├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── pull_request_template.md ├── .gitignore ├── LICENSE ├── README.md ├── doc ├── 01_quickstart.md ├── 02_config.md ├── 03_phase_and_basis.md ├── 04_movement_primitives.md ├── 05_nn-based_mp.md └── README.md ├── img ├── basis_norm_rbf.png └── basis_prodmp.png ├── mp_pytorch ├── __init__.py ├── basis_gn │ ├── __init__.py │ ├── basis_generator.py │ ├── norm_rbf_basis.py │ ├── prodmp_basis.py │ └── rhytmic_basis.py ├── demo │ ├── __init__.py │ ├── demo_basis_gn.py │ ├── demo_data_type_cast.py │ ├── demo_delay_and_scale.py │ ├── demo_dmp.py │ ├── demo_mp_config.py │ ├── demo_prodmp.py │ ├── demo_prodmp_autoscale.py │ └── demo_promp.py ├── mp │ ├── __init__.py │ ├── dmp.py │ ├── mp_factory.py │ ├── mp_interfaces.py │ ├── prodmp.py │ └── promp.py ├── phase_gn │ ├── __init__.py │ ├── exp_decay_phase.py │ ├── linear_phase.py │ ├── phase_generator.py │ ├── rhythmic_phase_generator.py │ └── smooth_phase_generator.py └── util │ ├── __init__.py │ ├── util_data_structure.py │ ├── util_debug.py │ ├── util_matrix.py │ ├── util_media.py │ └── util_string.py ├── setup.py └── test ├── __init__.py ├── test_dmp_vs_prodmp.py ├── test_main.py ├── test_prodmp_relative_goal.py ├── test_prodmp_speed.py └── test_quantitative.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: Fitz13009 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 
25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### All Submissions: 2 | 3 | * [ ] ❗ Have you successfully run the demos for **ALL** MPs ?? 4 | * [ ] Note in your description if the branch should **not** be deleted. Branches will be deleted after a successful merge by default. 5 | 6 | ### New Features: 7 | * [ ] Have you added new demos showcasing how to use your new feature? 8 | * [ ] Have you tried to implement the new feature for **ALL** MPs?
9 | * [ ] Optional: Have you thought about writing a `docs` file explaining your feature? This is not strictly necessary for small, obvious additions. 10 | 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MP_PyTorch: The Movement Primitives Package in PyTorch 2 | 3 | MP_PyTorch package focus on **Movement Primitives(MPs) on Imitation Learning(IL) and Reinforcement Learning(RL)** and provides convenient movement primitives interface implemented by PyTorch, including DMPs, ProMPs and [ProDMPs](https://arxiv.org/abs/2210.01531). 4 | Users can also implement custom Movement Primitives according to the basis and phase generator. Further, advanced NN-based Movement Primitives Algorithm can also be realized according to the convenient PyTorch-based Interface. 5 | This package aims to building a movement primitives toolkit which could be combined with modern imitation learning and reinforcement learning algorithm. 6 | 7 | 11 | 12 |   13 | ## Installation 14 | 15 | For the installation we recommend you set up a conda environment or venv beforehand. 
16 | 17 | This package will automatically install the following dependencies: addict, numpy, pytorch and matplotlib. 18 | 19 | ### 1. Install from Conda (Recommended) 20 | ```bash 21 | conda install -c conda-forge mp_pytorch 22 | ``` 23 | 24 | ### 2. Install from PyPI 25 | ```bash 26 | pip install mp_pytorch 27 | ``` 28 | 29 | ### 3. Install from source 30 | 31 | ```bash 32 | git clone git@github.com:ALRhub/MP_PyTorch.git 33 | cd mp_pytorch 34 | pip install -e . 35 | ``` 36 | 37 | After installation, you can import the package easily. 38 | ```bash 39 | import mp_pytorch 40 | from mp_pytorch import MPFactory 41 | ``` 42 | 43 |   44 | ## Quickstart 45 | For further information, please refer to the [User Guide](./doc/README.md). 46 | 47 | The main steps to create ProDMPs instance and generate trajectories are as follows: 48 | 49 | ### 1. Edit configuration 50 | Suppose you have edited the required configuration. 51 | You can view the demo and check how to edit the configuration in [Edit Configuration](./doc/02_config.md). 52 | ```python 53 | # config, times, params, params_L, init_time, init_pos, init_vel, demos = get_mp_utils("prodmp", True, True) 54 | ``` 55 | 56 | ### 2. Initial prodmp instance and update inputs 57 | ```python 58 | mp = MPFactory.init_mp(**config) 59 | mp.update_inputs(times=times, params=params, params_L=params_L, 60 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 61 | 62 | # you can also choose to learn parameters from demonstrations. 63 | params_dict = mp.learn_mp_params_from_trajs(times, demos) 64 | ``` 65 | 66 | ### 3. 
Generate trajectories 67 | ```python 68 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=True, 69 | get_pos_std=True, get_vel=True, 70 | get_vel_cov=True, get_vel_std=True) 71 | 72 | # for probablistic movement primitives, you can also choose to sample trajectories 73 | samples, samples_vel = mp.sample_trajectories(num_smp=10) 74 | ``` 75 | 76 | The structure of this package can be seen as follows: 77 | 78 | | Types | Classes | Description | 79 | |-------------------------|------------------------------------------|------------------------------------------------------------------------------| 80 | | **Phase Generator** | `PhaseGenerator` | Interface for Phase Generators | 81 | | | `RhythmicPhaseGenerator` | Rhythmic phase generator | 82 | | | `SmoothPhaseGenerator` | Smooth phase generator | 83 | | | `LinearPhaseGenerator` | Linear phase generator | 84 | | | `ExpDecayPhaseGenerator` | Exponential decay phase generator | 85 | | **Basis Generator** | `BasisGenerator` | Interface for Basis Generators | 86 | | | `RhythmicBasisGenerator` | Rhythmic basis generator | 87 | | | `NormalizedRBFBasisGenerator` | Normalized RBF basis generator | 88 | | | `ProDMPBasisGenerator` | ProDMP basis generator | 89 | | **Movement Primitives** | `MPFactory` | Create an MP instance given configuration | 90 | | | `MPInterface` | Interface for Deterministic Movement Primitives | 91 | | | `ProbabilisticMPInterface` | Interface for Probablistic Movement Primitives | 92 | | | `DMP` | Dynamic Movement Primitives | 93 | | | `ProMP` | Probablistic Movement Primitives | 94 | | | `ProDMP` | [Probablistic Dynamic Movement Primitives](https://arxiv.org/abs/2210.01531) | 95 | 96 | 97 | 98 |   99 | ## Cite 100 | If you interest this project and use it in a scientific publication, we would appreciate citations to the following information: 101 | ```markdown 102 | @article{li2023prodmp, 103 | title={ProDMP: A Unified Perspective on Dynamic and Probabilistic Movement Primitives}, 104 | author={Li, 
Ge and Jin, Zeqi and Volpp, Michael and Otto, Fabian and Lioutikov, Rudolf and Neumann, Gerhard}, 105 | journal={IEEE Robotics and Automation Letters}, 106 | year={2023}, 107 | publisher={IEEE} 108 | } 109 | 110 | ``` 111 | 112 |   113 | ## Team 114 | MP_PyTorch is developed and maintained by the [ALR-Lab](https://alr.anthropomatik.kit.edu)(Autonomous Learning Robots Lab), KIT. 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 |

Welcome to our GitHub Pages!

124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /doc/01_quickstart.md: -------------------------------------------------------------------------------- 1 | # 1. Quick Start 2 | 3 | MP_PyTorch provides convenient interfaces to develop Movement Primitives with modern Imitation Learning and Reinforcement Learning algorithms. 4 | You can create the basic Movement Primitives Instance(DMPs, ProMPs and ProDMPs) in the MPFactory or define your own custom MPs with the MPInterface. 5 | It's also convenient to combine the MPs with modern neural-network-based algorithms to realize more complex tasks. 6 | 7 |   8 | ### 1.1 Quick start for MPFactory 9 | In this quick start section, we will provide a demo showing how to create a ProDMPs instance and generate trajectories. 10 | 11 | #### 1.1.1 Edit Configuration 12 | Suppose you have edited the required configuration. 13 | You can view the demo and check how to edit the configuration in [Edit Configuration](./02_config.md). 14 | ```python 15 | # config, times, params, params_L, init_time, init_pos, init_vel, demos = get_mp_utils("prodmp", True, True) 16 | ``` 17 | 18 | #### 1.1.2 Initialize ProDMPs instance and update inputs 19 | ```python 20 | mp = MPFactory.init_mp(**config) 21 | mp.update_inputs(times=times, params=params, params_L=params_L, 22 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 23 | 24 | # you can also choose to learn parameters from demonstrations.
25 | params_dict = mp.learn_mp_params_from_trajs(times, demos) 26 | ``` 27 | 28 | #### 1.1.3 Generate trajectories 29 | ```python 30 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=True, 31 | get_pos_std=True, get_vel=True, 32 | get_vel_cov=True, get_vel_std=True) 33 | 34 | # for probablistic movement primitives, you can also choose to sample trajectories 35 | samples, samples_vel = mp.sample_trajectories(num_smp=10) 36 | ``` 37 | 38 |   39 | ### 1.2 Define the custom Movement Primitives 40 | To define the custom Movement Primitives method, you need to understand the following interfaces in the corresponding sections: 41 | - [Phase Generator Interface](./03_phase_and_basis.md) 42 | - [Basis Generator Interface](./03_phase_and_basis.md) 43 | - [Movement Primitives Interface](./04_movement_primitives.md) 44 | 45 |   46 | ### 1.3 Combining Movement Primitives with Neural Networks 47 | **The corresponding docs and demos are under construction.** 48 | 49 | 50 | 51 | 52 | [Back to Overview](./) -------------------------------------------------------------------------------- /doc/02_config.md: -------------------------------------------------------------------------------- 1 | # 2. Edit Configuration 2 | 3 | We recommend using `addict` or `yaml` to edit the configuration files. 4 | 5 |   6 | ### 2.1 Demo 7 | We provide a [demo](../mp_pytorch/demo/demo_mp_config.py) to show how to edit the configuration.
8 | 9 | You can call this demo as follows: 10 | 11 | ```python 12 | from mp_pytorch import demo 13 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 14 | demo.get_mp_utils(mp_type="prodmp", learn_tau=True, learn_delay=True) 15 | ``` 16 | 17 |   18 | ### 2.2 Parameters in Configuration 19 | 20 | | Type | Parameters | Description | 21 | |---------------------|--------------------------|-----------------------------| 22 | | General | `num_dof` | Number of DoFs | 23 | | | `tau` | | 24 | | | `learn_tau` | If tau is learnable | 25 | | | `learn_delay` | If delay is learnable | 26 | | Movement Primitives | `num_basis` | Number of Basis functions | 27 | | | `basis_bandwidth_factor` | | 28 | | | `alpha` | | 29 | | | `alpha_phase` | | 30 | | | `dt` | Timestep | 31 | | | `weights_scale` | | 32 | | | `goal_scale` | | 33 | | | `mp_type` | Type of Movement Primitives | 34 | 35 | 36 | [Back to Overview](./) -------------------------------------------------------------------------------- /doc/03_phase_and_basis.md: -------------------------------------------------------------------------------- 1 | # 3. Phase and Basis Generation 2 | You can view the main features of the Phase Generator and Basis Generator in this section. 3 | 4 |   5 | ### 3.1 Demo 6 | 7 | You can create the phase and basis instance by calling the corresponding classes. To create the Basis Generator, you need create a Phase Generator firstly. 8 | 9 | ```python 10 | phase_gn = LinearPhaseGenerator(tau=3, delay=1, 11 | learn_tau=False, learn_delay=False) 12 | basis_gn = NormalizedRBFBasisGenerator(phase_generator=phase_gn, 13 | num_basis=10, 14 | basis_bandwidth_factor=3, 15 | num_basis_outside=0) 16 | basis_gn.show_basis(plot=True) 17 | ``` 18 | 19 | We also provide a [demo](../mp_pytorch/demo/demo_basis_gn.py) to visualize the norm RBF and ProDMPs basis functions. 
20 | You can call this demo as follows: 21 | ```python 22 | from mp_pytorch import demo 23 | demo.demo_norm_rbf_basis() 24 | demo.dmmo_prodmp_basis() 25 | ``` 26 | The corresponding basis functions will be visualized as: 27 | 28 | | Norm RBF Basis Functions | ProDMPs Basis Functions | 29 | |--------------------------------------|-----------------------------------| 30 | | ![image](../img/basis_norm_rbf.png) | ![image](../img/basis_prodmp.png) | 31 | 32 |   33 | ### 3.2 Phase Generator 34 | We provide the Phase Generator Interface to help to define the custom Movement Primitives Method. 35 | The main features of the Interface and derived Generator are as follows: 36 | 37 | | Classes | Main Functions | Description | 38 | |------------------------|---------------------------------------------------|---------------------------------------------------------------------------------------------------------------| 39 | | PhaseGenerator | | Abstract Basic Class for Phase Generators. Transfer time duration to [0, 1] range. | 40 | | | `PhaseGenerator.phase` | Abstractmethod for phase interface. | 41 | | | `PhaseGenerator.unbound_phase` | Abstractmethod for unbound phase interface. | 42 | | | `PhaseGenerator.phase_to_time` | Abstractmethod for inverse operation, compute times given phase. | 43 | | | `PhaseGenerator.set_params` | Set parameters of current object and attributes | 44 | | | `PhaseGenerator.get_params` | Return all learnable parameters. | 45 | | | `PhaseGenerator.get_params_bounds` | Return all learnable parameters' bounds. | 46 | | | `PhaseGenerator.finalize` | Mark the phase generator as finalized so that the parameters cannot be updated any more. | 47 | | RhythmicPhaseGenerator | | Rhythmic phase generator. 
| 48 | | SmoothPhaseGenerator | | Smooth phase generator with five order spline phase | 49 | | LinearPhaseGenerator | | Linear Phase Generator | 50 | | ExpDecayPhaseGenerator | | Exponential decay phase generator | 51 | 52 |   53 | ### 3.3 Basis Generator 54 | We provide the Basis Generator Interface to help to define the custom Movement Primitives Method. 55 | The main features of the Interface and derived Generator are as follows: 56 | 57 | | Classes | Main Functions | Description | 58 | |-----------------------------------------|---------------------------------------------------|--------------------------------------------------------------------------| 59 | | BasisGenerator | | Abstract Basic Class for Basis Generators | 60 | | | `BasisGenerator.basis` | Abstractmethod to generate value of single basis function at given time. | 61 | | | `BasisGenerator.basis_multi_dofs` | Interface to generate basis functions for multi-dof at given time | 62 | | | `BasisGenerator.set_params` | Set parameters of current object and attributes | 63 | | | `BasisGenerator.get_params` | Return all learnable parameters | 64 | | | `BasisGenerator.get_params_bounds` | Return all learnable parameters' bounds | 65 | | | `BasisGenerator.show_basis` | Visualize the basis functions for debug usage | 66 | | RhythmicBasisGenerator | | Rhythmic Basis Generator | 67 | | NormalizedRBFBasisGenerator | | Normalized RBF basis generator | 68 | | ZeroPaddingNormalizedRBFBasisGenerator | | Normalized RBF with zero padding basis generator | 69 | | ProDMPBasisGenerator | | ProDMP basis generator | 70 | | | `ProDMP.pre_compute` | Precompute basis functions and other stuff. 
| 71 | | | `BasisGenerator.basis_and_phase` | Set basis and phase for the rhythmic basis generator | 72 | | | `BasisGenerator.times_to_indices` | Map time points to pre-compute indices | 73 | 74 | [Back to Overview](./) -------------------------------------------------------------------------------- /doc/04_movement_primitives.md: -------------------------------------------------------------------------------- 1 | # 4. Movement Primitives 2 | Currently, we provide three Movement Primitives, including Dynamic Movement Primitives(DMPs), Probablistic Movement Primitives(ProMPs) and Probablistic Dynamic Movement Primitives(ProDMPs). 3 | We also provide two Movement Primitives Interfaces(Dynamic based and Probablistic based), which can be used to define the custom Movement Primitives. 4 | You can register your custom MPs to the MPFactory and create the MP instance given configuration. 5 | 6 | The main features of the MP Factory and MP Interfaces are as follows: 7 | 8 | | Classes | Main Functions | Description | 9 | |--------------------------|------------------------------------------|-----------------------------------------------------------------------------------------------| 10 | | MPFactory | `MPFactory.init_mp` | Create an MP instance given configuration. 
| 11 | | MPInterface | | Abstract Basic Class for Deterministic Movement Primitives | 12 | | | `MPInterface.update_inputs` | Update MPs parameters | 13 | | | `MPInterface.get_trajs` | Get movement primitives trajectories given flag | 14 | | | `MPInterface.learn_mp_params_from_trajs` | Abstractmethod for learning parameters from trajectories | 15 | | ProbabilisticMPInterface | | Abstract Basic Class for Probablistic Movement Primitives | 16 | | | `MPInterface.update_inputs` | Update MPs parameters | 17 | | | `MPInterface.get_trajs` | Get movement primitives trajectories given flag, including trajectories mean and distribution | 18 | | | `MPInterface.sample_trajectories` | Sample trajectories from MPs | 19 | | | `MPInterface.learn_mp_params_from_trajs` | Abstractmethod for learning parameters from trajectories | 20 | 21 |   22 | ### 4.1 Dynamic Movement Primitives 23 | We provide a [DMPs demo](../mp_pytorch/demo/demo_dmp.py) to show how to create a DMPs instance and visualize the corresponding result. 24 | 25 | To run the demo, you can run the following code: 26 | ```python 27 | from mp_pytorch import demo 28 | demo.test_dmp() 29 | ``` 30 | 31 |   32 | ### 4.2 Probablistic Movement Primitives 33 | We provide a [ProMPs demo](../mp_pytorch/demo/demo_promp.py) to show how to create a ProMPs instance and visualize the corresponding result. 34 | 35 | To run the demo, you can run the following code: 36 | ```python 37 | from mp_pytorch import demo 38 | demo.test_promp() 39 | demo.test_zero_padding_promp() 40 | ``` 41 | 42 |   43 | ### 4.3 Probablistic Dynamic Movement Primitives 44 | [Probablistic Dynamic Movement Primitives(ProDMPs)](https://arxiv.org/abs/2210.01531) is a recently presented Method, which combing the Dynamic and Probablistic properties of Movement Primitives from a unified perspective. 45 | 46 | We provide a [ProDMPs demo](../mp_pytorch/demo/demo_prodmp.py) to show how to create a ProDMPs instance and visualize the corresponding result. 
47 | To run the demo, you can run the following code: 48 | ```python 49 | from mp_pytorch import demo 50 | demo.test_prodmp() 51 | ``` 52 | 53 | 54 | [Back to Overview](./) 55 | -------------------------------------------------------------------------------- /doc/05_nn-based_mp.md: -------------------------------------------------------------------------------- 1 | # 5. Neural Networks based Movement Primitives 2 | **The corresponding docs and demos are under construction.** 3 | 4 | 5 | 6 | 7 | [Back to Overview](./) -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # MP_PyTorch User Guide 2 | 3 | ### Introduction and Quick Start 4 | - [1. Quick Start](./01_quickstart.md) 5 | 6 | ### Basic Movement Primitives 7 | - [2. Edit Configuration](./02_config.md) 8 | - [3. Phase and Basis Generation](./03_phase_and_basis.md) 9 | - [4. Movement Primitives](./04_movement_primitives.md) 10 | 11 | ### Combining Movement Primitives with Neural Networks 12 | - [5.
NN-based Movement Primitives](./05_nn-based_mp.md) 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /img/basis_norm_rbf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/MP_PyTorch/8c0413adeea9ff9df560d1ac0b6bdd0945f6c149/img/basis_norm_rbf.png -------------------------------------------------------------------------------- /img/basis_prodmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/MP_PyTorch/8c0413adeea9ff9df560d1ac0b6bdd0945f6c149/img/basis_prodmp.png -------------------------------------------------------------------------------- /mp_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.4" 2 | -------------------------------------------------------------------------------- /mp_pytorch/basis_gn/__init__.py: -------------------------------------------------------------------------------- 1 | from .basis_generator import * 2 | from .norm_rbf_basis import * 3 | from .prodmp_basis import * 4 | -------------------------------------------------------------------------------- /mp_pytorch/basis_gn/basis_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Basis generators in PyTorch 3 | """ 4 | from typing import Tuple 5 | 6 | from mp_pytorch.phase_gn.phase_generator import * 7 | 8 | 9 | class BasisGenerator(ABC): 10 | @abstractmethod 11 | def __init__(self, 12 | phase_generator: PhaseGenerator, 13 | num_basis: int = 10, 14 | dtype: torch.dtype = torch.float32, 15 | device: torch.device = 'cpu', 16 | ): 17 | """ 18 | Constructor for basis class 19 | Args: 20 | phase_generator: phase generator 21 | num_basis: number of basis functions 22 | dtype: torch data type 23 | device: torch device to run on 24 | """ 25 | self.dtype = dtype 
26 | self.device = device 27 | 28 | # Internal number of basis 29 | self._num_basis = num_basis 30 | self.phase_generator = phase_generator 31 | 32 | # Flag of finalized basis generator 33 | self.is_finalized = False 34 | 35 | @property 36 | def num_basis(self) -> int: 37 | """ 38 | Returns: the number of basis with learnable weights 39 | """ 40 | return self._num_basis 41 | 42 | @property 43 | def _num_local_params(self) -> int: 44 | """ 45 | Returns: number of parameters of current class 46 | """ 47 | return 0 48 | 49 | @property 50 | def num_params(self) -> int: 51 | """ 52 | Returns: number of parameters of current class plus parameters of all 53 | attributes 54 | """ 55 | return self._num_local_params + self.phase_generator.num_params 56 | 57 | def set_params(self, 58 | params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 59 | """ 60 | Set parameters of current object and attributes 61 | Args: 62 | params: parameters to be set 63 | 64 | Returns: 65 | None 66 | """ 67 | params = torch.as_tensor(params, dtype=self.dtype, device=self.device) 68 | remaining_params = self.phase_generator.set_params(params) 69 | self.finalize() 70 | return remaining_params 71 | 72 | def get_params(self) -> torch.Tensor: 73 | """ 74 | Return all learnable parameters 75 | Returns: 76 | parameters 77 | """ 78 | # Shape of params 79 | # [*add_dim, num_params] 80 | params = self.phase_generator.get_params() 81 | return params 82 | 83 | def get_params_bounds(self) -> torch.Tensor: 84 | """ 85 | Return all learnable parameters' bounds 86 | Returns: 87 | parameters bounds 88 | """ 89 | # Shape of params_bounds 90 | # [2, num_params] 91 | 92 | params_bounds = self.phase_generator.get_params_bounds() 93 | return params_bounds 94 | 95 | @abstractmethod 96 | def basis(self, times: torch.Tensor) -> torch.Tensor: 97 | """ 98 | Interface to generate value of single basis function at given time 99 | points 100 | Args: 101 | times: times in Tensor 102 | 103 | Returns: 104 | basis functions in 
Tensor 105 | 106 | """ 107 | pass 108 | 109 | def basis_multi_dofs(self, 110 | times: torch.Tensor, 111 | num_dof: int) -> torch.Tensor: 112 | """ 113 | Interface to generate value of single basis function at given time 114 | points 115 | Args: 116 | times: times in Tensor 117 | num_dof: num of Degree of freedoms 118 | Returns: 119 | basis_multi_dofs: Multiple DoFs basis functions in Tensor 120 | 121 | """ 122 | # Shape of time 123 | # [*add_dim, num_times] 124 | # 125 | # Shape of basis_multi_dofs 126 | # [*add_dim, num_dof * num_times, num_dof * num_basis] 127 | 128 | # Extract additional dimensions 129 | add_dim = list(times.shape[:-1]) 130 | 131 | # Get single basis, shape: [*add_dim, num_times, num_basis] 132 | basis_single_dof = self.basis(times) 133 | num_times = basis_single_dof.shape[-2] 134 | num_basis = basis_single_dof.shape[-1] 135 | 136 | # Multiple Dofs, shape: 137 | # [*add_dim, num_dof * num_times, num_dof * num_basis] 138 | basis_multi_dofs = torch.zeros(*add_dim, num_dof * num_times, 139 | num_dof * num_basis, dtype=self.dtype, 140 | device=self.device) 141 | # Assemble 142 | for i in range(num_dof): 143 | row_indices = slice(i * num_times, (i + 1) * num_times) 144 | col_indices = slice(i * num_basis, (i + 1) * num_basis) 145 | basis_multi_dofs[..., row_indices, col_indices] = basis_single_dof 146 | 147 | # Return 148 | return basis_multi_dofs 149 | 150 | def finalize(self): 151 | """ 152 | Mark the basis generator as finalized so that the parameters cannot be 153 | updated any more 154 | Returns: None 155 | 156 | """ 157 | self.is_finalized = True 158 | 159 | def reset(self): 160 | """ 161 | Unmark the finalization 162 | Returns: None 163 | 164 | """ 165 | self.phase_generator.reset() 166 | self.is_finalized = False 167 | 168 | def show_basis(self, plot=False) -> Tuple[torch.Tensor, torch.Tensor]: 169 | """ 170 | Compute basis function values for debug usage 171 | The times are in the range of [delay - tau, delay + 2 * tau] 172 | 173 | Returns: 
basis function values 174 | 175 | """ 176 | tau = self.phase_generator.tau 177 | delay = self.phase_generator.delay 178 | assert tau.ndim == 0 and delay.ndim == 0 179 | times = torch.linspace(delay - tau, delay + 2 * tau, steps=1000) 180 | basis_values = self.basis(times) 181 | if plot: 182 | import matplotlib.pyplot as plt 183 | plt.figure() 184 | for i in range(basis_values.shape[-1]): 185 | plt.plot(times, basis_values[:, i], label=f"basis_{i}") 186 | plt.grid() 187 | plt.legend() 188 | plt.axvline(x=delay, linestyle='--', color='k', alpha=0.3) 189 | plt.axvline(x=delay + tau, linestyle='--', color='k', alpha=0.3) 190 | plt.show() 191 | return times, basis_values 192 | -------------------------------------------------------------------------------- /mp_pytorch/basis_gn/norm_rbf_basis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mp_pytorch.phase_gn import PhaseGenerator 4 | from .basis_generator import BasisGenerator 5 | from ..phase_gn import ExpDecayPhaseGenerator 6 | 7 | 8 | class NormalizedRBFBasisGenerator(BasisGenerator): 9 | 10 | def __init__(self, 11 | phase_generator: PhaseGenerator, 12 | num_basis: int = 10, 13 | basis_bandwidth_factor: int = 3, 14 | num_basis_outside: int = 0, 15 | dtype: torch.dtype = torch.float32, 16 | device: torch.device = 'cpu'): 17 | """ 18 | Constructor of class RBF 19 | 20 | Args: 21 | phase_generator: phase generator 22 | num_basis: number of basis function 23 | basis_bandwidth_factor: basis bandwidth factor 24 | num_basis_outside: basis function outside the duration 25 | dtype: torch data type 26 | device: torch device to run on 27 | """ 28 | self.basis_bandwidth_factor = basis_bandwidth_factor 29 | self.num_basis_outside = num_basis_outside 30 | 31 | super(NormalizedRBFBasisGenerator, self).__init__(phase_generator, 32 | num_basis, 33 | dtype, device) 34 | 35 | # Compute centers and bandwidth 36 | # Distance between basis centers 37 | assert 
def basis(self, times: torch.Tensor) -> torch.Tensor:
    """
    Generate values of basis function at given time points.

    Args:
        times: times in Tensor, shape [*add_dim, num_times]

    Returns:
        basis: basis functions in Tensor,
        shape [*add_dim, num_times, num_basis]
    """
    # Map times to phase space, then compare against the basis centers
    # through broadcasting: [*add_dim, num_times, 1] vs [num_basis]
    phase_values = self.phase_generator.phase(times)
    sq_dist = (phase_values[..., None] - self.centers_p) ** 2

    # Un-normalized Gaussian bumps, one column per basis function
    rbf = torch.exp(-0.5 * sq_dist * self.bandwidth)

    # Normalize across basis functions (skipped for a single basis);
    # the epsilon guards against division by zero far outside the support
    if self._num_basis > 1:
        rbf = rbf / (rbf.sum(dim=-1, keepdim=True) + 1e-15)
    return rbf
# --- ZeroPaddingNormalizedRBFBasisGenerator ---
@property
def num_basis(self):
    """
    Number of effective (non-padding) basis functions.

    The superclass counts the zero-padding basis functions too, so they
    are subtracted here to report only the learnable ones.
    """
    return super().num_basis - self.num_basis_zero_start \
        - self.num_basis_zero_goal


# --- ProDMPBasisGenerator ---
def __init__(self, phase_generator: ExpDecayPhaseGenerator,
             num_basis: int = 10,
             basis_bandwidth_factor: int = 3,
             num_basis_outside: int = 0,
             dt: float = 0.01,
             alpha: float = 25,
             pre_compute_length_factor=6,
             dtype: torch.dtype = torch.float32,
             device: torch.device = 'cpu'):
    """
    Constructor of the ProDMP basis generator.

    Args:
        phase_generator: phase generator
        num_basis: number of basis function
        basis_bandwidth_factor: basis bandwidth factor
        num_basis_outside: basis function outside the duration
        dt: time step
        alpha: alpha value of DMP
        pre_compute_length_factor: (n x tau) time length in pre-computation
        dtype: data type
        device: device of the data
    """
    super(ProDMPBasisGenerator, self).__init__(phase_generator,
                                               num_basis,
                                               basis_bandwidth_factor,
                                               num_basis_outside,
                                               dtype, device)

    self.alpha = alpha
    self.scaled_dt = dt / self.phase_generator.tau

    # BUGFIX: the message previously said "<= 5" although the check (and
    # the default value) allow a factor of 6.
    assert pre_compute_length_factor <= 6, \
        "For numerical stability, please use a length factor <= 6."
    self.pre_compute_length_factor = pre_compute_length_factor

    # Pre-computed values; filled in by pre_compute()
    self.y_1_value = None
    self.y_2_value = None
    self.dy_1_value = None
    self.dy_2_value = None
    self.pc_pos_basis = None
    self.pc_vel_basis = None

    # One extra "goal" basis on top of the weight basis functions
    self.num_basis_g = self.num_basis + 1
    self.auto_basis_scale_factors = None
    self.pre_compute()


def pre_compute(self):
    """
    Precompute the ProDMP position/velocity basis functions, the
    homogeneous ODE solutions y1, y2 (and their derivatives) and the
    auto basis scale factors on a dense scaled-time grid of length
    pre_compute_length_factor (in units of tau).

    Returns: None
    """
    # Shape of pc_scaled_times: [num_pc_times]
    # Shape of y_1_value, y_2_value, dy_1_value, dy_2_value: [num_pc_times]
    # Shape of q_1_value, q_2_value: [num_pc_times]
    # Shape of p_1_value, p_2_value: [num_pc_times, num_basis]
    # Shape of pc_pos_basis, pc_vel_basis: [num_pc_times, num_basis_g]
    # Note: num_basis_g = num_basis + 1

    # Pre-compute scaled time steps in [0, pre_compute_length_factor]
    num_pre_compute = self.pre_compute_length_factor * \
                      torch.round(1 / self.scaled_dt).long().item() + 1
    pc_scaled_times = torch.linspace(0, self.pre_compute_length_factor,
                                     num_pre_compute, dtype=self.dtype,
                                     device=self.device)

    # Homogeneous solutions y1, y2 and their derivatives
    self.y_1_value = torch.exp(-0.5 * self.alpha * pc_scaled_times)
    self.y_2_value = pc_scaled_times * self.y_1_value

    self.dy_1_value = -0.5 * self.alpha * self.y_1_value
    self.dy_2_value = -0.5 * self.alpha * self.y_2_value + self.y_1_value

    # q_1 and q_2: closed-form integrals for the goal basis
    q_1_value = \
        (0.5 * self.alpha * pc_scaled_times - 1) \
        * torch.exp(0.5 * self.alpha * pc_scaled_times) + 1
    q_2_value = \
        0.5 * self.alpha \
        * (torch.exp(0.5 * self.alpha * pc_scaled_times) - 1)

    # Get basis of one DOF, shape [num_pc_times, num_basis]
    pc_times = self.phase_generator.linear_phase_to_time(pc_scaled_times)
    basis_single_dof = super().basis(pc_times)
    assert list(basis_single_dof.shape) == [*pc_times.shape,
                                            self.num_basis]

    # Get canonical phase x, shape [num_pc_times]
    canonical_x = self.phase_generator.phase(pc_times)
    assert list(canonical_x.shape) == [*pc_times.shape]

    # Integrands of p_1 and p_2
    dp_1_value = \
        torch.einsum('...i,...i,...ij->...ij',
                     pc_scaled_times
                     * torch.exp(self.alpha * pc_scaled_times / 2),
                     canonical_x,
                     basis_single_dof)
    dp_2_value = \
        torch.einsum('...i,...i,...ij->...ij',
                     torch.exp(self.alpha * pc_scaled_times / 2),
                     canonical_x,
                     basis_single_dof)

    # PERF FIX: the original loop re-integrated every prefix with
    # torch.trapz, which is O(T^2) in the number of pre-computed steps.
    # A cumulative trapezoid rule is O(T) and mathematically identical
    # (up to floating-point summation order). Fall back to the original
    # loop on torch versions without torch.cumulative_trapezoid.
    if hasattr(torch, "cumulative_trapezoid"):
        zero_row = torch.zeros(1, dp_1_value.shape[-1], dtype=self.dtype,
                               device=self.device)
        p_1_value = torch.cat(
            [zero_row,
             torch.cumulative_trapezoid(dp_1_value, pc_scaled_times,
                                        dim=0)], dim=0)
        p_2_value = torch.cat(
            [zero_row,
             torch.cumulative_trapezoid(dp_2_value, pc_scaled_times,
                                        dim=0)], dim=0)
    else:
        p_1_value = torch.zeros(size=dp_1_value.shape, dtype=self.dtype,
                                device=self.device)
        p_2_value = torch.zeros(size=dp_2_value.shape, dtype=self.dtype,
                                device=self.device)
        for i in range(pc_scaled_times.shape[0]):
            p_1_value[i] = torch.trapz(dp_1_value[:i + 1],
                                       pc_scaled_times[:i + 1], dim=0)
            p_2_value[i] = torch.trapz(dp_2_value[:i + 1],
                                       pc_scaled_times[:i + 1], dim=0)

    # Compute integral form basis values
    pos_basis_w = p_2_value * self.y_2_value[:, None] \
                  - p_1_value * self.y_1_value[:, None]
    pos_basis_g = q_2_value * self.y_2_value \
                  - q_1_value * self.y_1_value
    vel_basis_w = p_2_value * self.dy_2_value[:, None] \
                  - p_1_value * self.dy_1_value[:, None]
    vel_basis_g = q_2_value * self.dy_2_value \
                  - q_1_value * self.dy_1_value

    # Pre-computed pos and vel basis: weight basis columns + goal column
    self.pc_pos_basis = \
        torch.cat([pos_basis_w, pos_basis_g[:, None]], dim=-1)
    self.pc_vel_basis = \
        torch.cat([vel_basis_w, vel_basis_g[:, None]], dim=-1)

    self.auto_compute_basis_scale_factors()
def times_to_indices(self, times: torch.Tensor, round_int: bool = True):
    """
    Map time points to pre-compute indices.

    Args:
        times: time points
        round_int: if indices should be rounded to the closest integer

    Returns:
        time indices
    """
    # Times to scaled (linear-phase) times
    scaled = self.phase_generator.left_bound_linear_phase(times)
    if scaled.max() > self.pre_compute_length_factor:
        raise RuntimeError("Time is beyond the pre-computation range. "
                           "Set larger pre-computation factor")
    indices = scaled / self.scaled_dt
    return torch.round(indices).long() if round_int else indices


def basis(self, times: torch.Tensor):
    """
    Generate values of basis function at given time points by
    interpolating the pre-computed position basis.

    Args:
        times: times in Tensor, shape [*add_dim, num_times]

    Returns:
        basis: basis functions, shape [*add_dim, num_times, num_basis_g]
    """
    return util.indexing_interpolate(
        data=self.pc_pos_basis,
        indices=self.times_to_indices(times, False))


def vel_basis(self, times: torch.Tensor):
    """
    Generate values of velocity basis function at given time points by
    interpolating the pre-computed velocity basis.

    Args:
        times: times in Tensor, shape [*add_dim, num_times]

    Returns:
        vel_basis: velocity basis, shape [*add_dim, num_times, num_basis_g]
    """
    return util.indexing_interpolate(
        data=self.pc_vel_basis,
        indices=self.times_to_indices(times, False))
def vel_basis_multi_dofs(self, times: torch.Tensor, num_dof: int):
    """
    Generate the blocked-diagonal multiple DoF velocity basis matrix.

    Args:
        times: times in Tensor, shape [*add_dim, num_times]
        num_dof: num of Degree of freedoms

    Returns:
        vel_basis_multi_dofs: Multiple DoFs velocity basis functions,
        shape [*add_dim, num_dof * num_times, num_dof * num_basis_g]
    """
    batch_dims = list(times.shape[:-1])

    # Single-DoF velocity basis, shape [*add_dim, num_times, num_basis_g]
    single_dof = self.vel_basis(times)
    n_times = single_dof.shape[-2]
    n_basis_g = self.num_basis_g

    multi_dof = torch.zeros(*batch_dims, num_dof * n_times,
                            num_dof * n_basis_g,
                            dtype=self.dtype, device=self.device)
    # Place the shared single-DoF basis on each diagonal block
    for dof in range(num_dof):
        rows = slice(dof * n_times, (dof + 1) * n_times)
        cols = slice(dof * n_basis_g, (dof + 1) * n_basis_g)
        multi_dof[..., rows, cols] = single_dof
    return multi_dof


def general_solution_values(self, times: torch.Tensor):
    """
    Look up (with interpolation) the homogeneous solution values y1, y2
    and their derivatives dy1, dy2 at the given time steps.

    Args:
        times: time points, shape [*add_dim, num_times]

    Returns:
        values of y1, y2, dy1, dy2, each of shape [*add_dim, num_times]
    """
    indices = self.times_to_indices(times, False)
    return tuple(
        util.indexing_interpolate(data=source, indices=indices)
        for source in (self.y_1_value, self.y_2_value,
                       self.dy_1_value, self.dy_2_value))
def get_basis_scale_factors(self):
    """
    Compute the scale factors of all basis functions, so that their
    magnitudes are all equal to 1

    Returns:
        auto_basis_scale_factors: scale factors
    """
    # Factors are filled in by auto_compute_basis_scale_factors(), which
    # pre_compute() runs at construction time; guard against early access.
    assert self.auto_basis_scale_factors is not None, "Basis scale factors is not computed."
    return self.auto_basis_scale_factors


# TODO: some things still missing
from typing import Tuple

import numpy as np

from mp_pytorch import BasisGenerator
from mp_pytorch import PhaseGenerator


class RhythmicBasisGenerator(BasisGenerator):
    """
    Numpy-based rhythmic (von-Mises-like) basis generator. Legacy code.

    NOTE(review): this class appears unfinished/broken:
    - basis_and_phase() calls self.getInputTensorIndex(0), which is not
      defined in this class or visible in the base class -- verify.
    - __init__ reads self.n_basis, but sibling generators access the
      count via self._num_basis, so this attribute likely does not
      exist -- confirm against BasisGenerator.__init__.
    - The `duration` argument and the `t` parameter of basis_and_phase
      are never used.
    """

    def __init__(
            self, phase_generator: PhaseGenerator, n_basis: int = 5,
            duration: float = 1,
            basis_bandwidth_factor: float = 3
    ):
        BasisGenerator.__init__(self, phase_generator, n_basis)

        self.num_bandwidth_factor = basis_bandwidth_factor
        # Evenly spaced centers in the normalized phase interval [0, 1]
        self.centers = np.linspace(0, 1, self.n_basis)

        # Bandwidth from neighbor distances; last distance duplicated
        tmp_bandwidth = np.hstack((self.centers[1:] - self.centers[0:-1],
                                   self.centers[-1] - self.centers[- 2]))

        # The Centers should not overlap too much (makes w almost random due to aliasing effect).Empirically chosen
        self.bandwidth = self.num_bandwidth_factor / (tmp_bandwidth ** 2)

    def basis_and_phase(self, t: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return normalized rhythmic basis values and the phase.

        NOTE(review): `t` is ignored; the phase comes from
        self.getInputTensorIndex(0), which is undefined here -- this
        method likely raises AttributeError if called.
        """
        phase = self.getInputTensorIndex(0)

        diff = phase - self.centers
        diff_cos = np.array([np.cos(diff * self.bandwidth * 2 * np.pi)])
        basis = np.exp(diff_cos)

        # Normalize so the basis values sum to one per time step
        sum_b = np.sum(basis, axis=1)
        basis = [column / sum_b for column in basis.transpose()]
        return np.array(basis).transpose(), phase
import numpy as np
import torch

import mp_pytorch.util as util


def test_to_ts():
    """Demo: cast scalars, numpy arrays and tensors to torch Tensors
    of various dtypes/devices via util.to_ts."""
    util.print_wrap_title("test_to_ts")

    a = torch.Tensor([1, 2, 3])
    b = torch.Tensor([1, 2, 3]).double()
    c = 3.14
    d = np.array([1, 2, 3])  # NOTE(review): default-int array (the old comment claiming float64 was wrong)
    e = np.array([1, 2, 3], dtype=float)  # float64 array

    util.print_line_title("Original data")
    for data in [a, b, c, d, e]:
        print(f"data: {data}")

    # NOTE(review): the "cuda" pass requires a CUDA-enabled device.
    for data_type in [torch.float32, torch.float64]:
        for device in ["cpu", "cuda"]:
            util.print_line_title(f"data_type: {data_type}, device: {device}")
            for data in [a, b, c, d, e]:
                tensor_data = util.to_ts(data, data_type, device)
                print(tensor_data)
                print(tensor_data.device)
                print(tensor_data.type(), "\n")


def test_to_tss():
    """Demo: cast several heterogeneous inputs at once via util.to_tss."""
    util.print_wrap_title("test_to_tss")
    a = torch.Tensor([1, 2, 3])
    b = torch.Tensor([1, 2, 3]).double()
    c = 3.14
    d = np.array([1, 2, 3])  # NOTE(review): default-int array, not float64
    e = np.array([1, 2, 3], dtype=float)  # float64 array

    util.print_line_title("Original data")
    for data in [a, b, c, d, e]:
        print(f"data: {data}")

    util.print_line_title("Casted data")
    a, b, c, d, e = util.to_tss(a, b, c, d, e, dtype=torch.float64,
                                device="cuda")
    for data in [a, b, c, d, e]:
        util.print_line()
        print(data)
        print(data.device)
        print(data.type(), "\n")


if __name__ == '__main__':
    test_to_ts()
    test_to_tss()
import matplotlib.pyplot as plt
import torch
from addict import Dict

from mp_pytorch.demo import get_mp_utils
from mp_pytorch.mp import MPFactory
from mp_pytorch.mp import ProMP
from mp_pytorch import util


def get_mp_scale_and_delay_util(mp_type: str, tau: float, delay: float):
    """
    Build demo inputs for the given MP type with a fixed (static) scale
    tau and delay.

    Args:
        mp_type: movement primitive type
        tau: trajectory scale (duration)
        delay: trajectory delay before the movement starts

    Returns:
        config dict, times, params, params_L, init_time, init_pos, init_vel
    """
    base_config, _, params, params_L, _, init_pos, _, _ = \
        get_mp_utils(mp_type, False, False)
    cfg = Dict(base_config)
    cfg.tau = tau
    cfg.delay = delay

    n_traj = params.shape[0]
    # Sample twice as densely as dt over the delayed + scaled duration
    n_steps = int((tau + delay) / cfg.mp_args.dt) * 2 + 1
    end_times = torch.ones([n_traj, 1]) * tau + delay
    times = util.tensor_linspace(0, end_times, n_steps).squeeze(-1)

    init_time = times[:, 0] + delay
    init_vel = torch.zeros_like(init_pos)

    return cfg.to_dict(), times, params, params_L, init_time, init_pos, \
        init_vel
def test_learnable_delay_and_scale():
    """
    Demo: make tau (scale) and delay part of the learnable parameter
    vector and plot the resulting trajectories over a grid of
    (tau, delay) combinations, one subplot per combination.
    """
    tau_list = [1.0, 2.0, 3.0]
    delay_list = [0.0, 1.0, 2.0]
    mp_list = ["promp", "dmp", "prodmp"]
    for mp_type in mp_list:
        config = get_mp_utils(mp_type, learn_tau=True, learn_delay=True)[0]
        config = Dict(config)
        # Generate parameters
        num_param = config.num_dof * config.mp_args.num_basis
        params_scale_factor = 100
        params_L_scale_factor = 10

        if "dmp" in config.mp_type:
            # DMP-type MPs carry one extra goal parameter per DoF
            num_param += config.num_dof
            params_scale_factor = 1000
            params_L_scale_factor = 0.1

        # One trajectory per (tau, delay) combination
        num_traj = len(tau_list) * len(delay_list)
        time_max = tau_list[-1] + delay_list[-1]
        num_t = int(time_max / config.mp_args.dt) * 2 + 1
        times = util.tensor_linspace(0, torch.ones([num_traj, 1]) * time_max,
                                     num_t).squeeze(-1)

        torch.manual_seed(0)
        params = torch.randn([1, num_param]).expand([num_traj, num_param]) \
            * params_scale_factor
        if "dmp" in config.mp_type:
            # BUGFIX: the per-DoF layout is [w_1 .. w_num_basis, goal], so
            # the goal entries sit at stride num_basis + 1. The previous
            # stride of num_basis scaled the wrong entries; this now
            # matches get_mp_utils in demo_mp_config.py.
            params[:,
                   config.mp_args.num_basis::config.mp_args.num_basis + 1] \
                *= 0.001

        lct = torch.distributions.transforms.LowerCholeskyTransform(
            cache_size=0)
        params_L = lct(torch.randn([1, num_param, num_param]).expand(
            [num_traj, num_param, num_param])) * params_L_scale_factor

        # Prepend [tau, delay] to each trajectory's parameter vector.
        # BUGFIX: the row-major index over the (tau, delay) grid is
        # i * len(delay_list) + j; the previous i * len(tau_list) + j only
        # worked because both lists happen to have the same length.
        tau_delay = torch.zeros([num_traj, 2])
        for i, tau in enumerate(tau_list):
            for j, delay in enumerate(delay_list):
                tau_delay[i * len(delay_list) + j] = \
                    torch.Tensor([tau, delay])
        params = torch.cat([tau_delay, params], dim=-1)

        init_time = times[:, 0]
        init_pos = 5 * torch.ones([num_traj, config.num_dof])
        init_vel = torch.zeros_like(init_pos)

        mp = MPFactory.init_mp(**config)
        mp.update_inputs(times=times, params=params,
                         params_L=params_L,
                         init_time=init_time, init_pos=init_pos,
                         init_vel=init_vel)

        traj_pos = mp.get_traj_pos()[..., 0]
        traj_pos = util.to_np(traj_pos)

        times = util.to_np(times)

        fig, axes = plt.subplots(len(tau_list), len(delay_list), sharex='all',
                                 sharey='all', squeeze=False)
        fig.suptitle(f"Learnable scale and delay of {mp_type}")
        for i, tau in enumerate(tau_list):
            for j, delay in enumerate(delay_list):
                # Same grid layout as used when filling tau_delay above
                traj_idx = i * len(delay_list) + j
                axes[i, j].plot(times[traj_idx], traj_pos[traj_idx])

                if isinstance(mp, ProMP):
                    traj_std = mp.get_traj_pos_std()[..., 0]
                    traj_std = util.to_np(traj_std)
                    util.fill_between(times[traj_idx],
                                      traj_pos[traj_idx],
                                      traj_std[traj_idx],
                                      axes[i, j])
                axes[i, j].axvline(x=delay, linestyle='--', color='r',
                                   alpha=0.3)
                axes[i, j].axvline(x=tau + delay, linestyle='--', color='r',
                                   alpha=0.3)
                axes[i, j].title.set_text(f"Scale: {tau}s, Delay: {delay}s")
                axes[i, j].grid(alpha=0.2)

        plt.show()


if __name__ == '__main__':
    test_static_delay_and_scale()
    test_learnable_delay_and_scale()
"""
@brief: testing MPs
"""

import mp_pytorch.util as util
from mp_pytorch.demo import get_mp_utils
from mp_pytorch.mp import MPFactory


def test_dmp():
    """Run the DMP demo: trajectories, parameter bounds, scaled basis."""
    util.print_wrap_title("test_dmp")
    config, times, params, params_L, init_time, init_pos, init_vel, demos = \
        get_mp_utils("dmp", True, True)

    dmp = MPFactory.init_mp(**config)

    # params_L here is redundant, but it will not fail the update func
    # (Uncomment the next line to exclude init_time from the prediction.)
    # times = times[..., 1:]
    dmp.update_inputs(times=times, params=params, params_L=params_L,
                      init_time=init_time, init_pos=init_pos,
                      init_vel=init_vel)

    trajs = dmp.get_trajs(get_pos=True, get_vel=True)

    # Position
    util.print_line_title("pos")
    print(trajs["pos"].shape)
    util.debug_plot(times[0], [trajs["pos"][0, :, 0]], title="dmp_pos")

    # Velocity
    util.print_line_title("vel")
    util.debug_plot(times[0], [trajs["vel"][0, :, 0]], title="dmp_vel")

    # Parameter bounds
    util.print_line_title("params_bounds")
    low, high = dmp.get_params_bounds()
    print("Lower bound", low, sep="\n")
    print("Upper bound", high, sep="\n")
    print(dmp.get_params_bounds().shape)

    # Show the scaled basis of a fresh, non-learnable DMP
    config, times, params, params_L, init_time, init_pos, init_vel, demos = \
        get_mp_utils("dmp", False, False)
    dmp = MPFactory.init_mp(**config)
    dmp.show_scaled_basis(plot=True)


def main():
    """Entry point of the demo script."""
    test_dmp()


if __name__ == "__main__":
    main()
util 5 | 6 | 7 | def get_mp_utils(mp_type: str, learn_tau=False, learn_delay=False, 8 | relative_goal=False): 9 | torch.manual_seed(0) 10 | config = Dict() 11 | 12 | config.num_dof = 2 13 | config.tau = 3 14 | config.learn_tau = learn_tau 15 | config.learn_delay = learn_delay 16 | 17 | config.mp_args.num_basis = 10 18 | config.mp_args.basis_bandwidth_factor = 2 19 | config.mp_args.num_basis_outside = 0 20 | config.mp_args.alpha = 25 21 | config.mp_args.alpha_phase = 2 22 | config.mp_args.dt = 0.01 23 | config.mp_args.weights_scale = torch.ones([config.mp_args.num_basis]) 24 | # config.mp_args.weights_scale = 10 25 | config.mp_args.goal_scale = 1 26 | config.mp_args.relative_goal = relative_goal 27 | config.mp_type = mp_type 28 | 29 | if mp_type == "zero_padding_promp": 30 | config.mp_args.num_basis_zero_start = int( 31 | 0.4 * config.mp_args.num_basis) 32 | config.mp_args.num_basis_zero_goal = 0 33 | 34 | # Generate parameters 35 | num_param = config.num_dof * config.mp_args.num_basis 36 | params_scale_factor = 100 37 | params_L_scale_factor = 10 38 | 39 | if "dmp" in config.mp_type: 40 | num_param += config.num_dof 41 | params_scale_factor = 1000 42 | params_L_scale_factor = 0.3 43 | 44 | # assume we have 3 trajectories in a batch 45 | num_traj = 3 46 | num_t = int(3 / config.mp_args.dt) * 2 + 1 47 | 48 | # Get parameters 49 | torch.manual_seed(0) 50 | 51 | # initial position 52 | init_pos = 5 * torch.ones([num_traj, config.num_dof]) 53 | 54 | params = torch.randn([num_traj, num_param]) * params_scale_factor 55 | # params = torch.ones([num_traj, num_param]) * params_scale_factor 56 | 57 | if "dmp" in config.mp_type: 58 | params[:, config.mp_args.num_basis::config.mp_args.num_basis+1] *= 0.001 59 | if relative_goal: 60 | params[:, config.mp_args.num_basis::config.mp_args.num_basis+1] -= \ 61 | init_pos 62 | 63 | 64 | if config.learn_delay: 65 | torch.manual_seed(0) 66 | delay = torch.rand([num_traj, 1]) 67 | params = torch.cat([delay, params], dim=-1) 68 | else: 69 
| delay = 0 70 | 71 | if config.learn_tau: 72 | torch.manual_seed(0) 73 | tau = torch.rand([num_traj, 1]) + 4 74 | params = torch.cat([tau, params], dim=-1) 75 | times = util.tensor_linspace(0, tau + delay, num_t).squeeze(-1) 76 | else: 77 | times = util.tensor_linspace(0, torch.ones([num_traj, 1]) * config.tau 78 | + delay, num_t).squeeze(-1) 79 | 80 | lct = torch.distributions.transforms.LowerCholeskyTransform(cache_size=0) 81 | torch.manual_seed(0) 82 | params_L = lct(torch.randn([num_traj, num_param, num_param])) \ 83 | * params_L_scale_factor 84 | 85 | init_time = times[:, 0] 86 | 87 | if config.learn_delay: 88 | init_vel = torch.zeros_like(init_pos) 89 | else: 90 | init_vel = -5 * torch.ones([num_traj, config.num_dof]) 91 | 92 | demos = torch.zeros([*times.shape, config.num_dof]) 93 | for i in range(config.num_dof): 94 | demos[..., i] = torch.sin(2 * times + i) + 5 95 | 96 | return config.to_dict(), times, params, params_L, init_time, init_pos, \ 97 | init_vel, demos 98 | -------------------------------------------------------------------------------- /mp_pytorch/demo/demo_prodmp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: testing MPs 3 | """ 4 | 5 | import torch 6 | from matplotlib import pyplot as plt 7 | 8 | import mp_pytorch.util as util 9 | from mp_pytorch.demo import get_mp_utils 10 | from mp_pytorch.mp import MPFactory 11 | from mp_pytorch.mp import ProDMP 12 | 13 | 14 | def test_prodmp(): 15 | util.print_wrap_title("test_prodmp") 16 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 17 | get_mp_utils("prodmp", True, True, False) 18 | 19 | mp = MPFactory.init_mp(**config) 20 | mp.update_inputs(times=times, params=params, params_L=params_L, 21 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 22 | assert isinstance(mp, ProDMP) 23 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=True, 24 | get_pos_std=True, get_vel=True, 25 | get_vel_cov=True, 
get_vel_std=True) 26 | # Pos 27 | util.print_line_title("pos") 28 | print(traj_dict["pos"].shape) 29 | util.debug_plot(times[0], [traj_dict["pos"][0, :, 0]], title="prodmp_pos") 30 | 31 | # Pos_cov 32 | util.print_line_title("pos_cov") 33 | pass 34 | 35 | # Pos_std 36 | util.print_line_title("pos_std") 37 | plt.figure() 38 | util.fill_between(times[0], traj_dict["pos"][0, :, 0], 39 | traj_dict["pos_std"][0, :, 0], draw_mean=True) 40 | plt.title("prodmp pos std") 41 | plt.show() 42 | 43 | # Vel 44 | util.print_line_title("vel") 45 | util.debug_plot(times[0], [traj_dict["vel"][0, :, 0]], title="prodmp_vel") 46 | 47 | # Vel_cov 48 | util.print_line_title("vel_cov") 49 | pass 50 | 51 | # Vel_std 52 | util.print_line_title("vel_std") 53 | plt.figure() 54 | print("traj_dict[vel_std].shape", traj_dict["vel_std"].shape) 55 | util.fill_between(times[0], traj_dict["vel"][0, :, 0], 56 | traj_dict["vel_std"][0, :, 0], draw_mean=True) 57 | plt.title("prodmp vel std") 58 | plt.show() 59 | 60 | # Sample trajectories 61 | util.print_line_title("sample trajectories") 62 | num_smp = 50 63 | samples, samples_vel = mp.sample_trajectories(num_smp=num_smp) 64 | print("samples.shape", samples.shape) 65 | util.debug_plot(times[0], [samples[0, i, :, 0] for i in range(num_smp)], 66 | title="prodmp_samples") 67 | 68 | # Parameters demo 69 | util.print_line_title("params_bounds") 70 | low, high = mp.get_params_bounds() 71 | print("Lower bound", low, sep="\n") 72 | print("Upper bound", high, sep="\n") 73 | print(mp.get_params_bounds().shape) 74 | 75 | # Learn weights 76 | util.print_line_title("learn weights") 77 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 78 | get_mp_utils("prodmp", False, False, True) 79 | 80 | mp = MPFactory.init_mp(**config) 81 | params_dict = mp.learn_mp_params_from_trajs(times, demos) 82 | 83 | # Reconstruct demos using learned weights 84 | rec_demo = mp.get_traj_pos(times, **params_dict) 85 | util.debug_plot(times[0], [demos[0, :, 0], 
rec_demo[0, :, 0]], 86 | labels=["demos", "rec_demos"], 87 | title="ProDMP demos vs. rec_demos") 88 | 89 | des_init_pos = torch.zeros_like(demos[:, 0]) - 0.25 + 5 90 | des_init_vel = torch.zeros_like(demos[:, 0]) 91 | 92 | params_dict = \ 93 | mp.learn_mp_params_from_trajs(times, demos, init_time=times[:, 0], 94 | init_pos=des_init_pos, init_vel=des_init_vel) 95 | 96 | # Reconstruct demos using learned weights 97 | rec_demo = mp.get_traj_pos(times, **params_dict) 98 | util.debug_plot(times[0], [demos[0, :, 0], rec_demo[0, :, 0]], 99 | labels=["demos", "rec_demos"], 100 | title="ProDMP demos vs. rec_demos") 101 | 102 | # Show scaled basis 103 | mp.show_scaled_basis(plot=True) 104 | 105 | 106 | def test_prodmp_disable_weights(): 107 | util.print_wrap_title("test_prodmp_disable_weights") 108 | learn_tau = True 109 | learn_delay = True 110 | 111 | config, times, params, _, init_time, init_pos, init_vel, demos = \ 112 | get_mp_utils("prodmp", learn_tau, learn_delay) 113 | 114 | # Disable weights 115 | config["mp_args"]["disable_weights"] = True 116 | num_dof = config["num_dof"] 117 | add_dim = params.shape[:-1] 118 | goal = 2 119 | params = torch.ones([*add_dim, num_dof]) * goal 120 | if learn_delay: 121 | params = torch.cat([torch.ones([*add_dim, 1]) * 1, params], dim=-1) 122 | if learn_tau: 123 | params = torch.cat([torch.ones([*add_dim, 1]) * 3, params], dim=-1) 124 | 125 | mp = MPFactory.init_mp(**config) 126 | mp.update_inputs(times=times, params=params, params_L=None, 127 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 128 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=False, 129 | get_pos_std=False, get_vel=True, 130 | get_vel_cov=False, get_vel_std=False) 131 | 132 | # Pos 133 | util.print_line_title("pos") 134 | print(traj_dict["pos"].shape) 135 | util.debug_plot(times[0], [traj_dict["pos"][0, :, 0]], 136 | title="prodmp_pos, disable weights") 137 | 138 | # Vel 139 | util.print_line_title("vel") 140 | util.debug_plot(times[0], 
[traj_dict["vel"][0, :, 0]], 141 | title="prodmp_vel, disable weights") 142 | 143 | 144 | def test_prodmp_disable_goal(): 145 | util.print_wrap_title("test_prodmp_disable_goals") 146 | learn_tau = True 147 | learn_delay = True 148 | relative_goal = True 149 | 150 | config, times, params, _, init_time, init_pos, init_vel, demos = \ 151 | get_mp_utils("prodmp", learn_tau, learn_delay, relative_goal) 152 | 153 | # Disable weights 154 | config["mp_args"]["disable_goal"] = True 155 | num_dof = config["num_dof"] 156 | add_dim = params.shape[:-1] 157 | goal = 2 158 | params = \ 159 | torch.ones([*add_dim, num_dof * config['mp_args']['num_basis']]) * 500 160 | 161 | if learn_delay: 162 | params = torch.cat([torch.ones([*add_dim, 1]) * 1, params], dim=-1) 163 | if learn_tau: 164 | params = torch.cat([torch.ones([*add_dim, 1]) * 3, params], dim=-1) 165 | 166 | mp = MPFactory.init_mp(**config) 167 | mp.update_inputs(times=times, params=params, params_L=None, 168 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 169 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=False, 170 | get_pos_std=False, get_vel=True, 171 | get_vel_cov=False, get_vel_std=False) 172 | 173 | # Pos 174 | util.print_line_title("pos") 175 | print(traj_dict["pos"].shape) 176 | util.debug_plot(times[0], [traj_dict["pos"][0, :, 0]], 177 | title="prodmp_pos, disable goal") 178 | 179 | # Vel 180 | util.print_line_title("vel") 181 | util.debug_plot(times[0], [traj_dict["vel"][0, :, 0]], 182 | title="prodmp_vel, disable goal") 183 | 184 | 185 | def main(): 186 | # To suppress the warning message, uncomment the following lines 187 | # import logging 188 | # logging.basicConfig(level=logging.ERROR) 189 | test_prodmp() 190 | test_prodmp_disable_weights() 191 | test_prodmp_disable_goal() 192 | 193 | 194 | if __name__ == "__main__": 195 | main() 196 | -------------------------------------------------------------------------------- /mp_pytorch/demo/demo_prodmp_autoscale.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @breif: Demo of the ProDMPs with autoscaling. 3 | """ 4 | 5 | from matplotlib import pyplot as plt 6 | 7 | from mp_pytorch.demo import get_mp_utils 8 | from mp_pytorch.mp import MPFactory 9 | 10 | 11 | def test_prodmp_scaling(auto_scale=True, manual_w_scale=1., manual_g_scale=1.): 12 | config, time, params, params_L, init_time, init_pos, init_vel, demos = \ 13 | get_mp_utils("prodmp", True, True, relative_goal=True) 14 | config['mp_args']['auto_scale_basis'] = auto_scale 15 | config['mp_args']['weights_scale'] = manual_w_scale 16 | config['mp_args']['goal_scale'] = manual_g_scale 17 | mp = MPFactory.init_mp(**config) 18 | mp.show_scaled_basis(True) 19 | 20 | 21 | if __name__ == "__main__": 22 | test_prodmp_scaling(auto_scale=False, manual_w_scale=1., manual_g_scale=1.) 23 | test_prodmp_scaling(auto_scale=True, manual_w_scale=1., manual_g_scale=1.) 24 | test_prodmp_scaling(auto_scale=True, manual_w_scale=0.3, manual_g_scale=0.3) 25 | -------------------------------------------------------------------------------- /mp_pytorch/demo/demo_promp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: testing MPs 3 | """ 4 | 5 | from matplotlib import pyplot as plt 6 | 7 | import mp_pytorch.util as util 8 | from mp_pytorch.demo import get_mp_utils 9 | from mp_pytorch.mp import MPFactory 10 | from mp_pytorch.mp import ProMP 11 | 12 | 13 | def test_promp(): 14 | util.print_wrap_title("test_promp") 15 | 16 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 17 | get_mp_utils("promp", True, True) 18 | 19 | mp = MPFactory.init_mp(**config) 20 | assert isinstance(mp, ProMP) 21 | mp.update_inputs(times=times, params=params, params_L=params_L, 22 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 23 | traj_dict = mp.get_trajs(get_pos=True, get_pos_cov=True, 24 | get_pos_std=True, get_vel=True, 25 | 
get_vel_cov=True, get_vel_std=True) 26 | # Pos 27 | util.print_line_title("pos") 28 | print("traj_dict[pos].shape", traj_dict["pos"].shape) 29 | util.debug_plot(times[0], [traj_dict["pos"][0, :, 0]], title="promp_mean") 30 | 31 | # Pos_cov 32 | util.print_line_title("pos_cov") 33 | pass 34 | 35 | # Pos_std 36 | util.print_line_title("pos_std") 37 | plt.figure() 38 | print("traj_dict[pos_std].shape", traj_dict["pos_std"].shape) 39 | util.fill_between(times[0], traj_dict["pos"][0, :, 0], 40 | traj_dict["pos_std"][0, :, 0], draw_mean=True) 41 | plt.title("promp std") 42 | plt.show() 43 | 44 | # Vel 45 | util.print_line_title("vel") 46 | print("traj_dict[vel].shape", traj_dict["vel"].shape) 47 | util.debug_plot(times[0], [traj_dict["vel"][0, :, 0]], 48 | title="promp_vel_mean") 49 | 50 | # Vel_cov 51 | util.print_line_title("vel_cov") 52 | assert traj_dict["vel_cov"] is None 53 | 54 | # Vel_std 55 | util.print_line_title("vel_std") 56 | assert traj_dict["vel_std"] is None 57 | 58 | # Sample trajectories 59 | util.print_line_title("sample trajectories") 60 | num_smp = 50 61 | samples, samples_vel = mp.sample_trajectories(num_smp=num_smp) 62 | print("samples.shape", samples.shape) 63 | util.debug_plot(times[0], [samples[0, i, :, 0] for i in range(num_smp)], 64 | title="promp_samples") 65 | 66 | # Parameters demo 67 | util.print_line_title("params_bounds") 68 | low, high = mp.get_params_bounds() 69 | print("Lower bound", low, sep="\n") 70 | print("Upper bound", high, sep="\n") 71 | print(mp.get_params_bounds().shape) 72 | 73 | # Learn weights 74 | util.print_line_title("learn weights") 75 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 76 | get_mp_utils("promp", False, False) 77 | 78 | mp = MPFactory.init_mp(**config) 79 | mp.update_inputs(times=times, params=params, params_L=params_L, 80 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 81 | params_dict = mp.learn_mp_params_from_trajs(times, demos) 82 | # Reconstruct demos using 
learned weights 83 | rec_demo = mp.get_traj_pos(times, **params_dict) 84 | util.debug_plot(times[0], [demos[0, :, 0], rec_demo[0, :, 0]], 85 | labels=["demos", "rec_demos"], 86 | title="ProMP demos vs. rec_demos") 87 | 88 | # Show scaled basis 89 | mp.show_scaled_basis(plot=True) 90 | 91 | 92 | def test_zero_padding_promp(): 93 | util.print_wrap_title("test_zero_padding_promp") 94 | 95 | config, times, params, params_L, init_time, init_pos, init_vel, demos = \ 96 | get_mp_utils("zero_padding_promp", False, False) 97 | 98 | mp = MPFactory.init_mp(**config) 99 | assert isinstance(mp, ProMP) 100 | mp.update_inputs(times=times, params=params, params_L=params_L, 101 | init_time=init_time, init_pos=init_pos, init_vel=init_vel) 102 | 103 | # Pos 104 | util.print_line_title("zero padding pos") 105 | pos = mp.get_traj_pos() 106 | print("traj_dict[pos].shape", pos.shape) 107 | util.debug_plot(times[0], [pos[0, :, 0]], title="zero_promp_mean") 108 | 109 | # Vel 110 | util.print_line_title("zero padding vel") 111 | vel = mp.get_traj_vel() 112 | print("traj_dict[vel].shape", vel.shape) 113 | util.debug_plot(times[0], [vel[0, :, 0]], title="zero_promp_vel_mean") 114 | 115 | # Show scaled basis 116 | mp.show_scaled_basis(plot=True) 117 | 118 | 119 | def main(): 120 | test_promp() 121 | test_zero_padding_promp() 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /mp_pytorch/mp/__init__.py: -------------------------------------------------------------------------------- 1 | from .dmp import * 2 | from .mp_factory import * 3 | from .mp_interfaces import * 4 | from .prodmp import * 5 | from .promp import * 6 | -------------------------------------------------------------------------------- /mp_pytorch/mp/dmp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Dynamic Movement Primitives in PyTorch 3 | """ 4 | from typing import Iterable 5 
| from typing import Union 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from mp_pytorch.util import to_nps 11 | from mp_pytorch.basis_gn import BasisGenerator 12 | from .mp_interfaces import MPInterface 13 | 14 | 15 | class DMP(MPInterface): 16 | """DMP in PyTorch""" 17 | 18 | def __init__(self, 19 | basis_gn: BasisGenerator, 20 | num_dof: int, 21 | weights_scale: Union[float, Iterable] = 1., 22 | goal_scale: float = 1., 23 | alpha: float = 25, 24 | dtype: torch.dtype = torch.float32, 25 | device: torch.device = 'cpu', 26 | **kwargs): 27 | """ 28 | Constructor of DMP 29 | Args: 30 | basis_gn: basis function value generator 31 | num_dof: number of Degrees of Freedoms 32 | weights_scale: scaling for the parameters weights 33 | goal_scale: scaling for the goal 34 | dtype: torch data type 35 | device: torch device to run on 36 | kwargs: keyword arguments 37 | """ 38 | 39 | super().__init__(basis_gn, num_dof, weights_scale, dtype, device, 40 | **kwargs) 41 | 42 | # Number of parameters 43 | self.num_basis_g = self.num_basis + 1 44 | 45 | # Control parameters 46 | self.alpha = alpha 47 | self.beta = self.alpha / 4 48 | 49 | # Goal scale 50 | self.goal_scale = goal_scale 51 | self.weights_goal_scale = self.get_weights_goal_scale() 52 | 53 | @property 54 | def _num_local_params(self) -> int: 55 | """ 56 | Returns: number of parameters of current class 57 | """ 58 | return super()._num_local_params + self.num_dof 59 | 60 | def get_weights_goal_scale(self) -> torch.Tensor: 61 | """ 62 | Returns: the weights and goal scaling vector 63 | """ 64 | w_g_scale = torch.zeros(self.num_basis_g, 65 | dtype=self.dtype, device=self.device) 66 | w_g_scale[:-1] = self.weights_scale 67 | w_g_scale[-1] = self.goal_scale 68 | return w_g_scale 69 | 70 | def set_initial_conditions(self, init_time: Union[torch.Tensor, np.ndarray], 71 | init_pos: Union[torch.Tensor, np.ndarray], 72 | init_vel: Union[torch.Tensor, np.ndarray]): 73 | """ 74 | Set initial conditions 
in a batched manner 75 | 76 | Args: 77 | init_time: initial condition time 78 | init_pos: initial condition position 79 | init_vel: initial condition velocity 80 | 81 | Returns: 82 | None 83 | """ 84 | # Shape of init_time: 85 | # [*add_dim] 86 | # 87 | # Shape of init_pos: 88 | # [*add_dim, num_dof] 89 | # 90 | # Shape of init_vel: 91 | # [*add_dim, num_dof] 92 | 93 | init_time = torch.as_tensor(init_time, dtype=self.dtype, device=self.device) 94 | init_pos = torch.as_tensor(init_pos, dtype=self.dtype, device=self.device) 95 | init_vel = torch.as_tensor(init_vel, dtype=self.dtype, device=self.device) 96 | 97 | assert list(init_time.shape) == [*self.add_dim], \ 98 | f"shape of initial condition time {list(init_time.shape)} " \ 99 | f"does not match batch dimension {[*self.add_dim]}" 100 | assert list(init_pos.shape) == list(init_vel.shape) \ 101 | and list(init_vel.shape) == [*self.add_dim, self.num_dof], \ 102 | f"shape of initial condition position {list(init_pos.shape)} " \ 103 | f"and initial condition velocity do not match {list(init_vel.shape)}" 104 | super().set_initial_conditions(init_time, init_pos, init_vel) 105 | 106 | def get_traj_pos(self, times=None, params=None, 107 | init_time=None, init_pos=None, init_vel=None): 108 | """ 109 | Compute trajectories at desired time points 110 | 111 | Refer setting functions for desired shape of inputs 112 | 113 | Args: 114 | times: time points 115 | params: learnable parameters 116 | init_time: initial condition time 117 | init_pos: initial condition position 118 | init_vel: initial condition velocity 119 | 120 | Returns: 121 | pos 122 | """ 123 | 124 | # Shape of pos 125 | # [*add_dim, num_times, num_dof] 126 | 127 | # Update inputs 128 | self.update_inputs(times, params, init_time, init_pos, init_vel) 129 | 130 | # Reuse result if existing 131 | if self.pos is not None: 132 | return self.pos 133 | 134 | # Check initial condition, the desired times should start from 135 | # initial condition time steps or plus dt 
136 | if not torch.allclose(self.init_time, self.times[..., 0]): 137 | assert torch.allclose(self.times[..., 1] + self.init_time, 2 * self.times[..., 0]), \ 138 | f"The start time value {self.times[..., 1]} should be either init_time {self.init_time} or init_time + dt." 139 | times_include_init = torch.cat([self.init_time[..., None], self.times], dim=-1) 140 | 141 | # Recursively call itself 142 | self.get_traj_pos(times_include_init) 143 | 144 | # Remove the init_time from the result 145 | self.pos = self.pos[..., 1:, :] 146 | self.vel = self.vel[..., 1:, :] 147 | self.times = self.times[..., 1:] 148 | return self.pos 149 | 150 | # Scale basis functions 151 | weights_goal_scale = self.weights_goal_scale.repeat(self.num_dof) 152 | 153 | # Split weights and goal 154 | # Shape of w: 155 | # [*add_dim, num_dof, num_basis] 156 | # Shape of g: 157 | # [*add_dim, num_dof, 1] 158 | w, g = self._split_weights_goal(self.params * weights_goal_scale) 159 | 160 | # Get basis, shape [*add_dim, num_times, num_basis] 161 | basis = self.basis_gn.basis(self.times) 162 | 163 | # Get canonical phase x, shape [*add_dim, num_times] 164 | canonical_x = self.phase_gn.phase(self.times) 165 | 166 | # Get forcing function 167 | # Einsum shape: [*add_dim, num_times] 168 | # [*add_dim, num_times, num_basis] 169 | # [*add_dim, num_dof, num_basis] 170 | # -> [*add_dim, num_times, num_dof] 171 | f = torch.einsum('...i,...ik,...jk->...ij', canonical_x, basis, w) 172 | 173 | # Initialize trajectory position, velocity 174 | pos = torch.zeros([*self.add_dim, self.times.shape[-1], self.num_dof], 175 | dtype=self.dtype, device=self.device) 176 | vel = torch.zeros([*self.add_dim, self.times.shape[-1], self.num_dof], 177 | dtype=self.dtype, device=self.device) 178 | 179 | pos[..., 0, :] = self.init_pos 180 | vel[..., 0, :] = self.init_vel * self.phase_gn.tau[..., None] 181 | 182 | # Get scaled time increment steps 183 | scaled_times = self.phase_gn.left_bound_linear_phase(self.times) 184 | scaled_dt = 
torch.diff(scaled_times, dim=-1) 185 | 186 | # Apply Euler Integral 187 | for i in range(scaled_dt.shape[-1]): 188 | acc = (self.alpha * (self.beta * (g - pos[..., i, :]) 189 | - vel[..., i, :]) + f[..., i, :]) 190 | vel[..., i + 1, :] = \ 191 | vel[..., i, :] + torch.einsum('...,...i->...i', 192 | scaled_dt[..., i], acc) 193 | pos[..., i + 1, :] = \ 194 | pos[..., i, :] + torch.einsum('...,...i->...i', 195 | scaled_dt[..., i], 196 | vel[..., i + 1, :]) 197 | 198 | # Unscale velocity to original time space 199 | vel /= self.phase_gn.tau[..., None, None] 200 | 201 | # Store pos and vel 202 | self.pos = pos 203 | self.vel = vel 204 | 205 | return pos 206 | 207 | def get_traj_vel(self, times=None, params=None, 208 | init_time=None, init_pos=None, init_vel=None): 209 | """ 210 | Get trajectory velocity 211 | 212 | Refer setting functions for desired shape of inputs 213 | 214 | Args: 215 | times: time points, can be None 216 | params: learnable parameters, can be None 217 | init_time: initial condition time 218 | init_pos: initial condition position 219 | init_vel: initial condition velocity 220 | 221 | Returns: 222 | vel 223 | """ 224 | 225 | # Shape of vel 226 | # [*add_dim, num_times, num_dof] 227 | 228 | # Update inputs 229 | self.update_inputs(times, params, init_time, init_pos, init_vel) 230 | 231 | # Reuse result if existing 232 | if self.vel is not None: 233 | return self.vel 234 | 235 | # Recompute otherwise 236 | # Velocity is computed together with position in DMP 237 | self.get_traj_pos() 238 | return self.vel 239 | 240 | def learn_mp_params_from_trajs(self, times: torch.Tensor, 241 | trajs: torch.Tensor, 242 | reg: float = 1e-9): 243 | raise NotImplementedError 244 | 245 | def _split_weights_goal(self, wg): 246 | """ 247 | Helper function to split weights and goal 248 | 249 | Args: 250 | wg: vector storing weights and goal 251 | 252 | Returns: 253 | w: weights 254 | g: goal 255 | 256 | """ 257 | # Shape of wg: 258 | # [*add_dim, num_dof * num_basis_g] 259 | 
# 260 | # Shape of w: 261 | # [*add_dim, num_dof, num_basis] 262 | # 263 | # Shape of g: 264 | # [*add_dim, num_dof, 1] 265 | 266 | wg = wg.reshape([*wg.shape[:-1], self.num_dof, self.num_basis_g]) 267 | w = wg[..., :-1] 268 | g = wg[..., -1] 269 | 270 | return w, g 271 | 272 | def _show_scaled_basis(self, plot=False) \ 273 | -> Tuple[torch.Tensor, torch.Tensor]: 274 | tau = self.phase_gn.tau 275 | delay = self.phase_gn.delay 276 | assert tau.ndim == 0 and delay.ndim == 0 277 | times = torch.linspace(delay - tau, delay + 2 * tau, steps=1000, 278 | device=self.device, dtype=self.dtype) 279 | self.set_add_dim([]) 280 | self.set_times(times) 281 | 282 | weights_scale = self.weights_scale 283 | canonical_x = self.phase_gn.phase(self.times) 284 | # Get basis 285 | # Shape: [*add_dim, num_times, num_basis] 286 | basis_values = self.basis_gn.basis(times) * canonical_x[..., None] * weights_scale 287 | 288 | # Enforce all variables to numpy 289 | times, basis_values, delay, tau = \ 290 | to_nps(times, basis_values, delay, tau) 291 | 292 | if plot: 293 | import matplotlib.pyplot as plt 294 | plt.figure() 295 | for i in range(basis_values.shape[-1]): 296 | plt.plot(times, basis_values[:, i], label=f"basis_{i}") 297 | plt.grid() 298 | plt.legend() 299 | plt.axvline(x=delay, linestyle='--', color='k', alpha=0.3) 300 | plt.axvline(x=delay + tau, linestyle='--', color='k', alpha=0.3) 301 | plt.show() 302 | return times, basis_values 303 | -------------------------------------------------------------------------------- /mp_pytorch/mp/mp_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator 4 | from mp_pytorch.basis_gn import ProDMPBasisGenerator 5 | from mp_pytorch.basis_gn import ZeroPaddingNormalizedRBFBasisGenerator 6 | from mp_pytorch.phase_gn import ExpDecayPhaseGenerator 7 | from mp_pytorch.phase_gn import LinearPhaseGenerator 8 | from .dmp import DMP 9 
| from .prodmp import ProDMP 10 | from .promp import ProMP 11 | 12 | 13 | class MPFactory: 14 | @staticmethod 15 | def init_mp(mp_type: str, 16 | mp_args: dict, 17 | num_dof: int = 1, 18 | tau: float = 3, 19 | delay: float = 0, 20 | learn_tau: bool = False, 21 | learn_delay: bool = False, 22 | dtype: torch.dtype = torch.float32, 23 | device: torch.device = "cpu"): 24 | """ 25 | This is a helper class to initialize MPs, 26 | You can also directly initialize the MPs without using this class 27 | 28 | Create an MP instance given configuration 29 | 30 | Args: 31 | mp_type: type of movement primitives 32 | mp_args: arguments to a specific mp, refer each MP class 33 | num_dof: the number of degree of freedoms 34 | tau: default length of the trajectory 35 | delay: default delay before executing the trajectory 36 | learn_tau: if the length is a learnable parameter 37 | learn_delay: if the delay is a learnable parameter 38 | dtype: data type of the torch tensor 39 | device: device of the torch tensor 40 | 41 | 42 | Returns: 43 | MP instance 44 | """ 45 | 46 | # Get phase generator 47 | if mp_type == "promp": 48 | phase_gn = LinearPhaseGenerator(tau=tau, delay=delay, 49 | learn_tau=learn_tau, 50 | learn_delay=learn_delay, 51 | dtype=dtype, device=device) 52 | basis_gn = NormalizedRBFBasisGenerator( 53 | phase_generator=phase_gn, 54 | num_basis=mp_args["num_basis"], 55 | basis_bandwidth_factor=mp_args["basis_bandwidth_factor"], 56 | num_basis_outside=mp_args["num_basis_outside"], 57 | dtype=dtype, device=device) 58 | mp = ProMP(basis_gn=basis_gn, num_dof=num_dof, dtype=dtype, 59 | device=device, **mp_args) 60 | 61 | elif mp_type == 'zero_padding_promp': 62 | phase_gn = LinearPhaseGenerator(tau=tau, 63 | learn_tau=learn_tau, 64 | learn_delay=learn_delay, 65 | dtype=dtype, device=device) 66 | basis_gn = ZeroPaddingNormalizedRBFBasisGenerator( 67 | phase_generator=phase_gn, 68 | num_basis=mp_args["num_basis"], 69 | num_basis_zero_start=mp_args['num_basis_zero_start'], 70 | 
num_basis_zero_goal=mp_args['num_basis_zero_goal'], 71 | basis_bandwidth_factor=mp_args["basis_bandwidth_factor"], 72 | dtype=dtype, device=device 73 | ) 74 | mp = ProMP(basis_gn=basis_gn, num_dof=num_dof, dtype=dtype, 75 | device=device, **mp_args) 76 | 77 | elif mp_type == "dmp": 78 | phase_gn = ExpDecayPhaseGenerator(tau=tau, delay=delay, 79 | learn_tau=learn_tau, 80 | learn_delay=learn_delay, 81 | alpha_phase=mp_args[ 82 | "alpha_phase"], 83 | dtype=dtype, device=device) 84 | basis_gn = NormalizedRBFBasisGenerator( 85 | phase_generator=phase_gn, 86 | num_basis=mp_args["num_basis"], 87 | basis_bandwidth_factor=mp_args["basis_bandwidth_factor"], 88 | num_basis_outside=mp_args["num_basis_outside"], 89 | dtype=dtype, device=device) 90 | mp = DMP(basis_gn=basis_gn, num_dof=num_dof, dtype=dtype, 91 | device=device, **mp_args) 92 | elif mp_type == "prodmp": 93 | phase_gn = ExpDecayPhaseGenerator(tau=tau, delay=delay, 94 | learn_tau=learn_tau, 95 | learn_delay=learn_delay, 96 | alpha_phase=mp_args[ 97 | "alpha_phase"], 98 | dtype=dtype, device=device) 99 | basis_gn = ProDMPBasisGenerator( 100 | phase_generator=phase_gn, 101 | num_basis=mp_args["num_basis"], 102 | basis_bandwidth_factor=mp_args["basis_bandwidth_factor"], 103 | num_basis_outside=mp_args["num_basis_outside"], 104 | dt=mp_args["dt"], 105 | alpha=mp_args["alpha"], 106 | dtype=dtype, device=device) 107 | mp = ProDMP(basis_gn=basis_gn, num_dof=num_dof, dtype=dtype, 108 | device=device, **mp_args) 109 | else: 110 | raise NotImplementedError 111 | 112 | return mp 113 | -------------------------------------------------------------------------------- /mp_pytorch/mp/mp_interfaces.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Movement Primitives interfaces in PyTorch 3 | """ 4 | import copy 5 | from abc import ABC 6 | from abc import abstractmethod 7 | from typing import Iterable 8 | from typing import Optional 9 | from typing import Union 10 | 11 | import 
numpy as np 12 | import torch 13 | from torch.distributions import MultivariateNormal 14 | 15 | import mp_pytorch.util as util 16 | from mp_pytorch.basis_gn import BasisGenerator 17 | 18 | 19 | class MPInterface(ABC): 20 | @abstractmethod 21 | def __init__(self, 22 | basis_gn: BasisGenerator, 23 | num_dof: int, 24 | weights_scale: Union[float, Iterable] = 1., 25 | dtype: torch.dtype = torch.float32, 26 | device: torch.device = 'cpu', 27 | **kwargs): 28 | """ 29 | Constructor interface 30 | Args: 31 | basis_gn: basis generator 32 | num_dof: number of dof 33 | weights_scale: scaling for the parameters weights 34 | dtype: torch.dtype = torch.float32, 35 | device: torch.device = 'cpu', 36 | **kwargs: keyword arguments 37 | """ 38 | self.dtype = dtype 39 | self.device = device 40 | 41 | # Additional batch dimension 42 | self.add_dim = list() 43 | 44 | # The basis generators 45 | self.basis_gn = basis_gn 46 | 47 | # Number of DoFs 48 | self.num_dof = num_dof 49 | 50 | # Scaling of weights 51 | self.weights_scale = \ 52 | torch.as_tensor(weights_scale, dtype=self.dtype, device=self.device) 53 | assert self.weights_scale.ndim <= 1, \ 54 | "weights_scale should be float or 1-dim vector" 55 | 56 | # Value caches 57 | # Compute values at these time points 58 | self.times = None 59 | 60 | # Learnable parameters 61 | self.params = None 62 | 63 | # Initial conditions 64 | self.init_time = None 65 | self.init_pos = None 66 | self.init_vel = None 67 | 68 | # Runtime computation results, shall be reset every time when 69 | # inputs are reset 70 | self.pos = None 71 | self.vel = None 72 | 73 | # Flag of if the MP instance is finalized 74 | self.is_finalized = False 75 | 76 | # Local parameters bound 77 | self.local_params_bound = kwargs.get("params_bound", None) 78 | if not self.local_params_bound: 79 | self.local_params_bound = torch.zeros([2, self._num_local_params], 80 | dtype=self.dtype, 81 | device=self.device) 82 | self.local_params_bound[0, :] = -torch.inf 83 | 
self.local_params_bound[1, :] = torch.inf 84 | else: 85 | self.local_params_bound = torch.as_tensor(self.local_params_bound, 86 | dtype=self.dtype, 87 | device=self.device) 88 | assert list(self.local_params_bound.shape) == [2, 89 | self._num_local_params] 90 | 91 | @property 92 | def learn_tau(self): 93 | return self.phase_gn.learn_tau 94 | 95 | @property 96 | def learn_delay(self): 97 | return self.phase_gn.learn_delay 98 | 99 | @property 100 | def tau(self): 101 | return self.phase_gn.tau 102 | 103 | @property 104 | def num_basis(self): 105 | return self.basis_gn.num_basis 106 | 107 | @property 108 | def phase_gn(self): 109 | return self.basis_gn.phase_generator 110 | 111 | def clear_computation_result(self): 112 | """ 113 | Clear runtime computation result 114 | 115 | Returns: 116 | None 117 | """ 118 | 119 | self.pos = None 120 | self.vel = None 121 | 122 | def set_add_dim(self, add_dim: Union[list, torch.Size]): 123 | """ 124 | Set additional batch dimension 125 | Args: 126 | add_dim: additional batch dimension 127 | 128 | Returns: None 129 | 130 | """ 131 | self.add_dim = add_dim 132 | self.clear_computation_result() 133 | 134 | def set_times(self, times: Union[torch.Tensor, np.ndarray]): 135 | """ 136 | Set MP time points 137 | Args: 138 | times: desired time points 139 | 140 | Returns: 141 | None 142 | """ 143 | 144 | # Shape of times 145 | # [*add_dim, num_times] 146 | 147 | self.times = torch.as_tensor(times, dtype=self.dtype, 148 | device=self.device) 149 | self.clear_computation_result() 150 | 151 | def set_duration(self, duration: Optional[float], dt: float, 152 | include_init_time: bool = False): 153 | """ 154 | Set MP time points of a duration. The times start from init_time or 0 155 | 156 | Args: 157 | duration: desired duration of trajectory 158 | dt: control frequency 159 | include_init_time: if the duration includes the bc time step. 
160 | Returns: 161 | None 162 | """ 163 | 164 | # Shape of times 165 | # [*add_dim, num_times] 166 | 167 | if duration is None: 168 | duration = round(self.tau.max().item() / dt) * dt 169 | 170 | # dt = torch.as_tensor(dt, dtype=self.dtype, device=self.device) 171 | times = torch.linspace(0, duration, round(duration / dt) + 1, 172 | dtype=self.dtype, device=self.device) 173 | times = util.add_expand_dim(times, list(range(len(self.add_dim))), 174 | self.add_dim) 175 | 176 | if self.init_time is not None: 177 | times = times + self.init_time[..., None] 178 | if include_init_time: 179 | self.set_times(times) 180 | else: 181 | self.set_times(times[..., 1:]) 182 | 183 | def set_params(self, 184 | params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 185 | """ 186 | Set MP params 187 | Args: 188 | params: parameters 189 | 190 | Returns: unused parameters 191 | 192 | """ 193 | 194 | # Shape of params 195 | # [*add_dim, num_params] 196 | 197 | params = torch.as_tensor(params, dtype=self.dtype, device=self.device) 198 | 199 | # Check number of params 200 | assert params.shape[-1] == self.num_params 201 | 202 | # Set additional batch size 203 | self.set_add_dim(list(params.shape[:-1])) 204 | 205 | remaining_params = self.basis_gn.set_params(params) 206 | self.params = remaining_params[..., :self._num_local_params] 207 | self.clear_computation_result() 208 | return remaining_params[..., self._num_local_params:] 209 | 210 | def set_initial_conditions(self, init_time: Union[torch.Tensor, np.ndarray], 211 | init_pos: Union[torch.Tensor, np.ndarray], 212 | init_vel: Union[torch.Tensor, np.ndarray]): 213 | """ 214 | Set initial conditions in a batched manner 215 | 216 | Args: 217 | init_time: initial condition time 218 | init_pos: initial condition position 219 | init_vel: initial condition velocity 220 | 221 | Returns: 222 | None 223 | """ 224 | 225 | # Shape of init_time: 226 | # [*add_dim] 227 | # 228 | # Shape of init_pos: 229 | # [*add_dim, num_dof] 230 | # 231 | # Shape 
of init_vel: 232 | # [*add_dim, num_dof] 233 | 234 | self.init_time = torch.as_tensor(init_time, dtype=self.dtype, 235 | device=self.device) 236 | self.init_pos = torch.as_tensor(init_pos, dtype=self.dtype, 237 | device=self.device) 238 | init_vel = torch.as_tensor(init_vel, dtype=self.dtype, device=self.device) 239 | 240 | # If velocity is non-zero, then cannot wait 241 | if torch.count_nonzero(init_vel) != 0: 242 | assert torch.all(self.init_time - self.phase_gn.delay >= 0), \ 243 | f"Cannot set non-zero initial velocity {init_vel} if initial condition time" \ 244 | f"value(s) {self.init_time} is (are) smaller than delay value(s) {self.phase_gn.delay}" 245 | self.init_vel = init_vel 246 | self.clear_computation_result() 247 | 248 | def update_inputs(self, times=None, params=None, 249 | init_time=None, init_pos=None, init_vel=None, **kwargs): 250 | """ 251 | Update MP 252 | Args: 253 | times: desired time points 254 | params: parameters 255 | init_time: initial condition time 256 | init_pos: initial condition position 257 | init_vel: initial condition velocity 258 | kwargs: other keyword arguments 259 | 260 | Returns: None 261 | 262 | """ 263 | if params is not None: 264 | self.set_params(params) 265 | if times is not None: 266 | self.set_times(times) 267 | if all([data is not None for data in {init_time, init_pos, init_vel}]): 268 | self.set_initial_conditions(init_time, init_pos, init_vel) 269 | 270 | def get_params(self) -> torch.Tensor: 271 | """ 272 | Return all learnable parameters 273 | Returns: 274 | parameters 275 | """ 276 | # Shape of params 277 | # [*add_dim, num_params] 278 | params = self.basis_gn.get_params() 279 | params = torch.cat([params, self.params], dim=-1) 280 | return params 281 | 282 | def get_params_bounds(self) -> torch.Tensor: 283 | """ 284 | Return all learnable parameters' bounds 285 | Returns: 286 | parameters bounds 287 | """ 288 | # Shape of params_bounds 289 | # [2, num_params] 290 | 291 | params_bounds = 
self.basis_gn.get_params_bounds() 292 | params_bounds = torch.cat([params_bounds, self.local_params_bound], 293 | dim=1) 294 | return params_bounds 295 | 296 | def get_trajs(self, get_pos: bool = True, get_vel: bool = True) -> dict: 297 | """ 298 | Get movement primitives trajectories given flag 299 | Args: 300 | get_pos: True if pos shall be computed 301 | get_vel: True if vel shall be computed 302 | 303 | Returns: 304 | results in dictionary 305 | """ 306 | 307 | # Initialize result dictionary 308 | result = dict() 309 | 310 | # Position 311 | result["pos"] = self.get_traj_pos() if get_pos else None 312 | 313 | # Velocity 314 | result["vel"] = self.get_traj_vel() if get_vel else None 315 | 316 | # Return 317 | return result 318 | 319 | @property 320 | def _num_local_params(self) -> int: 321 | """ 322 | Returns: number of parameters of current class 323 | """ 324 | return self.num_basis * self.num_dof 325 | 326 | @property 327 | def num_params(self) -> int: 328 | """ 329 | Returns: number of parameters of current class plus parameters of all 330 | attributes 331 | """ 332 | return self._num_local_params + self.basis_gn.num_params 333 | 334 | @abstractmethod 335 | def get_traj_pos(self, times=None, params=None, 336 | init_time=None, init_pos=None, init_vel=None): 337 | """ 338 | Get trajectory position 339 | Args: 340 | times: time points 341 | params: learnable parameters 342 | init_time: initial condition time 343 | init_pos: initial condition position 344 | init_vel: initial condition velocity 345 | 346 | Returns: 347 | pos 348 | """ 349 | pass 350 | 351 | @abstractmethod 352 | def get_traj_vel(self, times=None, params=None, 353 | init_time=None, init_pos=None, init_vel=None): 354 | """ 355 | Get trajectory velocity 356 | 357 | Args: 358 | times: time points 359 | params: learnable parameters 360 | init_time: initial condition time 361 | init_pos: initial condition position 362 | init_vel: initial condition velocity 363 | 364 | Returns: vel 365 | """ 366 | pass 
367 | 368 | @abstractmethod 369 | def learn_mp_params_from_trajs(self, times: torch.Tensor, 370 | trajs: torch.Tensor, 371 | reg=1e-9): 372 | """ 373 | Learn params from trajectories 374 | 375 | Args: 376 | times: time points of the trajectories 377 | trajs: demonstration trajectories 378 | reg: regularization term of linear ridge regression 379 | 380 | Returns: 381 | learned parameters 382 | """ 383 | pass 384 | 385 | def finalize(self): 386 | """ 387 | Mark the MP as finalized so that the parameters cannot be 388 | updated any more 389 | Returns: None 390 | 391 | """ 392 | self.is_finalized = True 393 | 394 | def reset(self): 395 | """ 396 | Unmark the finalization 397 | Returns: None 398 | 399 | """ 400 | self.basis_gn.reset() 401 | self.is_finalized = False 402 | 403 | @abstractmethod 404 | def _show_scaled_basis(self, *args, **kwargs): 405 | pass 406 | 407 | def show_scaled_basis(self, plot=False): 408 | """ 409 | External call of show basis, it will make a hard copy of the current mp, 410 | and feed artificial time sequence. 411 | 412 | The current mp will not get influenced. 413 | 414 | Args: 415 | plot: if to plot the basis 416 | 417 | Returns: 418 | 419 | """ 420 | # Make a hard copy to show basis and do not change other settings of the 421 | # original mp instance 422 | try: 423 | copied_mp = copy.deepcopy(self) 424 | except RuntimeError: 425 | print("Please do not use this function during NN training. 
" 426 | "The deepcopy cannot work when there is a computation graph.") 427 | return 428 | return copied_mp._show_scaled_basis(plot) 429 | 430 | 431 | class ProbabilisticMPInterface(MPInterface): 432 | def __init__(self, 433 | basis_gn: BasisGenerator, 434 | num_dof: int, 435 | weights_scale: float = 1., 436 | dtype: torch.dtype = torch.float32, 437 | device: torch.device = 'cpu', 438 | **kwargs): 439 | """ 440 | Constructor interface 441 | Args: 442 | basis_gn: basis generator 443 | num_dof: number of dof 444 | weights_scale: scaling for the parameters weights 445 | dtype: torch data type 446 | device: torch device to run on 447 | **kwargs: keyword arguments 448 | """ 449 | 450 | super().__init__(basis_gn, num_dof, weights_scale, dtype, device, 451 | **kwargs) 452 | 453 | # Learnable parameters variance 454 | self.params_L = None 455 | 456 | # Runtime computation results, shall be reset every time when 457 | # inputs are reset 458 | self.pos_cov = None 459 | self.pos_std = None 460 | self.vel_cov = None 461 | self.vel_std = None 462 | 463 | def clear_computation_result(self): 464 | """ 465 | Clear runtime computation result 466 | 467 | Returns: 468 | None 469 | """ 470 | super().clear_computation_result() 471 | self.pos_cov = None 472 | self.pos_std = None 473 | self.vel_cov = None 474 | self.vel_std = None 475 | 476 | def set_mp_params_variances(self, 477 | params_L: Union[ 478 | torch.Tensor, None, np.ndarray]): 479 | """ 480 | Set variance of MP params 481 | Args: 482 | params_L: cholesky of covariance matrix of the MP parameters 483 | 484 | Returns: None 485 | 486 | """ 487 | # Shape of params_L 488 | # [*add_dim, num_mp_params, num_mp_params] 489 | 490 | self.params_L = torch.as_tensor( 491 | params_L) if params_L is not None else params_L 492 | self.clear_computation_result() 493 | 494 | def update_inputs(self, times=None, params=None, params_L=None, 495 | init_time=None, init_pos=None, init_vel=None, **kwargs): 496 | """ 497 | Set MP 498 | Args: 499 | times: 
desired time points 500 | params: parameters 501 | params_L: cholesky of covariance matrix of the MP parameters 502 | init_time: initial condition time 503 | init_pos: initial condition position 504 | init_vel: initial condition velocity 505 | kwargs: other keyword arguments 506 | 507 | Returns: None 508 | 509 | """ 510 | super().update_inputs(times, params, init_time, init_pos, init_vel) 511 | if params_L is not None: 512 | self.set_mp_params_variances(params_L) 513 | 514 | @property 515 | def params_cov(self): 516 | """ 517 | Compute params covariance using params_L 518 | Returns: 519 | covariance matrix of parameters 520 | """ 521 | assert self.params_L is not None 522 | params_cov = torch.einsum('...ij,...kj->...ik', 523 | self.params_L, 524 | self.params_L) 525 | return params_cov 526 | 527 | def get_trajs(self, get_pos=True, get_pos_cov=True, get_pos_std=True, 528 | get_vel=True, get_vel_cov=True, get_vel_std=True, 529 | flat_shape=False, reg: float = 1e-4): 530 | """ 531 | Get movement primitives trajectories given flag 532 | Args: 533 | get_pos: True if pos shall be computed 534 | get_vel: True if vel shall be computed 535 | get_pos_cov: True if pos_cov shall be computed 536 | get_pos_std: True if pos_std shall be computed 537 | get_vel_cov: True if vel_cov shall be computed 538 | get_vel_std: True if vel_std shall be computed 539 | flat_shape: if flatten the dimensions of Dof and time 540 | reg: regularization term 541 | 542 | Returns: 543 | results in dictionary 544 | """ 545 | # Initialize result dictionary 546 | result = dict() 547 | 548 | # pos 549 | result["pos"] = self.get_traj_pos( 550 | flat_shape=flat_shape) if get_pos else None 551 | 552 | # vel 553 | result["vel"] = self.get_traj_vel( 554 | flat_shape=flat_shape) if get_vel else None 555 | 556 | # pos_cov 557 | result["pos_cov"] = self.get_traj_pos_cov( 558 | reg=reg) if get_pos_cov else None 559 | 560 | # pos_std 561 | result["pos_std"] = self.get_traj_pos_std(flat_shape=flat_shape, 562 | 
reg=reg) if get_pos_std else None 563 | 564 | # vel_cov 565 | result["vel_cov"] = self.get_traj_vel_cov( 566 | reg=reg) if get_vel_cov else None 567 | 568 | # vel_std 569 | result["vel_std"] = self.get_traj_vel_std(flat_shape=flat_shape, 570 | reg=reg) if get_vel_std else None 571 | 572 | return result 573 | 574 | @abstractmethod 575 | def get_traj_pos(self, times=None, params=None, 576 | init_time=None, init_pos=None, init_vel=None, 577 | flat_shape=False): 578 | """ 579 | Get trajectory position 580 | Args: 581 | times: time points 582 | params: learnable parameters 583 | init_time: initial condition time 584 | init_pos: initial condition position 585 | init_vel: initial condition velocity 586 | flat_shape: if flatten the dimensions of Dof and time 587 | 588 | Returns: 589 | pos 590 | """ 591 | pass 592 | 593 | @abstractmethod 594 | def get_traj_pos_cov(self, times=None, params_L=None, 595 | init_time=None, init_pos=None, init_vel=None, 596 | reg: float = 1e-4): 597 | """ 598 | Get trajectory covariance 599 | Returns: cov 600 | 601 | Args: 602 | times: time points 603 | params_L: learnable parameters' variance 604 | init_time: initial condition time 605 | init_pos: initial condition position 606 | init_vel: initial condition velocity 607 | reg: regularization term 608 | 609 | Returns: 610 | pos cov 611 | """ 612 | pass 613 | 614 | @abstractmethod 615 | def get_traj_pos_std(self, times=None, params_L=None, 616 | init_time=None, init_pos=None, init_vel=None, 617 | flat_shape=False, reg: float = 1e-4): 618 | """ 619 | Get trajectory standard deviation 620 | Args: 621 | times: time points 622 | params_L: learnable parameters' variance 623 | init_time: initial condition time 624 | init_pos: initial condition position 625 | init_vel: initial condition velocity 626 | flat_shape: if flatten the dimensions of Dof and time 627 | reg: regularization term 628 | 629 | Returns: 630 | pos std 631 | """ 632 | pass 633 | 634 | @abstractmethod 635 | def get_traj_vel(self, 
times=None, params=None, 636 | init_time=None, init_pos=None, init_vel=None, 637 | flat_shape=False): 638 | """ 639 | Get trajectory velocity 640 | Returns: vel 641 | 642 | Args: 643 | times: time points 644 | params: learnable parameters 645 | init_time: initial condition time 646 | init_pos: initial condition position 647 | init_vel: initial condition velocity 648 | flat_shape: if flatten the dimensions of Dof and time 649 | 650 | Returns: 651 | vel 652 | """ 653 | pass 654 | 655 | @abstractmethod 656 | def get_traj_vel_cov(self, times=None, params_L=None, 657 | init_time=None, init_pos=None, init_vel=None, 658 | reg: float = 1e-4): 659 | """ 660 | Get trajectory covariance 661 | Args: 662 | times: time points 663 | params_L: learnable parameters' variance 664 | init_time: initial condition time 665 | init_pos: initial condition position 666 | init_vel: initial condition velocity 667 | reg: regularization term 668 | 669 | Returns: 670 | vel cov 671 | """ 672 | pass 673 | 674 | @abstractmethod 675 | def get_traj_vel_std(self, times=None, params_L=None, 676 | init_time=None, init_pos=None, init_vel=None, 677 | flat_shape=False, reg: float = 1e-4): 678 | """ 679 | Get trajectory standard deviation 680 | Args: 681 | times: time points 682 | params_L: learnable parameters' variance 683 | init_time: initial condition time 684 | init_pos: initial condition position 685 | init_vel: initial condition velocity 686 | flat_shape: if flatten the dimensions of Dof and time 687 | reg: regularization term 688 | 689 | Returns: 690 | vel std 691 | """ 692 | pass 693 | 694 | def sample_trajectories(self, times=None, params=None, params_L=None, 695 | init_time=None, init_pos=None, init_vel=None, 696 | num_smp=1, flat_shape=False): 697 | """ 698 | Sample trajectories from MP 699 | 700 | Args: 701 | times: time points 702 | params: learnable parameters 703 | params_L: learnable parameters' variance 704 | init_time: initial condition time 705 | init_pos: initial condition position 706 
| init_vel: initial condition velocity 707 | num_smp: num of trajectories to be sampled 708 | flat_shape: if flatten the dimensions of Dof and time 709 | 710 | Returns: 711 | sampled trajectories 712 | """ 713 | 714 | # Shape of pos_smp 715 | # [*add_dim, num_smp, num_times, num_dof] 716 | # or [*add_dim, num_smp, num_dof * num_times] 717 | 718 | if all([data is None for data in {times, params, params_L, init_time, 719 | init_pos, init_vel}]): 720 | times = self.times 721 | params = self.params 722 | params_L = self.params_L 723 | init_time = self.init_time 724 | init_pos = self.init_pos 725 | init_vel = self.init_vel 726 | 727 | num_add_dim = params.ndim - 1 728 | 729 | # Add additional sample axis to time 730 | # Shape [*add_dim, num_smp, num_times] 731 | times_smp = util.add_expand_dim(times, [num_add_dim], [num_smp]) 732 | 733 | # Sample parameters, shape [num_smp, *add_dim, num_mp_params] 734 | params_smp = MultivariateNormal(loc=params, 735 | scale_tril=params_L, 736 | validate_args=False).rsample([num_smp]) 737 | 738 | # Switch axes to [*add_dim, num_smp, num_mp_params] 739 | params_smp = torch.einsum('i...j->...ij', params_smp) 740 | 741 | params_super = self.basis_gn.get_params() 742 | if params_super.nelement() != 0: 743 | params_super_smp = util.add_expand_dim(params_super, [-2], 744 | [num_smp]) 745 | params_smp = torch.cat([params_super_smp, params_smp], dim=-1) 746 | 747 | # Add additional sample axis to initial condition 748 | if init_time is not None: 749 | init_time_smp = util.add_expand_dim(init_time, [num_add_dim], [num_smp]) 750 | init_pos_smp = util.add_expand_dim(init_pos, [num_add_dim], [num_smp]) 751 | init_vel_smp = util.add_expand_dim(init_vel, [num_add_dim], [num_smp]) 752 | else: 753 | init_time_smp = None 754 | init_pos_smp = None 755 | init_vel_smp = None 756 | 757 | # Update inputs 758 | self.reset() 759 | self.update_inputs(times_smp, params_smp, None, 760 | init_time_smp, init_pos_smp, init_vel_smp) 761 | 762 | # Get sample 
trajectories 763 | pos_smp = self.get_traj_pos(flat_shape=flat_shape) 764 | vel_smp = self.get_traj_vel(flat_shape=flat_shape) 765 | 766 | # Recover old inputs 767 | if params_super.nelement() != 0: 768 | params = torch.cat([params_super, params], dim=-1) 769 | self.reset() 770 | self.update_inputs(times, params, None, init_time, init_pos, init_vel) 771 | 772 | return pos_smp, vel_smp 773 | -------------------------------------------------------------------------------- /mp_pytorch/mp/promp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Probabilistic Movement Primitives in PyTorch 3 | """ 4 | import logging 5 | from typing import Iterable 6 | from typing import Union 7 | from typing import Tuple 8 | 9 | import numpy as np 10 | import torch 11 | from mp_pytorch.util import to_nps 12 | from mp_pytorch.basis_gn import BasisGenerator 13 | from .mp_interfaces import ProbabilisticMPInterface 14 | 15 | 16 | class ProMP(ProbabilisticMPInterface): 17 | """ProMP in PyTorch""" 18 | 19 | def __init__(self, 20 | basis_gn: BasisGenerator, 21 | num_dof: int, 22 | weights_scale: Union[float, Iterable] = 1., 23 | dtype: torch.dtype = torch.float32, 24 | device: torch.device = 'cpu', 25 | **kwargs): 26 | """ 27 | Constructor of ProMP 28 | Args: 29 | basis_gn: basis function value generator 30 | num_dof: number of Degrees of Freedoms 31 | weights_scale: scaling for the parameters weights 32 | dtype: torch data type 33 | device: torch device to run on 34 | **kwargs: keyword arguments 35 | """ 36 | 37 | super().__init__(basis_gn, num_dof, weights_scale, dtype, device, 38 | **kwargs) 39 | 40 | # Some legacy code for a smooth start from 0 41 | self.has_zero_padding = hasattr(self.basis_gn, 'num_basis_zero_start') 42 | if self.has_zero_padding: 43 | # if no zero start/ zero goal, use weights as it is 44 | self.padding = torch.nn.ConstantPad2d( 45 | (self.basis_gn.num_basis_zero_start, 46 | self.basis_gn.num_basis_zero_goal, 0, 0), 
    def get_traj_pos(self, times=None, params=None,
                     init_time=None, init_pos=None, init_vel=None,
                     flat_shape=False):
        """
        Get trajectory position as the weighted sum of basis functions.

        Refer setting functions for desired shape of inputs

        Args:

            times: time points
            params: learnable parameters
            init_time: initial condition time
            init_pos: initial condition position
            init_vel: initial condition velocity
            flat_shape: if flatten the dimensions of Dof and time

        Returns:
            pos, shape [*add_dim, num_times, num_dof] or
            [*add_dim, num_dof * num_times] when flat_shape is True
        """

        # Shape of pos
        # [*add_dim, num_times, num_dof] or [*add_dim, num_dof * num_times]

        # Update inputs; caching below relies on update_inputs clearing
        # self.pos whenever any input actually changes
        self.update_inputs(times, params, None, init_time, init_pos, init_vel)

        # Reuse cached result if existing
        if self.pos is not None:
            pos = self.pos

        else:
            assert self.params is not None

            # Reshape params
            # [*add_dim, num_dof * num_basis] -> [*add_dim, num_dof, num_basis]
            params = self.params.reshape(*self.add_dim, self.num_dof, -1)

            # Zero-padding, legacy case: pad the per-dof weights with zeros at
            # the start/goal basis positions
            # [*add_dim, num_dof, num_basis]
            # -> [*add_dim, num_dof, num_basis + num_padding]
            params = self.padding(params)
            # The weights scale must be padded in exactly the same way so it
            # stays aligned with the (possibly padded) basis dimension
            if self.weights_scale.ndim != 0:
                weights_scale = self.padding(self.weights_scale[None])[0]
            else:
                # Scalar scale: broadcast to a [1, num_basis] row first, then pad
                weights_scale = self.padding(torch.ones([1, self.num_basis],
                                                        dtype=self.dtype,
                                                        device=self.device) *
                                             self.weights_scale)[0]

            # Get basis, already multiplied by the (padded) weights scale
            # Shape: [*add_dim, num_times, num_basis]
            basis_single_dof = \
                self.basis_gn.basis(self.times) * weights_scale

            # Einsum shape: [*add_dim, num_times, num_basis],
            #               [*add_dim, num_dof, num_basis]
            #               -> [*add_dim, num_times, num_dof]
            pos = torch.einsum('...ik,...jk->...ij', basis_single_dof, params)

            # Zero-padding legacy case: the padded basis starts at zero, so
            # offset the whole trajectory by the initial position
            pos += self.init_pos[..., None, :] if self.has_zero_padding else 0

            # Cache for reuse until inputs change
            self.pos = pos

        if flat_shape:
            # Switch axes to [*add_dim, num_dof, num_times]
            pos = torch.einsum('...ji->...ij', pos)

            # Reshape to [*add_dim, num_dof * num_times]
            pos = pos.reshape(*self.add_dim, -1)

        return pos
num_dof * num_times] 183 | 184 | # Update inputs 185 | self.update_inputs(times, None, params_L, init_time, init_pos, init_vel) 186 | 187 | # Reuse result if existing 188 | if self.pos_cov is not None: 189 | return self.pos_cov 190 | 191 | # Otherwise recompute result 192 | if self.params_L is None: 193 | return None 194 | 195 | # Get weights scale 196 | if self.weights_scale.ndim == 0: 197 | weights_scale = self.weights_scale 198 | else: 199 | weights_scale = self.weights_scale.repeat(self.num_dof) 200 | 201 | # Get basis of all Dofs 202 | # Shape: [*add_dim, num_dof * num_times, num_dof * num_basis] 203 | basis_multi_dofs = self.basis_gn.basis_multi_dofs( 204 | self.times, self.num_dof) * weights_scale 205 | 206 | # Einsum shape: [*add_dim, num_dof * num_times, num_dof * num_basis] 207 | # [*add_dim, num_dof * num_basis, num_dof * num_basis] 208 | # [*add_dim, num_dof * num_times, num_dof * num_basis] 209 | # -> [*add_dim, num_dof * num_times, num_dof * num_times] 210 | pos_cov = torch.einsum('...ik,...kl,...jl->...ij', 211 | basis_multi_dofs, 212 | self.params_cov, 213 | basis_multi_dofs) 214 | 215 | # Determine regularization term to make traj_cov positive definite 216 | traj_cov_reg = reg 217 | reg_term_pos = torch.max(torch.einsum('...ii->...i', 218 | pos_cov)).item() * traj_cov_reg 219 | 220 | # Add regularization term for numerical stability 221 | self.pos_cov = pos_cov + torch.eye(pos_cov.shape[-1], 222 | dtype=self.dtype, 223 | device=self.device) * reg_term_pos 224 | return self.pos_cov 225 | 226 | def get_traj_pos_std(self, times=None, params_L=None, init_time=None, 227 | init_pos=None, 228 | init_vel=None, flat_shape=False, reg: float = 1e-4): 229 | """ 230 | Compute position standard deviation 231 | 232 | Refer setting functions for desired shape of inputs 233 | 234 | Args: 235 | times: time points 236 | params_L: learnable parameters' variance 237 | init_time: initial condition time 238 | init_pos: initial condition position 239 | init_vel: initial 
condition velocity 240 | flat_shape: if flatten the dimensions of Dof and time 241 | reg: regularization term 242 | 243 | Returns: 244 | pos_std 245 | """ 246 | 247 | # Shape of pos_std 248 | # [*add_dim, num_times, num_dof] or [*add_dim, num_dof * num_times] 249 | 250 | # Update inputs 251 | self.update_inputs(times, None, params_L, init_time, init_pos, init_vel) 252 | 253 | # Reuse result if existing 254 | if self.pos_std is not None: 255 | pos_std = self.pos_std 256 | 257 | else: 258 | # Otherwise recompute 259 | if self.pos_cov is not None: 260 | pos_cov = self.pos_cov 261 | else: 262 | pos_cov = self.get_traj_pos_cov() 263 | 264 | if pos_cov is None: 265 | pos_std = None 266 | else: 267 | # Shape [*add_dim, num_dof * num_times] 268 | pos_std = torch.sqrt(torch.einsum('...ii->...i', pos_cov)) 269 | 270 | self.pos_std = pos_std 271 | 272 | if pos_std is not None and not flat_shape: 273 | # Reshape to [*add_dim, num_dof, num_times] 274 | pos_std = pos_std.reshape(*self.add_dim, self.num_dof, -1) 275 | 276 | # Switch axes to [*add_dim, num_times, num_dof] 277 | pos_std = torch.einsum('...ji->...ij', pos_std) 278 | 279 | return pos_std 280 | 281 | def get_traj_vel(self, times=None, params=None, 282 | init_time=None, init_pos=None, init_vel=None, 283 | flat_shape=False): 284 | """ 285 | Get trajectory velocity 286 | 287 | Refer setting functions for desired shape of inputs 288 | 289 | Args: 290 | times: time points 291 | params: learnable parameters 292 | init_time: initial condition time 293 | init_pos: initial condition position 294 | init_vel: initial condition velocity 295 | flat_shape: if flatten the dimensions of Dof and time 296 | 297 | Returns: 298 | vel 299 | """ 300 | 301 | # Shape of vel 302 | # [*add_dim, num_times, num_dof] or [*add_dim, num_dof * num_times] 303 | 304 | # Update inputs 305 | self.update_inputs(times, params, None, init_time, init_pos, init_vel) 306 | 307 | # Reuse result if existing 308 | if self.vel is not None: 309 | vel = self.vel 
    def learn_mp_params_from_trajs(self, times: torch.Tensor,
                                   trajs: torch.Tensor,
                                   reg: float = 1e-9, **kwargs) -> dict:
        """
        Learn ProMP weights from demonstration via linear ridge regression.

        Solves (Phi^T Phi + reg * I) w = Phi^T y for the weights w, where Phi
        is the multi-dof basis matrix and y the flattened demonstration.

        Args:
            times: trajectory time points
            trajs: trajectory from which weights should be learned
            reg: regularization term of the ridge regression
            kwargs: keyword arguments (unused here)

        Returns:
            param_dict: dictionary of parameters containing
                - weights
        """
        # Shape of times
        # [*add_dim, num_times]
        #
        # Shape of trajs:
        # [*add_dim, num_times, num_dof]
        #
        # Shape of params:
        # [*add_dim, num_dof * num_basis]

        # Demonstrations must share the time grid and match the dof count
        assert trajs.shape[:-1] == times.shape
        assert trajs.shape[-1] == self.num_dof

        self.set_add_dim(list(trajs.shape[:-2]))
        self.set_times(times)

        # Get weights scale; a per-basis scale vector is tiled over all dofs
        # so it lines up with the flattened multi-dof weight vector
        if self.weights_scale.ndim == 0:
            weights_scale = self.weights_scale
        else:
            weights_scale = self.weights_scale.repeat(self.num_dof)

        # Get multiple dof basis function values (Phi), already scaled
        # Tensor [*add_dim, num_dof * num_times, num_dof * num_basis]
        basis_multi_dofs = self.basis_gn.basis_multi_dofs(
            times, self.num_dof) * weights_scale

        # Normal-equation matrix A = Phi^T Phi
        # Einsum shape: [*add_dim, num_dof * num_times, num_dof * num_basis],
        #               [*add_dim, num_dof * num_times, num_dof * num_basis],
        #               -> [*add_dim, num_dof * num_basis, num_dof * num_basis]
        A = torch.einsum('...ki,...kj->...ij', basis_multi_dofs,
                         basis_multi_dofs)
        # Ridge term keeps A well-conditioned / invertible
        A += torch.eye(self._num_local_params,
                       dtype=self.dtype,
                       device=self.device) * reg

        # Reorder axis [*add_dim, num_times, num_dof]
        # -> [*add_dim, num_dof, num_times]
        trajs = torch.as_tensor(trajs, dtype=self.dtype, device=self.device)
        trajs = torch.einsum('...ij->...ji', trajs)

        # Reshape: [*add_dim, num_dof, num_times]
        # -> [*add_dim, num_dof * num_times]
        add_dim = trajs.shape[:-2]
        trajs = trajs.reshape(*add_dim, -1)

        # Right-hand side B = Phi^T y
        # Einsum shape: [*add_dim, num_dof * num_times, num_dof * num_basis],
        #               [*add_dim, num_dof * num_times],
        #               -> [*add_dim, num_dof * num_basis]
        B = torch.einsum('...ki,...k->...i', basis_multi_dofs, trajs)

        # Solve for weights, shape [*add_dim, num_dof * num_basis]
        params = torch.linalg.solve(A, B)

        # Check if parameters of basis or phase generator exist; if so they
        # are prepended so set_params can distribute them
        if self.basis_gn.num_params > 0:
            params_super = self.basis_gn.get_params()
            params = torch.cat([params_super, params], dim=-1)

        self.set_params(params)
        # Learned weights are deterministic; drop any stale covariance
        self.set_mp_params_variances(None)

        return {"params": params}
class ExpDecayPhaseGenerator(PhaseGenerator):
    """Phase generator whose phase decays exponentially over time.

    The phase solves tau * dx/dt = -alpha_phase * x, i.e.
    x(t) = exp(-alpha_phase * (t - delay) / tau), clipped so that
    x(t) = 1 for all t <= delay.
    """

    def __init__(self,
                 tau: float = 1.0,
                 delay: float = 0.0,
                 alpha_phase: float = 3.0,
                 learn_tau: bool = False,
                 learn_delay: bool = False,
                 learn_alpha_phase: bool = False,
                 dtype: torch.dtype = torch.float32,
                 device: torch.device = 'cpu',
                 *args, **kwargs):
        """
        Constructor of the exponential decay phase generator.

        Args:
            tau: trajectory length scaling factor
            delay: time to wait before execution starts
            alpha_phase: decay factor: tau * dx/dt = -alpha_phase * x
            learn_tau: whether tau is a learnable parameter
            learn_delay: whether delay is a learnable parameter
            learn_alpha_phase: whether alpha_phase is a learnable parameter
            dtype: torch data type
            device: torch device to run on
            *args: other arguments list
            **kwargs: other keyword arguments; may contain
                "alpha_phase_bound" = [lower, upper]
        """
        super().__init__(tau=tau, delay=delay, learn_tau=learn_tau,
                         learn_delay=learn_delay, dtype=dtype,
                         device=device, *args, **kwargs)

        self.alpha_phase = torch.tensor(alpha_phase, dtype=self.dtype,
                                        device=self.device)
        self.learn_alpha_phase = learn_alpha_phase

        if learn_alpha_phase:
            self.alpha_phase_bound = kwargs.get("alpha_phase_bound",
                                                [1e-5, torch.inf])
            assert len(self.alpha_phase_bound) == 2

    @property
    def _num_local_params(self) -> int:
        """Number of learnable parameters including the parent's."""
        return super()._num_local_params + int(self.learn_alpha_phase)

    def set_params(self,
                   params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Set parameters of current object and its attributes.

        Args:
            params: parameters to be set

        Returns:
            Unused (remaining) parameters
        """
        params = torch.as_tensor(params, dtype=self.dtype, device=self.device)

        # Remember the state before the parent call flips it via finalize()
        finalized_before = self.is_finalized
        leftover = super().set_params(params)

        used = 0
        if self.learn_alpha_phase:
            new_alpha = leftover[..., used]
            if finalized_before:
                # Finalized parameters must not carry gradients
                assert not new_alpha.requires_grad, \
                    "Parameters are finalized and won't be updated. " \
                    "Requiring gradient of it will cause errors."
            else:
                self.alpha_phase = new_alpha
            used += 1
        self.finalize()
        return leftover[..., used:]

    def get_params(self) -> torch.Tensor:
        """Return all learnable parameters, shape [*add_dim, num_params]."""
        params = super().get_params()
        if not self.learn_alpha_phase:
            return params
        return torch.cat([params, self.alpha_phase[..., None]], dim=-1)

    def get_params_bounds(self) -> torch.Tensor:
        """Return all learnable parameters' bounds, shape [2, num_params]."""
        bounds = super().get_params_bounds()
        if self.learn_alpha_phase:
            alpha_col = torch.as_tensor(self.alpha_phase_bound,
                                        dtype=self.dtype,
                                        device=self.device)[..., None]
            bounds = torch.cat([bounds, alpha_col], dim=1)
        return bounds

    def left_bound_linear_phase(self, times):
        """
        Linear phase clipped from below, i.e. in [0, +inf].

        Args:
            times: times Tensor, shape [*add_dim, num_times]

        Returns:
            linear phase in Tensor
        """
        shifted = times - self.delay[..., None]
        return torch.clip(shifted / self.tau[..., None], min=0)

    def phase(self, times: torch.Tensor):
        """
        Compute the exponentially decaying phase; equals 1 for t <= delay.

        Args:
            times: times Tensor, shape [*add_dim, num_times]

        Returns:
            phase in Tensor
        """
        return torch.exp(-self.alpha_phase[..., None]
                         * self.left_bound_linear_phase(times))

    def phase_to_time(self, phases: torch.Tensor) -> torch.Tensor:
        """
        Inverse operation: compute times at which the given phases occur.

        Args:
            phases: phases in Tensor

        Returns:
            times in Tensor
        """
        linear = torch.log(phases) / (-self.alpha_phase[..., None])
        return linear * self.tau[..., None] + self.delay[..., None]

    def linear_phase_to_time(self, phases: torch.Tensor) -> torch.Tensor:
        """
        Inverse of the linear phase: times for given linear phases.

        Args:
            phases: phases in Tensor

        Returns:
            times in Tensor
        """
        return phases * self.tau[..., None] + self.delay[..., None]

    def unbound_linear_phase(self, times):
        """
        Unclipped linear phase in (-inf, +inf).

        Args:
            times: times Tensor, shape [*add_dim, num_times]

        Returns:
            phase in Tensor
        """
        return (times - self.delay[..., None]) / self.tau[..., None]

    def unbound_phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Exponential phase computed on the unclipped linear phase.

        Args:
            times: times Tensor, shape [*add_dim, num_times]

        Returns:
            phase in Tensor
        """
        return torch.exp(-self.alpha_phase[..., None]
                         * self.unbound_linear_phase(times))
class LinearPhaseGenerator(PhaseGenerator):
    """Phase generator with a linear phase profile."""

    def phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute the bounded linear phase in [0, 1].

        Args:
            times: times in Tensor, shape [*add_dim, num_times]

        Returns:
            phase in Tensor
        """
        raw = (times - self.delay[..., None]) / self.tau[..., None]
        return torch.clip(raw, 0, 1)

    def phase_to_time(self, phases: torch.Tensor) -> torch.Tensor:
        """
        Inverse operation: compute times for the given phases.

        Args:
            phases: phases in Tensor

        Returns:
            times in Tensor
        """
        return phases * self.tau[..., None] + self.delay[..., None]

    def unbound_phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute the unclipped linear phase in (-inf, +inf).

        Args:
            times: times in Tensor, shape [*add_dim, num_times]

        Returns:
            phase in Tensor
        """
        return (times - self.delay[..., None]) / self.tau[..., None]
class PhaseGenerator(ABC):
    """Abstract base class of all phase generators.

    Manages the two parameters every phase generator shares — the time
    scaling `tau` and the start `delay` — including their optional
    learnability, bounds, and a finalization flag that freezes them.
    """

    def __init__(self,
                 tau: float = 1.0,
                 delay: float = 0.0,
                 learn_tau: bool = False,
                 learn_delay: bool = False,
                 dtype: torch.dtype = torch.float32,
                 device: torch.device = 'cpu',
                 *args, **kwargs):
        """
        Base class constructor.

        Args:
            tau: trajectory length scaling factor
            delay: time to wait before execution starts
            learn_tau: whether tau is a learnable parameter
            learn_delay: whether delay is a learnable parameter
            dtype: torch data type
            device: torch device to run on
            *args: other arguments list
            **kwargs: other keyword arguments; may contain "tau_bound"
                and "delay_bound" as [lower, upper]
        """
        self.dtype = dtype
        self.device = device

        self.tau = torch.as_tensor(tau, dtype=self.dtype, device=self.device)
        self.delay = torch.as_tensor(delay, dtype=self.dtype,
                                     device=self.device)
        self.learn_tau = learn_tau
        self.learn_delay = learn_delay

        if learn_tau:
            self.tau_bound = kwargs.get("tau_bound", [1e-5, torch.inf])
            assert len(self.tau_bound) == 2
        if learn_delay:
            self.delay_bound = kwargs.get("delay_bound", [0, torch.inf])
            assert len(self.delay_bound) == 2

        self.is_finalized = False

    @abstractmethod
    def phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Phase interface.

        Args:
            times: times in Tensor

        Returns:
            phases in Tensor
        """
        pass

    @abstractmethod
    def unbound_phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Unbounded phase interface.

        Args:
            times: times in Tensor

        Returns:
            phases in Tensor
        """
        pass

    @abstractmethod
    def phase_to_time(self, phases: torch.Tensor) -> torch.Tensor:
        """
        Inverse operation: compute times for the given phases.

        Args:
            phases: phases in Tensor

        Returns:
            times in Tensor
        """
        pass

    @property
    def _num_local_params(self) -> int:
        """Number of learnable parameters owned by this class."""
        return int(self.learn_tau) + int(self.learn_delay)

    @property
    def num_params(self) -> int:
        """Total number of learnable parameters incl. attribute objects."""
        return self._num_local_params

    def set_params(self,
                   params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Set parameters of current object and its attributes.

        Args:
            params: parameters to be set, shape [*add_dim, num_params]

        Returns:
            Unused (remaining) parameters
        """
        params = torch.as_tensor(params, dtype=self.dtype, device=self.device)

        used = 0
        finalized_before = self.is_finalized

        if self.learn_tau:
            new_tau = params[..., used]
            assert new_tau.min() > 0
            if finalized_before:
                # Finalized parameters must not carry gradients
                assert not new_tau.requires_grad, \
                    "Parameters are finalized and won't be updated. " \
                    "Requiring gradient of it will cause errors."
            else:
                self.tau = new_tau
            used += 1
        if self.learn_delay:
            new_delay = params[..., used]
            assert new_delay.min() >= 0
            if finalized_before:
                assert not new_delay.requires_grad, \
                    "Parameters are finalized and won't be updated. " \
                    "Requiring gradient of it will cause errors."
            else:
                self.delay = new_delay
            used += 1

        self.finalize()
        return params[..., used:]

    def get_params(self) -> torch.Tensor:
        """
        Return all learnable parameters, shape [*add_dim, num_params].
        """
        pieces = []
        if self.learn_tau:
            pieces.append(self.tau[..., None])
        if self.learn_delay:
            pieces.append(self.delay[..., None])
        if pieces:
            return torch.cat(pieces, dim=-1)
        # No learnable parameter: empty tensor keeps downstream cat valid
        return torch.as_tensor([], dtype=self.dtype, device=self.device)

    def get_params_bounds(self) -> torch.Tensor:
        """
        Return all learnable parameters' bounds, shape [2, num_params].
        """
        columns = []
        if self.learn_tau:
            columns.append(torch.as_tensor(self.tau_bound, dtype=self.dtype,
                                           device=self.device)[..., None])
        if self.learn_delay:
            columns.append(torch.as_tensor(self.delay_bound, dtype=self.dtype,
                                           device=self.device)[..., None])
        bounds = torch.zeros([2, 0], dtype=self.dtype, device=self.device)
        if columns:
            bounds = torch.cat([bounds] + columns, dim=1)
        return bounds

    def finalize(self):
        """
        Mark the phase generator as finalized so that the parameters
        cannot be updated any more.

        Returns: None
        """
        self.is_finalized = True

    def reset(self):
        """
        Unmark the finalization.

        Returns: None
        """
        self.is_finalized = False
def make_iterable(data: any, default: str = 'tuple') \
        -> Union[Tuple, List]:
    """
    Make data a tuple or list, i.e. (data) or [data].

    Args:
        data: some data
        default: default container type, 'tuple' or 'list'

    Returns:
        data unchanged if it is already a tuple or list, otherwise
        data wrapped in the requested container

    Raises:
        NotImplementedError: if default is neither 'tuple' nor 'list'
    """
    if isinstance(data, (tuple, list)):
        return data
    if default == 'tuple':
        return (data,)  # Do not use tuple(); it would iterate data
    if default == 'list':
        return [data]
    raise NotImplementedError


def to_np(tensor: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """
    Transfer any type and device of tensor to a numpy ndarray.

    Args:
        tensor: np.ndarray, cpu tensor or gpu tensor

    Returns:
        tensor as np.ndarray
    """
    if is_np(tensor):
        return tensor
    if is_ts(tensor):
        return tensor.detach().cpu().numpy()
    # Bug fix: the original fell through without `return` here and
    # implicitly returned None for plain Python inputs (lists, scalars).
    return np.array(tensor)


def to_nps(*tensors: [Union[np.ndarray, torch.Tensor]]) -> [np.ndarray]:
    """
    Transfer a list of any type of tensors to np.ndarray.

    Args:
        tensors: a list of tensors

    Returns:
        a list of np.ndarray
    """
    return [to_np(tensor) for tensor in tensors]


def is_np(data: any) -> bool:
    """
    Is data a numpy array?
    """
    return isinstance(data, np.ndarray)


def to_ts(data: Union[int, float, np.ndarray, torch.Tensor],
          dtype: torch.dtype = torch.float32,
          device: str = "cpu") -> torch.Tensor:
    """
    Transfer any numerical input to a torch tensor in the given data
    type and device.

    Args:
        data: float, np.ndarray, torch.Tensor
        dtype: data type of tensor, float 32 or float 64 (double)
        device: device of the tensor, default: cpu

    Returns:
        tensor in torch.Tensor
    """
    return torch.as_tensor(data, dtype=dtype, device=device)


def to_tss(*datas: [Union[int, float, np.ndarray, torch.Tensor]],
           dtype: torch.dtype = torch.float32,
           device: str = "cpu") \
        -> [torch.Tensor]:
    """
    Transfer a list of any type of numerical input to a list of tensors
    in the given data type and device.

    Args:
        datas: a list of data
        dtype: data type of tensor, float 32 or float 64 (double)
        device: device of the tensor, default: cpu

    Returns:
        a list of torch.Tensor
    """
    return [to_ts(data, dtype, device) for data in datas]


def is_ts(data: any) -> bool:
    """
    Is data a torch Tensor?
    """
    return isinstance(data, torch.Tensor)
111 | """ 112 | return isinstance(data, torch.Tensor) 113 | -------------------------------------------------------------------------------- /mp_pytorch/util/util_debug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for debugging 3 | """ 4 | 5 | import time 6 | from typing import Callable 7 | from typing import Optional 8 | from typing import Union 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | from mp_pytorch import util 15 | 16 | 17 | def how_fast(repeat: int, func: Callable, *args, **kwargs): 18 | """ 19 | Test how fast a given function call is 20 | Args: 21 | repeat: number of times to run the function 22 | func: function to be tested 23 | *args: list of arguments used in the function call 24 | **kwargs: dict of arguments used in the function call 25 | 26 | Returns: 27 | avg duration function call 28 | 29 | Raise: 30 | any type of exception when test the function call 31 | """ 32 | run_time_test(lock=True) 33 | try: 34 | for i in range(repeat): 35 | func(*args, **kwargs) 36 | duration = run_time_test(lock=False) 37 | if duration is not None: 38 | print(f"total_time of {repeat} runs: {duration} s") 39 | print(f"avg_time of each run: {duration / repeat} s") 40 | return duration / repeat 41 | except RuntimeError: 42 | raise 43 | except Exception: 44 | raise 45 | 46 | 47 | def run_time_test(lock: bool) -> Optional[float]: 48 | """ 49 | A manual running time computing function. It will print the running time 50 | for every second call 51 | 52 | E.g.: 53 | run_time_test(lock=True) 54 | some_func1() 55 | some_func2() 56 | ... 
57 | run_time_test(lock=False) 58 | 59 | Args: 60 | lock: flag indicating if time counter starts 61 | 62 | Returns: 63 | None (every first call) or duration (every second call) 64 | 65 | Raise: 66 | RuntimeError if is used in a wrong way 67 | """ 68 | # Initialize function attribute 69 | if not hasattr(run_time_test, "lock_state"): 70 | run_time_test.lock_state = False 71 | run_time_test.last_run_time = time.time() 72 | run_time_test.duration_list = list() 73 | 74 | # Check correct usage 75 | if run_time_test.lock_state == lock: 76 | run_time_test.lock_state = False 77 | raise RuntimeError("run_time_test is wrongly used.") 78 | 79 | # Setup lock 80 | run_time_test.lock_state = lock 81 | 82 | # Update time 83 | if lock is False: 84 | duration = time.time() - run_time_test.last_run_time 85 | run_time_test.duration_list.append(duration) 86 | run_time_test.last_run_time = time.time() 87 | print("duration", duration) 88 | return duration 89 | else: 90 | run_time_test.last_run_time = time.time() 91 | return None 92 | 93 | 94 | def debug_plot(x: Union[np.ndarray, torch.Tensor], 95 | y: [], labels: [] = None, title="debug_plot", grid=True) -> \ 96 | plt.Figure: 97 | """ 98 | One line to plot some variable for debugging, numpy + torch 99 | Args: 100 | x: data used for x-axis, can be None 101 | y: list of data used for y-axis 102 | labels: labels in plots 103 | title: title of current plot 104 | grid: show grid or not 105 | 106 | Returns: 107 | None 108 | """ 109 | fig = plt.figure() 110 | y = util.make_iterable(y) 111 | if labels is not None: 112 | labels = util.make_iterable(labels) 113 | 114 | for i, yi in enumerate(y): 115 | yi = util.to_np(yi) 116 | label = labels[i] if labels is not None else None 117 | if x is not None: 118 | x = util.to_np(x) 119 | plt.plot(x, yi, label=label) 120 | else: 121 | plt.plot(yi, label=label) 122 | 123 | plt.title(title) 124 | if labels is not None: 125 | plt.legend() 126 | if grid: 127 | plt.grid(alpha=0.5) 128 | plt.show() 129 | return 
def build_lower_matrix(param_diag: torch.Tensor,
                       param_off_diag: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Compose the lower triangular matrix L from diag and off-diag elements.
    It seems to be faster than using the cholesky transformation from
    PyTorch.

    Args:
        param_diag: diagonal parameters, shape [..., n]
        param_off_diag: off-diagonal parameters in row-major tril order
            (n * (n - 1) / 2 values), or None for a diagonal matrix

    Returns:
        Lower triangular matrix L, shape [..., n, n]
    """
    dim_pred = param_diag.shape[-1]
    # Fill diagonal terms
    L = param_diag.diag_embed()
    if param_off_diag is not None:
        # Fill off-diagonal terms (strictly below the diagonal)
        [row, col] = torch.tril_indices(dim_pred, dim_pred, -1)
        L[..., row, col] = param_off_diag[..., :]
    return L


def add_expand_dim(data: Union[torch.Tensor, np.ndarray],
                   add_dim_indices: [int],
                   add_dim_sizes: [int]) -> Union[torch.Tensor, np.ndarray]:
    """
    Add additional dimensions to a tensor and expand accordingly.

    Rewritten without building and eval()-ing code strings: the index
    tuple (None / full slice per output dim) is constructed directly.
    Negative dim indices are interpreted relative to the OUTPUT rank,
    matching the original behavior.

    Args:
        data: tensor to be operated on, torch.Tensor or numpy.ndarray
        add_dim_indices: indices of the added dimensions in the result
        add_dim_sizes: expanding sizes of the additional dimensions
            (paired with the added dims in ascending dim order)

    Returns:
        result tensor after adding and expanding dims
    """
    total_dims = data.ndim + len(add_dim_indices)
    # Normalize negative indices against the output rank
    added = {idx % total_dims for idx in add_dim_indices}

    is_torch = isinstance(data, torch.Tensor)
    if not is_torch and not isinstance(data, np.ndarray):
        raise NotImplementedError

    index = []
    sizes = []
    size_iter = 0
    for dim in range(total_dims):
        if dim in added:
            index.append(None)  # insert a new axis here
            sizes.append(add_dim_sizes[size_iter])
            size_iter += 1
        else:
            index.append(slice(None))  # keep the original axis
            # torch.expand keeps a dim with -1; np.tile keeps it with 1
            sizes.append(-1 if is_torch else 1)

    expanded = data[tuple(index)]
    if is_torch:
        return expanded.expand(*sizes)
    return np.tile(expanded, sizes)


def tensor_linspace(start: Union[float, int, torch.Tensor],
                    end: Union[float, int, torch.Tensor],
                    steps: int) -> torch.Tensor:
    """
    Vectorized version of torch.linspace.
    Modified from:
    https://github.com/zhaobozb/layout2im/blob/master/models/bilinear.py#L246

    Args:
        start: start value, scalar or tensor of shape [*add_dim, dim_data]
        end: end value, scalar or tensor of shape [*add_dim, dim_data]
        steps: num of steps

    Returns:
        linspace tensor of shape [*add_dim, steps, dim_data]
        (or a plain 1D linspace when both start and end are scalars)
    """
    if isinstance(start, torch.Tensor) and not isinstance(end, torch.Tensor):
        end += torch.zeros_like(start)
    elif not isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
        start += torch.zeros_like(end)
    elif isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
        assert start.size() == end.size()
    else:
        # Both scalars: fall back to the built-in
        return torch.linspace(start, end, steps)

    view_size = start.size() + (1,)
    w_size = (1,) * start.dim() + (steps,)
    out_size = start.size() + (steps,)

    # Linear interpolation weights, broadcast over all leading dims
    start_w = torch.linspace(1, 0, steps=steps).to(start)
    start_w = start_w.view(w_size).expand(out_size)
    end_w = torch.linspace(0, 1, steps=steps).to(start)
    end_w = end_w.view(w_size).expand(out_size)

    start = start.contiguous().view(view_size).expand(out_size)
    end = end.contiguous().view(view_size).expand(out_size)

    out = start_w * start + end_w * end
    # Move the steps axis before the data axis
    out = torch.einsum('...ji->...ij', out)
    return out


def indexing_interpolate(data: torch.Tensor,
                         indices: torch.Tensor) -> torch.Tensor:
    """
    Index values from a given tensor's data using non-integer indices,
    applying linear interpolation between the two neighboring entries.

    Args:
        data: data tensor to index into, shape [num_data, *dim_data]
        indices: float indices tensor, shape [*add_dim, num_indices]

    Returns:
        indexed and interpolated data,
        shape [*add_dim, num_indices, *dim_data]
    """
    ndim_data = data.ndim - 1
    # Clip so that indices_0 + 1 is always a valid row
    indices_0 = torch.clip(indices.floor().long(), 0,
                           data.shape[-data.ndim] - 2)
    indices_1 = indices_0 + 1
    weights = indices - indices_0
    if ndim_data > 0:
        weights = add_expand_dim(weights,
                                 range(indices.ndim, indices.ndim + ndim_data),
                                 [-1] * ndim_data)
    interpolate_result = torch.lerp(data[indices_0], data[indices_1], weights)
    return interpolate_result


def get_sub_tensor(data: torch.Tensor, dims: Iterable, indices: Iterable):
    """
    Get a sub tensor from the given tensor using multi-dimensional
    indices, applying one indexing operation per (dim, index) pair, in
    order.

    Rewritten without building and eval()-ing an expression string; the
    sequential `[:, ..., :, index]` applications are performed directly.
    Negative dims are interpreted relative to the ORIGINAL tensor's
    rank, matching the original behavior.

    Args:
        data: original tensor
        dims: dimensions to be sliced
        indices: slice of these dimensions (one per entry of dims)

    Returns:
        sliced tensor
    """
    result = data
    for i, dim in enumerate(dims):
        if dim < 0:
            dim += data.ndim
        # Equivalent to result[:, ..., :, indices[i]] with `dim` full slices
        result = result[(slice(None),) * dim + (indices[i],)]
    return result
def print_line(char: str = "=", length: int = 60,
               before: int = 0, after: int = 0) -> None:
    """
    Print a horizontal rule made of `char`, `length` characters wide.

    Args:
        char: char used to draw the line
        length: length of the line
        before: number of new lines before the line
        after: number of new lines after the line

    Returns: None
    """
    text = "\n" * before + char * length + "\n" + "\n" * after
    print(text, end="")
    # End of function print_line


def print_line_title(title: str = "", middle: bool = True, char: str = "=",
                     length: int = 60, before: int = 1, after: int = 1) -> None:
    """
    Print a line carrying a title.

    Args:
        title: title to print
        middle: if the title should be centered, otherwise left-aligned
        char: char used to draw the line
        length: length of the line
        before: number of new lines before the line
        after: number of new lines after the line

    Returns: None
    """
    assert len(title) < length, "Title is longer than line length"
    pad_left = (length - len(title)) // 2 - 1
    pad_right = length - len(title) - (length - len(title)) // 2 - 1
    print("\n" * before, end="")
    if middle is True:
        print(f"{char * pad_left} {title} {char * pad_right}")
    else:
        print(f"{title}  {char * (length - len(title) - 1)}")
    print("\n" * after, end="")
    # End of function print_line_title


def print_wrap_title(title: str = "", char: str = "*", length: int = 60,
                     wrap: int = 1, before: int = 1, after: int = 1) -> None:
    """
    Print a title inside a wrapped box.

    Args:
        title: title to print
        char: char used to draw the box
        length: length of the box
        wrap: number of wrapping layers
        before: number of new lines before the box
        after: number of new lines after the box

    Returns: None
    """
    assert len(title) < length - 4, "Title is longer than line length - 4"

    pad_left = (length - len(title)) // 2 - 1
    pad_right = length - len(title) - (length - len(title)) // 2 - 1
    filler = char + " " * (length - 2) + char

    print_line(char=char, length=length, before=before)
    for _ in range(wrap - 1):
        print(filler)
    print(char + " " * pad_left + title + " " * pad_right + char)
    for _ in range(wrap - 1):
        print(filler)
    print_line(char=char, length=length, after=after)
    # End of function print_wrap_title
def get_mp_config():
    """
    Build a shared config plus inputs for comparing DMP and ProDMP.

    Returns:
        config (addict Dict), params, times, init_time, init_pos, init_vel
    """
    torch.manual_seed(0)

    config = Dict()
    config.num_dof = 2
    config.tau = 3
    config.learn_tau = True
    config.learn_delay = True

    config.mp_args.num_basis = 9
    config.mp_args.basis_bandwidth_factor = 2
    config.mp_args.num_basis_outside = 0
    config.mp_args.alpha = 25
    config.mp_args.alpha_phase = 2
    config.mp_args.dt = 0.001
    config.mp_args.weights_scale = torch.ones([9]) * 1
    config.mp_args.goal_scale = 1

    # Batch of 3 trajectories
    batch_size = 3

    # Trajectory scaling (tau, delay), broadcast over the batch
    tau, delay = 4, 1
    scale_delay = torch.Tensor([tau, delay])
    scale_delay = util.add_expand_dim(scale_delay, [0], [batch_size])

    # Basis weights + goal per DoF, broadcast over the batch, then
    # prepended with the (tau, delay) scaling parameters
    params = torch.Tensor([100, 200, 300, -100, -200, -300,
                           100, 200, 300, -2] * config.num_dof)
    params = util.add_expand_dim(params, [0], [batch_size])
    params = torch.cat([scale_delay, params], dim=-1)

    # Time grid covering delay plus the scaled duration
    num_t = int(config.tau / config.mp_args.dt) * 2 + 1
    times = util.tensor_linspace(0, (tau + delay), num_t).squeeze(-1)
    times = util.add_expand_dim(times, [0], [batch_size])

    # Initial conditions
    init_time = times[:, 0]
    init_pos = 5 * torch.ones([batch_size, config.num_dof])
    init_vel = torch.zeros_like(init_pos)

    return config, params, times, init_time, init_pos, init_vel


def test_dmp_vs_prodmp_identical(plot=False):
    """DMP and ProDMP should produce numerically identical trajectories."""
    config, params, times, init_time, init_pos, init_vel = get_mp_config()

    # Same config, two MP types
    config.mp_type = "dmp"
    dmp = MPFactory.init_mp(**config.to_dict())
    config.mp_type = "prodmp"
    prodmp = MPFactory.init_mp(**config.to_dict())

    # Feed identical inputs to both
    dmp.update_inputs(times=times, params=params,
                      init_time=init_time, init_pos=init_pos,
                      init_vel=init_vel)
    prodmp.update_inputs(times=times, params=params, params_L=None,
                         init_time=init_time, init_pos=init_pos,
                         init_vel=init_vel)

    dmp_pos = dmp.get_traj_pos()
    dmp_vel = dmp.get_traj_vel()
    prodmp_pos = prodmp.get_traj_pos()
    prodmp_vel = prodmp.get_traj_vel()

    if plot:
        util.debug_plot(x=None, y=[dmp_pos[0, :, 0], prodmp_pos[0, :, 0]],
                        labels=["dmp", "prodmp"], title="DMP vs. ProDMP")
        util.debug_plot(x=None, y=[dmp_vel[0, :, 0], prodmp_vel[0, :, 0]],
                        labels=["dmp", "prodmp"], title="DMP vs. ProDMP")

    # Compare positions; ProDMP approximates the DMP solution numerically
    error = dmp_pos - prodmp_pos
    print(f"Desired_max_error: {0.000406}, "
          f"Actual_error: {error.max()}")
    assert error.max() < 4.1e-3


if __name__ == "__main__":
    test_dmp_vs_prodmp_identical(True)
| tau, delay = 4, 1 45 | scale_delay = torch.Tensor([tau, delay]) 46 | scale_delay = util.add_expand_dim(scale_delay, [0], [num_traj]) 47 | 48 | # Get times 49 | num_t = int(config.tau / config.mp_args.dt) * 2 + 1 50 | times = util.tensor_linspace(0, (tau + delay), num_t).squeeze(-1) 51 | times = util.add_expand_dim(times, [0], [num_traj]) 52 | 53 | # Get IC 54 | init_time = times[:, 0] 55 | init_pos_scalar = 1 56 | init_pos = init_pos_scalar * torch.ones([num_traj, config.num_dof]) 57 | init_vel = torch.zeros_like(init_pos) 58 | 59 | # Get params 60 | goal = init_pos_scalar 61 | if relative_goal: 62 | goal -= init_pos_scalar 63 | if not disable_goal: 64 | params_list = [100, 200, 300, -100, goal] 65 | else: 66 | params_list = [100, 200, 300, -100] 67 | params = torch.Tensor(params_list * config.num_dof) 68 | params = util.add_expand_dim(params, [0], [num_traj]) 69 | params = torch.cat([scale_delay, params], dim=-1) 70 | 71 | return config, params, times, init_time, init_pos, init_vel 72 | 73 | 74 | def get_prodmp_results(relative_goal, disable_goal=False): 75 | config, params, times, init_time, init_pos, init_vel = get_mp_config( 76 | relative_goal, disable_goal) 77 | mp = MPFactory.init_mp(**config) 78 | mp.update_inputs(times, params, None, init_time, init_pos, init_vel) 79 | result_dict = mp.get_trajs() 80 | return result_dict 81 | 82 | 83 | if __name__ == "__main__": 84 | no_relative_goal_results = get_prodmp_results(False) 85 | relative_goal_results = get_prodmp_results(True) 86 | disable_goal_results = get_prodmp_results(True, True) 87 | 88 | for key in no_relative_goal_results.keys(): 89 | print(key) 90 | if no_relative_goal_results[key] is None: 91 | print("None") 92 | elif torch.allclose(no_relative_goal_results[key], 93 | relative_goal_results[key])\ 94 | and torch.allclose(no_relative_goal_results[key], 95 | disable_goal_results[key]): 96 | print("PASS") 97 | else: 98 | print("FAIL") 99 | 
-------------------------------------------------------------------------------- /test/test_prodmp_speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from addict import Dict 3 | 4 | from mp_pytorch import util 5 | from mp_pytorch.mp import MPFactory 6 | 7 | 8 | def get_mp_config(): 9 | """ 10 | Get the config of DMPs for testing 11 | 12 | Args: 13 | mp_type: "dmp" or "prodmp" 14 | 15 | Returns: 16 | config in dictionary 17 | """ 18 | 19 | device = torch.device("cuda") 20 | 21 | torch.manual_seed(0) 22 | 23 | config = Dict() 24 | config.num_dof = 2 25 | config.tau = 3 26 | config.learn_tau = True 27 | config.learn_delay = True 28 | 29 | config.mp_args.num_basis = 5 30 | config.mp_args.basis_bandwidth_factor = 2 31 | config.mp_args.num_basis_outside = 0 32 | config.mp_args.alpha = 25 33 | config.mp_args.alpha_phase = 2 34 | config.mp_args.dt = 0.02 35 | config.mp_args.weights_scale = 1 36 | config.mp_args.goal_scale = 1 37 | 38 | # assume we have 3 trajectories in a batch 39 | num_traj = 3 40 | 41 | # Get trajectory scaling 42 | tau, delay = 4, 1 43 | scale_delay = torch.Tensor([tau, delay]).to(device) 44 | scale_delay = util.add_expand_dim(scale_delay, [0], [num_traj]) 45 | 46 | # Get params 47 | params = torch.Tensor([100, 200, 300, -100, -200, -2] * config.num_dof).to(device) 48 | params.requires_grad = True 49 | params = util.add_expand_dim(params, [0], [num_traj]) 50 | params = torch.cat([scale_delay, params], dim=-1).to(device) 51 | 52 | # Get times 53 | num_t = int(config.tau / config.mp_args.dt) * 2 + 1 54 | times = util.tensor_linspace(0, (tau + delay), num_t).squeeze(-1) 55 | times = util.add_expand_dim(times, [0], [num_traj]) 56 | times = times.to(device) 57 | # Get IC 58 | init_time = times[:, 0] 59 | init_pos = 5 * torch.ones([num_traj, config.num_dof]).to(device) 60 | init_vel = torch.zeros_like(init_pos).to(device) 61 | 62 | return config, params, times, init_time, init_pos, init_vel 63 | 64 
def speed_test():
    """Benchmark trajectory generation speed of a DMP vs. a ProDMP by
    timing 100 generation calls of each."""
    # Fix: only request CUDA when it is present so the benchmark does not
    # crash on CPU-only machines; on GPU machines behavior is unchanged.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Get config and shared inputs
    config, params, times, init_time, init_pos, init_vel = get_mp_config()

    # Initialize the DMP and ProDMP from the same config
    config.mp_type = "dmp"
    dmp = MPFactory.init_mp(**config.to_dict(), device=device)
    config.mp_type = "prodmp"
    prodmp = MPFactory.init_mp(**config.to_dict(), device=device)

    def traj_gen_func_dmp(params):
        # Perturb the inputs on every call — presumably so repeated runs
        # cannot reuse cached results; TODO confirm against util.how_fast.
        params += 0.01
        dmp.update_inputs(times=times, params=params,
                          init_time=init_time, init_pos=init_pos + 0.01,
                          init_vel=init_vel)
        # Results are discarded; the calls exist to force the computation.
        dmp_pos = dmp.get_traj_pos()
        dmp_vel = dmp.get_traj_vel()

    def traj_gen_func_prodmp(params):
        params += 0.01
        prodmp.update_inputs(times=times, params=params, params_L=None,
                             init_time=init_time, init_pos=init_pos + 0.01,
                             init_vel=init_vel)
        prodmp_pos = prodmp.get_traj_pos()
        prodmp_vel = prodmp.get_traj_vel()

    # Time 100 trajectory generations for each MP type
    print("dmp: ")
    util.how_fast(100, traj_gen_func_dmp, params)
    print("prodmp: ")
    util.how_fast(100, traj_gen_func_prodmp, params)


if __name__ == "__main__":
    # Fix: guard the CUDA default tensor type — setting it on a CPU-only
    # machine raises; on GPU machines behavior is unchanged.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    speed_test()

# ------------------------------------------------------------------------------
# /test/test_quantitative.py
# ------------------------------------------------------------------------------
import torch
from addict import Dict

from mp_pytorch.mp import MPFactory
from mp_pytorch import util


def get_mp_config():
    """
    Build the shared config and batched inputs for the quantitative tests.

    Returns:
        config, params, params_L, times, init_time, init_pos, init_vel
    """

    torch.manual_seed(0)

    config = Dict()
    config.num_dof = 2
    config.tau = 3
    config.learn_tau = True
    config.learn_delay = True

    config.mp_args.num_basis = 9
    config.mp_args.basis_bandwidth_factor = 2
    config.mp_args.num_basis_outside = 0
    config.mp_args.alpha = 25
    config.mp_args.alpha_phase = 2
    config.mp_args.dt = 0.001

    # assume we have 3 trajectories in a batch
    num_traj = 3

    # Get trajectory scaling: [tau, delay] prepended because
    # learn_tau / learn_delay are enabled.
    tau, delay = 4, 1
    scale_delay = torch.Tensor([tau, delay])
    scale_delay = util.add_expand_dim(scale_delay, [0], [num_traj])

    # Get params
    params = torch.Tensor([100, 200, 300, -100, -200, -300,
                           100, 200, 300, -2] * config.num_dof)
    params = util.add_expand_dim(params, [0], [num_traj])
    params = torch.cat([scale_delay, params], dim=-1)

    # Get params_L: Cholesky-style lower-triangular parameter matrix built
    # from fixed diagonal and off-diagonal entries.
    diag = torch.Tensor([10, 20, 30, 10, 20, 30,
                         10, 20, 30, 4] * config.num_dof)
    off_diag = torch.linspace(-9.5, 9.4, 190)
    params_L = util.build_lower_matrix(diag, off_diag)
    params_L = util.add_expand_dim(params_L, [0], [num_traj])

    # Get times
    num_t = int(config.tau / config.mp_args.dt) * 2 + 1
    times = util.tensor_linspace(0, (tau + delay), num_t).squeeze(-1)
    times = util.add_expand_dim(times, [0], [num_traj])

    # Get initial conditions
    init_time = times[:, 0]
    init_pos = 5 * torch.ones([num_traj, config.num_dof])
    init_vel = torch.zeros_like(init_pos)

    return config, params, params_L, times, init_time, init_pos, init_vel


def dmp_quantitative_test(plot=False):
    """Regression-test DMP positions at fixed time indices against
    previously recorded reference values."""
    config, params, params_L, times, init_time, init_pos, init_vel = \
        get_mp_config()
    config.mp_type = "dmp"
    dmp = MPFactory.init_mp(**config.to_dict())
    dmp.update_inputs(times=times, params=params,
                      init_time=init_time, init_pos=init_pos,
                      init_vel=init_vel)
    pos = dmp.get_traj_pos()
    vel = dmp.get_traj_vel()

    if plot:
        util.debug_plot(x=None, y=[pos[0, :, 0]], title="DMP pos")
        util.debug_plot(x=None, y=[vel[0, :, 0]], title="DMP vel")

    # Quantitative testing: the first two samples fall inside the delay,
    # so the trajectory must still equal the initial position (5).
    assert torch.abs(pos[0, 100, 0] - 5) < 1e-9
    assert torch.abs(pos[0, 1000, 0] - 5) < 1e-9
    assert torch.abs(pos[0, 2000, 0] - 1.2169) < 3.71e-5
    assert torch.abs(pos[0, 3000, 0] + 0.9573) < 3.6e-5
    assert torch.abs(pos[0, 4000, 0] + 2.0863) < 3.78e-5
    assert torch.abs(pos[0, 5000, 0] + 2.2135) < 3.6e-5
    assert torch.abs(pos[0, 6000, 0] + 1.8863) < 1.4e-5
    return True


def promp_quantitative_test(plot=False):
    """Regression-test ProMP positions and trajectory distribution against
    previously recorded reference values."""
    config, params, params_L, times, init_time, init_pos, init_vel = \
        get_mp_config()
    config.mp_type = "promp"

    # Fix the number of basis (ProMP uses one more than the DMP config)
    config.mp_args.num_basis += 1

    promp = MPFactory.init_mp(**config.to_dict())

    promp.update_inputs(times=times, params=params, params_L=params_L,
                        init_time=init_time, init_pos=init_pos,
                        init_vel=init_vel)
    pos = promp.get_traj_pos()
    vel = promp.get_traj_vel()
    pos_flat = promp.get_traj_pos(flat_shape=True)
    pos_cov = promp.get_traj_pos_cov()
    mvn = torch.distributions.MultivariateNormal(loc=pos_flat,
                                                 covariance_matrix=pos_cov,
                                                 validate_args=False)

    if plot:
        util.debug_plot(x=None, y=[pos[0, :, 0]], title="ProMP pos")
        util.debug_plot(x=None, y=[vel[0, :, 0]], title="ProMP vel")

    # Quantitative testing
    assert torch.abs(pos[0, 100, 0] - 129.1609) < 4.6e-5
    assert torch.abs(pos[0, 1000, 0] - 129.1609) < 4.6e-5
    assert torch.abs(pos[0, 2000, 0] - 219.7397) < 4.6e-5
    assert torch.abs(pos[0, 3000, 0] + 111.4337) < 3.1e-5
    assert torch.abs(pos[0, 4000, 0] + 145.4950) < 3.1e-5
    assert torch.abs(pos[0, 5000, 0] - 203.8375) < 4.6e-5
    # NOTE(review): tolerance 3.82 is ~5 orders of magnitude looser than all
    # sibling tolerances — likely a typo for 3.82e-5; confirm before
    # tightening, since tightening could turn a passing test red.
    assert torch.abs(pos[0, 6000, 0] - 80.8178) < 3.82

    assert torch.abs(mvn.log_prob(pos_flat)[0] - 801.7334) < 1e-1
    return True


def prodmp_quantitative_test(plot=False):
    """Regression-test ProDMP positions and trajectory distribution against
    previously recorded reference values.

    Fix: default plot=False for consistency with the two sibling tests, so a
    bare call (e.g. from a CI runner) does not open blocking plot windows;
    the __main__ block below still passes plot=True explicitly.
    """
    config, params, params_L, times, init_time, init_pos, init_vel = \
        get_mp_config()
    config.mp_type = "prodmp"
    prodmp = MPFactory.init_mp(**config.to_dict())
    prodmp.update_inputs(times=times, params=params, params_L=params_L,
                         init_time=init_time, init_pos=init_pos,
                         init_vel=init_vel)
    pos = prodmp.get_traj_pos()
    vel = prodmp.get_traj_vel()
    pos_flat = prodmp.get_traj_pos(flat_shape=True)
    pos_cov = prodmp.get_traj_pos_cov()
    mvn = torch.distributions.MultivariateNormal(loc=pos_flat,
                                                 covariance_matrix=pos_cov,
                                                 validate_args=False)

    if plot:
        util.debug_plot(x=None, y=[pos[0, :, 0]], title="ProDMP pos")
        util.debug_plot(x=None, y=[vel[0, :, 0]], title="ProDMP vel")

    # Quantitative testing
    assert torch.abs(pos[0, 100, 0] - 5) < 1e-9
    assert torch.abs(pos[0, 1000, 0] - 5) < 1e-9
    assert torch.abs(pos[0, 2000, 0] - 1.2203) < 4.37e-5
    assert torch.abs(pos[0, 3000, 0] + 0.9576) < 3.9e-5
    assert torch.abs(pos[0, 4000, 0] + 2.0867) < 3.56e-5
    assert torch.abs(pos[0, 5000, 0] + 2.2139) < 2.6e-5
    assert torch.abs(pos[0, 6000, 0] + 1.8863) < 4e-5

    assert torch.abs(mvn.log_prob(pos_flat)[0] - 774.3701) < 6.11e-5
    return True


if __name__ == "__main__":
    dmp_quantitative_test(plot=True)
    promp_quantitative_test(plot=True)
    prodmp_quantitative_test(plot=True)