├── robopianist ├── py.typed ├── suite │ ├── robopianist.png │ ├── README.md │ ├── tasks │ │ ├── __init__.py │ │ ├── self_actuated_piano_test.py │ │ ├── base.py │ │ └── self_actuated_piano.py │ ├── composite_reward.py │ ├── suite_test.py │ ├── __init__.py │ ├── variations_test.py │ └── variations.py ├── music │ ├── data │ │ ├── rousseau │ │ │ ├── nocturne-trimmed.mid │ │ │ └── twinkle-twinkle-trimmed.mid │ │ └── README.md │ ├── audio.py │ ├── constants.py │ ├── music_test.py │ ├── midi_message.py │ ├── __init__.py │ ├── synthesizer.py │ ├── midi_file_test.py │ └── piano_roll.py ├── viewer │ ├── README.md │ ├── __init__.py │ ├── gui │ │ ├── __init__.py │ │ ├── base.py │ │ └── fullscreen_quad.py │ ├── figures.py │ ├── user_input.py │ └── util.py ├── models │ ├── __init__.py │ ├── piano │ │ ├── __init__.py │ │ ├── piano_test.py │ │ ├── piano_constants.py │ │ └── midi_module.py │ ├── arenas │ │ ├── __init__.py │ │ ├── stage_test.py │ │ └── stage.py │ └── hands │ │ ├── __init__.py │ │ ├── shadow_hand_constants.py │ │ ├── base.py │ │ └── shadow_hand_test.py ├── wrappers │ ├── __init__.py │ ├── pixels.py │ ├── sound.py │ └── evaluation.py └── __init__.py ├── docs ├── index.md ├── cli.md ├── teaser1x3.jpeg ├── soundfonts.md ├── faq.md ├── dataset.md └── contributing.md ├── mkdocs.yml ├── .gitignore ├── examples ├── twinkle_twinkle_actions.npy ├── midi_data_to_file.py ├── play_midi_file.py ├── http_player.py ├── self_actuated_piano_env.py └── piano_with_shadow_hands_env.py ├── .gitmodules ├── MANIFEST.in ├── pyproject.toml ├── Makefile ├── .github └── workflows │ ├── docs.yml │ ├── publish.yml │ └── ci.yml ├── CITATION.cff ├── scripts ├── get_soundfonts.sh └── install_deps.sh ├── setup.py └── README.md /robopianist/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/index.md: 
-------------------------------------------------------------------------------- 1 | # Welcome to RoboPianist 2 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | # RoboPianist Command Line Interface 2 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: RoboPianist 2 | theme: 3 | name: material 4 | -------------------------------------------------------------------------------- /docs/teaser1x3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/robopianist/HEAD/docs/teaser1x3.jpeg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | third_party/ 3 | soundfonts/ 4 | *.egg-info/ 5 | .eggs/ 6 | build/ 7 | dist/ 8 | pig_single_finger/ 9 | -------------------------------------------------------------------------------- /robopianist/suite/robopianist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/robopianist/HEAD/robopianist/suite/robopianist.png -------------------------------------------------------------------------------- /examples/twinkle_twinkle_actions.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/robopianist/HEAD/examples/twinkle_twinkle_actions.npy -------------------------------------------------------------------------------- /robopianist/music/data/rousseau/nocturne-trimmed.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-research/robopianist/HEAD/robopianist/music/data/rousseau/nocturne-trimmed.mid -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/mujoco_menagerie"] 2 | path = third_party/mujoco_menagerie 3 | url = https://github.com/deepmind/mujoco_menagerie.git 4 | branch = main 5 | -------------------------------------------------------------------------------- /robopianist/music/data/rousseau/twinkle-twinkle-trimmed.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/robopianist/HEAD/robopianist/music/data/rousseau/twinkle-twinkle-trimmed.mid -------------------------------------------------------------------------------- /robopianist/viewer/README.md: -------------------------------------------------------------------------------- 1 | # RoboPianist Viewer 2 | 3 | This is a fork of [`dm_control.viewer`](https://github.com/deepmind/dm_control/tree/main/dm_control/viewer) with customizations specific to the `robopianist` project. 
4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include robopianist/models/hands/third_party * 2 | recursive-include robopianist/soundfonts * 3 | recursive-include robopianist/music/data * 4 | include robopianist/py.typed 5 | include README.md 6 | include LICENSE 7 | include CITATION.cff 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | target-version = ["py310"] 3 | 4 | [tool.mypy] 5 | ignore_missing_imports = true 6 | exclude = ["third_party", "dist", "build"] 7 | 8 | [tool.pytest.ini_options] 9 | norecursedirs = ["third_party", "dist", "build"] 10 | 11 | [tool.ruff] 12 | extend-exclude = ["third_party", "dist", "build"] 13 | select = ["E", "F", "I", "W"] 14 | ignore = ["E501"] 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | .PHONY: help check format test 4 | .DEFAULT: help 5 | 6 | help: 7 | @echo "Usage: make " 8 | @echo 9 | @echo "Available targets:" 10 | @echo " help: Show this help" 11 | @echo " format: Run type checking and code styling inplace" 12 | @echo " test: Run all tests" 13 | 14 | format: 15 | black . 16 | ruff --fix . 17 | mypy . 
18 | 19 | test: 20 | pytest -n auto 21 | 22 | server: 23 | python examples/http_player.py 24 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: 3.x 19 | - uses: actions/cache@v2 20 | with: 21 | key: ${{ github.ref }} 22 | path: .cache 23 | - run: pip install mkdocs-material 24 | - run: mkdocs gh-deploy --force 25 | -------------------------------------------------------------------------------- /robopianist/music/data/README.md: -------------------------------------------------------------------------------- 1 | # MIDI Data 2 | 3 | ## `rousseau` 4 | 5 | There are 2 MIDI files in this directory containing 2 short snippets of music taken from [Rousseau](https://www.youtube.com/channel/UCPZUQqtVDmcjm4NY5FkzqLA) performances. We have gotten explicit permission from Rousseau to use these MIDI files in this benchmark. 6 | 7 | ## `pig_single_finger` 8 | 9 | This directory will be created after the PIG dataset has been downloaded from the official website and upon running `robopianist preprocess`. These MIDI files (saved as `.proto` files) retain their original license. 10 | -------------------------------------------------------------------------------- /robopianist/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /robopianist/models/piano/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from robopianist.models.piano.piano import Piano 16 | 17 | __all__ = [ 18 | "Piano", 19 | ] 20 | -------------------------------------------------------------------------------- /robopianist/models/arenas/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Arenas for piano playing tasks.""" 16 | 17 | from robopianist.models.arenas.stage import Stage 18 | 19 | __all__ = ["Stage"] 20 | -------------------------------------------------------------------------------- /robopianist/models/hands/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from robopianist.models.hands.base import Hand, HandSide 16 | from robopianist.models.hands.shadow_hand import ShadowHand 17 | 18 | __all__ = ["Hand", "HandSide", "ShadowHand"] 19 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | authors: 3 | - family-names: Zakka 4 | given-names: Kevin 5 | - family-names: Wu 6 | given-names: Philipp 7 | - family-names: Smith 8 | given-names: Laura 9 | - family-names: Gileadi 10 | given-names: Nimrod 11 | - family-names: Howell 12 | given-names: Taylor 13 | - family-names: Peng 14 | given-names: Xue Bin 15 | - family-names: Singh 16 | given-names: Sumeet 17 | - family-names: Tassa 18 | given-names: Yuval 19 | - family-names: Florence 20 | given-names: Pete 21 | - family-names: Zeng 22 | given-names: Andy 23 | - family-names: Abbeel 24 | given-names: Pieter 25 | date-released: "2023-04-01" 26 | message: "If you use this software, please cite it as below." 27 | title: "RoboPianist: Dexterous Piano Playing with Deep Reinforcement Learning" 28 | url: "https://github.com/google-research/robopianist" 29 | -------------------------------------------------------------------------------- /robopianist/suite/README.md: -------------------------------------------------------------------------------- 1 | # RoboPianist Suite 2 | 3 | ![Task Suite](robopianist.png) 4 | 5 | ## Quickstart 6 | 7 | ```python 8 | import numpy as np 9 | from robopianist import suite 10 | 11 | # Print out all available tasks. 12 | print(suite.ALL) 13 | 14 | # Print out robopianist-etude-12 task subset. 15 | print(suite.ETUDE_12) 16 | 17 | # Load an environment from the debug subset. 18 | env = suite.load("RoboPianist-debug-TwinkleTwinkleLittleStar-v0") 19 | action_spec = env.action_spec() 20 | 21 | # Step through an episode and print out the reward, discount and observation.
22 | timestep = env.reset() 23 | while not timestep.last(): 24 | action = np.random.uniform( 25 | action_spec.minimum, action_spec.maximum, size=action_spec.shape 26 | ).astype(action_spec.dtype) 27 | timestep = env.step(action) 28 | print(timestep.reward, timestep.discount, timestep.observation) 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/soundfonts.md: -------------------------------------------------------------------------------- 1 | # Soundfonts 2 | 3 | ## What is a soundfont? 4 | 5 | ## Soundfonts in RoboPianist 6 | 7 | 1. `TimGM6mb.sf2` 8 | * [Download Link](https://sourceforge.net/p/mscore/code/HEAD/tree/trunk/mscore/share/sound/TimGM6mb.sf2?format=raw) 9 | * Creator: [Tim Brechbill](https://timbrechbill.com/saxguru/) 10 | * License: GNU General Public License 11 | * Size: 6 MB 12 | * Note: Musescore 1.0 came with this soundfont. 13 | 2. `SalamanderGrandPiano-V3+20200602.sf2` 14 | * [Download Link](https://freepats.zenvoid.org/Piano/acoustic-grand-piano.html) 15 | * Creator: Alexander Holm 16 | * License: Public Domain (as of 4/3/2022) 17 | * Size: 1.27 GB 18 | 19 | ## Using custom soundfonts 20 | 21 | ## Resources 22 | 23 | Check out these links for more soundfonts: 24 | 25 | * [Musescore](https://musescore.org/en/handbook/3/soundfonts-and-sfz-files) 26 | * [Soundfonts4u](https://sites.google.com/site/soundfonts4u/) 27 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | - [Can I use my own custom soundfont?](#can-i-use-my-own-custom-soundfont) 4 | - [I get a `ImportError("Couldn't find the FluidSynth library.")`](#i-get-a-importerrorcouldnt-find-the-fluidsynth-library) 5 | - [OSError: undefined symbol ffi\_type\_uint32](#oserror-undefined-symbol-ffi_type_uint32) 6 | 7 | ## Can I use my own custom soundfont? 
8 | 9 | You are free to use a soundfont of your choosing; just make sure to update `SF2_PATH` in [`robopianist/__init__.py`](https://github.com/google-research/robopianist/blob/main/robopianist/__init__.py) to point to its location. Note that only `.sf2` soundfonts are supported. 10 | 11 | ## I get a `ImportError("Couldn't find the FluidSynth library.")` 12 | 13 | See [this stackoverflow answer](https://stackoverflow.com/a/75339618) for a solution. 14 | 15 | ## OSError: undefined symbol ffi_type_uint32 16 | 17 | Add the following to your `~/.bashrc`: 18 | 19 | ```bash 20 | export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7 21 | ``` 22 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.x' 16 | - name: Add repository sub-modules 17 | run: | 18 | git submodule init 19 | git submodule update 20 | - name: Install dependencies 21 | shell: bash 22 | run: | 23 | bash scripts/install_deps.sh 24 | - name: Prepare Python 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install setuptools wheel twine 28 | - name: Build and publish 29 | env: 30 | TWINE_USERNAME: __token__ 31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 32 | run: | 33 | python setup.py sdist bdist_wheel 34 | twine upload dist/* 35 | -------------------------------------------------------------------------------- /robopianist/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from robopianist.wrappers.evaluation import MidiEvaluationWrapper 16 | from robopianist.wrappers.pixels import PixelWrapper 17 | from robopianist.wrappers.sound import PianoSoundVideoWrapper 18 | 19 | __all__ = [ 20 | "MidiEvaluationWrapper", 21 | "PianoSoundVideoWrapper", 22 | "PixelWrapper", 23 | ] 24 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | paths-ignore: 6 | - "tutorial.ipynb" 7 | - "docs/**" 8 | - "**/README.md" 9 | pull_request: 10 | paths-ignore: 11 | - "tutorial.ipynb" 12 | - "docs/**" 13 | - "**/README.md" 14 | 15 | jobs: 16 | run-robopianist-tests: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: ["python:3.11"] 21 | steps: 22 | - name: Checkout robopianist 23 | uses: actions/checkout@v3 24 | - name: Checkout submodules 25 | run: | 26 | git submodule init 27 | git submodule update 28 | - name: Install dependencies 29 | shell: bash 30 | run: | 31 | bash scripts/install_deps.sh 32 | - name: Prepare Python 33 | run: | 34 | python -m pip install --upgrade pip wheel 35 | pip install -e ".[test]" 36 | - name: Run tests 37 | run: | 38 | make test 39 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # Piano Fingering Dataset 2 | 3 | Due to licensing restrictions, we 
are unable to redistribute the PIG dataset inside this repository. You will need to go to the [PIG website](https://beam.kisarazu.ac.jp/~saito/research/PianoFingeringDataset/) and download it by [registering for an account](https://beam.kisarazu.ac.jp/~saito/research/PianoFingeringDataset/register.php). 4 | 5 | **Note: We are working with the authors of PIG to make the dataset available for download in a more convenient way.** 6 | 7 | The download will contain a zip file called `PianoFingeringDataset_v1.2.zip`. Extract the folder `PianoFingeringDataset_v1.2` from the zip file, then use the CLI as follows: 8 | 9 | ```bash 10 | robopianist preprocess --dataset-dir /PATH/TO/PianoFingeringDataset_v1.2 11 | ``` 12 | 13 | This will create a directory called `pig_single_finger` in `robopianist/music/data`. 14 | 15 | To double check that the dataset was successfully preprocessed, run the following command: 16 | 17 | ```bash 18 | robopianist --check-pig-exists 19 | ``` 20 | 21 | If successful, it will print `PIG dataset is ready to use!`. 22 | -------------------------------------------------------------------------------- /robopianist/models/arenas/stage_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for stage.py.""" 16 | 17 | from absl.testing import absltest 18 | from dm_control import mjcf 19 | 20 | from robopianist.models.arenas import stage 21 | 22 | 23 | class StageTest(absltest.TestCase): 24 | def test_compiles_and_steps(self) -> None: # Smoke test: build the default Stage, compile it, and take one step. 25 | arena = stage.Stage() 26 | physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model) # Compiling the arena's MJCF model must succeed. 27 | physics.step() # A single simulation step must run without error. 28 | 29 | 30 | if __name__ == "__main__": 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /robopianist/suite/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from robopianist.suite.tasks.base import PianoOnlyTask, PianoTask 16 | from robopianist.suite.tasks.piano_with_one_shadow_hand import PianoWithOneShadowHand 17 | from robopianist.suite.tasks.piano_with_shadow_hands import PianoWithShadowHands 18 | from robopianist.suite.tasks.self_actuated_piano import SelfActuatedPiano 19 | 20 | __all__ = [ 21 | "PianoTask", 22 | "PianoOnlyTask", 23 | "SelfActuatedPiano", 24 | "PianoWithShadowHands", 25 | "PianoWithOneShadowHand", 26 | ] 27 | -------------------------------------------------------------------------------- /robopianist/viewer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | 17 | """Suite environments viewer package.""" 18 | 19 | 20 | from robopianist.viewer import application 21 | 22 | 23 | def launch( # Create a viewer Application and launch it on the given environment loader. 24 | environment_loader, policy=None, title="RoboPianist", width=1024, height=768 25 | ) -> None: 26 | app = application.Application(title=title, width=width, height=height) # Window title/size are forwarded unchanged. 27 | app.launch(environment_loader=environment_loader, policy=policy) # policy defaults to None; its handling is up to Application.launch. 28 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests.
34 | -------------------------------------------------------------------------------- /examples/midi_data_to_file.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Save a programmatically generated NoteSequence to a MIDI file. 16 | 17 | Example usage: 18 | python examples/midi_data_to_file.py --name CMajorScaleTwoHands --save_path /tmp/c_major_scale_two_hands.mid 19 | """ 20 | 21 | from absl import app, flags 22 | 23 | from robopianist import music 24 | 25 | _NAME = flags.DEFINE_string("name", None, "") 26 | _SAVE_PATH = flags.DEFINE_string("save_path", None, "") 27 | 28 | 29 | def main(_) -> None: 30 | music.load(_NAME.value).save(_SAVE_PATH.value) 31 | 32 | 33 | if __name__ == "__main__": 34 | flags.mark_flags_as_required(["name", "save_path"]) 35 | app.run(main) 36 | -------------------------------------------------------------------------------- /scripts/get_soundfonts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Install additional soundfonts. 17 | 18 | set -ex 19 | 20 | mkdir -p third_party/soundfonts 21 | 22 | # Salamander Grand Piano. 23 | LINK=https://freepats.zenvoid.org/Piano/SalamanderGrandPiano/SalamanderGrandPiano-SF2-V3+20200602.tar.xz 24 | wget $LINK 25 | tar -xvf SalamanderGrandPiano-SF2-V3+20200602.tar.xz 26 | mv SalamanderGrandPiano-SF2-V3+20200602/SalamanderGrandPiano-V3+20200602.sf2 third_party/soundfonts/SalamanderGrandPiano.sf2 27 | rm -r SalamanderGrandPiano-SF2-V3+20200602.tar.xz SalamanderGrandPiano-SF2-V3+20200602 28 | 29 | # Make a copy of the soundfonts in robopianist/soundfonts. 30 | mkdir -p robopianist/soundfonts 31 | cp third_party/soundfonts/* robopianist/soundfonts 32 | -------------------------------------------------------------------------------- /examples/play_midi_file.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Play a MIDI file using FluidSynth and PyAudio. 16 | 17 | Example usage: 18 | python examples/play_midi_file.py --file robopianist/music/data/rousseau/nocturne-trimmed.mid 19 | """ 20 | 21 | from absl import app, flags 22 | 23 | from robopianist import music 24 | 25 | _FILE = flags.DEFINE_string("file", None, "Path to the MIDI file.") 26 | _STRETCH = flags.DEFINE_float("stretch", 1.0, "Stretch the MIDI file by this factor.") 27 | _SHIFT = flags.DEFINE_integer("shift", 0, "Shift the MIDI file by this many semitones.") 28 | 29 | 30 | def main(_) -> None: # Positional argv from app.run is unused. 31 | music.load( 32 | _FILE.value, stretch=_STRETCH.value, shift=_SHIFT.value 33 | ).trim_silence().play() # Load with the requested stretch/shift, trim silence, then play it. 34 | 35 | 36 | if __name__ == "__main__": 37 | flags.mark_flag_as_required("file") 38 | app.run(main) 39 | -------------------------------------------------------------------------------- /robopianist/viewer/gui/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ============================================================================ 16 | """Viewer's windowing systems.""" 17 | 18 | from dm_control import _render 19 | 20 | # pylint: disable=g-import-not-at-top 21 | # pylint: disable=invalid-name 22 | 23 | RenderWindow = None 24 | 25 | try: 26 | from robopianist.viewer.gui import glfw_gui 27 | 28 | RenderWindow = glfw_gui.GlfwWindow 29 | except ImportError: 30 | pass # GLFW-backed window unavailable; RenderWindow stays None and is replaced below. 31 | 32 | if RenderWindow is None: 33 | 34 | def ErrorRenderWindow(*args, **kwargs): # Stand-in constructor that always raises, so the failure surfaces only when a window is requested. 35 | del args, kwargs 36 | raise ImportError( 37 | "Cannot create a window because no windowing system could be imported" 38 | ) 39 | 40 | RenderWindow = ErrorRenderWindow # type: ignore 41 | 42 | del _render # NOTE(review): presumably imported only for its import-time side effect; not re-exported. 43 | -------------------------------------------------------------------------------- /robopianist/models/hands/shadow_hand_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
23 | 24 | JOINT_GROUP: Dict[str, Tuple[str, ...]] = { 25 | "wrist": ("WRJ1", "WRJ0"), 26 | "thumb": ("THJ4", "THJ3", "THJ2", "THJ1", "THJ0"), 27 | "first": ("FFJ3", "FFJ2", "FFJ1", "FFJ0"), 28 | "middle": ("MFJ3", "MFJ2", "MFJ1", "MFJ0"), 29 | "ring": ("RFJ3", "RFJ2", "RFJ1", "RFJ0"), 30 | "little": ("LFJ4", "LFJ3", "LFJ2", "LFJ1", "LFJ0"), 31 | } 32 | 33 | FINGERTIP_BODIES: Tuple[str, ...] = ( 34 | # Important: the order of these names should not be changed. 35 | "thdistal", 36 | "ffdistal", 37 | "mfdistal", 38 | "rfdistal", 39 | "lfdistal", 40 | ) 41 | 42 | FINGERTIP_COLORS: Tuple[Tuple[float, float, float], ...] = ( 43 | # Important: the order of these colors should not be changed. 44 | (0.8, 0.2, 0.8), # Purple. 45 | (0.8, 0.2, 0.2), # Red. 46 | (0.2, 0.8, 0.8), # Cyan. 47 | (0.2, 0.2, 0.8), # Blue. 48 | (0.8, 0.8, 0.2), # Yellow. 49 | ) 50 | 51 | # Path to the shadow hand E3M5 XML file. 52 | RIGHT_SHADOW_HAND_XML = _SHADOW_HAND_DIR / "right_hand.xml" 53 | LEFT_SHADOW_HAND_XML = _SHADOW_HAND_DIR / "left_hand.xml" 54 | -------------------------------------------------------------------------------- /robopianist/music/audio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Audio playback utils.""" 16 | 17 | import time 18 | 19 | import numpy as np 20 | import pyaudio 21 | 22 | from robopianist.music import constants as consts 23 | 24 | 25 | def play_sound( 26 | waveform: np.ndarray, sampling_rate: int = consts.SAMPLING_RATE, chunk: int = 1024 27 | ) -> None: 28 | """Play a waveform using PyAudio.""" 29 | if waveform.dtype != np.int16: 30 | raise ValueError("waveform must be an np.int16 array.") 31 | 32 | # An iterator that yields chunks of audio data. 33 | def chunkifier(): 34 | for i in range(0, len(waveform), chunk): 35 | yield waveform[i : i + chunk] 36 | 37 | audio_generator = chunkifier() 38 | 39 | def callback(in_data, frame_count, time_info, status): 40 | del in_data, frame_count, time_info, status 41 | return (next(audio_generator), pyaudio.paContinue) 42 | 43 | p = pyaudio.PyAudio() 44 | 45 | stream = p.open( 46 | format=pyaudio.paInt16, 47 | channels=1, 48 | rate=sampling_rate, 49 | output=True, 50 | frames_per_buffer=chunk, 51 | stream_callback=callback, 52 | ) 53 | 54 | try: 55 | stream.start_stream() 56 | while stream.is_active(): 57 | time.sleep(0.1) 58 | except KeyboardInterrupt: 59 | print("Ctrl-C detected. Stopping playback.") 60 | finally: 61 | stream.stop_stream() 62 | stream.close() 63 | p.terminate() 64 | -------------------------------------------------------------------------------- /robopianist/music/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Music constants.""" 16 | 17 | MIN_MIDI_PITCH = 0 18 | MAX_MIDI_PITCH = 127 19 | 20 | # MIDI pitch number of the lowest note on the piano (A0). 21 | MIN_MIDI_PITCH_PIANO = 21 22 | # MIDI pitch number of the highest note on the piano (C8). 23 | MAX_MIDI_PITCH_PIANO = 108 24 | 25 | # Min and max key numbers on the piano. 26 | MIN_KEY_NUMBER = 0 27 | MAX_KEY_NUMBER = 87 28 | NUM_KEYS = MAX_KEY_NUMBER - MIN_KEY_NUMBER + 1 29 | 30 | # Notes in an octave. 31 | NOTES_IN_OCTAVE = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] 32 | assert len(NOTES_IN_OCTAVE) == 12 33 | 34 | # Notes on an 88-key piano from left (A0) to right (C8). 35 | NOTES = [] 36 | NOTES.extend(["A0", "A#0", "B0"]) 37 | for octave in range(1, 8): 38 | for note in NOTES_IN_OCTAVE: 39 | NOTES.append(note + str(octave)) 40 | NOTES.append("C8") 41 | assert len(NOTES) == NUM_KEYS 42 | 43 | # Mapping for converting between a key number and its note name. 44 | KEY_NUMBER_TO_NOTE_NAME = {i: note for i, note in enumerate(NOTES)} 45 | NOTE_NAME_TO_KEY_NUMBER = {note: i for i, note in enumerate(NOTES)} 46 | 47 | # Mapping for converting between a note name and its MIDI pitch number. 48 | MIDI_NUMBER_TO_NOTE_NAME = {i + 21: name for i, name in enumerate(NOTES)} 49 | NOTE_NAME_TO_MIDI_NUMBER = {v: k for k, v in MIDI_NUMBER_TO_NOTE_NAME.items()} 50 | 51 | # Sampling frequency of the audio, in Hz. 
52 | SAMPLING_RATE = 44100 53 | 54 | SUSTAIN_PEDAL_CC_NUMBER = 64 55 | MIN_CC_VALUE = 0 56 | MAX_CC_VALUE = 127 57 | 58 | MIN_VELOCITY = 0 59 | MAX_VELOCITY = 127 60 | -------------------------------------------------------------------------------- /robopianist/suite/composite_reward.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility class for composite reward functions.""" 16 | 17 | from typing import Callable, Dict 18 | 19 | from dm_control import mjcf 20 | 21 | Reward = float 22 | RewardFn = Callable[[mjcf.Physics], Reward] 23 | 24 | 25 | class CompositeReward: 26 | """A reward function composed of individual reward terms. 27 | 28 | Useful for grouping sub-rewards of a task into a single reward function, computing 29 | their sum, and logging the individual terms. 
30 | """ 31 | 32 | def __init__(self, **kwargs) -> None: 33 | self._reward_fns: Dict[str, RewardFn] = {} 34 | for name, reward_fn in kwargs.items(): 35 | self.add(name, reward_fn) 36 | self._reward_terms: Dict[str, Reward] = {} 37 | 38 | def add(self, name: str, reward_fn: RewardFn) -> None: 39 | """Adds a reward term to the reward terms.""" 40 | self._reward_fns[name] = reward_fn 41 | 42 | def remove(self, name: str) -> None: 43 | """Removes a reward term from the reward terms.""" 44 | del self._reward_fns[name] 45 | 46 | def compute(self, physics: mjcf.Physics) -> float: 47 | """Computes the reward terms sequentially and returns their sum. 48 | 49 | Note that the reward terms are computed in the order they were added. 50 | """ 51 | sum_of_rewards = 0.0 52 | for name, reward_fn in self._reward_fns.items(): 53 | rew = reward_fn(physics) 54 | sum_of_rewards += rew 55 | self._reward_terms[name] = rew 56 | return sum_of_rewards 57 | 58 | @property 59 | def reward_fns(self) -> Dict[str, RewardFn]: 60 | return self._reward_fns 61 | 62 | @property 63 | def reward_terms(self) -> Dict[str, Reward]: 64 | return self._reward_terms 65 | -------------------------------------------------------------------------------- /robopianist/suite/suite_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for robopianist.suite.""" 16 | 17 | import numpy as np 18 | from absl.testing import absltest, parameterized 19 | 20 | from robopianist import suite 21 | 22 | _SEED = 12345 23 | _NUM_EPISODES = 1 24 | _NUM_STEPS_PER_EPISODE = 10 25 | 26 | 27 | class RoboPianistSuiteTest(parameterized.TestCase): 28 | """Tests for all registered tasks in robopianist.suite.""" 29 | 30 | def _validate_observation(self, observation, observation_spec): 31 | self.assertEqual(list(observation.keys()), list(observation_spec.keys())) 32 | for name, array_spec in observation_spec.items(): 33 | array_spec.validate(observation[name]) 34 | 35 | @parameterized.parameters(*suite.DEBUG) 36 | def test_task_runs(self, environment_name: str) -> None: 37 | """Tests task loading and observation spec validity.""" 38 | env = suite.load(environment_name, seed=_SEED) 39 | random_state = np.random.RandomState(_SEED) 40 | 41 | observation_spec = env.observation_spec() 42 | action_spec = env.action_spec() 43 | self.assertTrue(np.all(np.isfinite(action_spec.minimum))) 44 | self.assertTrue(np.all(np.isfinite(action_spec.maximum))) 45 | 46 | for _ in range(_NUM_EPISODES): 47 | timestep = env.reset() 48 | for _ in range(_NUM_STEPS_PER_EPISODE): 49 | self._validate_observation(timestep.observation, observation_spec) 50 | if timestep.first(): 51 | self.assertIsNone(timestep.reward) 52 | self.assertIsNone(timestep.discount) 53 | action = random_state.uniform( 54 | action_spec.minimum, action_spec.maximum, size=action_spec.shape 55 | ).astype(action_spec.dtype) 56 | timestep = env.step(action) 57 | 58 | 59 | if __name__ == "__main__": 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /robopianist/models/piano/piano_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for piano.py.""" 16 | 17 | from absl.testing import absltest 18 | from dm_control import mjcf 19 | 20 | from robopianist.models.piano import piano 21 | from robopianist.models.piano import piano_constants as consts 22 | 23 | 24 | class PianoTest(absltest.TestCase): 25 | def test_compiles_and_steps(self) -> None: 26 | robot = piano.Piano() 27 | physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) 28 | for _ in range(100): 29 | physics.step() 30 | 31 | def test_set_name(self) -> None: 32 | robot = piano.Piano(name="mozart") 33 | self.assertEqual(robot.mjcf_model.model, "mozart") 34 | 35 | def test_joints(self) -> None: 36 | robot = piano.Piano() 37 | self.assertEqual(len(robot.joints), consts.NUM_KEYS) 38 | for joint in robot.joints: 39 | self.assertEqual(joint.tag, "joint") 40 | 41 | def test_keys(self) -> None: 42 | robot = piano.Piano() 43 | self.assertEqual(len(robot.keys), consts.NUM_KEYS) 44 | for key in robot.keys: 45 | self.assertEqual(key.tag, "body") 46 | 47 | def test_sorted(self) -> None: 48 | robot = piano.Piano() 49 | for i in range(consts.NUM_KEYS - 1): 50 | self.assertLess( 51 | int(robot.keys[i].name.split("_")[-1]), 52 | int(robot.keys[i + 1].name.split("_")[-1]), 53 | ) 54 | self.assertLess( 55 | int(robot._sites[i].name.split("_")[-1]), 56 | int(robot._sites[i + 1].name.split("_")[-1]), 57 | ) 58 | self.assertLess( 59 | 
int(robot._key_geoms[i].name.split("_")[-1]), 60 | int(robot._key_geoms[i + 1].name.split("_")[-1]), 61 | ) 62 | self.assertLess( 63 | int(robot.joints[i].name.split("_")[-1]), 64 | int(robot.joints[i + 1].name.split("_")[-1]), 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /robopianist/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | 17 | __version__ = "1.0.10" 18 | 19 | # Path to the root of the project. 20 | _PROJECT_ROOT = Path(__file__).parent.parent 21 | 22 | # Path to the soundfont directory. 23 | _SOUNDFONT_PATH = _PROJECT_ROOT / "robopianist" / "soundfonts" 24 | 25 | # TimGM6mb.sf2 is the default soundfont file that is packaged with the pip install or 26 | # available when installing from source using `bash scripts/install_deps.sh`. 27 | _DEFAULT_SF2_PATH = _SOUNDFONT_PATH / "TimGM6mb.sf2" 28 | 29 | # We first check if the user has a .robopianistrc file in their home directory. If so, 30 | # we check if it specifies a soundfont file. If so, we use that. 
31 | _RC_FILE = Path.home() / ".robopianistrc" 32 | if _RC_FILE.exists(): 33 | found = False 34 | with open(_RC_FILE, "r") as f: 35 | for line in f: 36 | if line.startswith("DEFAULT_SOUNDFONT="): 37 | soundfont_path = line.split("=")[1].strip() 38 | SF2_PATH = _SOUNDFONT_PATH / f"{soundfont_path}.sf2" 39 | if not SF2_PATH.exists(): 40 | SF2_PATH = _DEFAULT_SF2_PATH 41 | found = True 42 | break 43 | if not found: 44 | SF2_PATH = _DEFAULT_SF2_PATH 45 | # Otherwise, we look in the soundfont directory. Our preference is for the higher 46 | # quality SalamanderGrandPiano.sf2, but if that is not found, we fall back to the 47 | # default soundfont file. 48 | else: 49 | _SALAMANDER_SF2_PATH = _SOUNDFONT_PATH / "SalamanderGrandPiano.sf2" 50 | if _SALAMANDER_SF2_PATH.exists(): 51 | SF2_PATH = _SALAMANDER_SF2_PATH 52 | else: 53 | if not _DEFAULT_SF2_PATH.exists(): 54 | raise FileNotFoundError( 55 | f"The default soundfont file {_DEFAULT_SF2_PATH} does not exist. Make " 56 | "sure you have first run `bash scripts/install_deps.sh` in the root of " 57 | "the project directory." 58 | ) 59 | SF2_PATH = _DEFAULT_SF2_PATH 60 | 61 | 62 | __all__ = [ 63 | "__version__", 64 | "SF2_PATH", 65 | ] 66 | -------------------------------------------------------------------------------- /robopianist/wrappers/pixels.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A wrapper for adding pixels to the observation.""" 16 | 17 | 18 | import collections 19 | from typing import Any, Dict, Optional 20 | 21 | import dm_env 22 | import numpy as np 23 | from dm_env import specs 24 | from dm_env_wrappers import EnvironmentWrapper 25 | 26 | 27 | class PixelWrapper(EnvironmentWrapper): 28 | """Adds pixel observations to the observation spec.""" 29 | 30 | def __init__( 31 | self, 32 | environment: dm_env.Environment, 33 | render_kwargs: Optional[Dict[str, Any]] = None, 34 | observation_key: str = "pixels", 35 | ) -> None: 36 | super().__init__(environment) 37 | 38 | self._render_kwargs = render_kwargs or {} 39 | self._observation_key = observation_key 40 | 41 | # Update the observation spec. 42 | self._wrapped_observation_spec = self._environment.observation_spec() 43 | self._observation_spec = collections.OrderedDict() 44 | self._observation_spec.update(self._wrapped_observation_spec) 45 | pixels = self._environment.physics.render(**self._render_kwargs) 46 | pixels_spec = specs.Array( 47 | shape=pixels.shape, dtype=pixels.dtype, name=self._observation_key 48 | ) 49 | self._observation_spec[observation_key] = pixels_spec 50 | 51 | def observation_spec(self): 52 | return self._observation_spec 53 | 54 | def reset(self) -> dm_env.TimeStep: 55 | timestep = self._environment.reset() 56 | return self._add_pixel_observation(timestep) 57 | 58 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 59 | timestep = self._environment.step(action) 60 | return self._add_pixel_observation(timestep) 61 | 62 | def _add_pixel_observation(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: 63 | pixels = self._environment.physics.render(**self._render_kwargs) 64 | return timestep._replace( 65 | observation=collections.OrderedDict( 66 | timestep.observation, **{self._observation_key: pixels} 67 | ) 68 | ) 69 | 
-------------------------------------------------------------------------------- /robopianist/models/arenas/stage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Suite arenas.""" 16 | 17 | from mujoco_utils import composer_utils 18 | 19 | 20 | class Stage(composer_utils.Arena): 21 | """A custom arena with a ground plane, lights and a starry night sky.""" 22 | 23 | def _build(self, name: str = "stage") -> None: 24 | super()._build(name=name) 25 | 26 | # Change free camera settings. 27 | self._mjcf_root.statistic.extent = 0.6 28 | self._mjcf_root.statistic.center = (0.2, 0, 0.3) 29 | getattr(self._mjcf_root.visual, "global").azimuth = 180 30 | getattr(self._mjcf_root.visual, "global").elevation = -50 31 | 32 | self._mjcf_root.visual.scale.forcewidth = 0.04 33 | self._mjcf_root.visual.scale.contactwidth = 0.2 34 | self._mjcf_root.visual.scale.contactheight = 0.03 35 | 36 | # Lights. 37 | self._mjcf_root.worldbody.add("light", pos=(0, 0, 1)) 38 | self._mjcf_root.worldbody.add( 39 | "light", pos=(0.3, 0, 1), dir=(0, 0, -1), directional=False 40 | ) 41 | 42 | # Dark checkerboard floor. 
43 | self._mjcf_root.asset.add( 44 | "texture", 45 | name="grid", 46 | type="2d", 47 | builtin="checker", 48 | width=512, 49 | height=512, 50 | rgb1=[0.1, 0.1, 0.1], 51 | rgb2=[0.2, 0.2, 0.2], 52 | ) 53 | self._mjcf_root.asset.add( 54 | "material", 55 | name="grid", 56 | texture="grid", 57 | texrepeat=(1, 1), 58 | texuniform=True, 59 | reflectance=0.2, 60 | ) 61 | self._ground_geom = self._mjcf_root.worldbody.add( 62 | "geom", 63 | type="plane", 64 | size=(1, 1, 0.05), 65 | material="grid", 66 | contype=0, 67 | conaffinity=0, 68 | ) 69 | 70 | # Starry night sky. 71 | self._mjcf_root.asset.add( 72 | "texture", 73 | name="skybox", 74 | type="skybox", 75 | builtin="gradient", 76 | rgb1=[0.2, 0.2, 0.2], 77 | rgb2=[0.0, 0.0, 0.0], 78 | width=800, 79 | height=800, 80 | mark="random", 81 | markrgb=[1, 1, 1], 82 | ) 83 | -------------------------------------------------------------------------------- /robopianist/music/music_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for library.py.""" 16 | 17 | from absl.testing import absltest, parameterized 18 | 19 | from robopianist import _PROJECT_ROOT, music 20 | from robopianist.music import midi_file 21 | 22 | _PIG_DIR = _PROJECT_ROOT / "robopianist" / "music" / "data" / "pig_single_finger" 23 | 24 | 25 | @absltest.skipIf(not _PIG_DIR.exists(), "PIG dataset not found.") 26 | class ConstantsTest(parameterized.TestCase): 27 | def test_constants(self) -> None: 28 | # Check that all constants are non-empty. 29 | self.assertNotEmpty(music.ALL) 30 | self.assertNotEmpty(music.DEBUG_MIDIS) 31 | self.assertNotEmpty(music.PIG_MIDIS) 32 | self.assertNotEmpty(music.ETUDE_MIDIS) 33 | 34 | # Check that all = debug + pig. 35 | self.assertEqual(music.ALL, music.DEBUG_MIDIS + music.PIG_MIDIS) 36 | 37 | # Check that etude is a subset of pig. 38 | self.assertTrue(set(music.ETUDE_MIDIS).issubset(set(music.PIG_MIDIS))) 39 | 40 | 41 | @absltest.skipIf(not _PIG_DIR.exists(), "PIG dataset not found.") 42 | class LoadTest(parameterized.TestCase): 43 | def test_raises_key_error_on_invalid_midi(self) -> None: 44 | """Test that loading an invalid string MIDI raises a KeyError.""" 45 | with self.assertRaises(KeyError): 46 | music.load("invalid_midi") 47 | 48 | @parameterized.parameters(*music.ALL) 49 | def test_midis_in_library(self, midi_name: str) -> None: 50 | """Test that all midis in the library can be loaded.""" 51 | self.assertIsInstance(music.load(midi_name), midi_file.MidiFile) 52 | 53 | @parameterized.parameters(*music.ALL) 54 | def test_fingering_available_for_all_timesteps(self, midi_name: str) -> None: 55 | """Test that all midis in the library have fingering annotations for all 56 | timesteps.""" 57 | midi = music.load(midi_name).trim_silence() 58 | traj = midi_file.NoteTrajectory.from_midi(midi, dt=0.05) 59 | for timestep in traj.notes: 60 | for note in timestep: 61 | # -1 indicates no fingering annotation. Valid fingering lies in [0, 9]. 
62 | self.assertGreater(note.fingering, -1) 63 | self.assertLess(note.fingering, 10) 64 | 65 | 66 | if __name__ == "__main__": 67 | absltest.main() 68 | -------------------------------------------------------------------------------- /scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Install dependencies (macOS and Linux). 17 | # Command line arguments: 18 | # --no-soundfonts: Skip downloading soundfonts. 19 | # --no-menagerie: Skip copying shadow_hand menagerie model. 20 | 21 | set -x 22 | 23 | SKIP_SOUNDFONTS=false 24 | SKIP_MENAGERIE=false 25 | while [[ $# -gt 0 ]]; do 26 | key="$1" 27 | case $key in 28 | --no-soundfonts) 29 | SKIP_SOUNDFONTS=true 30 | shift 31 | ;; 32 | --no-menagerie) 33 | SKIP_MENAGERIE=true 34 | shift 35 | ;; 36 | *) 37 | echo "Unknown argument: $key" 38 | exit 1 39 | ;; 40 | esac 41 | done 42 | 43 | # Install fluidsynth and portaudio. 44 | if [[ $OSTYPE == darwin* ]]; then 45 | # Install homebrew if not installed. 46 | if ! 
command -v brew; then 47 | ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 48 | else 49 | brew update 50 | fi 51 | brew install portaudio fluid-synth ffmpeg 52 | elif [[ $OSTYPE == linux* ]]; then 53 | sudo apt update 54 | sudo apt install -y build-essential wget 55 | sudo apt install -y fluidsynth portaudio19-dev ffmpeg 56 | else 57 | echo "Unsupported OS" 58 | fi 59 | 60 | # Install soundfonts. 61 | if [ "$SKIP_SOUNDFONTS" = false ]; then 62 | # Download TimGM6mb.sf2 soundfont. 63 | mkdir -p third_party/soundfonts 64 | LINK=https://sourceforge.net/p/mscore/code/HEAD/tree/trunk/mscore/share/sound/TimGM6mb.sf2?format=raw 65 | if [ ! -f third_party/soundfonts/TimGM6mb.sf2 ]; then 66 | wget $LINK -O third_party/soundfonts/TimGM6mb.sf2 67 | fi 68 | 69 | # Copy soundfonts to robopianist. 70 | mkdir -p robopianist/soundfonts 71 | if [ ! -d "third_party/soundfonts" ]; then 72 | echo "third_party/soundfonts does not exist. Run scripts/get_soundfonts.sh first." 73 | exit 1 74 | fi 75 | cp -r third_party/soundfonts/* robopianist/soundfonts 76 | fi 77 | 78 | # Copy shadow_hand menagerie model to robopianist. 79 | if [ "$SKIP_MENAGERIE" = false ]; then 80 | cd third_party/mujoco_menagerie 81 | git checkout 1afc8be64233dcfe943b2fe0c505ec1e87a0a13e 82 | cd ../.. 83 | mkdir -p robopianist/models/hands/third_party/shadow_hand 84 | if [ ! -d "third_party/mujoco_menagerie/shadow_hand" ]; then 85 | echo "third_party/mujoco_menagerie/shadow_hand does not exist. Run git submodule init && git submodule update first." 86 | exit 1 87 | fi 88 | cp -r third_party/mujoco_menagerie/shadow_hand/* robopianist/models/hands/third_party/shadow_hand 89 | fi 90 | -------------------------------------------------------------------------------- /robopianist/music/midi_message.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Stripped down MIDI message class.""" 16 | 17 | import enum 18 | from dataclasses import dataclass 19 | from typing import Union 20 | 21 | from robopianist.music import constants as consts 22 | 23 | 24 | @enum.unique 25 | class EventType(enum.Enum): 26 | """The type of a MIDI event.""" 27 | 28 | NOTE_ON = enum.auto() 29 | NOTE_OFF = enum.auto() 30 | SUSTAIN_ON = enum.auto() 31 | SUSTAIN_OFF = enum.auto() 32 | 33 | 34 | @dataclass 35 | class NoteOn: 36 | """A note-on MIDI message.""" 37 | 38 | note: int 39 | velocity: int 40 | time: float 41 | 42 | def __post_init__(self) -> None: 43 | assert consts.MIN_MIDI_PITCH <= self.note <= consts.MAX_MIDI_PITCH 44 | assert consts.MIN_VELOCITY <= self.velocity <= consts.MAX_VELOCITY 45 | assert self.time >= 0 46 | 47 | @property 48 | def event_type(self) -> EventType: 49 | return EventType.NOTE_ON 50 | 51 | 52 | @dataclass 53 | class NoteOff: 54 | """A note-off MIDI message.""" 55 | 56 | note: int 57 | time: float 58 | 59 | def __post_init__(self) -> None: 60 | assert consts.MIN_MIDI_PITCH <= self.note <= consts.MAX_MIDI_PITCH 61 | assert self.time >= 0 62 | 63 | @property 64 | def event_type(self) -> EventType: 65 | return EventType.NOTE_OFF 66 | 67 | 68 | @dataclass 69 | class _ControlChange: 70 | """A control-change MIDI message.""" 71 | 72 | control: int 73 | value: int 74 | time: float 75 | 76 | def __post_init__(self) -> None: 
77 | assert consts.MIN_CC_VALUE <= self.control <= consts.MAX_CC_VALUE 78 | assert consts.MIN_CC_VALUE <= self.value <= consts.MAX_CC_VALUE 79 | assert self.time >= 0 80 | 81 | 82 | class SustainOn(_ControlChange): 83 | """A sustain-on MIDI message.""" 84 | 85 | def __init__(self, time: float) -> None: 86 | super().__init__( 87 | consts.SUSTAIN_PEDAL_CC_NUMBER, 88 | consts.SUSTAIN_PEDAL_CC_NUMBER, 89 | time, 90 | ) 91 | 92 | @property 93 | def event_type(self) -> EventType: 94 | return EventType.SUSTAIN_ON 95 | 96 | 97 | class SustainOff(_ControlChange): 98 | """A sustain-off MIDI message.""" 99 | 100 | def __init__(self, time: float) -> None: 101 | super().__init__( 102 | consts.SUSTAIN_PEDAL_CC_NUMBER, 103 | 0, 104 | time, 105 | ) 106 | 107 | @property 108 | def event_type(self) -> EventType: 109 | return EventType.SUSTAIN_OFF 110 | 111 | 112 | MidiMessage = Union[NoteOn, NoteOff, SustainOn, SustainOff] 113 | -------------------------------------------------------------------------------- /robopianist/viewer/gui/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Utilities and base classes used exclusively in the gui package.""" 17 | 18 | import abc 19 | import threading 20 | import time 21 | 22 | from robopianist.viewer import user_input 23 | 24 | _DOUBLE_CLICK_INTERVAL = 0.25 # seconds 25 | 26 | 27 | class InputEventsProcessor(metaclass=abc.ABCMeta): 28 | """Thread safe input events processor.""" 29 | 30 | def __init__(self): 31 | """Instance initializer.""" 32 | self._lock = threading.RLock() 33 | self._events = [] 34 | 35 | def add_event(self, receivers, *args): 36 | """Adds a new event to the processing queue.""" 37 | if not all(callable(receiver) for receiver in receivers): 38 | raise TypeError("Receivers are expected to be callables.") 39 | 40 | def event(): 41 | for receiver in list(receivers): 42 | receiver(*args) 43 | 44 | with self._lock: 45 | self._events.append(event) 46 | 47 | def process_events(self): 48 | """Invokes each of the events in the queue. 49 | 50 | Thread safe for queue access but not during event invocations. 51 | 52 | This method must be called regularly on the main thread. 53 | """ 54 | with self._lock: 55 | # Swap event buffers quickly so that we don't block the input thread for 56 | # too long. 57 | events_to_process = self._events 58 | self._events = [] 59 | 60 | # Now that we made the swap, process the received events in our own time. 61 | for event in events_to_process: 62 | event() 63 | 64 | 65 | class DoubleClickDetector: 66 | """Detects double click events.""" 67 | 68 | def __init__(self): 69 | self._double_clicks = {} 70 | 71 | def process(self, button, action): 72 | """Attempts to identify a mouse button click as a double click event.""" 73 | if action != user_input.PRESS: 74 | return False 75 | 76 | curr_time = time.time() 77 | timestamp = self._double_clicks.get(button, None) 78 | if timestamp is None: 79 | # No previous click registered. 
80 | self._double_clicks[button] = curr_time 81 | return False 82 | else: 83 | time_elapsed = curr_time - timestamp 84 | if time_elapsed < _DOUBLE_CLICK_INTERVAL: 85 | # Double click discovered. 86 | self._double_clicks[button] = None 87 | return True 88 | else: 89 | # The previous click was too long ago, so discard it and start a fresh 90 | # timer. 91 | self._double_clicks[button] = curr_time 92 | return False 93 | -------------------------------------------------------------------------------- /robopianist/models/piano/piano_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Piano modeling constants. 16 | 17 | Inspired by: https://kawaius.com/wp-content/uploads/2019/04/Kawai-Upright-Piano-Regulation-Manual.pdf 18 | """ 19 | 20 | from math import atan 21 | 22 | NUM_KEYS = 88 23 | NUM_WHITE_KEYS = 52 24 | WHITE_KEY_WIDTH = 0.0225 25 | WHITE_KEY_LENGTH = 0.15 26 | WHITE_KEY_HEIGHT = WHITE_KEY_WIDTH 27 | SPACING_BETWEEN_WHITE_KEYS = 0.001 28 | N_SPACES_BETWEEN_WHITE_KEYS = NUM_WHITE_KEYS - 1 29 | BLACK_KEY_WIDTH = 0.01 30 | BLACK_KEY_LENGTH = 0.09 31 | # Unlike the other dimensions, the height of the black key was roughly set such that 32 | # when a white key is fully depressed, the bottom of the black key is barely visible. 
BLACK_KEY_HEIGHT = 0.018
# Total extent of the keyboard along the row of keys: all white key widths plus the
# gaps between adjacent white keys.
PIANO_LENGTH = (NUM_WHITE_KEYS * WHITE_KEY_WIDTH) + (
    N_SPACES_BETWEEN_WHITE_KEYS * SPACING_BETWEEN_WHITE_KEYS
)

WHITE_KEY_X_OFFSET = 0
# Raises the white key by half its height, presumably so that a key geom centered on
# its body has its bottom face at z = 0.
WHITE_KEY_Z_OFFSET = WHITE_KEY_HEIGHT / 2
# Shifts the shorter black key so that its -x end lines up with the white key's -x
# end (assuming both offsets locate key centers).
BLACK_KEY_X_OFFSET = -WHITE_KEY_LENGTH / 2 + BLACK_KEY_LENGTH / 2
# The top of the black key should be 12.5 mm above the top of the white key.
BLACK_OFFSET_FROM_WHITE = 0.0125
# Center height of the black key: white key top (assuming white keys span
# z in [0, WHITE_KEY_HEIGHT]) + protrusion above it - half the black key height.
BLACK_KEY_Z_OFFSET = WHITE_KEY_HEIGHT + BLACK_OFFSET_FROM_WHITE - BLACK_KEY_HEIGHT / 2

BASE_HEIGHT = 0.04
BASE_LENGTH = 0.1
# The base spans the full width of the keyboard.
BASE_WIDTH = PIANO_LENGTH
# Halved dimensions; presumably MuJoCo box half-extents.
BASE_SIZE = [BASE_LENGTH / 2, BASE_WIDTH / 2, BASE_HEIGHT / 2]
# Places the base on the -x side of the keys, separated by a 0.002 gap.
BASE_X_OFFSET = -WHITE_KEY_LENGTH / 2 - 0.5 * BASE_LENGTH - 0.002
BASE_POS = [BASE_X_OFFSET, 0, BASE_HEIGHT / 2]

# A key is designed to travel downward 3/8 of an inch (roughly 10mm).
# Assuming the joint is positioned at the back of the key, we can write:
# tan(θ) = d / l, where d is the distance the key travels and l is the length of the
# key. Solving for θ, we get: θ = arctan(d / l).
# (0.01 corresponding to ~10mm indicates lengths in this module are in meters.)
WHITE_KEY_TRAVEL_DISTANCE = 0.01
WHITE_KEY_JOINT_MAX_ANGLE = atan(WHITE_KEY_TRAVEL_DISTANCE / WHITE_KEY_LENGTH)
# TODO(kevin): Figure out black key travel distance.
BLACK_KEY_TRAVEL_DISTANCE = 0.008
BLACK_KEY_JOINT_MAX_ANGLE = atan(BLACK_KEY_TRAVEL_DISTANCE / BLACK_KEY_LENGTH)
# Mass in kg.
WHITE_KEY_MASS = 0.04
BLACK_KEY_MASS = 0.02
# Joint spring reference, in degrees.
# At equilibrium, the joint should be at 0 degrees.
WHITE_KEY_SPRINGREF = -1
BLACK_KEY_SPRINGREF = -1
# Joint spring stiffness, in Nm/rad.
# The spring should be stiff enough to support the weight of the key at equilibrium.
WHITE_KEY_STIFFNESS = 2
BLACK_KEY_STIFFNESS = 2
# Joint damping and armature for smoothing key motion.
WHITE_JOINT_DAMPING = 0.05
BLACK_JOINT_DAMPING = 0.05
WHITE_JOINT_ARMATURE = 0.001
BLACK_JOINT_ARMATURE = 0.001

# Actuator parameters (for self-actuated only).
ACTUATOR_DYNPRM = 1
ACTUATOR_GAINPRM = 1

# Colors (four components, presumably RGBA in [0, 1]).
WHITE_KEY_COLOR = [0.9, 0.9, 0.9, 1]
BLACK_KEY_COLOR = [0.1, 0.1, 0.1, 1]
BASE_COLOR = [0.15, 0.15, 0.15, 1]

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | import re 16 | from pathlib import Path 17 | 18 | from setuptools import find_packages, setup 19 | 20 | _here = Path(__file__).resolve().parent 21 | 22 | name = "robopianist" 23 | 24 | # Reference: https://github.com/patrick-kidger/equinox/blob/main/setup.py 25 | with open(_here / name / "__init__.py") as f: 26 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 27 | if meta_match: 28 | version = meta_match.group(1) 29 | else: 30 | raise RuntimeError("Unable to find __version__ string.") 31 | 32 | 33 | with open(_here / "README.md", "r") as f: 34 | readme = f.read() 35 | 36 | core_requirements = [ 37 | "dm_control>=1.0.16", 38 | "dm_env_wrappers>=0.0.11", 39 | "mujoco>=3.1.1", 40 | "mujoco_utils>=0.0.6", 41 | "note_seq>=0.0.5", 42 | "pretty_midi>=0.2.10", 43 | "pyaudio>=0.2.12", 44 | "pyfluidsynth>=1.3.2", 45 | "scikit-learn==1.4.2", 46 | "termcolor", 47 | "tqdm", 48 | ] 49 | 50 | test_requirements = [ 51 | "absl-py", 52 | "pytest-xdist", 53 | ] 54 | 55 | dev_requirements = [ 56 | "black", 57 | "ruff", 58 | "mypy", 59 | ] + test_requirements 60 | 61 | classifiers = [ 62 | "Development Status :: 5 - Production/Stable", 63 | "Intended Audience :: Developers", 64 | "Intended Audience :: Science/Research", 65 | "License :: OSI Approved :: Apache Software License", 66 | "Natural Language :: English", 67 | "Programming Language :: Python :: 3", 68 | "Programming Language :: Python :: 3.8", 69 | "Programming Language :: Python :: 3.9", 70 | "Programming Language :: Python :: 3.10", 71 | "Programming Language :: Python :: 3.11", 72 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 73 | ] 74 | 75 | author = "Kevin Zakka" 76 | 77 | author_email = "kevinarmandzakka@gmail.com" 78 | 79 | description = "A benchmark for high-dimensional robot control" 80 | 81 | keywords = "reinforcement-learning mujoco bimanual dexterous-manipulation piano" 82 | 83 | setup( 84 | name=name, 85 | version=version, 86 | author=author, 87 | 
author_email=author_email, 88 | maintainer=author, 89 | maintainer_email=author_email, 90 | description=description, 91 | long_description=readme, 92 | long_description_content_type="text/markdown", 93 | keywords=keywords, 94 | url=f"https://github.com/google-research/{name}", 95 | license="Apache License 2.0", 96 | license_files=("LICENSE",), 97 | packages=find_packages(exclude=["examples"]), 98 | python_requires=">=3.8", 99 | install_requires=core_requirements, 100 | include_package_data=True, 101 | classifiers=classifiers, 102 | extras_require={ 103 | "test": test_requirements, 104 | "dev": dev_requirements, 105 | }, 106 | zip_safe=False, 107 | entry_points={"console_scripts": [f"{name}={name}.cli:main"]}, 108 | ) 109 | -------------------------------------------------------------------------------- /robopianist/viewer/figures.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Optional 16 | 17 | from robopianist.viewer import views 18 | 19 | 20 | class MujocoFigureModelWithRuntime(views.MujocoFigureModel): 21 | """Base class for figures that need access to the runtime.""" 22 | 23 | def __init__(self, **kwargs) -> None: 24 | super().__init__(**kwargs) 25 | 26 | self._runtime = None 27 | self._on_episode_begin_callbacks = [self.reset] 28 | 29 | def set_runtime(self, instance) -> None: 30 | if self._runtime is not None: 31 | for callback in self._on_episode_begin_callbacks: 32 | self._runtime.on_episode_begin -= callback 33 | self._runtime = instance 34 | if self._runtime: 35 | for callback in self._on_episode_begin_callbacks: 36 | self._runtime.on_episode_begin += callback 37 | 38 | 39 | class RewardFigure(MujocoFigureModelWithRuntime): 40 | """Plot the total reward over time.""" 41 | 42 | def __init__(self, pause, **kwargs) -> None: 43 | super().__init__(**kwargs) 44 | 45 | self._pause = pause 46 | self._series = views.TimeSeries() 47 | 48 | self._on_episode_begin_callbacks.append(self.reset_series) 49 | 50 | def configure_figure(self) -> None: 51 | self._figure.title = "Reward" 52 | self._figure.xlabel = "Timestep" 53 | 54 | def get_time_series(self) -> Optional[views.TimeSeries]: 55 | if self._runtime is None or self._pause.value: 56 | return None 57 | 58 | reward = self._runtime._time_step.reward 59 | self._series.add(reward) 60 | 61 | return self._series 62 | 63 | def reset_series(self) -> None: 64 | self._series.clear() 65 | 66 | 67 | class RewardTermsFigure(MujocoFigureModelWithRuntime): 68 | """Plot the different reward terms over time.""" 69 | 70 | def __init__(self, pause, **kwargs) -> None: 71 | super().__init__(**kwargs) 72 | 73 | self._pause = pause 74 | self._series = views.TimeSeries() 75 | 76 | self._on_episode_begin_callbacks.append(self.reset_series) 77 | 78 | def configure_figure(self) -> None: 79 | self._figure.title = "Reward" 80 | self._figure.xlabel = "Timestep" 81 | 82 | def 
get_time_series(self) -> Optional[views.TimeSeries]: 83 | if self._runtime is None or self._pause.value: 84 | return None 85 | 86 | reward = self._runtime._time_step.reward 87 | 88 | if not hasattr(self._runtime._environment.task, "reward_fn"): 89 | self._series.add(reward) 90 | else: 91 | reward_fn = self._runtime._environment.task.reward_fn 92 | reward_dict = {k: v for k, v in reward_fn.reward_terms.items()} 93 | reward_dict["total"] = reward # Also log the total reward. 94 | self._series.add_dict(reward_dict) 95 | 96 | return self._series 97 | 98 | def reset_series(self) -> None: 99 | self._series.clear() 100 | -------------------------------------------------------------------------------- /examples/http_player.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """An HTTP server that plays notes received through POST requests. 
16 | 17 | To try it out, start the server: 18 | python examples/http_player.py 19 | 20 | Then send it a post request like: 21 | curl -X POST localhost:8080 -d 'ACTIVATION=[40,44]' 22 | """ 23 | 24 | import re 25 | from http.server import BaseHTTPRequestHandler, HTTPServer 26 | from typing import Optional 27 | 28 | import numpy as np 29 | 30 | from robopianist.music import midi_file, synthesizer 31 | from robopianist.music.constants import NUM_KEYS 32 | 33 | hostname = "localhost" 34 | serverport = 8080 35 | 36 | _ACTIVATION_RE = re.compile(r"^ACTIVATION=\[((\d+,)*(\d+)?)\]$") 37 | 38 | 39 | class PianoServer(BaseHTTPRequestHandler): 40 | """An HTTP server that plays notes received through POST requests.""" 41 | 42 | def __init__(self, *args, **kwargs): 43 | self._prev_activation = np.zeros(NUM_KEYS, dtype=bool) 44 | 45 | super().__init__(*args, **kwargs) 46 | 47 | def do_POST(self) -> None: 48 | global _synth 49 | assert _synth is not None 50 | 51 | self.send_response(200) 52 | self.send_header("Content-Length", "0") 53 | self.end_headers() 54 | 55 | for line in self.rfile: 56 | m = _ACTIVATION_RE.match(str(line, "utf-8")) 57 | if m: 58 | activation = np.zeros(NUM_KEYS, dtype=bool) 59 | if m.group(1): 60 | active = [int(d) for d in m.group(1).split(",")] 61 | for key_id in active: 62 | if key_id < NUM_KEYS: 63 | activation[key_id] = True 64 | else: 65 | print(f"Invalid key id: {key_id}") 66 | 67 | state_change = activation ^ self._prev_activation 68 | 69 | # Note on events. 70 | for key_id in np.flatnonzero(state_change & ~self._prev_activation): 71 | _synth.note_on( 72 | midi_file.key_number_to_midi_number(key_id), 73 | 127, 74 | ) 75 | 76 | # Note off events. 77 | for key_id in np.flatnonzero(state_change & ~activation): 78 | _synth.note_off(midi_file.key_number_to_midi_number(key_id)) 79 | 80 | # Update state. 81 | self._prev_activation = activation.copy() 82 | 83 | break 84 | 85 | def log_request(self, request): 86 | del request # Unused. 
87 | 88 | 89 | _synth: Optional[synthesizer.Synthesizer] = None 90 | 91 | if __name__ == "__main__": 92 | _synth = synthesizer.Synthesizer() 93 | _synth.start() 94 | 95 | webServer = HTTPServer((hostname, serverport), PianoServer) 96 | print(f"Server started http://{hostname}:{serverport}") 97 | 98 | try: 99 | webServer.serve_forever() 100 | except KeyboardInterrupt: 101 | pass 102 | 103 | webServer.server_close() 104 | _synth.stop() 105 | print("Server stopped.") 106 | -------------------------------------------------------------------------------- /robopianist/music/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Music module.""" 16 | 17 | from pathlib import Path 18 | from typing import Union 19 | 20 | from robopianist import _PROJECT_ROOT 21 | from robopianist.music import library, midi_file 22 | 23 | 24 | def _camel_case(name: str) -> str: 25 | new_name = name.replace("'", "") # Remove apostrophes. 26 | new_name = new_name.replace("_", " ").title().replace(" ", "") 27 | # We have a -{number} suffix which originally came from the different fingering 28 | # annotations per file in the PIG dataset. We remove it here. 
29 | if "-" in new_name: 30 | new_name = new_name[: new_name.index("-")] 31 | return new_name 32 | 33 | 34 | _PIG_DIR = _PROJECT_ROOT / "robopianist" / "music" / "data" / "pig_single_finger" 35 | _PIG_FILES = sorted(_PIG_DIR.glob("*.proto")) 36 | PIG_MIDIS = [_camel_case(Path(f).stem) for f in _PIG_FILES] 37 | _ETUDE_SUBSET = ( 38 | "french_suite_no_1_allemande-1", 39 | "french_suite_no_5_sarabande-1", 40 | "piano_sonata_d_845_1st_mov-1", 41 | "partita_no_2_6-1", 42 | "waltz_op_64_no_1-1", 43 | "bagatelle_op_3_no_4-1", 44 | "kreisleriana_op_16_no_8-1", 45 | "french_suite_no_5_gavotte-1", 46 | "piano_sonata_no_23_2nd_mov-1", 47 | "golliwogg's_cakewalk-1", 48 | "piano_sonata_no_2_1st_mov-1", 49 | "piano_sonata_k_279_in_c_major_1st_mov-1", 50 | ) 51 | ETUDE_MIDIS = [_camel_case(name) for name in _ETUDE_SUBSET] 52 | _PIG_NAME_TO_FILE = dict(zip(PIG_MIDIS, _PIG_FILES)) 53 | DEBUG_MIDIS = list(library.MIDI_NAME_TO_CALLABLE.keys()) 54 | ALL = DEBUG_MIDIS + PIG_MIDIS 55 | 56 | 57 | def load( 58 | path_or_name: Union[str, Path], 59 | stretch: float = 1.0, 60 | shift: int = 0, 61 | ) -> midi_file.MidiFile: 62 | """Make a MidiFile object from a path or name. 63 | 64 | Args: 65 | path_or_name: Path or name of the midi file. 66 | stretch: Temporal stretch factor. Values greater than 1.0 slow down a song, and 67 | values less than 1.0 speed it up. 68 | shift: Number of semitones to transpose the song by. 69 | 70 | Returns: 71 | A MidiFile object. 72 | 73 | Raises: 74 | ValueError if the path extension is not supported or the MIDI file is invalid. 75 | KeyError if the name is not found in the library. 76 | """ 77 | path = Path(path_or_name) 78 | 79 | if path.suffix: # Note strings will have an empty string suffix. 80 | midi = midi_file.MidiFile.from_file(path) 81 | else: 82 | # Debug midis are generated programmatically and thus should not be loaded from 83 | # file. 
84 | if path.stem in DEBUG_MIDIS: 85 | midi = library.MIDI_NAME_TO_CALLABLE[path.stem]() 86 | # PIG midis are stored as proto files and should be loaded from file. 87 | elif path.stem in PIG_MIDIS: 88 | midi = midi_file.MidiFile.from_file(_PIG_NAME_TO_FILE[path.stem]) 89 | else: 90 | raise KeyError(f"Unknown name: {path.stem}. Available names: {ALL}.") 91 | 92 | return midi.stretch(stretch).transpose(shift) 93 | 94 | 95 | __all__ = [ 96 | "ALL", 97 | "DEBUG_MIDIS", 98 | "PIG_MIDIS", 99 | "ETUDE_MIDIS", 100 | "load", 101 | ] 102 | -------------------------------------------------------------------------------- /robopianist/suite/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """RoboPianist suite.""" 16 | 17 | from pathlib import Path 18 | from typing import Any, Dict, Mapping, Optional, Union 19 | 20 | from dm_control import composer 21 | from mujoco_utils import composer_utils 22 | 23 | from robopianist import music 24 | from robopianist.suite.tasks import piano_with_shadow_hands 25 | 26 | # RoboPianist-repertoire-150. 27 | _BASE_REPERTOIRE_NAME = "RoboPianist-repertoire-150-{}-v0" 28 | REPERTOIRE_150 = [_BASE_REPERTOIRE_NAME.format(name) for name in music.PIG_MIDIS] 29 | _REPERTOIRE_150_DICT = dict(zip(REPERTOIRE_150, music.PIG_MIDIS)) 30 | 31 | # RoboPianist-etude-12. 
32 | _BASE_ETUDE_NAME = "RoboPianist-etude-12-{}-v0" 33 | ETUDE_12 = [_BASE_ETUDE_NAME.format(name) for name in music.ETUDE_MIDIS] 34 | _ETUDE_12_DICT = dict(zip(ETUDE_12, music.ETUDE_MIDIS)) 35 | 36 | # RoboPianist-debug. 37 | _DEBUG_BASE_NAME = "RoboPianist-debug-{}-v0" 38 | DEBUG = [_DEBUG_BASE_NAME.format(name) for name in music.DEBUG_MIDIS] 39 | _DEBUG_DICT = dict(zip(DEBUG, music.DEBUG_MIDIS)) 40 | 41 | # All valid environment names. 42 | ALL = REPERTOIRE_150 + ETUDE_12 + DEBUG 43 | _ALL_DICT: Dict[str, Union[Path, str]] = { 44 | **_REPERTOIRE_150_DICT, 45 | **_ETUDE_12_DICT, 46 | **_DEBUG_DICT, 47 | } 48 | 49 | 50 | def load( 51 | environment_name: str, 52 | midi_file: Optional[Path] = None, 53 | seed: Optional[int] = None, 54 | stretch: float = 1.0, 55 | shift: int = 0, 56 | recompile_physics: bool = False, 57 | legacy_step: bool = True, 58 | task_kwargs: Optional[Mapping[str, Any]] = None, 59 | ) -> composer.Environment: 60 | """Loads a RoboPianist environment. 61 | 62 | Args: 63 | environment_name: Name of the environment to load. Must be of the form 64 | "RoboPianist-repertoire-150--v0", where is the name of a 65 | PIG dataset MIDI file in camel case notation. 66 | midi_file: Path to a MIDI file to load. If provided, this will override 67 | `environment_name`. 68 | seed: Optional random seed. 69 | stretch: Stretch factor for the MIDI file. 70 | shift: Shift factor for the MIDI file. 71 | recompile_physics: Whether to recompile the physics. 72 | legacy_step: Whether to use the legacy step function. 73 | task_kwargs: Additional keyword arguments to pass to the task. 74 | """ 75 | if midi_file is not None: 76 | midi = music.load(midi_file, stretch=stretch, shift=shift) 77 | else: 78 | if environment_name not in ALL: 79 | raise ValueError( 80 | f"Unknown environment {environment_name}. 
" 81 | f"Available environments: {ALL}" 82 | ) 83 | midi = music.load(_ALL_DICT[environment_name], stretch=stretch, shift=shift) 84 | 85 | task_kwargs = task_kwargs or {} 86 | 87 | return composer_utils.Environment( 88 | task=piano_with_shadow_hands.PianoWithShadowHands(midi=midi, **task_kwargs), 89 | random_state=seed, 90 | strip_singleton_obs_buffer_dim=True, 91 | recompile_physics=recompile_physics, 92 | legacy_step=legacy_step, 93 | ) 94 | 95 | 96 | __all__ = [ 97 | "ALL", 98 | "DEBUG", 99 | "ETUDE_12", 100 | "REPERTOIRE_150", 101 | "load", 102 | ] 103 | -------------------------------------------------------------------------------- /robopianist/models/hands/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import abc 16 | import enum 17 | from typing import Sequence 18 | 19 | import numpy as np 20 | from dm_control import composer, mjcf 21 | from dm_control.composer.observation import observable 22 | from mujoco_utils import types 23 | 24 | 25 | @enum.unique 26 | class HandSide(enum.Enum): 27 | """Which hand side is being modeled.""" 28 | 29 | LEFT = enum.auto() 30 | RIGHT = enum.auto() 31 | 32 | 33 | class Hand(composer.Entity, abc.ABC): 34 | """Base composer class for dexterous hands.""" 35 | 36 | def _build_observables(self) -> "HandObservables": 37 | return HandObservables(self) 38 | 39 | @property 40 | @abc.abstractmethod 41 | def name(self) -> str: 42 | ... 43 | 44 | @property 45 | @abc.abstractmethod 46 | def hand_side(self) -> HandSide: 47 | ... 48 | 49 | @property 50 | @abc.abstractmethod 51 | def root_body(self) -> types.MjcfElement: 52 | ... 53 | 54 | @property 55 | @abc.abstractmethod 56 | def joints(self) -> Sequence[types.MjcfElement]: 57 | ... 58 | 59 | @property 60 | @abc.abstractmethod 61 | def actuators(self) -> Sequence[types.MjcfElement]: 62 | ... 63 | 64 | @property 65 | @abc.abstractmethod 66 | def fingertip_sites(self) -> Sequence[types.MjcfElement]: 67 | ... 68 | 69 | 70 | class HandObservables(composer.Observables): 71 | """Base class for dexterous hand observables.""" 72 | 73 | _entity: Hand 74 | 75 | @composer.observable 76 | def joints_pos(self): 77 | """Returns the joint positions.""" 78 | return observable.MJCFFeature("qpos", self._entity.joints) 79 | 80 | @composer.observable 81 | def joints_pos_cos_sin(self): 82 | """Returns the joint positions encoded as (cos, sin) pairs. 83 | 84 | This has twice as many dimensions as the raw joint positions. 
85 | """ 86 | 87 | def _get_joint_angles(physics: mjcf.Physics) -> np.ndarray: 88 | qpos = physics.bind(self._entity.joints).qpos 89 | return np.hstack([np.cos(qpos), np.sin(qpos)]) 90 | 91 | return observable.Generic(raw_observation_callable=_get_joint_angles) 92 | 93 | @composer.observable 94 | def joints_vel(self): 95 | """Returns the joint velocities.""" 96 | return observable.MJCFFeature("qvel", self._entity.joints) 97 | 98 | @composer.observable 99 | def joints_torque(self) -> observable.Generic: 100 | """Returns the joint torques.""" 101 | 102 | def _get_joint_torques(physics: mjcf.Physics) -> np.ndarray: 103 | # We only care about torques acting on each joint's axis of rotation, so we 104 | # project them. 105 | torques = physics.bind(self._entity.joint_torque_sensors).sensordata 106 | joint_axes = physics.bind(self._entity.joints).axis 107 | return np.einsum("ij,ij->i", torques.reshape(-1, 3), joint_axes) 108 | 109 | return observable.Generic(raw_observation_callable=_get_joint_torques) 110 | 111 | @composer.observable 112 | def position(self): 113 | """Returns the position of the hand's root body in the world frame.""" 114 | return observable.MJCFFeature("xpos", self._entity.root_body) 115 | -------------------------------------------------------------------------------- /robopianist/wrappers/sound.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A wrapper for rendering videos with sound.""" 16 | 17 | import shutil 18 | import subprocess 19 | import wave 20 | from pathlib import Path 21 | 22 | import dm_env 23 | from dm_env_wrappers import DmControlVideoWrapper 24 | 25 | from robopianist import SF2_PATH 26 | from robopianist.models.piano import midi_module 27 | from robopianist.music import constants as consts 28 | from robopianist.music import midi_message, synthesizer 29 | 30 | 31 | class PianoSoundVideoWrapper(DmControlVideoWrapper): 32 | """Video rendering with sound from the piano keys.""" 33 | 34 | def __init__( 35 | self, 36 | environment: dm_env.Environment, 37 | sf2_path: Path = SF2_PATH, 38 | sample_rate: int = consts.SAMPLING_RATE, 39 | **kwargs, 40 | ) -> None: 41 | # Check that this is an environment with a piano. 42 | if not hasattr(environment.task, "piano"): 43 | raise ValueError("PianoVideoWrapper only works with piano environments.") 44 | 45 | super().__init__(environment, **kwargs) 46 | 47 | self._midi_module: midi_module.MidiModule = environment.task.piano.midi_module 48 | self._sample_rate = sample_rate 49 | self._synth = synthesizer.Synthesizer(sf2_path, sample_rate) 50 | 51 | def _write_frames(self) -> None: 52 | super()._write_frames() 53 | 54 | midi_events = self._midi_module.get_all_midi_messages() 55 | 56 | # Exit if there are no MIDI events or if all events are sustain events. 57 | # Sustain only events cause white noise in the audio (which has shattered my 58 | # eardrums on more than one occasion). 59 | no_events = len(midi_events) == 0 60 | are_events_sustains = [ 61 | isinstance(event, (midi_message.SustainOn, midi_message.SustainOff)) 62 | for event in midi_events 63 | ] 64 | only_sustain = all(are_events_sustains) and len(midi_events) > 0 65 | if no_events or only_sustain: 66 | return 67 | 68 | # Synthesize waveform. 
69 | waveform = self._synth.get_samples(midi_events) 70 | 71 | # Save waveform as mp3. 72 | waveform_name = self._record_dir / f"{self._counter:05d}.mp3" 73 | wf = wave.open(str(waveform_name), "wb") 74 | wf.setnchannels(1) 75 | wf.setsampwidth(2) 76 | wf.setframerate(self._sample_rate * self._playback_speed) 77 | wf.writeframes(waveform) # type: ignore 78 | wf.close() 79 | 80 | # Make a copy of the MP4 so that FFMPEG can overwrite it. 81 | filename = self._record_dir / f"{self._counter:05d}.mp4" 82 | temp_filename = self._record_dir / "temp.mp4" 83 | shutil.copyfile(filename, temp_filename) 84 | filename.unlink() 85 | 86 | # Add the sound to the MP4 using FFMPEG, suppressing the output. 87 | # Reference: https://stackoverflow.com/a/11783474 88 | ret = subprocess.run( 89 | [ 90 | "ffmpeg", 91 | "-nostdin", 92 | "-y", 93 | "-i", 94 | str(temp_filename), 95 | "-i", 96 | str(waveform_name), 97 | "-map", 98 | "0", 99 | "-map", 100 | "1:a", 101 | "-c:v", 102 | "copy", 103 | "-shortest", 104 | str(filename), 105 | ], 106 | stdout=subprocess.DEVNULL, 107 | stderr=subprocess.STDOUT, 108 | check=True, 109 | ) 110 | if ret.returncode != 0: 111 | print(f"FFMPEG failed to add sound to video {temp_filename}.") 112 | 113 | # Remove temporary files. 114 | temp_filename.unlink() 115 | waveform_name.unlink() 116 | 117 | def __del__(self) -> None: 118 | self._synth.stop() 119 | -------------------------------------------------------------------------------- /examples/self_actuated_piano_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Self-actuated piano environment.""" 16 | 17 | import dm_env 18 | import numpy as np 19 | from absl import app, flags 20 | from dm_control.mjcf import export_with_assets 21 | from mujoco import viewer as mujoco_viewer 22 | from mujoco_utils import composer_utils 23 | 24 | from robopianist import music, viewer 25 | from robopianist.suite.tasks import self_actuated_piano 26 | from robopianist.wrappers import MidiEvaluationWrapper, PianoSoundVideoWrapper 27 | 28 | _FILE = flags.DEFINE_string("file", "TwinkleTwinkleRousseau", "") 29 | _RECORD = flags.DEFINE_bool("record", False, "") 30 | _EXPORT = flags.DEFINE_bool("export", False, "") 31 | _TRIM_SILENCE = flags.DEFINE_bool("trim_silence", False, "") 32 | _CONTROL_TIMESTEP = flags.DEFINE_float("control_timestep", 0.01, "") 33 | _STRETCH = flags.DEFINE_float("stretch", 1.0, "") 34 | _SHIFT = flags.DEFINE_integer("shift", 0, "") 35 | _PLAYBACK_SPEED = flags.DEFINE_float("playback_speed", 1.0, "") 36 | 37 | 38 | def main(_) -> None: 39 | task = self_actuated_piano.SelfActuatedPiano( 40 | midi=music.load(_FILE.value, stretch=_STRETCH.value, shift=_SHIFT.value), 41 | change_color_on_activation=True, 42 | trim_silence=_TRIM_SILENCE.value, 43 | control_timestep=_CONTROL_TIMESTEP.value, 44 | ) 45 | if _EXPORT.value: 46 | export_with_assets( 47 | task.root_entity.mjcf_model, 48 | out_dir="/tmp/robopianist/self_actuated_piano", 49 | out_file_name="scene.xml", 50 | ) 51 | mujoco_viewer.launch_from_path("/tmp/robopianist/self_actuated_piano/scene.xml") 52 | return 53 | 54 | env = 
composer_utils.Environment( 55 | recompile_physics=False, task=task, strip_singleton_obs_buffer_dim=True 56 | ) 57 | env = MidiEvaluationWrapper(env) 58 | if _RECORD.value: 59 | env = PianoSoundVideoWrapper( 60 | env, 61 | record_every=1, 62 | camera_id="piano/topdown", 63 | playback_speed=_PLAYBACK_SPEED.value, 64 | ) 65 | 66 | action_spec = env.action_spec() 67 | min_ctrl = action_spec.minimum 68 | max_ctrl = action_spec.maximum 69 | print(f"Action dimension: {action_spec.shape}") 70 | 71 | # Sanity check observables. 72 | print("Observables:") 73 | timestep = env.reset() 74 | dim = 0 75 | for k, v in timestep.observation.items(): 76 | print(f"\t{k}: {v.shape} {v.dtype}") 77 | dim += np.prod(v.shape) 78 | print(f"Observation dimension: {dim}") 79 | 80 | print(f"Control frequency: {1 / _CONTROL_TIMESTEP.value} Hz") 81 | 82 | class Oracle: 83 | def __call__(self, timestep: dm_env.TimeStep) -> np.ndarray: 84 | if timestep.reward is not None: 85 | assert timestep.reward == 0 86 | # Only grab the next timestep's goal state. 87 | goal = timestep.observation["goal"][: task.piano.n_keys] 88 | key_idxs = np.flatnonzero(goal) 89 | # For goal keys that should be pressed, set the action to the maximum 90 | # actuator value. For goal keys that should be released, set the action to 91 | # the minimum actuator value. 92 | action = min_ctrl.copy() 93 | action[key_idxs] = max_ctrl[key_idxs] 94 | # Grab the sustain pedal action. 
95 | action[-1] = timestep.observation["goal"][-1] 96 | return action 97 | 98 | policy = Oracle() 99 | 100 | if not _RECORD.value: 101 | viewer.launch(env, policy=policy) 102 | else: 103 | timestep = env.reset() 104 | while not timestep.last(): 105 | action = policy(timestep) 106 | timestep = env.step(action) 107 | 108 | for k, v in env.get_musical_metrics().items(): 109 | np.testing.assert_equal(v, 1.0) 110 | 111 | 112 | if __name__ == "__main__": 113 | app.run(main) 114 | -------------------------------------------------------------------------------- /robopianist/viewer/gui/fullscreen_quad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # Copyright 2023 The RoboPianist Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | 17 | """OpenGL utility for rendering numpy arrays as images on a quad surface.""" 18 | 19 | import ctypes 20 | 21 | import numpy as np 22 | from OpenGL import GL 23 | from OpenGL.GL import shaders 24 | 25 | # This array contains packed position and texture cooridnates of a fullscreen 26 | # quad. 27 | # It contains definition of 4 vertices that will be rendered as a triangle 28 | # strip. 
Each vertex is described by a tuple: 29 | # (VertexPosition.X, VertexPosition.Y, TextureCoord.U, TextureCoord.V) 30 | _FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS = np.array( 31 | [-1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 1, 1, 1, 0], dtype=np.float32 32 | ) 33 | _FLOATS_PER_XY = 2 34 | _FLOATS_PER_VERTEX = 4 35 | _SIZE_OF_FLOAT = ctypes.sizeof(ctypes.c_float) 36 | 37 | _VERTEX_SHADER = """ 38 | #version 120 39 | attribute vec2 position; 40 | attribute vec2 uv; 41 | void main() { 42 | gl_Position = vec4(position, 0, 1); 43 | gl_TexCoord[0].st = uv; 44 | } 45 | """ 46 | _FRAGMENT_SHADER = """ 47 | #version 120 48 | uniform sampler2D tex; 49 | void main() { 50 | gl_FragColor = texture2D(tex, gl_TexCoord[0].st); 51 | } 52 | """ 53 | _VAR_POSITION = "position" 54 | _VAR_UV = "uv" 55 | _VAR_TEXTURE_SAMPLER = "tex" 56 | 57 | 58 | class FullscreenQuadRenderer: 59 | """Renders pixmaps on a fullscreen quad using OpenGL.""" 60 | 61 | def __init__(self): 62 | """Initializes the fullscreen quad renderer.""" 63 | GL.glClearColor(0, 0, 0, 0) 64 | self._init_geometry() 65 | self._init_texture() 66 | self._init_shaders() 67 | 68 | def _init_geometry(self): 69 | """Initializes the fullscreen quad geometry.""" 70 | vertex_buffer = GL.glGenBuffers(1) 71 | GL.glBindBuffer(GL.GL_ARRAY_BUFFER, vertex_buffer) 72 | GL.glBufferData( 73 | GL.GL_ARRAY_BUFFER, 74 | _FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS.nbytes, 75 | _FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS, 76 | GL.GL_STATIC_DRAW, 77 | ) 78 | 79 | def _init_texture(self): 80 | """Initializes the texture storage.""" 81 | self._texture = GL.glGenTextures(1) 82 | GL.glBindTexture(GL.GL_TEXTURE_2D, self._texture) 83 | GL.glTexParameteri(GL.GL_TEXTURE_2D, GL.GL_TEXTURE_MAG_FILTER, GL.GL_NEAREST) 84 | GL.glTexParameteri(GL.GL_TEXTURE_2D, GL.GL_TEXTURE_MIN_FILTER, GL.GL_NEAREST) 85 | 86 | def _init_shaders(self): 87 | """Initializes the shaders used to render the textures fullscreen quad.""" 88 | vs = 
shaders.compileShader(_VERTEX_SHADER, GL.GL_VERTEX_SHADER) 89 | fs = shaders.compileShader(_FRAGMENT_SHADER, GL.GL_FRAGMENT_SHADER) 90 | self._shader = shaders.compileProgram(vs, fs) 91 | 92 | stride = _FLOATS_PER_VERTEX * _SIZE_OF_FLOAT 93 | var_position = GL.glGetAttribLocation(self._shader, _VAR_POSITION) 94 | GL.glVertexAttribPointer( 95 | var_position, 2, GL.GL_FLOAT, GL.GL_FALSE, stride, None 96 | ) 97 | GL.glEnableVertexAttribArray(var_position) 98 | 99 | var_uv = GL.glGetAttribLocation(self._shader, _VAR_UV) 100 | uv_offset = ctypes.c_void_p(_FLOATS_PER_XY * _SIZE_OF_FLOAT) 101 | GL.glVertexAttribPointer(var_uv, 2, GL.GL_FLOAT, GL.GL_FALSE, stride, uv_offset) 102 | GL.glEnableVertexAttribArray(var_uv) 103 | 104 | self._var_texture_sampler = GL.glGetUniformLocation( 105 | self._shader, _VAR_TEXTURE_SAMPLER 106 | ) 107 | 108 | def render(self, pixmap, viewport_shape): 109 | """Renders the pixmap on a fullscreen quad. 110 | 111 | Args: 112 | pixmap: A 3D numpy array of bytes (np.uint8), with dimensions 113 | (width, height, 3). 114 | viewport_shape: A tuple of two elements, (width, height). 115 | """ 116 | GL.glClear(GL.GL_COLOR_BUFFER_BIT) 117 | GL.glViewport(0, 0, *viewport_shape) 118 | GL.glUseProgram(self._shader) 119 | GL.glActiveTexture(GL.GL_TEXTURE0) 120 | GL.glBindTexture(GL.GL_TEXTURE_2D, self._texture) 121 | GL.glPixelStorei(GL.GL_UNPACK_ALIGNMENT, 1) 122 | GL.glTexImage2D( 123 | GL.GL_TEXTURE_2D, 124 | 0, 125 | GL.GL_RGB, 126 | pixmap.shape[1], 127 | pixmap.shape[0], 128 | 0, 129 | GL.GL_RGB, 130 | GL.GL_UNSIGNED_BYTE, 131 | pixmap, 132 | ) 133 | GL.glUniform1i(self._var_texture_sampler, 0) 134 | GL.glDrawArrays(GL.GL_TRIANGLE_STRIP, 0, 4) 135 | -------------------------------------------------------------------------------- /robopianist/music/synthesizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library for synthesizing music from MIDI files.""" 16 | 17 | from pathlib import Path 18 | from typing import Sequence 19 | 20 | import fluidsynth 21 | import numpy as np 22 | 23 | from robopianist import SF2_PATH 24 | from robopianist.music import constants as consts 25 | from robopianist.music import midi_message 26 | from robopianist.music.constants import SAMPLING_RATE 27 | 28 | _PROGRAM = 0 # Acoustic Grand Piano 29 | _CHANNEL = 0 30 | _BANK = 0 31 | 32 | 33 | def _validate_note(note: int) -> None: 34 | assert consts.MIN_MIDI_PITCH <= note <= consts.MAX_MIDI_PITCH 35 | 36 | 37 | def _validate_velocity(velocity: int) -> None: 38 | assert consts.MIN_VELOCITY <= velocity <= consts.MAX_VELOCITY 39 | 40 | 41 | class Synthesizer: 42 | """FluidSynth-based synthesizer.""" 43 | 44 | def __init__( 45 | self, 46 | soundfont_path: Path = SF2_PATH, 47 | sample_rate: int = SAMPLING_RATE, 48 | ) -> None: 49 | self._soundfont_path = soundfont_path 50 | self._sample_rate = sample_rate 51 | self._muted: bool = False 52 | self._sustained: bool = False 53 | 54 | # Initialize FluidSynth. 
55 | self._synth = fluidsynth.Synth(samplerate=float(sample_rate)) 56 | soundfont_id = self._synth.sfload(str(soundfont_path)) 57 | self._synth.program_select(_CHANNEL, soundfont_id, _BANK, _PROGRAM) 58 | 59 | def start(self) -> None: 60 | self._synth.start() 61 | 62 | def stop(self) -> None: 63 | self._synth.delete() 64 | 65 | def mute(self, value: bool) -> None: 66 | self._muted = value 67 | if value: 68 | self.all_sounds_off() 69 | 70 | def all_sounds_off(self) -> None: 71 | self._synth.all_sounds_off(_CHANNEL) 72 | 73 | def all_notes_off(self) -> None: 74 | self._synth.all_notes_off(_CHANNEL) 75 | 76 | def note_on(self, note: int, velocity: int) -> None: 77 | if not self._muted: 78 | _validate_note(note) 79 | _validate_velocity(velocity) 80 | self._synth.noteon(_CHANNEL, note, velocity) 81 | 82 | def note_off(self, note: int) -> None: 83 | if not self._muted: 84 | _validate_note(note) 85 | self._synth.noteoff(_CHANNEL, note) 86 | 87 | def sustain_on(self) -> None: 88 | if not self._muted: 89 | self._synth.cc( 90 | _CHANNEL, consts.SUSTAIN_PEDAL_CC_NUMBER, consts.MAX_CC_VALUE 91 | ) 92 | self._sustained = True 93 | 94 | def sustain_off(self) -> None: 95 | if not self._muted: 96 | self._synth.cc( 97 | _CHANNEL, consts.SUSTAIN_PEDAL_CC_NUMBER, consts.MIN_CC_VALUE 98 | ) 99 | self._sustained = False 100 | 101 | @property 102 | def muted(self) -> bool: 103 | return self._muted 104 | 105 | @property 106 | def sustained(self) -> bool: 107 | return self._sustained 108 | 109 | def get_samples( 110 | self, 111 | event_list: Sequence[midi_message.MidiMessage], 112 | ) -> np.ndarray: 113 | """Synthesize a list of MIDI events into a waveform.""" 114 | current_time = event_list[0].time 115 | 116 | # Convert absolute seconds to relative seconds. 117 | next_event_times = [e.time for e in event_list[1:]] 118 | for event, end in zip(event_list[:-1], next_event_times): 119 | event.time = end - event.time 120 | 121 | # Include 1 second of silence at the end. 
122 | event_list[-1].time = 1.0 123 | 124 | total_time = current_time + np.sum([e.time for e in event_list]) 125 | synthesized = np.zeros(int(np.ceil(self._sample_rate * total_time))) 126 | for event in event_list: 127 | if isinstance(event, midi_message.NoteOn): 128 | self.note_on(event.note, event.velocity) 129 | elif isinstance(event, midi_message.NoteOff): 130 | self.note_off(event.note) 131 | elif isinstance(event, midi_message.SustainOn): 132 | self.sustain_on() 133 | elif isinstance(event, midi_message.SustainOff): 134 | self.sustain_off() 135 | else: 136 | raise ValueError(f"Unknown event type: {event}") 137 | current_sample = int(self._sample_rate * current_time) 138 | end = int(self._sample_rate * (current_time + event.time)) 139 | samples = self._synth.get_samples(end - current_sample)[::2] 140 | synthesized[current_sample:end] += samples 141 | current_time += event.time 142 | waveform_float = synthesized / np.abs(synthesized).max() 143 | 144 | # Convert to 16-bit ints. 145 | normalizer = float(np.iinfo(np.int16).max) 146 | return np.array(np.asarray(waveform_float) * normalizer, dtype=np.int16) 147 | -------------------------------------------------------------------------------- /robopianist/models/piano/midi_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Piano sound module."""

from typing import Callable, List, Optional

import numpy as np
from dm_control import mjcf

from robopianist.models.piano import piano_constants
from robopianist.music import midi_file, midi_message


class MidiModule:
    """The piano sound module.

    It is responsible for tracking the state of the piano keys and generating
    corresponding MIDI messages. The MIDI messages can be used with a synthesizer
    to produce sound.
    """

    def __init__(self) -> None:
        # Optional synthesizer hooks. `after_substep` invokes each one as soon as
        # the matching event is detected. Set them via the
        # `register_synth_*_callback` methods.
        self._note_on_callback: Optional[Callable[[int, int], None]] = None
        self._note_off_callback: Optional[Callable[[int], None]] = None
        self._sustain_on_callback: Optional[Callable[[], None]] = None
        self._sustain_off_callback: Optional[Callable[[], None]] = None

    def initialize_episode(self, physics: mjcf.Physics) -> None:
        """Resets key/pedal state and clears the recorded MIDI history.

        This is where the per-episode state attributes are first created, so it
        must run before `after_substep` or any of the MIDI message getters.
        """
        del physics  # Unused.

        # All keys and the sustain pedal start released.
        self._prev_activation = np.zeros(piano_constants.NUM_KEYS, dtype=bool)
        self._prev_sustain_activation = np.zeros(1, dtype=bool)
        # One inner list of messages per `after_substep` call.
        self._midi_messages: List[List[midi_message.MidiMessage]] = []

    def after_substep(
        self,
        physics: mjcf.Physics,
        activation: np.ndarray,
        sustain_activation: np.ndarray,
    ) -> None:
        """Emits MIDI events for key and pedal changes since the previous call.

        Each transition produces a message stamped with `physics.data.time`, and
        the matching registered callback (if any) is invoked immediately. Within
        one call the order is: note-ons, note-offs, sustain on, sustain off.

        Args:
            physics: Only its current simulation time is read.
            activation: Boolean per-key activation, compared elementwise with an
                array of length `piano_constants.NUM_KEYS`.
            sustain_activation: Boolean sustain-pedal activation, compared with a
                length-1 array.
        """
        # Sanity check dtype since we use bitwise operators.
        assert activation.dtype == bool
        assert sustain_activation.dtype == bool

        timestep_events: List[midi_message.MidiMessage] = []
        message: midi_message.MidiMessage

        # XOR marks every key/pedal whose state flipped since the last substep.
        state_change = activation ^ self._prev_activation
        sustain_change = sustain_activation ^ self._prev_sustain_activation

        # Note on events.
        # Changed and previously released, i.e. newly pressed.
        for key_id in np.flatnonzero(state_change & ~self._prev_activation):
            message = midi_message.NoteOn(
                note=midi_file.key_number_to_midi_number(key_id),
                # TODO(kevin): In the future, we will replace this with the actual
                # key velocity. For now, we hardcode it to the maximum velocity.
                velocity=127,
                time=physics.data.time,
            )
            timestep_events.append(message)
            if self._note_on_callback is not None:
                self._note_on_callback(message.note, message.velocity)

        # Note off events.
        # Changed and now released, i.e. newly released.
        for key_id in np.flatnonzero(state_change & ~activation):
            message = midi_message.NoteOff(
                note=midi_file.key_number_to_midi_number(key_id),
                time=physics.data.time,
            )
            timestep_events.append(message)
            if self._note_off_callback is not None:
                self._note_off_callback(message.note)

        # Sustain pedal events.
        # The pedal arrays hold a single element, so their truth value is that
        # element's value.
        if sustain_change & ~self._prev_sustain_activation:
            timestep_events.append(midi_message.SustainOn(time=physics.data.time))
            if self._sustain_on_callback is not None:
                self._sustain_on_callback()
        if sustain_change & ~sustain_activation:
            timestep_events.append(midi_message.SustainOff(time=physics.data.time))
            if self._sustain_off_callback is not None:
                self._sustain_off_callback()

        self._midi_messages.append(timestep_events)
        # Copy so later in-place changes to the caller's arrays can't alter the
        # stored previous state.
        self._prev_activation = activation.copy()
        self._prev_sustain_activation = sustain_activation.copy()

    def get_latest_midi_messages(self) -> List[midi_message.MidiMessage]:
        """Returns the MIDI messages generated in the last substep.

        Raises IndexError if no substep has been recorded since
        `initialize_episode`.
        """
        return self._midi_messages[-1]

    def get_all_midi_messages(self) -> List[midi_message.MidiMessage]:
        """Returns a list of all MIDI messages generated during the episode."""
        # Flatten the per-substep lists, preserving chronological order.
        return [message for timestep in self._midi_messages for message in timestep]

    # Callbacks for synthesizer events.
    # Registering again replaces any previously registered callback.

    def register_synth_note_on_callback(
        self,
        callback: Callable[[int, int], None],
    ) -> None:
        """Registers a callback for note on events.

        The callback receives (MIDI note number, velocity).
        """
        self._note_on_callback = callback

    def register_synth_note_off_callback(
        self,
        callback: Callable[[int], None],
    ) -> None:
        """Registers a callback for note off events.

        The callback receives the MIDI note number.
        """
        self._note_off_callback = callback

    def register_synth_sustain_on_callback(
        self,
        callback: Callable[[], None],
    ) -> None:
        """Registers a callback for sustain pedal on events."""
        self._sustain_on_callback = callback

    def register_synth_sustain_off_callback(
        self,
        callback: Callable[[], None],
    ) -> None:
        """Registers a callback for sustain pedal off events."""
        self._sustain_off_callback = callback

--------------------------------------------------------------------------------
/examples/piano_with_shadow_hands_env.py:
--------------------------------------------------------------------------------
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Piano with shadow hands environment."""

import dm_env
import numpy as np
from absl import app, flags
from dm_control.mjcf import export_with_assets
from dm_env_wrappers import CanonicalSpecWrapper
from mujoco import viewer as mujoco_viewer

from robopianist import suite, viewer
from robopianist.wrappers import PianoSoundVideoWrapper

# Command-line flags. `main` consumes them as follows:
# - env_name, midi_file, stretch and shift go straight to `suite.load`.
# - Most task options are forwarded inside `task_kwargs`.
# - record, export, headless, canonicalize and action_sequence select how the
#   environment is wrapped and run.
# - control_timestep is forwarded and also used to print the control frequency.
_ENV_NAME = flags.DEFINE_string(
    "env_name", "RoboPianist-debug-TwinkleTwinkleLittleStar-v0", ""
)
_MIDI_FILE = flags.DEFINE_string("midi_file", None, "")
_CONTROL_TIMESTEP = flags.DEFINE_float("control_timestep", 0.05, "")
_STRETCH = flags.DEFINE_float("stretch", 1.0, "")
_SHIFT = flags.DEFINE_integer("shift", 0, "")
_RECORD = flags.DEFINE_bool("record", False, "")
_EXPORT = flags.DEFINE_bool("export", False, "")
_GRAVITY_COMPENSATION = flags.DEFINE_bool("gravity_compensation", False, "")
_HEADLESS = flags.DEFINE_bool("headless", False, "")
_TRIM_SILENCE = flags.DEFINE_bool("trim_silence", False, "")
_PRIMITIVE_FINGERTIP_COLLISIONS = flags.DEFINE_bool(
    "primitive_fingertip_collisions", False, ""
)
_REDUCED_ACTION_SPACE = flags.DEFINE_bool("reduced_action_space", False, "")
_DISABLE_FINGERING_REWARD = flags.DEFINE_bool("disable_fingering_reward", False, "")
_DISABLE_FOREARM_REWARD = flags.DEFINE_bool("disable_forearm_reward", False, "")
_DISABLE_COLORIZATION = flags.DEFINE_bool("disable_colorization", False, "")
_DISABLE_HAND_COLLISIONS = flags.DEFINE_bool("disable_hand_collisions", False, "")
_CANONICALIZE = flags.DEFINE_bool("canonicalize", False, "")
_N_STEPS_LOOKAHEAD = flags.DEFINE_integer("n_steps_lookahead", 1, "")
_ATTACHMENT_YAW = flags.DEFINE_float("attachment_yaw", 0.0, "")
# When set, the policy in `main` replays the rows of this array one per step,
# instead of emitting its constant default action.
_ACTION_SEQUENCE = flags.DEFINE_string(
    "action_sequence",
    None,
    "Path to an npy file containing a sequence of actions to replay.",
)

def main(_) -> None: 58 | env = suite.load( 59 | environment_name=_ENV_NAME.value, 60 | midi_file=_MIDI_FILE.value, 61 | stretch=_STRETCH.value, 62 | shift=_SHIFT.value, 63 | task_kwargs=dict( 64 | change_color_on_activation=True, 65 | trim_silence=_TRIM_SILENCE.value, 66 | control_timestep=_CONTROL_TIMESTEP.value, 67 | gravity_compensation=_GRAVITY_COMPENSATION.value, 68 | primitive_fingertip_collisions=_PRIMITIVE_FINGERTIP_COLLISIONS.value, 69 | reduced_action_space=_REDUCED_ACTION_SPACE.value, 70 | n_steps_lookahead=_N_STEPS_LOOKAHEAD.value, 71 | disable_fingering_reward=_DISABLE_FINGERING_REWARD.value, 72 | disable_forearm_reward=_DISABLE_FOREARM_REWARD.value, 73 | disable_colorization=_DISABLE_COLORIZATION.value, 74 | disable_hand_collisions=_DISABLE_HAND_COLLISIONS.value, 75 | attachment_yaw=_ATTACHMENT_YAW.value, 76 | ), 77 | ) 78 | 79 | if _EXPORT.value: 80 | export_with_assets( 81 | env.task.root_entity.mjcf_model, 82 | out_dir="/tmp/robopianist/piano_with_shadow_hands", 83 | out_file_name="scene.xml", 84 | ) 85 | mujoco_viewer.launch_from_path( 86 | "/tmp/robopianist/piano_with_shadow_hands/scene.xml" 87 | ) 88 | return 89 | 90 | if _RECORD.value: 91 | env = PianoSoundVideoWrapper(env, record_every=1) 92 | if _CANONICALIZE.value: 93 | env = CanonicalSpecWrapper(env) 94 | 95 | action_spec = env.action_spec() 96 | zeros = np.zeros(action_spec.shape, dtype=action_spec.dtype) 97 | zeros[-1] = -1.0 # Disable sustain pedal. 98 | print(f"Action dimension: {action_spec.shape}") 99 | 100 | # Sanity check observables. 
101 | timestep = env.reset() 102 | dim = 0 103 | for k, v in timestep.observation.items(): 104 | print(f"\t{k}: {v.shape} {v.dtype}") 105 | dim += int(np.prod(v.shape)) 106 | print(f"Observation dimension: {dim}") 107 | 108 | print(f"Control frequency: {1 / _CONTROL_TIMESTEP.value} Hz") 109 | 110 | class Policy: 111 | def __init__(self) -> None: 112 | self.reset() 113 | 114 | def reset(self) -> None: 115 | if _ACTION_SEQUENCE.value is not None: 116 | self._idx = 0 117 | self._actions = np.load(_ACTION_SEQUENCE.value) 118 | 119 | def __call__(self, timestep: dm_env.TimeStep) -> np.ndarray: 120 | del timestep # Unused. 121 | if _ACTION_SEQUENCE.value is not None: 122 | actions = self._actions[self._idx] 123 | self._idx += 1 124 | return actions 125 | return zeros 126 | 127 | policy = Policy() 128 | 129 | if not _RECORD.value: 130 | if _HEADLESS.value: 131 | timestep = env.reset() 132 | while not timestep.last(): 133 | action = policy(timestep) 134 | timestep = env.step(action) 135 | else: 136 | viewer.launch(env, policy=policy) 137 | else: 138 | timestep = env.reset() 139 | while not timestep.last(): 140 | action = policy(timestep) 141 | timestep = env.step(action) 142 | 143 | 144 | if __name__ == "__main__": 145 | app.run(main) 146 | -------------------------------------------------------------------------------- /robopianist/suite/variations_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for variations.py.""" 16 | 17 | import numpy as np 18 | from absl.testing import absltest 19 | from note_seq.protobuf import compare 20 | 21 | from robopianist.music import ALL, library, midi_file 22 | from robopianist.suite import variations 23 | 24 | _SEED = 12345 25 | _NUM_SAMPLES = 100 26 | 27 | 28 | class MidiSelectTest(absltest.TestCase): 29 | def test_output_is_midi_file(self) -> None: 30 | var = variations.MidiSelect(midi_names=ALL) 31 | random_state = np.random.RandomState(_SEED) 32 | for _ in range(_NUM_SAMPLES): 33 | midi = var(random_state=random_state) 34 | self.assertIsInstance(midi, midi_file.MidiFile) 35 | 36 | 37 | class MidiTemporalStretchTest(absltest.TestCase): 38 | def assertProtoEquals(self, a, b, msg=None): 39 | if not compare.ProtoEq(a, b): 40 | compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 41 | 42 | def test_output_is_midi_file(self) -> None: 43 | original_midi = library.toy() 44 | var = variations.MidiTemporalStretch(prob=0.1, stretch_range=0.5) 45 | random_state = np.random.RandomState(_SEED) 46 | for _ in range(_NUM_SAMPLES): 47 | midi = var(initial_value=original_midi, random_state=random_state) 48 | self.assertIsInstance(midi, midi_file.MidiFile) 49 | 50 | def test_output_same_if_prob_zero(self) -> None: 51 | original_midi = library.toy() 52 | var = variations.MidiTemporalStretch(prob=0.0, stretch_range=0.5) 53 | random_state = np.random.RandomState(_SEED) 54 | for _ in range(_NUM_SAMPLES): 55 | new_midi = var(initial_value=original_midi, random_state=random_state) 56 | self.assertIs(original_midi, new_midi) 57 | 58 | def test_output_different_if_prob_one(self) -> None: 59 | original_midi = library.toy() 60 | var = variations.MidiTemporalStretch(prob=1.0, stretch_range=0.5) 61 | random_state = np.random.RandomState(_SEED) 62 | for _ in range(_NUM_SAMPLES): 63 | new_midi = 
var(initial_value=original_midi, random_state=random_state) 64 | self.assertIsNot(original_midi, new_midi) 65 | 66 | def test_raises_value_error_if_no_initial_value(self) -> None: 67 | var = variations.MidiTemporalStretch(prob=0.1, stretch_range=0.5) 68 | random_state = np.random.RandomState(_SEED) 69 | with self.assertRaises(ValueError): 70 | var(random_state=random_state) 71 | 72 | def test_raises_value_error_if_wrong_type(self) -> None: 73 | var = variations.MidiTemporalStretch(prob=0.1, stretch_range=0.5) 74 | random_state = np.random.RandomState(_SEED) 75 | with self.assertRaises(ValueError): 76 | var(initial_value=1, random_state=random_state) 77 | 78 | def test_output_same_if_stretch_range_zero(self) -> None: 79 | original_midi = library.toy() 80 | var = variations.MidiTemporalStretch(prob=0.1, stretch_range=0.0) 81 | random_state = np.random.RandomState(_SEED) 82 | for _ in range(_NUM_SAMPLES): 83 | new_midi = var(initial_value=original_midi, random_state=random_state) 84 | self.assertProtoEquals(original_midi.seq, new_midi.seq) 85 | self.assertEqual(original_midi.duration, new_midi.duration) 86 | 87 | 88 | class MidiPitchShiftTest(absltest.TestCase): 89 | def assertProtoEquals(self, a, b, msg=None): 90 | if not compare.ProtoEq(a, b): 91 | compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 92 | 93 | def test_output_is_midi_file(self) -> None: 94 | original_midi = library.toy() 95 | var = variations.MidiPitchShift(prob=0.1, shift_range=1) 96 | random_state = np.random.RandomState(_SEED) 97 | for _ in range(_NUM_SAMPLES): 98 | midi = var(initial_value=original_midi, random_state=random_state) 99 | self.assertIsInstance(midi, midi_file.MidiFile) 100 | 101 | def test_output_same_if_prob_zero(self) -> None: 102 | original_midi = library.toy() 103 | var = variations.MidiPitchShift(prob=0.0, shift_range=1) 104 | random_state = np.random.RandomState(_SEED) 105 | for _ in range(_NUM_SAMPLES): 106 | new_midi = var(initial_value=original_midi, 
random_state=random_state) 107 | self.assertIs(original_midi, new_midi) 108 | 109 | def test_raises_value_error_if_no_initial_value(self) -> None: 110 | var = variations.MidiPitchShift(prob=0.1, shift_range=1) 111 | random_state = np.random.RandomState(_SEED) 112 | with self.assertRaises(ValueError): 113 | var(random_state=random_state) 114 | 115 | def test_raises_value_error_if_wrong_type(self) -> None: 116 | var = variations.MidiPitchShift(prob=0.1, shift_range=1) 117 | random_state = np.random.RandomState(_SEED) 118 | with self.assertRaises(ValueError): 119 | var(initial_value=1, random_state=random_state) 120 | 121 | def test_output_same_if_range_zero(self) -> None: 122 | original_midi = library.toy() 123 | var = variations.MidiPitchShift(prob=0.1, shift_range=0) 124 | random_state = np.random.RandomState(_SEED) 125 | for _ in range(_NUM_SAMPLES): 126 | new_midi = var(initial_value=original_midi, random_state=random_state) 127 | self.assertIs(original_midi, new_midi) 128 | 129 | 130 | if __name__ == "__main__": 131 | absltest.main() 132 | -------------------------------------------------------------------------------- /robopianist/suite/tasks/self_actuated_piano_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for self_actuated_piano.py.""" 16 | 17 | import numpy as np 18 | from absl.testing import absltest, parameterized 19 | from dm_control import composer 20 | from note_seq.protobuf import music_pb2 21 | 22 | from robopianist.music import midi_file 23 | from robopianist.suite.tasks import self_actuated_piano 24 | 25 | 26 | def _get_test_midi(dt: float = 0.01) -> midi_file.MidiFile: 27 | seq = music_pb2.NoteSequence() 28 | 29 | # C6 for 2 dts. 30 | seq.notes.add( 31 | start_time=0.0, 32 | end_time=2 * dt, 33 | velocity=80, 34 | pitch=midi_file.note_name_to_midi_number("C6"), 35 | part=-1, 36 | ) 37 | # G5 for 1 dt. 38 | seq.notes.add( 39 | start_time=2 * dt, 40 | end_time=3 * dt, 41 | velocity=80, 42 | pitch=midi_file.note_name_to_midi_number("G5"), 43 | part=-1, 44 | ) 45 | 46 | seq.total_time = 3 * dt 47 | seq.tempos.add(qpm=60) 48 | return midi_file.MidiFile(seq=seq) 49 | 50 | 51 | def _get_env( 52 | control_timestep: float = 0.01, 53 | n_steps_lookahead: int = 0, 54 | reward_type: self_actuated_piano.RewardType = self_actuated_piano.RewardType.NEGATIVE_L2, 55 | ) -> composer.Environment: 56 | task = self_actuated_piano.SelfActuatedPiano( 57 | midi=_get_test_midi(dt=control_timestep), 58 | n_steps_lookahead=n_steps_lookahead, 59 | reward_type=reward_type, 60 | control_timestep=control_timestep, 61 | ) 62 | return composer.Environment(task, strip_singleton_obs_buffer_dim=True) 63 | 64 | 65 | class SelfActuatedPianoTest(parameterized.TestCase): 66 | def test_observables(self) -> None: 67 | env = _get_env() 68 | timestep = env.reset() 69 | 70 | self.assertIn("piano/activation", timestep.observation) 71 | self.assertIn("piano/sustain_activation", timestep.observation) 72 | self.assertIn("goal", timestep.observation) 73 | 74 | def test_action_spec(self) -> None: 75 | env = _get_env() 76 | self.assertEqual(env.action_spec().shape, (env.task.piano.n_keys + 1,)) 77 | 78 | def test_termination_and_discount(self) -> None: 79 | env = _get_env() 80 | 
action_spec = env.action_spec() 81 | zero_action = np.zeros(action_spec.shape) 82 | env.reset() 83 | 84 | # With a dt of 0.01 and a 3 dt long midi, the episode should end after 4 steps. 85 | for _ in range(3): 86 | timestep = env.step(zero_action) 87 | self.assertFalse(env.task.should_terminate_episode(env.physics)) 88 | np.testing.assert_array_equal(env.task.get_discount(env.physics), 1.0) 89 | 90 | # 1 more step to terminate. 91 | timestep = env.step(zero_action) 92 | self.assertTrue(timestep.last()) 93 | self.assertTrue(env.task.should_terminate_episode(env.physics)) 94 | # No failure, so discount should be 1.0. 95 | np.testing.assert_array_equal(env.task.get_discount(env.physics), 1.0) 96 | 97 | @parameterized.parameters(0, 1, 2, 5) 98 | def test_goal_observable_lookahead(self, n_steps_lookahead: int) -> None: 99 | env = _get_env(control_timestep=0.01, n_steps_lookahead=n_steps_lookahead) 100 | action_spec = env.action_spec() 101 | zero_action = np.zeros(action_spec.shape) 102 | timestep = env.reset() 103 | 104 | midi = _get_test_midi(dt=0.01) 105 | note_traj = midi_file.NoteTrajectory.from_midi( 106 | midi, dt=env.task.control_timestep 107 | ) 108 | notes = note_traj.notes 109 | sustains = note_traj.sustains 110 | self.assertLen(notes, 4) 111 | 112 | for i in range(len(notes)): 113 | expected_goal = np.zeros((n_steps_lookahead + 1, env.task.piano.n_keys + 1)) 114 | 115 | t_start = i 116 | t_end = min(i + n_steps_lookahead + 1, len(notes)) 117 | for j, t in enumerate(range(t_start, t_end)): 118 | keys = [note.key for note in notes[t]] 119 | expected_goal[j, keys] = 1.0 120 | expected_goal[j, -1] = sustains[t] 121 | 122 | actual_goal = timestep.observation["goal"] 123 | np.testing.assert_array_equal(actual_goal, expected_goal.ravel()) 124 | 125 | # Check that the 0th goal is always the goal at the current timestep. 
126 | expected_current = np.zeros((env.task.piano.n_keys + 1,)) 127 | keys = [note.key for note in notes[i]] 128 | expected_current[keys] = 1.0 129 | expected_current[-1] = sustains[i] 130 | actual_current = timestep.observation["goal"][0 : env.task.piano.n_keys + 1] 131 | np.testing.assert_array_equal(actual_current, expected_current) 132 | 133 | timestep = env.step(zero_action) 134 | 135 | # In the `after_step` method, we cache the goal for the current timestep 136 | # to compute the reward. Let's check that it matches the expected goal. 137 | np.testing.assert_array_equal(expected_current, env.task._goal_current) 138 | 139 | @parameterized.parameters( 140 | self_actuated_piano.RewardType.NEGATIVE_L2, 141 | self_actuated_piano.RewardType.NEGATIVE_XENT, 142 | ) 143 | def test_reward(self, reward_type: self_actuated_piano.RewardType) -> None: 144 | env = _get_env(reward_type=reward_type) 145 | action_spec = env.action_spec() 146 | timestep = env.reset() 147 | 148 | # The first timestep should have a None reward. 
149 | self.assertIsNone(timestep.reward) 150 | 151 | while not timestep.last(): 152 | random_ctrl = np.random.uniform( 153 | low=action_spec.minimum, 154 | high=action_spec.maximum, 155 | size=action_spec.shape, 156 | ).astype(action_spec.dtype) 157 | timestep = env.step(random_ctrl) 158 | 159 | actual_reward = timestep.reward 160 | expected_reward = reward_type.get()( 161 | np.concatenate( 162 | [env.task.piano.activation, env.task.piano.sustain_activation] 163 | ), 164 | env.task._goal_current, 165 | ) 166 | self.assertEqual(actual_reward, expected_reward) 167 | 168 | 169 | if __name__ == "__main__": 170 | absltest.main() 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RoboPianist: Dexterous Piano Playing with Deep Reinforcement Learning 2 | 3 | [![build][tests-badge]][tests] 4 | [![docs][docs-badge]][docs] 5 | [![PyPI Python Version][pypi-versions-badge]][pypi] 6 | [![PyPI version][pypi-badge]][pypi] 7 | 8 | [tests-badge]: https://github.com/google-research/robopianist/actions/workflows/ci.yml/badge.svg 9 | [docs-badge]: https://github.com/google-research/robopianist/actions/workflows/docs.yml/badge.svg 10 | [tests]: https://github.com/google-research/robopianist/actions/workflows/ci.yml 11 | [docs]: https://google-research.github.io/robopianist/ 12 | [pypi-versions-badge]: https://img.shields.io/pypi/pyversions/robopianist 13 | [pypi-badge]: https://badge.fury.io/py/robopianist.svg 14 | [pypi]: https://pypi.org/project/robopianist/ 15 | 16 | [![Video](http://img.youtube.com/vi/VBFn_Gg0yD8/hqdefault.jpg)](https://youtu.be/VBFn_Gg0yD8) 17 | 18 | RoboPianist is a new benchmarking suite for high-dimensional control, targeted at testing high spatial and temporal precision, coordination, and planning, all with an underactuated system frequently making-and-breaking contacts. 
The proposed challenge is *mastering the piano* through bimanual dexterity, using a pair of simulated anthropomorphic robot hands. 19 | 20 | This codebase contains software and tasks for the benchmark, and is powered by [MuJoCo](https://mujoco.org/). 21 | 22 | - [Latest Updates](#latest-updates) 23 | - [Getting Started](#getting-started) 24 | - [Installation](#installation) 25 | - [Install from source](#install-from-source) 26 | - [Install from PyPI](#install-from-pypi) 27 | - [Optional: Download additional soundfonts](#optional-download-additional-soundfonts) 28 | - [MIDI Dataset](#midi-dataset) 29 | - [CLI](#cli) 30 | - [Contributing](#contributing) 31 | - [FAQ](#faq) 32 | - [Citing RoboPianist](#citing-robopianist) 33 | - [Acknowledgements](#acknowledgements) 34 | - [Works that have used RoboPianist](#works-that-have-used-robopianist) 35 | - [License and Disclaimer](#license-and-disclaimer) 36 | 37 | ------- 38 | 39 | ## Latest Updates 40 | 41 | - [24/12/2023] Updated install script so that it checks out the correct Menagerie commit. Please re-run `bash scripts/install_deps.sh` to update your installation. 42 | - [17/08/2023] Added a [pixel wrapper](robopianist/wrappers/pixels.py) for augmenting the observation space with RGB images. 43 | - [11/08/2023] Code to train the model-free RL policies is now public; see [robopianist-rl](https://github.com/kevinzakka/robopianist-rl). 44 | 45 | ------- 46 | 47 | ## Getting Started 48 | 49 | We've created an introductory [Colab](https://colab.research.google.com/github/google-research/robopianist/blob/main/tutorial.ipynb) notebook that demonstrates how to use RoboPianist. It includes code for loading and customizing a piano-playing task, and a demonstration of a pretrained policy playing a short snippet of *Twinkle Twinkle Little Star*. Click the button below to get started!
50 | 51 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/robopianist/blob/main/tutorial.ipynb) 52 | 53 | ## Installation 54 | 55 | RoboPianist is supported on both Linux and macOS and can be installed with Python >= 3.8. We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to manage your Python environment. 56 | 57 | ### Install from source 58 | 59 | The recommended way to install this package is from source. Start by cloning the repository: 60 | 61 | ```bash 62 | git clone https://github.com/google-research/robopianist.git && cd robopianist 63 | ``` 64 | 65 | Next, install the prerequisite dependencies: 66 | 67 | ```bash 68 | git submodule init && git submodule update 69 | bash scripts/install_deps.sh 70 | ``` 71 | 72 | Finally, create a new conda environment and install RoboPianist in editable mode: 73 | 74 | ```bash 75 | conda create -n pianist python=3.10 76 | conda activate pianist 77 | 78 | pip install -e ".[dev]" 79 | ``` 80 | 81 | To test your installation, run `make test` and verify that all tests pass. 82 | 83 | ### Install from PyPI 84 | 85 | First, install the prerequisite dependencies: 86 | 87 | ```bash 88 | bash <(curl -s https://raw.githubusercontent.com/google-research/robopianist/main/scripts/install_deps.sh) --no-soundfonts 89 | ``` 90 | 91 | Next, create a new conda environment and install RoboPianist: 92 | 93 | ```bash 94 | conda create -n pianist python=3.10 95 | conda activate pianist 96 | 97 | pip install --upgrade robopianist 98 | ``` 99 | 100 | ### Optional: Download additional soundfonts 101 | 102 | We recommend installing additional soundfonts to improve the quality of the synthesized audio. You can easily do this using the RoboPianist CLI: 103 | 104 | ```bash 105 | robopianist soundfont --download 106 | ``` 107 | 108 | For more soundfont-related commands, see [docs/soundfonts.md](docs/soundfonts.md). 
109 | 110 | ## MIDI Dataset 111 | 112 | The PIG dataset cannot be redistributed on GitHub due to licensing restrictions. See [docs/dataset](docs/dataset.md) for instructions on where to download it and how to preprocess it. 113 | 114 | ## CLI 115 | 116 | RoboPianist comes with a command line interface (CLI) that can be used to download additional soundfonts, play MIDI files, preprocess the PIG dataset, and more. For more information, see [docs/cli.md](docs/cli.md). 117 | 118 | ## Contributing 119 | 120 | We welcome contributions to RoboPianist. Please see [docs/contributing.md](docs/contributing.md) for more information. 121 | 122 | ## FAQ 123 | 124 | See [docs/faq.md](docs/faq.md) for a list of frequently asked questions. 125 | 126 | ## Citing RoboPianist 127 | 128 | If you use RoboPianist in your work, please use the following citation: 129 | 130 | ```bibtex 131 | @inproceedings{robopianist2023, 132 | author = {Zakka, Kevin and Wu, Philipp and Smith, Laura and Gileadi, Nimrod and Howell, Taylor and Peng, Xue Bin and Singh, Sumeet and Tassa, Yuval and Florence, Pete and Zeng, Andy and Abbeel, Pieter}, 133 | title = {RoboPianist: Dexterous Piano Playing with Deep Reinforcement Learning}, 134 | booktitle = {Conference on Robot Learning (CoRL)}, 135 | year = {2023}, 136 | } 137 | ``` 138 | 139 | ## Acknowledgements 140 | 141 | We would like to thank the following people for making this project possible: 142 | 143 | - [Philipp Wu](https://www.linkedin.com/in/wuphilipp/) and [Mohit Shridhar](https://mohitshridhar.com/) for being a constant source of inspiration and support. 144 | - [Ilya Kostrikov](https://www.kostrikov.xyz/) for constantly raising the bar for RL engineering and for invaluable debugging help. 145 | - The [Magenta](https://magenta.tensorflow.org/) team for helpful pointers and feedback. 146 | - The [MuJoCo](https://mujoco.org/) team for the development of the MuJoCo physics engine and their support throughout the project. 
147 | 148 | ## Works that have used RoboPianist 149 | 150 | - *Privileged Sensing Scaffolds Reinforcement Learning*, Hu et al. ([paper](https://openreview.net/forum?id=EpVe8jAjdx), [website](https://penn-pal-lab.github.io/scaffolder/)) 151 | 152 | ## License and Disclaimer 153 | 154 | [MuJoCo Menagerie](https://github.com/deepmind/mujoco_menagerie)'s license can be found [here](https://github.com/deepmind/mujoco_menagerie/blob/main/LICENSE). Soundfont licensing information can be found [here](docs/soundfonts.md). MIDI licensing information can be found [here](docs/dataset.md). All other code is licensed under an [Apache-2.0 License](LICENSE). 155 | 156 | This is not an officially supported Google product. 157 | -------------------------------------------------------------------------------- /robopianist/models/hands/shadow_hand_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for shadow_hand.py.""" 16 | 17 | import numpy as np 18 | from absl.testing import absltest, parameterized 19 | from dm_control import composer, mjcf 20 | 21 | from robopianist.models.arenas import stage 22 | from robopianist.models.hands import base as base_hand 23 | from robopianist.models.hands import shadow_hand 24 | from robopianist.models.hands import shadow_hand_constants as consts 25 | from robopianist.models.hands.base import HandSide 26 | from robopianist.suite.tasks import base as base_task 27 | 28 | 29 | def _get_env(): 30 | task = base_task.PianoTask(arena=stage.Stage()) 31 | env = composer.Environment( 32 | task=task, time_limit=1.0, strip_singleton_obs_buffer_dim=True 33 | ) 34 | return env 35 | 36 | 37 | class ShadowHandConstantsTest(absltest.TestCase): 38 | def test_fingertip_bodies_order(self) -> None: 39 | expected_order = ["thdistal", "ffdistal", "mfdistal", "rfdistal", "lfdistal"] 40 | self.assertEqual(consts.FINGERTIP_BODIES, tuple(expected_order)) 41 | 42 | 43 | class ShadowHandTest(parameterized.TestCase): 44 | @parameterized.product( 45 | side=[base_hand.HandSide.RIGHT, base_hand.HandSide.LEFT], 46 | primitive_fingertip_collisions=[False, True], 47 | restrict_yaw_range=[False, True], 48 | reduced_action_space=[False, True], 49 | ) 50 | def test_compiles_and_steps( 51 | self, 52 | side: base_hand.HandSide, 53 | primitive_fingertip_collisions: bool, 54 | restrict_yaw_range: bool, 55 | reduced_action_space: bool, 56 | ) -> None: 57 | robot = shadow_hand.ShadowHand( 58 | side=side, 59 | primitive_fingertip_collisions=primitive_fingertip_collisions, 60 | restrict_wrist_yaw_range=restrict_yaw_range, 61 | reduced_action_space=reduced_action_space, 62 | ) 63 | physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) 64 | physics.step() 65 | 66 | def test_set_name(self) -> None: 67 | robot = shadow_hand.ShadowHand(name="larry") 68 | self.assertEqual(robot.name, "larry") 69 | self.assertEqual(robot.mjcf_model.model, "larry") 70 | 71 
| def test_default_name(self) -> None: 72 | robot = shadow_hand.ShadowHand(side=HandSide.RIGHT) 73 | self.assertEqual(robot.name, "rh_shadow_hand") 74 | robot = shadow_hand.ShadowHand(side=HandSide.LEFT) 75 | self.assertEqual(robot.name, "lh_shadow_hand") 76 | 77 | def test_raises_value_error_on_invalid_forearm_dofs(self) -> None: 78 | with self.assertRaises(ValueError): 79 | shadow_hand.ShadowHand(forearm_dofs=("invalid",)) 80 | 81 | def test_joints(self) -> None: 82 | robot = shadow_hand.ShadowHand() 83 | for joint in robot.joints: 84 | self.assertEqual(joint.tag, "joint") 85 | expected_dofs = consts.NQ + robot.n_forearm_dofs 86 | self.assertLen(robot.joints, expected_dofs) 87 | 88 | @parameterized.named_parameters( 89 | {"testcase_name": "full_action_space", "reduced_action_space": False}, 90 | {"testcase_name": "reduced_action_space", "reduced_action_space": True}, 91 | ) 92 | def test_actuators(self, reduced_action_space: bool) -> None: 93 | robot = shadow_hand.ShadowHand(reduced_action_space=reduced_action_space) 94 | for actuator in robot.actuators: 95 | self.assertEqual(actuator.tag, "position") 96 | expected_acts = consts.NU + robot.n_forearm_dofs 97 | if reduced_action_space: 98 | expected_acts -= len(shadow_hand._REDUCED_ACTION_SPACE_EXCLUDED_DOFS) 99 | self.assertLen(robot.actuators, expected_acts) 100 | 101 | def test_restrict_wrist_yaw_range(self) -> None: 102 | robot = shadow_hand.ShadowHand(restrict_wrist_yaw_range=True) 103 | physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) 104 | jnt_range = physics.bind(robot.joints[0]).range # W2 is the first joint. 
105 | self.assertEqual(jnt_range[0], -0.174533) 106 | self.assertEqual(jnt_range[1], 0.174533) 107 | 108 | def test_fingertip_sites_order(self) -> None: 109 | expected_order = ["thdistal", "ffdistal", "mfdistal", "rfdistal", "lfdistal"] 110 | robot = shadow_hand.ShadowHand() 111 | for i, site in enumerate(robot.fingertip_sites): 112 | self.assertEqual(site.tag, "site") 113 | self.assertEqual(site.name, f"{expected_order[i]}_site") 114 | 115 | @parameterized.named_parameters( 116 | {"testcase_name": "left_hand", "side": HandSide.LEFT}, 117 | {"testcase_name": "right_hand", "side": HandSide.RIGHT}, 118 | ) 119 | def test_action_spec(self, side: HandSide) -> None: 120 | robot = shadow_hand.ShadowHand(side=side) 121 | physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model) 122 | action_spec = robot.action_spec(physics) 123 | expected_shape = (consts.NU + robot.n_forearm_dofs,) 124 | self.assertEqual(action_spec.shape, expected_shape) 125 | 126 | 127 | class ShadowHandObservableTest(parameterized.TestCase): 128 | @parameterized.parameters( 129 | [ 130 | "root_body", 131 | ] 132 | ) 133 | def test_get_element_property(self, name: str) -> None: 134 | attribute_value = getattr(shadow_hand.ShadowHand(), name) 135 | self.assertIsInstance(attribute_value, mjcf.Element) 136 | 137 | @parameterized.parameters( 138 | [ 139 | "actuators", 140 | "joints", 141 | "joint_torque_sensors", 142 | "actuator_velocity_sensors", 143 | "actuator_force_sensors", 144 | "fingertip_sites", 145 | "fingertip_touch_sensors", 146 | ] 147 | ) 148 | def test_get_element_tuple_property(self, name: str) -> None: 149 | attribute_value = getattr(shadow_hand.ShadowHand(), name) 150 | self.assertNotEmpty(attribute_value) 151 | for element in attribute_value: 152 | self.assertIsInstance(element, mjcf.Element) 153 | 154 | @parameterized.parameters( 155 | [ 156 | "joints_pos", 157 | "joints_pos_cos_sin", 158 | "joints_vel", 159 | "joints_torque", 160 | "actuators_force", 161 | "actuators_velocity", 162 | 
"actuators_power", 163 | "position", 164 | "fingertip_force", 165 | ] 166 | ) 167 | def test_evaluate_observable(self, name: str) -> None: 168 | env = _get_env() 169 | physics = env.physics 170 | for hand in [env.task.right_hand, env.task.left_hand]: 171 | observable = getattr(hand.observables, name) 172 | observation = observable(physics) 173 | self.assertIsInstance(observation, (float, np.ndarray)) 174 | 175 | 176 | if __name__ == "__main__": 177 | absltest.main() 178 | -------------------------------------------------------------------------------- /robopianist/suite/tasks/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base piano composer task.""" 16 | 17 | from typing import Sequence 18 | 19 | import mujoco 20 | import numpy as np 21 | from dm_control import composer 22 | from mujoco_utils import composer_utils, physics_utils 23 | 24 | from robopianist.models.hands import HandSide, shadow_hand 25 | from robopianist.models.piano import piano 26 | 27 | # Timestep of the physics simulation, in seconds. 28 | _PHYSICS_TIMESTEP = 0.005 29 | 30 | # Interval between agent actions, in seconds. 31 | _CONTROL_TIMESTEP = 0.05 # 20 Hz. 32 | 33 | # Default position and orientation of the hands. 
34 | _LEFT_HAND_POSITION = (0.4, -0.15, 0.13) 35 | _LEFT_HAND_QUATERNION = (-1, -1, 1, 1) 36 | _RIGHT_HAND_POSITION = (0.4, 0.15, 0.13) 37 | _RIGHT_HAND_QUATERNION = (-1, -1, 1, 1) 38 | 39 | _ATTACHMENT_YAW = 0 # Degrees. 40 | 41 | 42 | class PianoOnlyTask(composer.Task): 43 | """Piano task with no hands.""" 44 | 45 | def __init__( 46 | self, 47 | arena: composer_utils.Arena, 48 | change_color_on_activation: bool = False, 49 | add_piano_actuators: bool = False, 50 | physics_timestep: float = _PHYSICS_TIMESTEP, 51 | control_timestep: float = _CONTROL_TIMESTEP, 52 | ) -> None: 53 | self._arena = arena 54 | self._piano = piano.Piano( 55 | change_color_on_activation=change_color_on_activation, 56 | add_actuators=add_piano_actuators, 57 | ) 58 | arena.attach(self._piano) 59 | 60 | # Harden the piano keys. 61 | # The default solref parameters are (0.02, 1). In particular, the first 62 | # parameter specifies -stiffness, and so decreasing it makes the contacts 63 | # harder. The documentation recommends keeping the stiffness at least 2x larger 64 | # than the physics timestep, see: 65 | # https://mujoco.readthedocs.io/en/latest/modeling.html?highlight=stiffness#solver-parameters 66 | self._piano.mjcf_model.default.geom.solref = (physics_timestep * 2, 1) 67 | 68 | self.set_timesteps( 69 | control_timestep=control_timestep, physics_timestep=physics_timestep 70 | ) 71 | 72 | # Accessors. 73 | 74 | @property 75 | def root_entity(self): 76 | return self._arena 77 | 78 | @property 79 | def arena(self): 80 | return self._arena 81 | 82 | @property 83 | def piano(self) -> piano.Piano: 84 | return self._piano 85 | 86 | # Composer methods. 87 | 88 | def get_reward(self, physics) -> float: 89 | del physics # Unused. 
90 | return 0.0 91 | 92 | 93 | class PianoTask(PianoOnlyTask): 94 | """Base class for piano tasks.""" 95 | 96 | def __init__( 97 | self, 98 | arena: composer_utils.Arena, 99 | gravity_compensation: bool = False, 100 | change_color_on_activation: bool = False, 101 | primitive_fingertip_collisions: bool = False, 102 | reduced_action_space: bool = False, 103 | attachment_yaw: float = _ATTACHMENT_YAW, 104 | forearm_dofs: Sequence[str] = shadow_hand._DEFAULT_FOREARM_DOFS, 105 | physics_timestep: float = _PHYSICS_TIMESTEP, 106 | control_timestep: float = _CONTROL_TIMESTEP, 107 | ) -> None: 108 | super().__init__( 109 | arena=arena, 110 | change_color_on_activation=change_color_on_activation, 111 | add_piano_actuators=False, 112 | physics_timestep=physics_timestep, 113 | control_timestep=control_timestep, 114 | ) 115 | 116 | self._right_hand = self._add_hand( 117 | hand_side=HandSide.RIGHT, 118 | position=_RIGHT_HAND_POSITION, 119 | quaternion=_RIGHT_HAND_QUATERNION, 120 | gravity_compensation=gravity_compensation, 121 | primitive_fingertip_collisions=primitive_fingertip_collisions, 122 | reduced_action_space=reduced_action_space, 123 | attachment_yaw=attachment_yaw, 124 | forearm_dofs=forearm_dofs, 125 | ) 126 | self._left_hand = self._add_hand( 127 | hand_side=HandSide.LEFT, 128 | position=_LEFT_HAND_POSITION, 129 | quaternion=_LEFT_HAND_QUATERNION, 130 | gravity_compensation=gravity_compensation, 131 | primitive_fingertip_collisions=primitive_fingertip_collisions, 132 | reduced_action_space=reduced_action_space, 133 | attachment_yaw=attachment_yaw, 134 | forearm_dofs=forearm_dofs, 135 | ) 136 | 137 | # Accessors. 138 | 139 | @property 140 | def left_hand(self) -> shadow_hand.ShadowHand: 141 | return self._left_hand 142 | 143 | @property 144 | def right_hand(self) -> shadow_hand.ShadowHand: 145 | return self._right_hand 146 | 147 | # Helper methods. 
148 | 149 | def _add_hand( 150 | self, 151 | hand_side: HandSide, 152 | position, 153 | quaternion, 154 | gravity_compensation: bool, 155 | primitive_fingertip_collisions: bool, 156 | reduced_action_space: bool, 157 | attachment_yaw: float, 158 | forearm_dofs: Sequence[str], 159 | ) -> shadow_hand.ShadowHand: 160 | joint_range = [-self._piano.size[1], self._piano.size[1]] 161 | 162 | # Offset the joint range by the hand's initial position. 163 | joint_range[0] -= position[1] 164 | joint_range[1] -= position[1] 165 | 166 | hand = shadow_hand.ShadowHand( 167 | side=hand_side, 168 | primitive_fingertip_collisions=primitive_fingertip_collisions, 169 | restrict_wrist_yaw_range=False, 170 | reduced_action_space=reduced_action_space, 171 | forearm_dofs=forearm_dofs, 172 | ) 173 | hand.root_body.pos = position 174 | 175 | # Slightly rotate the forearms inwards (Z-axis) to mimic human posture. 176 | rotate_axis = np.asarray([0, 0, 1], dtype=np.float64) 177 | rotate_by = np.zeros(4, dtype=np.float64) 178 | sign = -1 if hand_side == HandSide.LEFT else 1 179 | angle = np.radians(sign * attachment_yaw) 180 | mujoco.mju_axisAngle2Quat(rotate_by, rotate_axis, angle) 181 | final_quaternion = np.zeros(4, dtype=np.float64) 182 | mujoco.mju_mulQuat(final_quaternion, rotate_by, quaternion) 183 | hand.root_body.quat = final_quaternion 184 | 185 | if gravity_compensation: 186 | physics_utils.compensate_gravity(hand.mjcf_model) 187 | 188 | # Override forearm translation joint range. 
189 | forearm_tx_joint = hand.mjcf_model.find("joint", "forearm_tx") 190 | if forearm_tx_joint is not None: 191 | forearm_tx_joint.range = joint_range 192 | forearm_tx_actuator = hand.mjcf_model.find("actuator", "forearm_tx") 193 | if forearm_tx_actuator is not None: 194 | forearm_tx_actuator.ctrlrange = joint_range 195 | 196 | self._arena.attach(hand) 197 | return hand 198 | -------------------------------------------------------------------------------- /robopianist/suite/variations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Common variations for the suite.""" 16 | 17 | from typing import Sequence 18 | 19 | import numpy as np 20 | from dm_control.composer import variation 21 | from dm_control.composer.variation import distributions 22 | 23 | from robopianist import music 24 | from robopianist.music import constants, midi_file 25 | 26 | 27 | class MidiSelect(variation.Variation): 28 | """Randomly select a MIDI file from the registry.""" 29 | 30 | def __init__(self, midi_names: Sequence[str] = []) -> None: 31 | """Initializes the variation. 32 | 33 | Args: 34 | midi_names: A sequence of MIDI names to select from. Must be valid keys that 35 | can be loaded by `robopianist.music.load`. 
36 | """ 37 | self._midi_names = midi_names 38 | self._dist = distributions.UniformChoice(midi_names) 39 | 40 | def __call__( 41 | self, initial_value=None, current_value=None, random_state=None 42 | ) -> midi_file.MidiFile: 43 | del initial_value, current_value # Unused. 44 | random = random_state or np.random 45 | midi_key: str = self._dist(random_state=random) 46 | return music.load(midi_key) 47 | 48 | 49 | class MidiTemporalStretch(variation.Variation): 50 | """Randomly apply a temporal stretch to a MIDI file.""" 51 | 52 | def __init__( 53 | self, 54 | prob: float, 55 | stretch_range: float, 56 | ) -> None: 57 | """Initializes the variation. 58 | 59 | Args: 60 | prob: A float specifying the probability of applying a temporal stretch. 61 | stretch_range: Range specifying the bounds of the uniform distribution 62 | from which the multiplicative stretch factor is sampled from (i.e., 63 | [1 - stretch_range, 1 + stretch_range]). 64 | """ 65 | self._prob = prob 66 | self._dist = distributions.Uniform(-stretch_range, stretch_range) 67 | 68 | def __call__( 69 | self, initial_value=None, current_value=None, random_state=None 70 | ) -> midi_file.MidiFile: 71 | del current_value # Unused. 72 | random = random_state or np.random 73 | if random.uniform(0.0, 1.0) > self._prob: 74 | if initial_value is None or not isinstance( 75 | initial_value, midi_file.MidiFile 76 | ): 77 | raise ValueError( 78 | "Expected `initial_value` to be provided and be a midi_file.MidiFile." 79 | ) 80 | return initial_value 81 | stretch_factor = 1.0 + self._dist(random_state=random) 82 | return initial_value.stretch(stretch_factor) 83 | 84 | 85 | class MidiPitchShift(variation.Variation): 86 | """Randomly apply a pitch shift to a MIDI file.""" 87 | 88 | def __init__( 89 | self, 90 | prob: float, 91 | shift_range: int, 92 | ) -> None: 93 | """Initializes the variation. 94 | 95 | Args: 96 | prob: A float specifying the probability of applying a pitch shift. 
97 | shift_range: Range specifying the maximum absolute value of the uniform 98 | distribution from which the pitch shift, in semitones, is sampled from. 99 | This value will get truncated to the maximum number of semitones that 100 | can be shifted without exceeding the piano's range. 101 | """ 102 | self._prob = prob 103 | if not isinstance(shift_range, int): 104 | raise ValueError("`shift_range` must be an integer.") 105 | self._shift_range = shift_range 106 | 107 | def __call__( 108 | self, initial_value=None, current_value=None, random_state=None 109 | ) -> midi_file.MidiFile: 110 | del current_value # Unused. 111 | random = random_state or np.random 112 | if random.uniform(0.0, 1.0) > self._prob: 113 | if initial_value is None or not isinstance( 114 | initial_value, midi_file.MidiFile 115 | ): 116 | raise ValueError( 117 | "Expected `initial_value` to be provided and be a midi_file.MidiFile." 118 | ) 119 | return initial_value 120 | 121 | if self._shift_range == 0: 122 | return initial_value 123 | 124 | # Ensure that the pitch shift won't exceed the piano's range. 125 | pitches = [note.pitch for note in initial_value.seq.notes] 126 | min_pitch, max_pitch = min(pitches), max(pitches) 127 | low = max(constants.MIN_MIDI_PITCH_PIANO - min_pitch, -self._shift_range) 128 | high = min(constants.MAX_MIDI_PITCH_PIANO - max_pitch, self._shift_range) 129 | 130 | shift = random.randint(low, high + 1) 131 | if shift == 0: 132 | return initial_value 133 | return initial_value.transpose(shift) 134 | 135 | 136 | class MidiOctaveShift(variation.Variation): 137 | """Shift the pitch of a MIDI file in octaves.""" 138 | 139 | def __init__( 140 | self, 141 | prob: float, 142 | octave_range: int, 143 | ) -> None: 144 | """Initializes the variation. 145 | 146 | Args: 147 | prob: A float specifying the probability of applying a pitch shift. 148 | octave_range: Range specifying the maximum absolute value of the uniform 149 | distribution from which the octave shift is sampled from. 
This value 150 | will get truncated to the maximum number of octaves that can be 151 | shifted without exceeding the piano's range. 152 | """ 153 | self._prob = prob 154 | if not isinstance(octave_range, int): 155 | raise ValueError("`octave_range` must be an integer.") 156 | self._octave_range = octave_range 157 | 158 | def __call__( 159 | self, initial_value=None, current_value=None, random_state=None 160 | ) -> midi_file.MidiFile: 161 | del current_value # Unused. 162 | random = random_state or np.random 163 | if random.uniform(0.0, 1.0) > self._prob: 164 | if initial_value is None or not isinstance( 165 | initial_value, midi_file.MidiFile 166 | ): 167 | raise ValueError( 168 | "Expected `initial_value` to be provided and be a midi_file.MidiFile." 169 | ) 170 | return initial_value 171 | 172 | if self._octave_range == 0: 173 | return initial_value 174 | 175 | # Ensure that the octave shift won't exceed the piano's range. 176 | pitches = [note.pitch for note in initial_value.seq.notes] 177 | min_pitch, max_pitch = min(pitches), max(pitches) 178 | low = max(constants.MIN_MIDI_PITCH_PIANO - min_pitch, -self._octave_range * 12) 179 | high = min(constants.MAX_MIDI_PITCH_PIANO - max_pitch, self._octave_range * 12) 180 | 181 | shift = random.randint(low // 12, high // 12 + 1) 182 | if shift == 0: 183 | return initial_value 184 | return initial_value.transpose(shift * 12) 185 | -------------------------------------------------------------------------------- /robopianist/wrappers/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RoboPianist Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A wrapper for tracking episode statistics pertaining to music performance.

TODO(kevin):
- Look into `mir_eval` for metrics.
- Should sustain be a separate metric or should it just be applied to the note sequence
  as a whole?
"""

from collections import deque
from typing import Deque, Dict, List, NamedTuple, Sequence

import dm_env
import numpy as np
from dm_env_wrappers import EnvironmentWrapper
from sklearn.metrics import precision_recall_fscore_support


class EpisodeMetrics(NamedTuple):
    """A container for storing episode metrics."""

    precision: float
    recall: float
    f1: float


class MidiEvaluationWrapper(EnvironmentWrapper):
    """Track metrics related to musical performance.

    This wrapper calculates the precision, recall, and F1 score of the last `deque_size`
    episodes. The mean precision, recall and F1 score can be retrieved using
    `get_musical_metrics()`.

    By default, `deque_size` is set to 1 which means that only the current episode's
    statistics are tracked.
    """

    def __init__(self, environment: dm_env.Environment, deque_size: int = 1) -> None:
        super().__init__(environment)

        # Per-step piano activations recorded during the current episode.
        self._key_presses: List[np.ndarray] = []
        self._sustain_presses: List[np.ndarray] = []

        # Key press metrics.
        self._key_press_precisions: Deque[float] = deque(maxlen=deque_size)
        self._key_press_recalls: Deque[float] = deque(maxlen=deque_size)
        self._key_press_f1s: Deque[float] = deque(maxlen=deque_size)

        # Sustain metrics.
        self._sustain_precisions: Deque[float] = deque(maxlen=deque_size)
        self._sustain_recalls: Deque[float] = deque(maxlen=deque_size)
        self._sustain_f1s: Deque[float] = deque(maxlen=deque_size)

    def step(self, action: np.ndarray) -> dm_env.TimeStep:
        """Steps the environment, records activations, and scores finished episodes."""
        timestep = self._environment.step(action)

        piano = self._environment.task.piano
        self._key_presses.append(piano.activation.astype(np.float64))
        self._sustain_presses.append(piano.sustain_activation.astype(np.float64))

        if timestep.last():
            key_press_metrics = self._compute_key_press_metrics()
            self._key_press_precisions.append(key_press_metrics.precision)
            self._key_press_recalls.append(key_press_metrics.recall)
            self._key_press_f1s.append(key_press_metrics.f1)

            sustain_metrics = self._compute_sustain_metrics()
            self._sustain_precisions.append(sustain_metrics.precision)
            self._sustain_recalls.append(sustain_metrics.recall)
            self._sustain_f1s.append(sustain_metrics.f1)

            self._key_presses = []
            self._sustain_presses = []
        return timestep

    def reset(self) -> dm_env.TimeStep:
        """Clears the per-episode activation buffers and resets the environment."""
        self._key_presses = []
        self._sustain_presses = []
        return self._environment.reset()

    def get_musical_metrics(self) -> Dict[str, float]:
        """Returns the mean precision/recall/F1 over the last `deque_size` episodes.

        Raises:
            ValueError: If no episode has finished yet.
        """
        if not self._key_press_precisions:
            raise ValueError("No episode metrics available yet.")

        def _mean(seq: Sequence[float]) -> float:
            return sum(seq) / len(seq)

        return {
            "precision": _mean(self._key_press_precisions),
            "recall": _mean(self._key_press_recalls),
            "f1": _mean(self._key_press_f1s),
            "sustain_precision": _mean(self._sustain_precisions),
            "sustain_recall": _mean(self._sustain_recalls),
            "sustain_f1": _mean(self._sustain_f1s),
        }

    # Helper methods.

    def _compute_key_press_metrics(self) -> EpisodeMetrics:
        """Computes precision/recall/F1 for key presses over the episode."""
        # Build a binary key vector per timestep from the task's note trajectory.
        n_keys = self._environment.task.piano.n_keys
        ground_truth = []
        for notes in self._environment.task._notes:
            presses = np.zeros((n_keys,), dtype=np.float64)
            keys = [note.key for note in notes]
            presses[keys] = 1.0
            ground_truth.append(presses)
        return self._episode_metrics(ground_truth, self._key_presses)

    def _compute_sustain_metrics(self) -> EpisodeMetrics:
        """Computes precision/recall/F1 for sustain presses over the episode."""
        ground_truth = [
            np.atleast_1d(v).astype(float) for v in self._environment.task._sustains
        ]
        return self._episode_metrics(ground_truth, self._sustain_presses)

    def _truncate_on_failure(
        self, ground_truth: List[np.ndarray], n_steps: int
    ) -> List[np.ndarray]:
        """Trims `ground_truth` to `n_steps` if the episode ended on a wrong press.

        When the task terminates early due to a failure, fewer steps were recorded
        than there are ground-truth timesteps.
        """
        if getattr(self._environment.task, "_wrong_press_termination", False):
            return ground_truth[:n_steps]
        return ground_truth

    def _episode_metrics(
        self, ground_truth: List[np.ndarray], predictions: List[np.ndarray]
    ) -> EpisodeMetrics:
        """Averages per-timestep binary precision/recall/F1 over the episode."""
        ground_truth = self._truncate_on_failure(ground_truth, len(predictions))
        # Both trajectories must describe the same timesteps; `zip` would otherwise
        # silently drop the excess.
        assert len(ground_truth) == len(predictions)

        precisions = []
        recalls = []
        f1s = []
        for y_true, y_pred in zip(ground_truth, predictions):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true=y_true, y_pred=y_pred, average="binary", zero_division=1
            )
            precisions.append(precision)
            recalls.append(recall)
            f1s.append(f1)
        return EpisodeMetrics(np.mean(precisions), np.mean(recalls), np.mean(f1s))


# ------------------------------------------------------------------------------
# /robopianist/music/midi_file_test.py:
# ------------------------------------------------------------------------------
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for midi_file.py."""

from absl.testing import absltest, parameterized
from note_seq.protobuf import compare, music_pb2

from robopianist import music
from robopianist.music import midi_file


class MidiFileTest(parameterized.TestCase):
    def assertProtoEquals(self, a, b, msg=None):
        """Asserts that two protos are equal."""
        if not compare.ProtoEq(a, b):
            compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)

    @parameterized.parameters(0.5, 1.0, 2.0)
    def test_temporal_stretch(self, stretch_factor: float) -> None:
        midi = music.load("CMajorScaleTwoHands")
        stretched_midi = midi.stretch(stretch_factor)
        self.assertEqual(stretched_midi.n_notes, midi.n_notes)
        # Durations are floats; compare approximately rather than bit-for-bit.
        self.assertAlmostEqual(stretched_midi.duration, midi.duration * stretch_factor)

    @parameterized.parameters(-1, 0)
    def test_temporal_stretch_raises_value_error(self, stretch_factor: float) -> None:
        midi = music.load("CMajorScaleTwoHands")
        with self.assertRaises(ValueError):
            midi.stretch(stretch_factor)

    def test_temporal_stretch_no_op(self) -> None:
        midi = music.load("CMajorScaleTwoHands")
        stretched_midi = midi.stretch(1.0)
        self.assertProtoEquals(stretched_midi.seq, midi.seq)

    @parameterized.parameters(-2, -1, 0, 1, 2)
    def test_transpose(self, amount: int) -> None:
        midi = music.load("CMajorScaleTwoHands")
        transposed_midi = midi.transpose(amount)
        self.assertEqual(transposed_midi.n_notes, midi.n_notes)
        # Every note should move by exactly `amount` semitones. Pitches are compared
        # as sorted lists so the check does not depend on note ordering.
        expected_pitches = sorted(note.pitch + amount for note in midi.seq.notes)
        actual_pitches = sorted(note.pitch for note in transposed_midi.seq.notes)
        self.assertEqual(actual_pitches, expected_pitches)

    def test_transpose_no_op(self) -> None:
        midi = music.load("CMajorScaleTwoHands")
        transposed_midi = midi.transpose(0)
        self.assertProtoEquals(transposed_midi.seq, midi.seq)

    def test_trim_silence(self) -> None:
        midi = music.load("TwinkleTwinkleRousseau")
        midi_trimmed = midi.trim_silence()
        self.assertEqual(midi_trimmed.seq.notes[0].start_time, 0.0)


class PianoNoteTest(absltest.TestCase):
    def test_constructor(self) -> None:
        name = "C4"
        number = midi_file.note_name_to_midi_number(name)
        velocity = 100
        note = midi_file.PianoNote.create(number=number, velocity=velocity)
        self.assertEqual(note.number, number)
        self.assertEqual(note.velocity, velocity)
        self.assertEqual(note.name, name)

    def test_raises_value_error_negative_number(self) -> None:
        with self.assertRaises(ValueError):
            midi_file.PianoNote.create(number=-1, velocity=0)

    def test_raises_value_error_large_number(self) -> None:
        with self.assertRaises(ValueError):
            midi_file.PianoNote.create(number=128, velocity=0)

    def test_raises_value_error_negative_velocity(self) -> None:
        with self.assertRaises(ValueError):
            midi_file.PianoNote.create(number=0, velocity=-1)

    def test_raises_value_error_large_velocity(self) -> None:
        with self.assertRaises(ValueError):
            midi_file.PianoNote.create(number=0, velocity=128)


class ConversionMethodsTest(absltest.TestCase):
    def test_note_name_midi_number_consistency(self) -> None:
        name = "C4"
        number = midi_file.note_name_to_midi_number(name)
        self.assertEqual(midi_file.midi_number_to_note_name(number), name)

    def test_key_number_midi_number_consistency(self) -> None:
        key_number = 10
        number = midi_file.key_number_to_midi_number(key_number)
        self.assertEqual(midi_file.midi_number_to_key_number(number), key_number)

    def test_key_number_note_name_consistency(self) -> None:
        key_number = 39
        name = midi_file.key_number_to_note_name(key_number)
        self.assertEqual(midi_file.note_name_to_key_number(name), key_number)


def _get_test_midi(dt: float = 0.01) -> midi_file.MidiFile:
    """A sequence constructed specifically to test hitting a note 2x in a row."""
    seq = music_pb2.NoteSequence()

    # Silence for the first dt.

    # Hit C6 2 times in a row. First one for 1 dt, second one for 3 dt.
    seq.notes.add(
        start_time=1 * dt,
        end_time=2 * dt,
        velocity=80,
        pitch=midi_file.note_name_to_midi_number("C6"),
        part=-1,
    )
    seq.notes.add(
        start_time=2.0 * dt,
        end_time=5 * dt,
        velocity=80,
        pitch=midi_file.note_name_to_midi_number("C6"),
        part=-1,
    )

    seq.total_time = 5.0 * dt
    seq.tempos.add(qpm=60)
    return midi_file.MidiFile(seq=seq)


def _get_test_midi_with_sustain(dt: float = 0.01) -> midi_file.MidiFile:
    """A sequence with one sustained C6, a pause, then a second C6."""
    seq = music_pb2.NoteSequence()

    # Hit C6 for 1 dt.
    seq.notes.add(
        start_time=0 * dt,
        end_time=1 * dt,
        velocity=80,
        pitch=midi_file.note_name_to_midi_number("C6"),
        part=-1,
    )

    # Sustain it for 3 dt (control change 64 on at t=0, off at t=3 dt).
    seq.control_changes.add(
        time=0 * dt,
        control_number=64,
        control_value=64,
        instrument=0,
    )
    seq.control_changes.add(
        time=3 * dt,
        control_number=64,
        control_value=0,
        instrument=0,
    )

    seq.notes.add(
        start_time=5 * dt,
        end_time=6 * dt,
        velocity=80,
        pitch=midi_file.note_name_to_midi_number("C6"),
        part=-1,
    )

    seq.total_time = 6.0 * dt
    seq.tempos.add(qpm=60)
    return midi_file.MidiFile(seq=seq)


class NoteTrajectoryTest(absltest.TestCase):
    # NOTE: "not" in the method name is a typo for "note"; kept to avoid renaming
    # a discovered test.
    def test_same_not_pressed_consecutively(self) -> None:
        midi = _get_test_midi()

        note_traj = midi_file.NoteTrajectory.from_midi(midi, dt=0.01)
        self.assertEqual(len(note_traj), 6)

        self.assertEqual(note_traj.notes[0], [])  # Silence.

        self.assertLen(note_traj.notes[1], 1)
        self.assertEqual(note_traj.notes[1][0].name, "C6")

        # To prevent the note from being sustained, the third timestep should be empty
        # even though the note is played anew at that timestep.
        self.assertEqual(note_traj.notes[2], [])

        # Now the note should be active for 2 timesteps.
        self.assertLen(note_traj.notes[3], 1)
        self.assertEqual(note_traj.notes[3][0].name, "C6")
        self.assertLen(note_traj.notes[4], 1)
        self.assertEqual(note_traj.notes[4][0].name, "C6")

    def test_sustain(self) -> None:
        midi = _get_test_midi_with_sustain()

        note_traj = midi_file.NoteTrajectory.from_midi(midi, dt=0.01)
        self.assertEqual(len(note_traj), 7)

        sustain = note_traj.sustains

        # Sustain should be active for the first 3 timesteps (0 through 2).
        for i in range(3):
            self.assertTrue(sustain[i])

        # Sustain should be inactive for timesteps 3 through 5.
        for i in range(3, 6):
            self.assertFalse(sustain[i])


if __name__ == "__main__":
    absltest.main()


# ------------------------------------------------------------------------------
# /robopianist/suite/tasks/self_actuated_piano.py:
# ------------------------------------------------------------------------------
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A self-actuated piano that must learn to play a MIDI file."""

import enum
from typing import Callable, Optional, Sequence

import numpy as np
from dm_control import mjcf
from dm_control.composer import variation as base_variation
from dm_control.composer.observation import observable
from dm_env import specs
from mujoco_utils import spec_utils

from robopianist.models.arenas import stage
from robopianist.music import midi_file
from robopianist.suite import composite_reward
from robopianist.suite.tasks import base

# For numerical stability.
_EPS = 1e-6

RewardFn = Callable[[np.ndarray, np.ndarray], float]


def negative_binary_cross_entropy(
    predictions: np.ndarray,
    targets: np.ndarray,
) -> float:
    """Computes the negative binary cross entropy between predictions and targets.

    Returns `sum(t * log(p + eps) + (1 - t) * log(1 - p + eps))`, where `eps`
    guards against `log(0)`.
    """
    assert predictions.shape == targets.shape
    assert predictions.ndim >= 1
    log_p = np.log(predictions + _EPS)
    log_1_minus_p = np.log(1 - predictions + _EPS)
    return np.sum(targets * log_p + (1 - targets) * log_1_minus_p)


def negative_l2_distance(
    predictions: np.ndarray,
    targets: np.ndarray,
) -> float:
    """Computes the negative L2 distance between predictions and targets."""
    assert predictions.shape == targets.shape
    assert predictions.ndim >= 1
    return -np.sqrt(np.sum((predictions - targets) ** 2))


class RewardType(enum.Enum):
    """Selects the function used for the key press reward."""

    NEGATIVE_XENT = "negative_xent"
    NEGATIVE_L2 = "negative_l2"

    def get(self) -> RewardFn:
        """Returns the reward function associated with this reward type."""
        if self == RewardType.NEGATIVE_XENT:
            return negative_binary_cross_entropy
        elif self == RewardType.NEGATIVE_L2:
            return negative_l2_distance
        else:
            raise ValueError(f"Invalid reward type: {self}")


class SelfActuatedPiano(base.PianoOnlyTask):
    """Task where a piano self-actuates to play a MIDI file."""

    def __init__(
        self,
        midi: midi_file.MidiFile,
        n_steps_lookahead: int = 0,
        trim_silence: bool = False,
        reward_type: RewardType = RewardType.NEGATIVE_L2,
        augmentations: Optional[Sequence[base_variation.Variation]] = None,
        **kwargs,
    ) -> None:
        """Task constructor.

        Args:
            midi: A `MidiFile` object.
            n_steps_lookahead: Number of timesteps to look ahead when computing the
                goal state.
            trim_silence: If True, shifts the MIDI file so that the first note starts
                at time 0.
            reward_type: Reward function to use for the key press reward.
            augmentations: A list of `Variation` objects that will be applied to the
                MIDI file at the beginning of each episode. If None, no augmentations
                will be applied.
            **kwargs: Forwarded to `base.PianoOnlyTask`.
        """
        super().__init__(arena=stage.Stage(), add_piano_actuators=True, **kwargs)

        if trim_silence:
            midi = midi.trim_silence()
        self._midi = midi
        # Augmentations always start from this copy, so they don't compound across
        # episodes.
        self._initial_midi = midi
        self._n_steps_lookahead = n_steps_lookahead
        self._key_press_reward = reward_type.get()
        self._reward_fn = composite_reward.CompositeReward(
            key_press_reward=self._compute_key_press_reward,
        )
        self._augmentations = augmentations

        self._reset_quantities_at_episode_init()
        self._reset_trajectory()  # Important: call before adding observables.
        self._add_observables()

    def _reset_quantities_at_episode_init(self) -> None:
        self._t_idx: int = 0
        self._should_terminate: bool = False

    def _maybe_change_midi(self, random_state: np.random.RandomState) -> None:
        """Re-applies all augmentations to the original MIDI, if any are configured."""
        if self._augmentations is not None:
            midi = self._initial_midi
            for var in self._augmentations:
                midi = var(initial_value=midi, random_state=random_state)
            self._midi = midi
            self._reset_trajectory()

    def _reset_trajectory(self) -> None:
        """Rebuilds the per-timestep notes and sustain targets from `self._midi`."""
        note_traj = midi_file.NoteTrajectory.from_midi(
            self._midi, self.control_timestep
        )
        self._notes = note_traj.notes
        self._sustains = note_traj.sustains

    # Composer methods.

    def initialize_episode(
        self, physics: mjcf.Physics, random_state: np.random.RandomState
    ) -> None:
        del physics  # Unused.
        self._maybe_change_midi(random_state)
        self._reset_quantities_at_episode_init()

    def before_step(
        self,
        physics: mjcf.Physics,
        action: np.ndarray,
        random_state: np.random.RandomState,
    ) -> None:
        # Note that with a self-actuated piano, we don't need to separately apply the
        # sustain action.
        self.piano.apply_action(physics, action, random_state)

    def after_step(
        self, physics: mjcf.Physics, random_state: np.random.RandomState
    ) -> None:
        del physics, random_state  # Unused.
        self._t_idx += 1
        # Terminate once every timestep of the trajectory has been played. Using `>=`
        # (instead of the equivalent-for-nonempty `==`) also ends an empty trajectory.
        self._should_terminate = self._t_idx >= len(self._notes)

        # We need to save the goal state for the current timestep because observable
        # callables are called _before_ the reward is computed. Otherwise, we'd be off
        # by one timestep when computing the reward.
        # NOTE(kevin): The reason we don't need a `copy()` here is because
        # self._goal_state gets new memory allocated to it every time we call
        # `self._update_goal_state()`.
        # NOTE(review): this was said to be unit-tested in `play_midi_test.py`, which
        # does not exist in this tree; likely `self_actuated_piano_test.py` — confirm.
        self._goal_current = self._goal_state[0]

    def get_reward(self, physics: mjcf.Physics) -> float:
        return self._reward_fn.compute(physics)

    def should_terminate_episode(self, physics: mjcf.Physics) -> bool:
        del physics  # Unused.
        return self._should_terminate

    @property
    def task_observables(self):
        return self._task_observables

    def action_spec(self, physics: mjcf.Physics) -> specs.BoundedArray:
        """Key actuator spec with one extra [0, 1] dimension for the sustain pedal."""
        keys_spec = spec_utils.create_action_spec(physics, self.piano.actuators)
        sustain_spec = specs.BoundedArray(
            shape=(1,),
            dtype=keys_spec.dtype,
            minimum=[0.0],
            maximum=[1.0],
            name="sustain",
        )
        return spec_utils.merge_specs([keys_spec, sustain_spec])

    # Other.

    @property
    def midi(self) -> midi_file.MidiFile:
        return self._midi

    @property
    def reward_fn(self) -> composite_reward.CompositeReward:
        return self._reward_fn

    # Helper methods.

    def _compute_key_press_reward(self, physics: mjcf.Physics) -> float:
        """Scores the key + sustain activations against the saved current goal."""
        del physics  # Unused.
        return self._key_press_reward(
            np.concatenate([self.piano.activation, self.piano.sustain_activation]),
            self._goal_current,
        )

    def _update_goal_state(self) -> None:
        # Observable callables get called after `after_step` but before
        # `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`,
        # we need to guard against out of bounds indexing. Note that the goal state
        # does not matter at this point since we are terminating the episode and this
        # update is usually meant for the next timestep.
        if self._t_idx >= len(self._notes):
            return

        # Row i holds the key targets (plus sustain in the last column) for timestep
        # t_idx + i; rows past the end of the trajectory stay zero.
        self._goal_state = np.zeros(
            (self._n_steps_lookahead + 1, self.piano.n_keys + 1),
            dtype=np.float64,
        )
        t_start = self._t_idx
        t_end = min(t_start + self._n_steps_lookahead + 1, len(self._notes))
        for i, t in enumerate(range(t_start, t_end)):
            keys = [note.key for note in self._notes[t]]
            self._goal_state[i, keys] = 1.0
            self._goal_state[i, -1] = self._sustains[t]

    def _add_observables(self) -> None:
        # This returns the current state of the piano keys.
        self.piano.observables.activation.enabled = True
        self.piano.observables.sustain_activation.enabled = True

        # This returns the goal state for the current timestep and n steps ahead.
        def _get_goal_state(physics) -> np.ndarray:
            del physics  # Unused.
            self._update_goal_state()
            return self._goal_state.ravel()

        goal_observable = observable.Generic(_get_goal_state)
        goal_observable.enabled = True
        self._task_observables = {"goal": goal_observable}


# ------------------------------------------------------------------------------
# /robopianist/viewer/user_input.py:
# ------------------------------------------------------------------------------
# Copyright 2018 The dm_control Authors.
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for handling keyboard events."""

import collections

# Mapped input values, so that we don't have to reference glfw everywhere.
RELEASE = 0
PRESS = 1
REPEAT = 2

KEY_UNKNOWN = -1
KEY_SPACE = 32
KEY_APOSTROPHE = 39
KEY_COMMA = 44
KEY_MINUS = 45
KEY_PERIOD = 46
KEY_SLASH = 47
KEY_0 = 48
KEY_1 = 49
KEY_2 = 50
KEY_3 = 51
KEY_4 = 52
KEY_5 = 53
KEY_6 = 54
KEY_7 = 55
KEY_8 = 56
KEY_9 = 57
KEY_SEMICOLON = 59
KEY_EQUAL = 61
KEY_A = 65
KEY_B = 66
KEY_C = 67
KEY_D = 68
KEY_E = 69
KEY_F = 70
KEY_G = 71
KEY_H = 72
KEY_I = 73
KEY_J = 74
KEY_K = 75
KEY_L = 76
KEY_M = 77
KEY_N = 78
KEY_O = 79
KEY_P = 80
KEY_Q = 81
KEY_R = 82
KEY_S = 83
KEY_T = 84
KEY_U = 85
KEY_V = 86
KEY_W = 87
KEY_X = 88
KEY_Y = 89
KEY_Z = 90
KEY_LEFT_BRACKET = 91
KEY_BACKSLASH = 92
KEY_RIGHT_BRACKET = 93
KEY_GRAVE_ACCENT = 96
KEY_ESCAPE = 256
KEY_ENTER = 257
KEY_TAB = 258
KEY_BACKSPACE = 259
KEY_INSERT = 260
KEY_DELETE = 261
KEY_RIGHT = 262
KEY_LEFT = 263
KEY_DOWN = 264
KEY_UP = 265
KEY_PAGE_UP = 266
KEY_PAGE_DOWN = 267
KEY_HOME = 268
KEY_END = 269
KEY_CAPS_LOCK = 280
KEY_SCROLL_LOCK = 281
KEY_NUM_LOCK = 282
KEY_PRINT_SCREEN = 283
KEY_PAUSE = 284
KEY_F1 = 290
KEY_F2 = 291
KEY_F3 = 292
KEY_F4 = 293
KEY_F5 = 294
KEY_F6 = 295
KEY_F7 = 296
KEY_F8 = 297
KEY_F9 = 298
KEY_F10 = 299
KEY_F11 = 300
KEY_F12 = 301
KEY_KP_0 = 320
KEY_KP_1 = 321
KEY_KP_2 = 322
KEY_KP_3 = 323
KEY_KP_4 = 324
KEY_KP_5 = 325
KEY_KP_6 = 326
KEY_KP_7 = 327
KEY_KP_8 = 328
KEY_KP_9 = 329
KEY_KP_DECIMAL = 330
KEY_KP_DIVIDE = 331
KEY_KP_MULTIPLY = 332
KEY_KP_SUBTRACT = 333
KEY_KP_ADD = 334
KEY_KP_ENTER = 335
KEY_KP_EQUAL = 336
KEY_LEFT_SHIFT = 340
KEY_LEFT_CONTROL = 341
KEY_LEFT_ALT = 342
KEY_LEFT_SUPER = 343
KEY_RIGHT_SHIFT = 344
KEY_RIGHT_CONTROL = 345
KEY_RIGHT_ALT = 346
KEY_RIGHT_SUPER = 347

MOD_NONE = 0
MOD_SHIFT = 0x0001
MOD_CONTROL = 0x0002
MOD_ALT = 0x0004
MOD_SUPER = 0x0008
MOD_SHIFT_CONTROL = MOD_SHIFT | MOD_CONTROL

MOUSE_BUTTON_LEFT = 0
MOUSE_BUTTON_RIGHT = 1
MOUSE_BUTTON_MIDDLE = 2

# Sentinel for "no exclusive action is active": (key, callback) with a no-op
# callback that accepts the press/release flag.
_NO_EXCLUSIVE_KEY = (None, lambda _: None)
# Default (is_exclusive, callback) pair for unbound key combinations.
_NO_CALLBACK = (None, None)


class Exclusive(collections.namedtuple("Exclusive", "combination")):
    """Defines an exclusive action.

    Exclusive actions can be invoked in response to single key clicks only. The
    callback will be called twice. The first time when the key combination is
    pressed, passing True as the argument to the callback. The second time when
    the key is released (the modifiers don't have to be present then), passing
    False as the callback argument. Key-repeat events are not forwarded.

    Attributes:
        combination: An integer key code, or a tuple in format
            (keycode, modifier).
    """


class DoubleClick(collections.namedtuple("DoubleClick", "combination")):
    """Defines a mouse double click action.

    It will define a requirement to double click the mouse button specified in the
    combination in order to be triggered.

    Attributes:
        combination: An integer key code, or a tuple in format
            (keycode, modifier). The keycodes are limited only to mouse
            button codes.
    """


class Range(collections.namedtuple("Range", "collection")):
    """Binds a number of key combinations to a callback.

    When triggered, the index of the triggering key combination will be passed
    as an argument to the callback.

    Attributes:
        collection: A collection of combinations. Combinations may either be raw key
            codes, tuples in format (keycode, modifier), or one of the Exclusive or
            DoubleClick instances.
    """


class InputMap:
    """Provides ability to alias key combinations and map them to actions."""

    def __init__(self, mouse, keyboard):
        """Instance initializer.

        Args:
            mouse: GlfwMouse instance.
            keyboard: GlfwKeyboard instance.
        """
        self._keyboard = keyboard
        self._mouse = mouse

        self._keyboard.on_key += self._handle_key
        self._mouse.on_click += self._handle_key
        self._mouse.on_double_click += self._handle_double_click
        self._mouse.on_move += self._handle_mouse_move
        self._mouse.on_scroll += self._handle_mouse_scroll

        self.clear_bindings()

    def __del__(self):
        """Instance deleter: detaches every handler attached in `__init__`."""
        self._keyboard.on_key -= self._handle_key
        self._mouse.on_click -= self._handle_key
        self._mouse.on_double_click -= self._handle_double_click
        self._mouse.on_move -= self._handle_mouse_move
        self._mouse.on_scroll -= self._handle_mouse_scroll

    def clear_bindings(self):
        """Clears registered action bindings, while keeping key aliases."""
        # (keycode, modifier) -> (is_exclusive, callback).
        self._action_callbacks = {}
        # (keycode, modifier) -> argument-less callback.
        self._double_click_callbacks = {}
        self._plane_callback = []
        self._z_axis_callback = []
        self._active_exclusive = _NO_EXCLUSIVE_KEY

    def bind(self, callback, key_binding):
        """Binds a key combination to a callback.

        Args:
            callback: An argument-less callable. For an `Exclusive` binding it
                receives a bool (True on press, False on release); for a `Range`
                binding it receives the index of the triggering combination.
            key_binding: An integer with a key code, a tuple (keycode, modifier) or
                one of the actions Exclusive|DoubleClick|Range carrying the key
                combination.
        """

        def build_callback(index, callback):
            def indexed_callback():
                callback(index)

            return indexed_callback

        if isinstance(key_binding, Range):
            for index, binding in enumerate(key_binding.collection):
                self._add_binding(build_callback(index, callback), binding)
        else:
            self._add_binding(callback, key_binding)

    def _add_binding(self, callback, key_binding):
        key_combination = self._extract_key_combination(key_binding)
        if isinstance(key_binding, Exclusive):
            self._action_callbacks[key_combination] = (True, callback)
        elif isinstance(key_binding, DoubleClick):
            self._double_click_callbacks[key_combination] = callback
        else:
            self._action_callbacks[key_combination] = (False, callback)

    def _extract_key_combination(self, key_binding):
        """Normalizes a binding to a (keycode, modifier) tuple."""
        if isinstance(key_binding, (Exclusive, DoubleClick)):
            key_binding = key_binding.combination

        if not isinstance(key_binding, tuple):
            key_binding = (key_binding, MOD_NONE)
        return key_binding

    def bind_plane(self, callback):
        """Binds a callback to a planar motion action (mouse movement)."""
        self._plane_callback.append(callback)

    def bind_z_axis(self, callback):
        """Binds a callback to a z-axis motion action (mouse scroll)."""
        self._z_axis_callback.append(callback)

    def _handle_key(self, key, action, modifiers):
        """Handles a single key press (mouse and keyboard)."""
        alias_key = (key, modifiers)

        exclusive_key, exclusive_callback = self._active_exclusive
        if exclusive_key is not None:
            # While an exclusive action is active, all other events are ignored until
            # its key is released.
            if action == RELEASE and key == exclusive_key:
                exclusive_callback(False)
                self._active_exclusive = _NO_EXCLUSIVE_KEY
            return

        is_exclusive, callback = self._action_callbacks.get(alias_key, _NO_CALLBACK)
        if not callback:
            return
        if action == PRESS:
            if is_exclusive:
                callback(True)
                self._active_exclusive = (key, callback)
            else:
                callback()
        elif action == REPEAT and not is_exclusive:
            # Exclusive callbacks require a bool argument and fire only on press and
            # release; forwarding a repeat to them would raise a TypeError.
            callback()

    def _handle_double_click(self, key, modifiers):
        """Handles a double mouse click."""
        alias_key = (key, modifiers)
        callback = self._double_click_callbacks.get(alias_key, None)
        if callback is not None:
            callback()

    def _handle_mouse_move(self, position, translation):
        """Handles mouse move."""
        for callback in self._plane_callback:
            callback(position, translation)

    def _handle_mouse_scroll(self, value):
        """Handles mouse wheel scroll."""
        for callback in self._z_axis_callback:
            callback(value)


# ------------------------------------------------------------------------------
# /robopianist/viewer/util.py:
# ------------------------------------------------------------------------------
# Copyright 2018 The dm_control Authors.
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility classes."""

import collections.abc
import contextlib
import itertools
import sys
import time
import traceback

from absl import logging

# Lower bound of the time multiplier set through TimeMultiplier class.
_MIN_TIME_MULTIPLIER = 1.0 / 32.0
# Upper bound of the time multiplier set through TimeMultiplier class.
_MAX_TIME_MULTIPLIER = 1.0


def is_scalar(value):
    """Checks if the supplied value can be converted to a scalar.

    Returns:
        True if `float(value)` succeeds, False if it raises TypeError or
        ValueError.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        return False
    else:
        return True


def to_iterable(item):
    """Converts an item or iterable into an iterable.

    Strings are wrapped in a list instead of being treated as iterables of
    characters. Any other iterable is returned as-is (not copied), so a
    one-shot iterator such as a generator comes back unchanged.
    """
    if isinstance(item, str):
        return [item]
    elif isinstance(item, collections.abc.Iterable):
        return item
    else:
        return [item]


class QuietSet:
    """A set-like container that quietly processes removals of missing keys."""

    def __init__(self):
        self._items = set()

    def __iadd__(self, items):
        """Adds `items`, avoiding duplicates. `None` is never stored.

        Args:
            items: An iterable of items to add, or a single item to add.

        Returns:
            This instance of `QuietSet`.
        """
        self._items.update(to_iterable(items))
        self._items.discard(None)
        return self

    def __isub__(self, items):
        """Detaches `items`; items that are not present are ignored.

        Args:
            items: An iterable of items to detach, or a single item to detach.

        Returns:
            This instance of `QuietSet`.
        """
        for item in to_iterable(items):
            self._items.discard(item)
        return self

    def __len__(self):
        return len(self._items)

    def __iter__(self):
        return iter(self._items)


def interleave(a, b):
    """Interleaves the contents of two iterables.

    Built on `zip`, so iteration stops at the end of the shorter iterable.
    """
    return itertools.chain.from_iterable(zip(a, b))


class TimeMultiplier:
    """Controls the relative speed of the simulation compared to realtime."""

    def __init__(self, initial_time_multiplier):
        """Instance initializer.

        Args:
            initial_time_multiplier: A float scalar specifying the initial speed of
                the simulation with 1.0 corresponding to realtime.
        """
        self.set(initial_time_multiplier)

    def get(self):
        """Returns the current time factor value."""
        return self._real_time_multiplier

    def set(self, value):
        """Modifies the time factor.

        The value is clamped to [_MIN_TIME_MULTIPLIER, _MAX_TIME_MULTIPLIER].

        Args:
            value: A float scalar, new value of the time factor.
        """
        self._real_time_multiplier = max(
            _MIN_TIME_MULTIPLIER, min(_MAX_TIME_MULTIPLIER, value)
        )

    def __str__(self):
        """Returns a formatted string containing the time factor."""
        if self._real_time_multiplier >= 1.0:
            time_factor = "%d" % self._real_time_multiplier
        else:
            time_factor = "1/%d" % (1.0 // self._real_time_multiplier)
        return time_factor

    def increase(self):
        """Doubles the current time factor value."""
        self.set(self._real_time_multiplier * 2.0)

    def decrease(self):
        """Halves the current time factor value."""
        self.set(self._real_time_multiplier / 2.0)


class Integrator:
    """Integrates a value and averages it for the specified period of time."""

    def __init__(self, refresh_rate=0.5):
        """Instance initializer.

        Args:
            refresh_rate: How often, in seconds, is the integrated value averaged.
        """
        self._value = 0
        self._value_acc = 0
        self._num_samples = 0
        self._sampling_timestamp = time.time()
        self._refresh_rate = refresh_rate

    @property
    def value(self):
        """Returns the averaged value."""
        return self._value

    @value.setter
    def value(self, val):
        """Integrates the new value.

        Once `refresh_rate` seconds have elapsed since the last refresh, the
        published value becomes the mean of the samples accumulated so far and
        the accumulator is reset.
        """
        self._value_acc += val
        self._num_samples += 1

        time_elapsed = time.time() - self._sampling_timestamp
        if time_elapsed >= self._refresh_rate:
            self._value = self._value_acc / self._num_samples
            self._value_acc = 0
            self._num_samples = 0
            self._sampling_timestamp = time.time()


class AtomicAction:
    """An action that cannot be interrupted."""

    def __init__(self, state_change_callback=None):
        """Instance initializer.

        Args:
            state_change_callback: Callable invoked when action changes its state.
        """
        self._state_change_callback = state_change_callback
        self._watermark = None

    def begin(self, watermark):
        """Begins the action, signing it with the specified watermark.

        Ignored if an action is already in progress.
        """
        if self._watermark is None:
            self._watermark = watermark
            if self._state_change_callback is not None:
                self._state_change_callback(watermark)

    def end(self, watermark):
        """Ends a started action, provided the watermarks match."""
        if self._watermark == watermark:
            self._watermark = None
            if self._state_change_callback is not None:
                self._state_change_callback(None)

    @property
    def in_progress(self):
        """Returns a boolean value to indicate if the begin method was called."""
        return self._watermark is not None

    @property
    def watermark(self):
        """Returns the watermark passed to begin() method call, or None.

        None will be returned if the action is not in progress.
        """
        return self._watermark


class ObservableFlag(QuietSet):
    """Observable boolean flag.

    The QuietSet base class provides the functionality for managing listeners.

    A listener is a callable that takes one boolean parameter; it is always
    called with the flag's current (post-change) value.
    """

    def __init__(self, initial_value):
        """Instance initializer.

        Args:
            initial_value: A boolean value with the initial state of the flag.
        """
        self._value = initial_value
        super().__init__()

    def _notify(self):
        """Informs every registered listener about the current value."""
        # Iterate over a snapshot so a listener may (un)register listeners
        # without invalidating the iteration.
        for listener in list(self._items):
            listener(self._value)

    def toggle(self):
        """Toggles the value True/False and informs the listeners."""
        self._value = not self._value
        self._notify()

    def __iadd__(self, value):
        """Add new listeners and update them about the state."""
        # Materialize first: `to_iterable` may hand back a one-shot iterator
        # (e.g. a generator) that `QuietSet.__iadd__` would exhaust, leaving
        # nothing to notify below. `None` is skipped because QuietSet never
        # stores it and it is not callable.
        listeners = [
            listener for listener in to_iterable(value) if listener is not None
        ]
        super().__iadd__(listeners)
        for listener in listeners:
            listener(self._value)
        return self

    @property
    def value(self):
        """Value of the flag."""
        return self._value

    @value.setter
    def value(self, val):
        # Assign before notifying so listeners observe the new value, matching
        # `toggle` and `__iadd__` (listeners used to receive the stale value).
        changed = self._value != val
        self._value = val
        if changed:
            self._notify()


class Timer:
    """Measures time elapsed between two ticks."""

    def __init__(self):
        """Instance initializer."""
        self._previous_time = time.time()
        self._measured_time = 0.0

    def tick(self):
        """Updates the timer.

        Returns:
            Time elapsed since the last call to this method.
        """
        curr_time = time.time()
        self._measured_time = curr_time - self._previous_time
        self._previous_time = curr_time
        return self._measured_time

    @contextlib.contextmanager
    def measure_time(self):
        """Context manager storing the duration of its body in `measured_time`.

        The duration is only recorded if the body completes without raising.
        """
        start_time = time.time()
        yield
        self._measured_time = time.time() - start_time

    @property
    def measured_time(self):
        return self._measured_time


class ErrorLogger:
    """A context manager that catches and logs all errors."""

    def __init__(self, listeners):
        """Instance initializer.

        Args:
            listeners: An iterable of callables, listeners to inform when an error
                is caught. Each callable should accept a single string argument.
        """
        self._error_found = False
        self._listeners = listeners

    def __enter__(self, *args):
        # Each `with` block starts with a clean error state.
        self._error_found = False

    def __exit__(self, exception_type, exception_value, tb):
        if exception_type:
            self._error_found = True
            error_message = (
                "dm_control viewer intercepted an environment error.\n"
                "Original message: {}".format(exception_value)
            )
            logging.error(error_message)
            sys.stderr.write(error_message + "\nTraceback:\n")
            traceback.print_tb(tb)
            for listener in self._listeners:
                listener("{}".format(exception_value))
        # Returning True suppresses the exception.
        return True

    @property
    def errors_found(self):
        """Returns True if any errors were caught."""
        return self._error_found


class NullErrorLogger:
    """A context manager that replaces an ErrorLogger.

    This error logger will pass all thrown errors through.
    """

    def __enter__(self, *args):
        pass

    def __exit__(self, error_type, value, tb):
        pass

    @property
    def errors_found(self):
        """Returns True if any errors were caught."""
        return False
# ------------------------------------------------------------------------------
# /robopianist/music/piano_roll.py:
# ------------------------------------------------------------------------------
# Copyright 2022 The Magenta Authors.
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Piano roll utilities.

This is a copy of note_seq's implementation, modified to also compute fingering
information in the returned `Pianoroll` object.
"""

import collections
import math

import numpy as np
import pretty_midi
from note_seq import constants, music_pb2

# The amount to upweight note-on events vs note-off events.
ONSET_UPWEIGHT = 5.0

# The size of the frame extension for onset event.
# Frames in [onset_frame-ONSET_WINDOW, onset_frame+ONSET_WINDOW]
# are considered to contain onset events.
ONSET_WINDOW = 1

Pianoroll = collections.namedtuple(  # pylint:disable=invalid-name
    "Pianoroll",
    [
        "active",
        "weights",
        "onsets",
        "onset_velocities",
        "active_velocities",
        "offsets",
        "control_changes",
        "fingerings",
    ],
)

# Accepted values of the `onset_mode` argument of `sequence_to_pianoroll`.
_ONSET_MODES = ("window", "length_ms")


def _unscale_velocity(velocity, scale, bias):
    """Maps a normalized velocity back to an integer MIDI velocity.

    `velocity` is clipped to [0, 1], mapped to `velocity * scale + bias` and
    truncated to an int. A NaN input yields 0.
    """
    unscaled = max(min(velocity, 1.0), 0) * scale + bias
    if math.isnan(unscaled):
        return 0
    return int(unscaled)


def sequence_to_pianoroll(
    sequence,
    frames_per_second,
    min_pitch,
    max_pitch,
    # pylint: disable=unused-argument
    min_velocity=constants.MIN_MIDI_VELOCITY,
    # pylint: enable=unused-argument
    max_velocity=constants.MAX_MIDI_VELOCITY,
    add_blank_frame_before_onset=False,
    onset_upweight=ONSET_UPWEIGHT,
    onset_window=ONSET_WINDOW,
    onset_length_ms=0,
    offset_length_ms=0,
    onset_mode="window",
    onset_delay_ms=0.0,
    min_frame_occupancy_for_label=0.0,
    onset_overlap=True,
):
    """Converts a note sequence into frame-level piano roll arrays.

    Args:
        sequence: The sequence to convert; its `notes`, `control_changes` and
            `total_time` fields are read.
        frames_per_second: Number of frames per second.
        min_pitch: Lowest pitch kept (inclusive); lower notes are skipped.
        max_pitch: Highest pitch kept (inclusive); higher notes are skipped.
        min_velocity: Unused; kept for signature compatibility with note_seq.
        max_velocity: Velocities are divided by this value. A note whose
            velocity exceeds it raises ValueError.
        add_blank_frame_before_onset: If True, the frame right before each
            note's start frame is cleared in `active` and its weight reset to 1.
        onset_upweight: Weight of onset frames in `weights`. The note's frames
            after the onset label decay as `onset_upweight / k`, k = 1, 2, ...
        onset_window: In "window" mode, number of frames on each side of the
            onset frame that are labeled as onset.
        onset_length_ms: In "length_ms" mode, duration of the onset label.
        offset_length_ms: Duration of the offset label.
        onset_mode: Either "window" or "length_ms".
        onset_delay_ms: Time shift applied to the onset labels.
        min_frame_occupancy_for_label: If > 0, boundary frames that a note
            occupies by less than this fraction are not labeled.
        onset_overlap: If False, `active` starts where the onset label ends.

    Returns:
        A `Pianoroll`. `active`, `weights`, `onsets`, `onset_velocities`,
        `active_velocities`, `offsets` and `fingerings` have shape
        `(int(sequence.total_time * frames_per_second + 1),
        max_pitch - min_pitch + 1)`. `fingerings` holds `note.part` while a
        note sounds and -1 elsewhere. `control_changes` has shape
        `(num_frames, 128)` and holds `control_value + 1` at the frame of each
        control change (0 means no change).

    Raises:
        ValueError: If `onset_mode` is unknown or a note velocity exceeds
            `max_velocity`.
    """
    # Validate up front so a bad mode is reported even for sequences without
    # any in-range notes.
    if onset_mode not in _ONSET_MODES:
        raise ValueError("Unknown onset mode: {}".format(onset_mode))

    num_frames = int(sequence.total_time * frames_per_second + 1)
    roll = np.zeros((num_frames, max_pitch - min_pitch + 1), dtype=np.float32)

    roll_weights = np.ones_like(roll)

    onsets = np.zeros_like(roll)
    offsets = np.zeros_like(roll)

    control_changes = np.zeros((num_frames, 128), dtype=np.int32)

    fingerings = np.full_like(roll, -1)

    def frames_from_times(start_time, end_time):
        """Converts start/end times to start/end frames."""
        # Will round down because note may start or end in the middle of the frame.
        start_frame = int(start_time * frames_per_second)
        start_frame_occupancy = start_frame + 1 - start_time * frames_per_second
        # check for > 0.0 to avoid possible numerical issues
        if (
            min_frame_occupancy_for_label > 0.0
            and start_frame_occupancy < min_frame_occupancy_for_label
        ):
            start_frame += 1

        end_frame = int(math.ceil(end_time * frames_per_second))
        # NOTE(review): this measures the end occupancy relative to
        # `start_frame` rather than `end_frame - 1`. It mirrors upstream
        # note_seq; confirm whether it is intended before changing it.
        end_frame_occupancy = end_time * frames_per_second - start_frame - 1
        if (
            min_frame_occupancy_for_label > 0.0
            and end_frame_occupancy < min_frame_occupancy_for_label
        ):
            end_frame -= 1

        # Ensure that every note fills at least one frame.
        end_frame = max(start_frame + 1, end_frame)

        return start_frame, end_frame

    velocities_roll = np.zeros_like(roll, dtype=np.float32)

    for note in sorted(sequence.notes, key=lambda n: n.start_time):
        if note.pitch < min_pitch or note.pitch > max_pitch:
            continue
        pitch_index = note.pitch - min_pitch
        start_frame, end_frame = frames_from_times(note.start_time, note.end_time)

        # label onset events. Use a window size of onset_window to account of
        # rounding issue in the start_frame computation.
        onset_start_time = note.start_time + onset_delay_ms / 1000.0
        onset_end_time = note.end_time + onset_delay_ms / 1000.0
        if onset_mode == "window":
            onset_start_frame_without_window, _ = frames_from_times(
                onset_start_time, onset_end_time
            )

            onset_start_frame = max(0, onset_start_frame_without_window - onset_window)
            onset_end_frame = min(
                onsets.shape[0], onset_start_frame_without_window + onset_window + 1
            )
        else:  # onset_mode == "length_ms" (validated above).
            onset_end_time = min(
                onset_end_time, onset_start_time + onset_length_ms / 1000.0
            )
            onset_start_frame, onset_end_frame = frames_from_times(
                onset_start_time, onset_end_time
            )

        # label offset events.
        offset_start_time = min(
            note.end_time, sequence.total_time - offset_length_ms / 1000.0
        )
        offset_end_time = offset_start_time + offset_length_ms / 1000.0
        offset_start_frame, offset_end_frame = frames_from_times(
            offset_start_time, offset_end_time
        )
        offset_end_frame = max(offset_end_frame, offset_start_frame + 1)

        if not onset_overlap:
            start_frame = onset_end_frame
            end_frame = max(start_frame + 1, end_frame)

        offsets[offset_start_frame:offset_end_frame, pitch_index] = 1.0
        onsets[onset_start_frame:onset_end_frame, pitch_index] = 1.0
        roll[start_frame:end_frame, pitch_index] = 1.0

        if note.velocity > max_velocity:
            raise ValueError(
                "Note velocity exceeds max velocity: %d > %d"
                % (note.velocity, max_velocity)
            )

        velocities_roll[start_frame:end_frame, pitch_index] = (
            note.velocity / max_velocity
        )
        roll_weights[onset_start_frame:onset_end_frame, pitch_index] = onset_upweight
        # Size the decay from the actual (possibly end-clipped) slice, so a note
        # running past the last frame cannot cause a broadcast error.
        num_decay_frames = roll_weights[onset_end_frame:end_frame, pitch_index].shape[0]
        roll_weights[onset_end_frame:end_frame, pitch_index] = onset_upweight / np.arange(
            1, num_decay_frames + 1
        )
        # NOTE(review): if `sequence` is a note_seq proto, `note.part` is a
        # scalar field that is never None (unset parts read as 0), so this is
        # always true. Confirm whether unset parts should stay -1.
        if note.part is not None:
            fingerings[start_frame:end_frame, pitch_index] = note.part

        if add_blank_frame_before_onset:
            if start_frame > 0:
                roll[start_frame - 1, pitch_index] = 0.0
                roll_weights[start_frame - 1, pitch_index] = 1.0

    for cc in sequence.control_changes:
        frame, _ = frames_from_times(cc.time, 0)
        if frame < len(control_changes):
            # Stored shifted by one so that 0 can mean "no control change".
            control_changes[frame, cc.control_number] = cc.control_value + 1

    return Pianoroll(
        active=roll,
        weights=roll_weights,
        onsets=onsets,
        onset_velocities=velocities_roll * onsets,
        active_velocities=velocities_roll,
        offsets=offsets,
        control_changes=control_changes,
        fingerings=fingerings,
    )


def pianoroll_onsets_to_note_sequence(
    onsets,
    frames_per_second,
    note_duration_seconds=0.05,
    velocity=70,
    instrument=0,
    program=0,
    qpm=constants.DEFAULT_QUARTERS_PER_MINUTE,
    min_midi_pitch=constants.MIN_MIDI_PITCH,
    velocity_values=None,
    velocity_scale=80,
    velocity_bias=10,
):
    """Converts an onsets piano roll into a NoteSequence of fixed-length notes.

    Args:
        onsets: 2-D array indexed as `[frame, pitch - min_midi_pitch]`. Every
            nonzero entry becomes one note.
        frames_per_second: Number of frames per second.
        note_duration_seconds: Duration given to every emitted note.
        velocity: Fill value for `velocity_values` when that is None.
        instrument: Instrument assigned to every note.
        program: Program assigned to every note.
        qpm: Tempo added to the sequence.
        min_midi_pitch: Pitch of column 0 of `onsets`.
        velocity_values: Optional array indexed like `onsets`. Each value is
            passed through `_unscale_velocity`.
        velocity_scale: `scale` argument of `_unscale_velocity`.
        velocity_bias: `bias` argument of `_unscale_velocity`.

    Returns:
        A `music_pb2.NoteSequence`.
    """
    frame_length_seconds = 1 / frames_per_second

    sequence = music_pb2.NoteSequence()
    sequence.tempos.add().qpm = qpm
    sequence.ticks_per_quarter = constants.STANDARD_PPQ

    if velocity_values is None:
        # NOTE(review): `_unscale_velocity` clips its input to [0, 1], so this
        # fill (e.g. 70) yields `velocity_scale + velocity_bias` for every note.
        # This mirrors upstream note_seq; confirm it is intended.
        velocity_values = velocity * np.ones_like(onsets, dtype=np.int32)

    for frame, pitch in zip(*np.nonzero(onsets)):
        start_time = frame * frame_length_seconds
        end_time = start_time + note_duration_seconds

        note = sequence.notes.add()
        note.start_time = start_time
        note.end_time = end_time
        # np.nonzero yields NumPy integers; store a plain Python int.
        note.pitch = int(pitch) + min_midi_pitch
        note.velocity = _unscale_velocity(
            velocity_values[frame, pitch], scale=velocity_scale, bias=velocity_bias
        )
        note.instrument = instrument
        note.program = program

    sequence.total_time = len(onsets) * frame_length_seconds + note_duration_seconds
    if sequence.notes:
        assert sequence.total_time >= sequence.notes[-1].end_time

    return sequence


def sequence_to_valued_intervals(
    note_sequence,
    min_midi_pitch=constants.MIN_MIDI_PITCH,
    max_midi_pitch=constants.MAX_MIDI_PITCH,
    restrict_to_pitch=None,
):
    """Convert a NoteSequence to valued intervals.

    Value intervals are intended to be used with mir_eval metrics methods.

    Args:
        note_sequence: sequence to convert.
        min_midi_pitch: notes lower than this will be discarded.
        max_midi_pitch: notes higher than this will be discarded.
        restrict_to_pitch: if not None, notes that are not this pitch will be
            discarded (pitch 0 is a valid restriction).

    Returns:
        intervals: start and end times, shaped (num_notes, 2).
        pitches: pitches in Hz.
        velocities: MIDI velocities.
    """
    intervals = []
    pitches = []
    velocities = []

    for note in note_sequence.notes:
        # `is not None` so that restricting to pitch 0 is honored.
        if restrict_to_pitch is not None and restrict_to_pitch != note.pitch:
            continue
        if note.pitch < min_midi_pitch or note.pitch > max_midi_pitch:
            continue
        # mir_eval does not allow notes that start and end at the same time.
        if note.end_time == note.start_time:
            continue
        intervals.append((note.start_time, note.end_time))
        pitches.append(note.pitch)
        velocities.append(note.velocity)

    # Reshape intervals to ensure that the second dim is 2, even if the list is
    # of size 0. mir_eval functions will complain if intervals is not shaped
    # appropriately.
    intervals = np.array(intervals).reshape((-1, 2))
    pitches = np.array(pitches)
    pitches = pretty_midi.note_number_to_hz(pitches)
    velocities = np.array(velocities)
    return intervals, pitches, velocities
# ------------------------------------------------------------------------------