├── .gitignore ├── LICENSE ├── README.md ├── demos ├── occupancy.gif ├── panda.gif └── sdf_grid.gif ├── examples ├── mpot_occupancy.py ├── mpot_panda.py └── mpot_sdf.py ├── mpot ├── __init__.py ├── costs.py ├── envs │ ├── __init__.py │ ├── map_generator.py │ ├── obst_map.py │ ├── obst_utils.py │ └── occupancy.py ├── gp │ ├── LICENSE │ ├── __init__.py │ ├── field_factor.py │ ├── gp_factor.py │ ├── gp_prior.py │ └── unary_factor.py ├── ot │ ├── __init__.py │ ├── initializer.py │ ├── problem.py │ ├── sinkhorn.py │ └── sinkhorn_step.py ├── planner.py └── utils │ ├── __init__.py │ ├── misc.py │ ├── polytopes.py │ ├── probe.py │ ├── rotation.py │ └── trajectory.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | data/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 An Thai Le 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerating Motion Planning via Optimal Transport 2 | 3 | This repository implements Motion Planning via Optimal Transport (`mpot`) in PyTorch. 4 | The philosophy of `mpot` follows the Monte Carlo argument: more samples can discover more and better modes, given sufficiently high initialization variance. 5 | In other words, within the multi-modal motion planning scope, `mpot` enables better **brute-force** planning with GPU vectorization.
This enhances robustness against bad local minima, a common issue in optimization-based motion planning. 6 | 7 |

8 | 9 | 10 | 11 |

12 | 13 | For those interested in standalone Sinkhorn Step as a general-purpose batch gradient-free solver for non-convex optimization problems, please check out [ssax](https://github.com/anindex/ssax). 14 | 15 | ## Paper Preprint 16 | 17 | This work has been accepted to NeurIPS 2023. Please find the pre-print here: 18 | 19 | [Pre-print (PDF)](https://www.ias.informatik.tu-darmstadt.de/uploads/Team/AnThaiLe/mpot_preprint.pdf) 20 | 21 | ## Installation 22 | 23 | Simply activate your conda/Python environment, navigate to the `mpot` root directory and run 24 | 25 | ```bash 26 | pip install -e . 27 | ``` 28 | 29 | The `mpot` algorithm is specifically designed to run on GPUs. Please make sure PyTorch is installed with CUDA support. 30 | 31 | ## Examples 32 | 33 | The `examples/` folder contains a demo of vectorized planning in planar environments with an occupancy map: 34 | 35 | ```bash 36 | python examples/mpot_occupancy.py 37 | ``` 38 | 39 | and with a signed distance field (SDF): 40 | 41 | ```bash 42 | python examples/mpot_sdf.py 43 | ``` 44 | 45 | We also added a demo of vectorized Panda planning in dense obstacle environments (SDF): 46 | 47 | ```bash 48 | python examples/mpot_panda.py 49 | ``` 50 | 51 | Every run is associated with **a different seed**. The resulting optimization visualizations are stored in your current directory. 52 | Please refer to the example scripts to play around with options and different goal points. Note that in all cases, we normalize the joint space to the joint limits and velocity limits, then perform Sinkhorn Step on the normalized state space. Changing any hyperparameters may require re-tuning. 53 | 54 | **Tuning Tips**: The most sensitive parameters are: 55 | 56 | - `polytope`: for state dimensions below 10, `cube` is a good choice. For much higher state dimensions, the sensible choices are `orthoplex` or `simplex`. 57 | - `step_radius`: the step size.
58 | - `probe_radius`: the probing radius, which projects towards polytope vertices to compute cost-to-go. Note: `probe_radius` >= `step_radius`. 59 | - `num_probe`: number of probing points along the probe radius. This is critical for performance; usually 3-5 is enough. 60 | - `epsilon`: decay rate of the step/probe size, usually 0.01-0.05. 61 | - `ent_epsilon`: Sinkhorn entropy regularization, usually 1e-2 to 5e-2, balancing the optimal coupling's sharpness against speed. 62 | - Various cost term weightings. These depend on your application. 63 | 64 | ## Troubleshooting 65 | 66 | If you encounter memory problems, try: 67 | 68 | ```bash 69 | export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' 70 | ``` 71 | 72 | to reduce memory fragmentation. 73 | 74 | ## Acknowledgement 75 | 76 | The Gaussian Process prior implementation is adapted from Sasha Lambert's [`mpc_trajopt`](https://github.com/sashalambert/mpc_trajopt/blob/main/mpc_trajopt/factors/gp_factor.py). 77 | 78 | ## Citation 79 | 80 | If you find this repository useful, please consider citing: 81 | 82 | ```bibtex 83 | @inproceedings{le2023accelerating, 84 | title={Accelerating Motion Planning via Optimal Transport}, 85 | author={Le, An T.
and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan}, 86 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 87 | year={2023} 88 | } 89 | -------------------------------------------------------------------------------- /demos/occupancy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/occupancy.gif -------------------------------------------------------------------------------- /demos/panda.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/panda.gif -------------------------------------------------------------------------------- /demos/sdf_grid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/sdf_grid.gif -------------------------------------------------------------------------------- /examples/mpot_occupancy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import time 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 7 | 8 | from mpot.ot.problem import Epsilon 9 | from mpot.ot.sinkhorn import Sinkhorn 10 | from mpot.planner import MPOT 11 | from mpot.costs import CostGPHolonomic, CostField, CostComposite 12 | from mpot.envs.occupancy import EnvOccupancy2D 13 | from mpot.utils.trajectory import interpolate_trajectory 14 | 15 | from torch_robotics.robots.robot_point_mass import RobotPointMass 16 | from torch_robotics.torch_utils.seed import fix_random_seed 17 | from torch_robotics.torch_utils.torch_timer import TimerCUDA 18 | from torch_robotics.torch_utils.torch_utils import 
get_torch_device 19 | from torch_robotics.tasks.tasks import PlanningTask 20 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer 21 | 22 | allow_ops_in_compiled_graph() 23 | 24 | 25 | if __name__ == "__main__": # Demo: MPOT planning for a 2D point mass (RobotPointMass) on an occupancy map (EnvOccupancy2D) toward several goal states 26 | seed = int(time.time()) # time-based seed, so every run differs 27 | fix_random_seed(seed) 28 | 29 | device = get_torch_device() 30 | tensor_args = {'device': device, 'dtype': torch.float32} 31 | 32 | # ---------------------------- Environment, Robot, PlanningTask --------------------------------- 33 | q_limits = torch.tensor([[-10, -10], [10, 10]], **tensor_args) 34 | env = EnvOccupancy2D( 35 | precompute_sdf_obj_fixed=False, 36 | tensor_args=tensor_args 37 | ) 38 | 39 | robot = RobotPointMass( 40 | q_limits=q_limits, # joint limits 41 | tensor_args=tensor_args 42 | ) 43 | 44 | task = PlanningTask( 45 | env=env, 46 | robot=robot, 47 | ws_limits=q_limits, # workspace limits 48 | obstacle_cutoff_margin=0.05, 49 | tensor_args=tensor_args 50 | ) 51 | 52 | # -------------------------------- Params --------------------------------- 53 | # NOTE: these parameters are tuned for this environment 54 | step_radius = 0.15 55 | probe_radius = 0.15 # probe radius >= step radius 56 | 57 | # NOTE: changing polytope may require tuning again 58 | polytope = 'cube' # 'simplex' | 'orthoplex' | 'cube'; 59 | 60 | epsilon = 0.01 61 | ent_epsilon = Epsilon(1e-2) 62 | num_probe = 5 # number of probe points per polytope vertex 63 | num_particles_per_goal = 33 # number of plans per goal 64 | pos_limits = [-10, 10] 65 | vel_limits = [-10, 10] 66 | w_coll = 5e-3 # for tuning the obstacle cost 67 | w_smooth = 1e-7 # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2 68 | sigma_gp = 0.1 # for tuning the GP cost: Q_c = sigma_gp^2 * I 69 | sigma_gp_init = 1.6 # for controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I 70 | max_inner_iters = 100 # max inner iterations for Sinkhorn-Knopp 71 | max_outer_iters = 100 # max outer iterations for MPOT 72 | 73 |
start_state = torch.tensor([-9, -9, 0., 0.], **tensor_args) # position (2) + velocity (2) 74 | 75 | # NOTE: change goal states here (zero vel goals) 76 | multi_goal_states = torch.tensor([ 77 | [0, 9, 0., 0.], 78 | [9, 9, 0., 0.], 79 | [9, 0, 0., 0.] 80 | ], **tensor_args) 81 | 82 | traj_len = 64 83 | dt = 0.1 84 | 85 | #--------------------------------- Cost function --------------------------------- 86 | 87 | cost_coll = CostField( 88 | robot, traj_len, 89 | field=env.occupancy_map, 90 | sigma_coll=1.0, 91 | tensor_args=tensor_args 92 | ) 93 | cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args) # NOTE(review): `[0, 1]` binds to `probe_range`; CostGPHolonomic has no `weight` parameter, so `weight=` falls through **kwargs and is ignored -- the smoothness weight presumably takes effect via weights_cost_l below 94 | cost_func_list = [cost_coll, cost_gp] 95 | weights_cost_l = [w_coll, w_smooth] 96 | cost = CostComposite( 97 | robot, traj_len, cost_func_list, 98 | weights_cost_l=weights_cost_l, 99 | tensor_args=tensor_args 100 | ) 101 | 102 | #--------------------------------- MPOT Init --------------------------------- 103 | 104 | linear_ot_solver = Sinkhorn( 105 | threshold=1e-6, 106 | inner_iterations=1, 107 | max_iterations=max_inner_iters, 108 | ) 109 | ss_params = dict( 110 | epsilon=epsilon, 111 | ent_epsilon=ent_epsilon, 112 | step_radius=step_radius, 113 | probe_radius=probe_radius, 114 | num_probe=num_probe, 115 | min_iterations=5, 116 | max_iterations=max_outer_iters, 117 | threshold=2e-3, 118 | store_history=True, 119 | tensor_args=tensor_args, 120 | ) 121 | 122 | mpot_params = dict( 123 | objective_fn=cost, 124 | linear_ot_solver=linear_ot_solver, 125 | ss_params=ss_params, 126 | dim=2, 127 | traj_len=traj_len, 128 | num_particles_per_goal=num_particles_per_goal, 129 | dt=dt, 130 | start_state=start_state, 131 | multi_goal_states=multi_goal_states, 132 | pos_limits=pos_limits, 133 | vel_limits=vel_limits, 134 | polytope=polytope, 135 | fixed_goal=True, 136 | sigma_start_init=0.001, 137 | sigma_goal_init=0.001, 138 | sigma_gp_init=sigma_gp_init, 139 | seed=seed, 140 | tensor_args=tensor_args, 141 | ) 142 | planner = MPOT(**mpot_params) 143 | 144
| #--------------------------------- Optimize --------------------------------- 145 | 146 | with TimerCUDA() as t: 147 | trajs, optim_state, opt_iters = planner.optimize() 148 | int_trajs = interpolate_trajectory(trajs, num_interpolation=3) 149 | colls = env.occupancy_map.get_collisions(int_trajs[..., :2]).any(dim=1) 150 | sinkhorn_iters = optim_state.linear_convergence[:opt_iters] 151 | print(f'Optimization finished at {opt_iters}! Parallelization Quality (GOOD [%]): {(1 - colls.float().mean()) * 100:.2f}') 152 | print(f'Time(s) optim: {t.elapsed} sec') 153 | print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}') 154 | 155 | # -------------------------------- Visualize --------------------------------- 156 | planner_visualizer = PlanningVisualizer( 157 | task=task, 158 | planner=planner 159 | ) 160 | 161 | traj_history = optim_state.X_history[:opt_iters] 162 | traj_history = traj_history.view(opt_iters, -1, traj_len, 4) 163 | base_file_name = Path(os.path.basename(__file__)).stem 164 | pos_trajs_iters = robot.get_position(traj_history) 165 | 166 | planner_visualizer.animate_opt_iters_joint_space_state( 167 | trajs=traj_history, 168 | pos_start_state=start_state, 169 | vel_start_state=torch.zeros_like(start_state), 170 | video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4', 171 | n_frames=max((2, opt_iters // 5)), 172 | anim_time=5 173 | ) 174 | 175 | planner_visualizer.animate_opt_iters_robots( 176 | trajs=pos_trajs_iters, start_state=start_state, 177 | video_filepath=f'{base_file_name}-traj-opt-iters.mp4', 178 | n_frames=max((2, opt_iters // 5)), 179 | anim_time=5 180 | ) 181 | 182 | plt.show() 183 | -------------------------------------------------------------------------------- /examples/mpot_panda.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import time 5 | import matplotlib.pyplot as 
plt 6 | import torch 7 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 8 | 9 | from mpot.ot.problem import Epsilon 10 | from mpot.ot.sinkhorn import Sinkhorn 11 | from mpot.planner import MPOT 12 | from mpot.costs import CostGPHolonomic, CostField, CostComposite 13 | 14 | from torch_robotics.environments.env_spheres_3d import EnvSpheres3D 15 | from torch_robotics.robots.robot_panda import RobotPanda 16 | from torch_robotics.tasks.tasks import PlanningTask 17 | from torch_robotics.torch_utils.seed import fix_random_seed 18 | from torch_robotics.torch_utils.torch_timer import TimerCUDA 19 | from torch_robotics.torch_utils.torch_utils import get_torch_device 20 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer 21 | 22 | allow_ops_in_compiled_graph() 23 | 24 | 25 | if __name__ == "__main__": 26 | base_file_name = Path(os.path.basename(__file__)).stem 27 | 28 | seed = int(time.time()) 29 | fix_random_seed(seed) 30 | 31 | device = get_torch_device() 32 | tensor_args = {'device': device, 'dtype': torch.float32} 33 | 34 | # ---------------------------- Environment, Robot, PlanningTask --------------------------------- 35 | env = EnvSpheres3D( 36 | precompute_sdf_obj_fixed=False, 37 | sdf_cell_size=0.01, 38 | tensor_args=tensor_args 39 | ) 40 | 41 | robot = RobotPanda( 42 | use_self_collision_storm=False, 43 | tensor_args=tensor_args 44 | ) 45 | 46 | task = PlanningTask( 47 | env=env, 48 | robot=robot, 49 | ws_limits=torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], **tensor_args), # workspace limits 50 | obstacle_cutoff_margin=0.03, 51 | tensor_args=tensor_args 52 | ) 53 | 54 | # -------------------------------- Params --------------------------------- 55 | for _ in range(100): 56 | q_free = task.random_coll_free_q(n_samples=2) 57 | start_state = q_free[0] 58 | goal_state = q_free[1] 59 | 60 | # check if the EE positions are "enough" far apart 61 | start_state_ee_pos = 
robot.get_EE_position(start_state).squeeze() 62 | goal_state_ee_pos = robot.get_EE_position(goal_state).squeeze() 63 | 64 | if torch.linalg.norm(start_state_ee_pos - goal_state_ee_pos) > 0.5: 65 | break # NOTE(review): if no pair is farther apart than 0.5 within 100 samples, the loop simply ends and the last sampled pair is used anyway 66 | 67 | # start_state = torch.tensor([1.0403, 0.0493, 0.0251, -1.2673, 1.6676, 3.3611, -1.5428], **tensor_args) 68 | # goal_state = torch.tensor([1.1142, 1.7289, -0.1771, -0.9284, 2.7171, 1.2497, 1.7724], **tensor_args) 69 | 70 | print('Start state: ', start_state) 71 | print('Goal state: ', goal_state) 72 | start_state = torch.concatenate((start_state, torch.zeros_like(start_state))) # append zero velocities -> [q, q_dot] 73 | goal_state = torch.concatenate((goal_state, torch.zeros_like(goal_state))) 74 | multi_goal_states = goal_state.unsqueeze(0) 75 | 76 | # Construct planner 77 | duration = 5 # sec 78 | traj_len = 64 79 | dt = duration / traj_len # NOTE(review): traj_len support points give traj_len - 1 intervals, so the planned trajectory spans (traj_len - 1) * dt, slightly less than `duration` 80 | num_particles_per_goal = 10 # number of plans per goal # NOTE: if memory is not enough, reduce this number 81 | 82 | # NOTE: these parameters are tuned for this environment 83 | step_radius = 0.03 84 | probe_radius = 0.08 # probe radius >= step radius 85 | 86 | # NOTE: changing polytope may require tuning again 87 | # NOTE: cube in this case could lead to memory insufficiency, depending on how many plans are optimized 88 | polytope = 'cube' # 'simplex' | 'orthoplex' | 'cube'; 89 | 90 | epsilon = 0.02 91 | ent_epsilon = Epsilon(1e-2) 92 | num_probe = 3 # number of probe points per polytope vertex 93 | # panda joint limits 94 | q_max = torch.tensor([2.7437, 1.7837, 2.9007, -0.1518, 2.8065, 4.5169, 3.0159], **tensor_args) 95 | q_min = torch.tensor([-2.7437, -1.7837, -2.9007, -3.0421, -2.8065, 0.5445, -3.0159], **tensor_args) 96 | pos_limits = torch.stack([q_min, q_max], dim=1) 97 | vel_limits = [-5, 5] 98 | w_coll = 1e-1 # for tuning the obstacle cost 99 | w_smooth = 1e-7 # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2 100 | sigma_gp = 0.01 # for tuning the GP cost: Q_c = sigma_gp^2 * I 101 | sigma_gp_init = 0.5 # for
controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I 102 | max_inner_iters = 100 # max inner iterations for Sinkhorn-Knopp 103 | max_outer_iters = 40 # max outer iterations for MPOT 104 | 105 | #--------------------------------- Cost function --------------------------------- 106 | 107 | cost_func_list = [] 108 | weights_cost_l = [] 109 | for collision_field in task.get_collision_fields(): # one CostField per collision field, each weighted by w_coll 110 | cost_func_list.append( 111 | CostField( 112 | robot, traj_len, 113 | field=collision_field, 114 | sigma_coll=1.0, 115 | tensor_args=tensor_args 116 | ) 117 | ) 118 | weights_cost_l.append(w_coll) 119 | cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args) # NOTE(review): `[0, 1]` binds to `probe_range`; CostGPHolonomic has no `weight` parameter, so `weight=` falls through **kwargs and is ignored -- the smoothness weight presumably takes effect via weights_cost_l 120 | cost_func_list.append(cost_gp) 121 | weights_cost_l.append(w_smooth) 122 | cost = CostComposite( 123 | robot, traj_len, cost_func_list, 124 | weights_cost_l=weights_cost_l, 125 | tensor_args=tensor_args 126 | ) 127 | 128 | #--------------------------------- MPOT Init --------------------------------- 129 | 130 | linear_ot_solver = Sinkhorn( 131 | threshold=1e-3, 132 | inner_iterations=1, 133 | max_iterations=max_inner_iters, 134 | ) 135 | ss_params = dict( 136 | epsilon=epsilon, 137 | ent_epsilon=ent_epsilon, 138 | step_radius=step_radius, 139 | probe_radius=probe_radius, 140 | num_probe=num_probe, 141 | min_iterations=5, 142 | max_iterations=max_outer_iters, 143 | threshold=2e-3, 144 | store_history=True, 145 | tensor_args=tensor_args, 146 | ) 147 | 148 | mpot_params = dict( 149 | objective_fn=cost, 150 | linear_ot_solver=linear_ot_solver, 151 | ss_params=ss_params, 152 | dim=7, 153 | traj_len=traj_len, 154 | num_particles_per_goal=num_particles_per_goal, 155 | dt=dt, 156 | start_state=start_state, 157 | multi_goal_states=multi_goal_states, 158 | pos_limits=pos_limits, 159 | vel_limits=vel_limits, 160 | polytope=polytope, 161 | fixed_goal=True, 162 | sigma_start_init=0.001, 163 | sigma_goal_init=0.001, 164 | sigma_gp_init=sigma_gp_init, 165 | seed=seed,
166 | tensor_args=tensor_args, 167 | ) 168 | planner = MPOT(**mpot_params) 169 | 170 | # Optimize 171 | with TimerCUDA() as t: 172 | trajs, optim_state, opt_iters = planner.optimize() 173 | sinkhorn_iters = optim_state.linear_convergence[:opt_iters] 174 | print(f'Optimization finished at {opt_iters}! Optimization time: {t.elapsed:.3f} sec') 175 | print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}') 176 | 177 | # -------------------------------- Visualize --------------------------------- 178 | planner_visualizer = PlanningVisualizer( 179 | task=task, 180 | planner=planner 181 | ) 182 | 183 | traj_history = optim_state.X_history[:opt_iters] 184 | traj_history = traj_history.view(opt_iters, -1, traj_len, 14) # 7 + 7 185 | pos_trajs_iters = robot.get_position(traj_history) 186 | trajs = trajs.flatten(0, 1) 187 | trajs_coll, trajs_free = task.get_trajs_collision_and_free(trajs) 188 | 189 | planner_visualizer.animate_opt_iters_joint_space_state( 190 | trajs=traj_history, 191 | pos_start_state=start_state, pos_goal_state=goal_state, 192 | vel_start_state=torch.zeros_like(start_state), vel_goal_state=torch.zeros_like(goal_state), 193 | video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4', 194 | n_frames=max((2, opt_iters // 2)), 195 | anim_time=5 196 | ) 197 | 198 | if trajs_free is not None: 199 | planner_visualizer.animate_robot_trajectories( 200 | trajs=trajs_free, start_state=start_state, goal_state=goal_state, 201 | plot_trajs=False, 202 | draw_links_spheres=False, 203 | video_filepath=f'{base_file_name}-robot-traj.mp4', 204 | # n_frames=max((2, pos_trajs_iters[-1].shape[1]//10)), 205 | n_frames=trajs_free.shape[-2], 206 | anim_time=duration 207 | ) 208 | 209 | plt.show() 210 | -------------------------------------------------------------------------------- /examples/mpot_sdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from pathlib import Path 3 | import time 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 7 | 8 | from mpot.ot.problem import Epsilon 9 | from mpot.ot.sinkhorn import Sinkhorn 10 | from mpot.planner import MPOT 11 | from mpot.costs import CostGPHolonomic, CostField, CostComposite 12 | 13 | from torch_robotics.environments.env_grid_circles_2d import EnvGridCircles2D 14 | from torch_robotics.robots.robot_point_mass import RobotPointMass 15 | from torch_robotics.tasks.tasks import PlanningTask 16 | from torch_robotics.torch_utils.seed import fix_random_seed 17 | from torch_robotics.torch_utils.torch_timer import TimerCUDA 18 | from torch_robotics.torch_utils.torch_utils import get_torch_device 19 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer 20 | 21 | allow_ops_in_compiled_graph() 22 | 23 | 24 | if __name__ == "__main__": 25 | seed = int(time.time()) 26 | fix_random_seed(seed) 27 | 28 | device = get_torch_device() 29 | tensor_args = {'device': device, 'dtype': torch.float32} 30 | 31 | # ---------------------------- Environment, Robot, PlanningTask --------------------------------- 32 | env = EnvGridCircles2D( 33 | precompute_sdf_obj_fixed=False, 34 | sdf_cell_size=0.001, 35 | tensor_args=tensor_args 36 | ) 37 | 38 | robot = RobotPointMass( 39 | q_limits=torch.tensor([[-1, -1], [1, 1]], **tensor_args), # joint limits 40 | tensor_args=tensor_args 41 | ) 42 | 43 | task = PlanningTask( 44 | env=env, 45 | robot=robot, 46 | ws_limits=torch.tensor([[-0.81, -0.81], [0.95, 0.95]], **tensor_args), # workspace limits 47 | obstacle_cutoff_margin=0.005, 48 | tensor_args=tensor_args 49 | ) 50 | 51 | # -------------------------------- Params --------------------------------- 52 | 53 | # NOTE: these parameters are tuned for this environment 54 | step_radius = 0.15 55 | probe_radius = 0.15 # probe radius >= step radius 56 | 57 | # NOTE: changing polytope 
may require tuning again 58 | polytope = 'cube' # 'simplex' | 'orthoplex' | 'cube'; 59 | 60 | epsilon = 0.01 61 | ent_epsilon = Epsilon(1e-2) 62 | num_probe = 5 # number of probe points per polytope vertex 63 | num_particles_per_goal = 50 # number of plans per goal # NOTE: if memory is not enough, reduce this number 64 | pos_limits = [-1, 1] 65 | vel_limits = [-1, 1] 66 | w_coll = 7e-2 # for tuning the obstacle cost 67 | w_smooth = 1e-7 # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2 68 | sigma_gp = 0.03 # for tuning the GP cost: Q_c = sigma_gp^2 * I 69 | sigma_gp_init = 0.8 # for controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I 70 | max_inner_iters = 100 # max inner iterations for Sinkhorn-Knopp 71 | max_outer_iters = 70 # max outer iterations for MPOT 72 | start_state = torch.tensor([-0.8, -0.8, 0., 0.], **tensor_args) # position (2) + velocity (2) 73 | multi_goal_states = torch.tensor([ 74 | [0, 0.75, 0., 0.], 75 | [0.75, 0.75, 0., 0.], 76 | [0.75, 0, 0., 0.] 77 | ], **tensor_args) 78 | 79 | traj_len = 64 80 | dt = 0.04 81 | 82 | #--------------------------------- Cost function --------------------------------- 83 | 84 | cost_coll = CostField( 85 | robot, traj_len, 86 | field=task.df_collision_objects, 87 | sigma_coll=1.0, 88 | tensor_args=tensor_args 89 | ) 90 | cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args) # NOTE(review): `[0, 1]` binds to `probe_range`; CostGPHolonomic has no `weight` parameter, so `weight=` falls through **kwargs and is ignored -- the smoothness weight presumably takes effect via weights_cost_l below 91 | cost_func_list = [cost_coll, cost_gp] 92 | weights_cost_l = [w_coll, w_smooth] 93 | cost = CostComposite( 94 | robot, traj_len, cost_func_list, 95 | weights_cost_l=weights_cost_l, 96 | tensor_args=tensor_args 97 | ) 98 | 99 | #--------------------------------- MPOT Init --------------------------------- 100 | 101 | linear_ot_solver = Sinkhorn( 102 | threshold=1e-4, 103 | inner_iterations=1, 104 | max_iterations=max_inner_iters, 105 | ) 106 | ss_params = dict( 107 | epsilon=epsilon, 108 | ent_epsilon=ent_epsilon, 109 | step_radius=step_radius, 110 | probe_radius=probe_radius,
111 | num_probe=num_probe, 112 | min_iterations=5, 113 | max_iterations=max_outer_iters, 114 | threshold=1e-3, 115 | store_history=True, 116 | tensor_args=tensor_args, 117 | ) 118 | 119 | mpot_params = dict( 120 | objective_fn=cost, 121 | linear_ot_solver=linear_ot_solver, 122 | ss_params=ss_params, 123 | dim=2, 124 | traj_len=traj_len, 125 | num_particles_per_goal=num_particles_per_goal, 126 | dt=dt, 127 | start_state=start_state, 128 | multi_goal_states=multi_goal_states, 129 | pos_limits=pos_limits, 130 | vel_limits=vel_limits, 131 | polytope=polytope, 132 | fixed_goal=True, 133 | sigma_start_init=0.001, 134 | sigma_goal_init=0.001, 135 | sigma_gp_init=sigma_gp_init, 136 | seed=seed, 137 | tensor_args=tensor_args, 138 | ) 139 | planner = MPOT(**mpot_params) 140 | 141 | # Optimize 142 | with TimerCUDA() as t: 143 | trajs, optim_state, opt_iters = planner.optimize() 144 | sinkhorn_iters = optim_state.linear_convergence[:opt_iters] 145 | print(f'Optimization finished at {opt_iters}! Optimization time: {t.elapsed:.3f} sec') 146 | print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}') 147 | 148 | # -------------------------------- Visualize --------------------------------- 149 | planner_visualizer = PlanningVisualizer( 150 | task=task, 151 | planner=planner 152 | ) 153 | 154 | traj_history = optim_state.X_history[:opt_iters] 155 | traj_history = traj_history.view(opt_iters, -1, traj_len, 4) 156 | base_file_name = Path(os.path.basename(__file__)).stem 157 | pos_trajs_iters = robot.get_position(traj_history) 158 | 159 | planner_visualizer.animate_opt_iters_joint_space_state( 160 | trajs=traj_history, 161 | pos_start_state=start_state, 162 | vel_start_state=torch.zeros_like(start_state), 163 | video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4', 164 | n_frames=max((2, opt_iters // 5)), 165 | anim_time=5 166 | ) 167 | 168 | planner_visualizer.animate_opt_iters_robots( 169 | 
trajs=pos_trajs_iters, start_state=start_state, 170 | video_filepath=f'{base_file_name}-traj-opt-iters.mp4', 171 | n_frames=max((2, opt_iters// 5)), 172 | anim_time=5 173 | ) 174 | 175 | plt.show() 176 | -------------------------------------------------------------------------------- /mpot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/__init__.py -------------------------------------------------------------------------------- /mpot/costs.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Optional, Tuple, List, Callable 3 | import torch 4 | import einops 5 | 6 | from mpot.gp.field_factor import FieldFactor 7 | 8 | 9 | class Cost(ABC): # Abstract base for trajectory costs: subclasses implement `eval`, and calling an instance forwards to it. 10 | def __init__(self, robot, n_support_points, tensor_args=None, **kwargs): # NOTE(review): unrecognised keyword arguments are accepted and silently dropped, so e.g. `weight=` passed to CostGPHolonomic has no effect 11 | self.robot = robot 12 | self.n_dof = robot.q_dim 13 | self.dim = 2 * self.n_dof # position + velocity 14 | self.n_support_points = n_support_points 15 | 16 | self.tensor_args = tensor_args # NOTE(review): defaults to None, yet subclasses such as CostGPHolonomic unpack it with `**self.tensor_args`; callers must pass a dict 17 | 18 | def set_cost_factors(self): # hook for subclasses to build their cost factors; no-op in the base class 19 | pass 20 | 21 | def __call__(self, trajs, **kwargs): # calling the cost is shorthand for `eval` 22 | return self.eval(trajs, **kwargs) 23 | 24 | @abstractmethod 25 | def eval(self, trajs, **kwargs): 26 | pass 27 | 28 | def get_q_pos_vel_and_fk_map(self, trajs, **kwargs): # Flatten a 4-D batch to 3-D and return (trajs, q_pos, q_vel, H_positions) from the robot model. 29 | assert trajs.ndim == 3 or trajs.ndim == 4 # [B, H, D] or [N, B, H, D] 30 | N = 1 # NOTE(review): N, B, H and D are assigned below but never read 31 | if trajs.ndim == 4: 32 | N, B, H, D = trajs.shape # n_goals (or steps), batch of trajectories, length, dim 33 | trajs = einops.rearrange(trajs, 'N B H D -> (N B) H D') 34 | else: 35 | B, H, D = trajs.shape 36 | 37 | q_pos = self.robot.get_position(trajs) 38 | q_vel = self.robot.get_velocity(trajs) 39 | H_positions = self.robot.fk_map_collision(q_pos) # I, taskspaces, x_dim+1, x_dim+1 (homogeneous transformation matrices) 40 | return trajs, q_pos, q_vel, H_positions 41 | 42 | 43 | class CostField(Cost):
44 | 45 | def __init__( 46 | self, 47 | robot, 48 | n_support_points: int, 49 | field: Callable = None, 50 | sigma: float = 1.0, 51 | **kwargs 52 | ): 53 | super().__init__(robot, n_support_points, **kwargs) 54 | self.field = field 55 | self.sigma = sigma 56 | 57 | self.set_cost_factors() 58 | 59 | def set_cost_factors(self): 60 | # ========= Cost factors =============== 61 | self.obst_factor = FieldFactor( 62 | self.n_dof, 63 | self.sigma, 64 | [0, None] 65 | ) 66 | 67 | def cost(self, trajs: torch.Tensor, **observation) -> torch.Tensor: 68 | traj_dim = observation.get('traj_dim', None) 69 | if self.field is None: 70 | return 0 71 | trajs = trajs.view(traj_dim) # [..., traj_len, dim] 72 | states = trajs[..., :self.n_dof] 73 | field_cost = self.field.compute_cost(states, **observation) 74 | return field_cost.view(traj_dim[:-1]).mean(-1) # mean the traj_len 75 | 76 | def eval(self, trajs: torch.Tensor, q_pos=None, q_vel=None, H_positions=None, **observation): 77 | optim_dim = observation.get('optim_dim') 78 | costs = 0 79 | if self.field is not None: 80 | # H_pos = link_pos_from_link_tensor(H) # get translation part from transformation matrices 81 | H_pos = H_positions 82 | err_obst = self.obst_factor.get_error( 83 | trajs, 84 | self.field, 85 | q_pos=q_pos, 86 | q_vel=q_vel, 87 | H_pos=H_pos, 88 | calc_jacobian=False 89 | ) 90 | w_mat = self.obst_factor.K 91 | obst_costs = w_mat * err_obst.mean(-1) 92 | costs = obst_costs.reshape(optim_dim[:2]) 93 | 94 | return costs 95 | 96 | 97 | class CostGPHolonomic(Cost): 98 | 99 | def __init__( 100 | self, 101 | robot, 102 | n_support_points: int, 103 | dt: float, 104 | sigma: float, 105 | probe_range: Tuple[int, int], 106 | Q_c_inv: torch.Tensor = None, 107 | **kwargs 108 | ): 109 | super().__init__(robot, n_support_points, **kwargs) 110 | self.dt = dt 111 | self.phi = self.calc_phi() 112 | self.phi_T = self.phi.T 113 | if Q_c_inv is None: 114 | Q_c_inv = torch.eye(self.n_dof, **self.tensor_args) / sigma**2 115 | 
self.Q_c_inv = torch.zeros(self.n_support_points - 1, self.n_dof, self.n_dof, **self.tensor_args) + Q_c_inv 116 | self.Q_inv = self.calc_Q_inv() 117 | self.single_Q_inv = self.Q_inv[[0]] 118 | self.probe_range = probe_range 119 | 120 | def calc_phi(self) -> torch.Tensor: 121 | I = torch.eye(self.n_dof, **self.tensor_args) 122 | Z = torch.zeros(self.n_dof, self.n_dof, **self.tensor_args) 123 | phi_u = torch.cat((I, self.dt * I), dim=1) 124 | phi_l = torch.cat((Z, I), dim=1) 125 | phi = torch.cat((phi_u, phi_l), dim=0) 126 | return phi # [dim, dim] 127 | 128 | def calc_Q_inv(self) -> torch.Tensor: 129 | m1 = 12. * (self.dt ** -3.) * self.Q_c_inv 130 | m2 = -6. * (self.dt ** -2.) * self.Q_c_inv 131 | m3 = 4. * (self.dt ** -1.) * self.Q_c_inv 132 | 133 | Q_inv_u = torch.cat((m1, m2), dim=-1) 134 | Q_inv_l = torch.cat((m2, m3), dim=-1) 135 | Q_inv = torch.cat((Q_inv_u, Q_inv_l), dim=-2) 136 | return Q_inv 137 | 138 | def cost(self, trajs: torch.Tensor, **observation) -> torch.Tensor: 139 | traj_dim = observation.get('traj_dim', None) 140 | trajs = trajs.view(traj_dim) # [..., n_support_points, dim] 141 | errors = (trajs[..., 1:, :] - trajs[..., :-1, :] @ self.phi_T) # [..., n_support_points-1, dim * 2] 142 | costs = torch.einsum('...ij,...ijk,...ik->...i', errors, self.single_Q_inv, errors) # [..., n_support_points-1] 143 | return costs.mean(dim=-1) # mean the n_support_points 144 | 145 | def eval(self, trajs: torch.Tensor, **observation) -> torch.Tensor: 146 | traj_dim = observation.get('traj_dim') 147 | optim_dim = observation.get('optim_dim') 148 | 149 | current_trajs = observation.get('current_trajs') 150 | current_trajs = current_trajs.view(traj_dim) # [..., n_support_points, dim] 151 | current_trajs = current_trajs.unsqueeze(-2).unsqueeze(-2) # [..., n_support_points, 1, 1, dim] 152 | 153 | cost_dim = traj_dim[:-1] + optim_dim[1:3] # [..., n_support_points] + [nb2, num_probe] 154 | costs = torch.zeros(cost_dim, **self.tensor_args) 155 | states = trajs 156 | 157 | 
probe_points = states[..., self.probe_range[0]:self.probe_range[1], :] # [..., nb2, num_eval, dim * 2] 158 | len_probe = probe_points.shape[-2] 159 | probe_points = probe_points.view(traj_dim[:-1] + (optim_dim[1], len_probe, self.dim,)) # [..., n_support_points] + [nb2, num_eval, dim * 2] 160 | right_errors = probe_points[..., 1:self.n_support_points, :, :, :] - current_trajs[..., 0:self.n_support_points-1, :, :, :] @ self.phi_T # [..., n_support_points-1, nb2, num_eval, dim * 2] 161 | left_errors = current_trajs[..., 1:self.n_support_points, :, :, :] - probe_points[..., 0:self.n_support_points-1, :, :, :] @ self.phi_T # [..., n_support_points-1, nb2, num_eval, dim * 2] 162 | # mahalanobis distance 163 | left_cost_dist = torch.einsum('...ij,...ijk,...ik->...i', left_errors, self.single_Q_inv, left_errors) # [..., n_support_points-1, nb2, num_eval] 164 | right_cost_dist = torch.einsum('...ij,...ijk,...ik->...i', right_errors, self.single_Q_inv, right_errors) # [..., n_support_points-1, nb2, num_eval] 165 | 166 | costs[..., 0:self.n_support_points-1, :, self.probe_range[0]:self.probe_range[1]] += left_cost_dist 167 | costs[..., 1:self.n_support_points, :, self.probe_range[0]:self.probe_range[1]] += right_cost_dist 168 | costs = costs.view(optim_dim).mean(dim=-1) # mean the probe 169 | 170 | return costs 171 | 172 | 173 | class CostComposite(Cost): 174 | 175 | def __init__( 176 | self, 177 | robot, 178 | n_support_points, 179 | cost_list, 180 | weights_cost_l=None, 181 | **kwargs 182 | ): 183 | super().__init__(robot, n_support_points, **kwargs) 184 | self.cost_l = cost_list 185 | self.weight_cost_l = weights_cost_l if weights_cost_l is not None else [1.0] * len(cost_list) 186 | 187 | def eval(self, trajs, trajs_interpolated=None, return_invidual_costs_and_weights=False, **kwargs): 188 | trajs, q_pos, q_vel, H_positions = self.get_q_pos_vel_and_fk_map(trajs) 189 | 190 | if not return_invidual_costs_and_weights: 191 | cost_total = 0 192 | for cost, weight_cost in 
zip(self.cost_l, self.weight_cost_l): 193 | if trajs_interpolated is not None: 194 | trajs_tmp = trajs_interpolated 195 | else: 196 | trajs_tmp = trajs 197 | cost_tmp = weight_cost * cost(trajs_tmp, q_pos=q_pos, q_vel=q_vel, H_positions=H_positions, **kwargs) 198 | cost_total += cost_tmp 199 | return cost_total 200 | else: 201 | cost_l = [] 202 | for cost in self.cost_l: 203 | if trajs_interpolated is not None: 204 | # Compute only collision costs with interpolated trajectories. 205 | # Other costs are computed with non-interpolated trajectories, e.g. smoothness 206 | trajs_tmp = trajs_interpolated 207 | else: 208 | trajs_tmp = trajs 209 | 210 | cost_tmp = cost(trajs_tmp, q_pos=q_pos, q_vel=q_vel, H_positions=H_positions, **kwargs) 211 | cost_l.append(cost_tmp) 212 | 213 | if return_invidual_costs_and_weights: 214 | return cost_l, self.weight_cost_l 215 | -------------------------------------------------------------------------------- /mpot/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/envs/__init__.py -------------------------------------------------------------------------------- /mpot/envs/map_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from mpot.envs.obst_map import ObstacleRectangle, ObstacleMap, ObstacleCircle 5 | from mpot.envs.obst_utils import random_rect, random_circle 6 | import copy 7 | 8 | from torch_robotics.environments.primitives import MultiSphereField, ObjectField, MultiBoxField 9 | 10 | 11 | def random_obstacles( 12 | map_dim = (1, 1), 13 | cell_size: float = 1., 14 | num_obst: int = 5, 15 | rand_xy_limits=[[-1, 1], [-1, 1]], 16 | rand_rect_shape=[2, 2], 17 | rand_circle_radius: float = 1, 18 | max_attempts: int = 50, 19 | tensor_args=None, 20 | ): 21 | obst_map = ObstacleMap(map_dim, cell_size, 
tensor_args=tensor_args) 22 | num_boxes = np.random.randint(0, num_obst) 23 | num_circles = num_obst - num_boxes 24 | # randomize box obstacles 25 | xlim = rand_xy_limits[0] 26 | ylim = rand_xy_limits[1] 27 | width, height = rand_rect_shape 28 | 29 | boxes = [] 30 | for _ in range(num_boxes): 31 | num_attempts = 0 32 | while num_attempts <= max_attempts: 33 | obst = random_rect(xlim, ylim, width, height) 34 | 35 | # Check validity of new obstacle 36 | # Do not overlap obstacles 37 | valid = obst._obstacle_collision_check(obst_map) 38 | if valid: 39 | # Add to Map 40 | obst._add_to_map(obst_map) 41 | # Add to list 42 | boxes.append(obst.to_array()) 43 | break 44 | 45 | if num_attempts == max_attempts: 46 | print("Obstacle generation: Max. number of attempts reached. ") 47 | print(f"Total num. boxes: {len(boxes)}") 48 | num_attempts += 1 49 | boxes = torch.tensor(np.array(boxes), **tensor_args) 50 | cubes = MultiBoxField(boxes[:, :2], boxes[:, 2:], tensor_args=tensor_args) 51 | box_field = ObjectField([cubes], 'random-boxes') 52 | 53 | # randomize circle obstacles 54 | circles = [] 55 | for _ in range(num_circles): 56 | num_attempts = 0 57 | while num_attempts <= max_attempts: 58 | obst = random_circle(xlim, ylim, rand_circle_radius) 59 | # Check validity of new obstacle 60 | # Do not overlap obstacles 61 | valid = obst._obstacle_collision_check(obst_map) 62 | 63 | if valid: 64 | # Add to Map 65 | obst._add_to_map(obst_map) 66 | # Add to list 67 | circles.append(obst.to_array()) 68 | break 69 | 70 | if num_attempts == max_attempts: 71 | print("Obstacle generation: Max. number of attempts reached. ") 72 | print(f"Total num. 
boxes: {len(circles)}") 73 | 74 | num_attempts += 1 75 | circles = torch.tensor(np.array(circles), **tensor_args) 76 | spheres = MultiSphereField(circles[:, :2], circles[:, 2], tensor_args=tensor_args) 77 | sphere_field = ObjectField([spheres], 'random-spheres') 78 | obj_list = [box_field, sphere_field] 79 | obst_map.convert_map() 80 | return obst_map, obj_list 81 | 82 | 83 | if __name__ == "__main__": 84 | cell_size = 0.1 85 | map_dim = [20, 20] 86 | seed = 2 87 | tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32} 88 | obst_map, obst_list = random_obstacles( 89 | map_dim, cell_size, 90 | num_obst=5, 91 | rand_xy_limits=[[-5, 5], [-5, 5]], 92 | rand_rect_shape=[2,2], 93 | rand_circle_radius=1, 94 | tensor_args=tensor_args 95 | ) 96 | fig = obst_map.plot() 97 | -------------------------------------------------------------------------------- /mpot/envs/obst_map.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from math import ceil 6 | from abc import ABC, abstractmethod 7 | import os.path as osp 8 | from copy import deepcopy 9 | 10 | 11 | class Obstacle(ABC): 12 | """ 13 | Base 2D Obstacle class 14 | """ 15 | 16 | def __init__(self,center_x,center_y): 17 | self.center_x = center_x 18 | self.center_y = center_y 19 | self.origin = np.array([self.center_x, self.center_y]) 20 | 21 | def _obstacle_collision_check(self, obst_map): 22 | valid = True 23 | obst_map_test = self._add_to_map(deepcopy(obst_map)) 24 | if np.any(obst_map_test.map > 1): 25 | valid = False 26 | return valid 27 | 28 | def _point_collision_check(self, obst_map, pts): 29 | valid = True 30 | if pts is not None: 31 | obst_map_test = self._add_to_map(np.copy(obst_map)) 32 | for pt in pts: 33 | if obst_map_test[ceil(pt[0]), ceil(pt[1])] >= 1: 34 | valid = False 35 | break 36 | return valid 37 | 38 | @abstractmethod 39 | def _add_to_map(self, 
obst_map): 40 | pass 41 | 42 | 43 | class ObstacleRectangle(Obstacle): 44 | """ 45 | Derived 2D rectangular Obstacle class 46 | """ 47 | 48 | def __init__( 49 | self, 50 | center_x=0, 51 | center_y=0, 52 | width=None, 53 | height=None, 54 | ): 55 | super().__init__(center_x, center_y) 56 | self.width = width 57 | self.height = height 58 | 59 | def _add_to_map(self, obst_map): 60 | # Convert dims to cell indices 61 | w = ceil(self.width / obst_map.cell_size) 62 | h = ceil(self.height / obst_map.cell_size) 63 | c_x = ceil(self.center_x / obst_map.cell_size) 64 | c_y = ceil(self.center_y / obst_map.cell_size) 65 | 66 | obst_map.map[ 67 | c_y - ceil(h/2.) + obst_map.origin_yi: 68 | c_y + ceil(h/2.) + obst_map.origin_yi, 69 | c_x - ceil(w/2.) + obst_map.origin_xi: 70 | c_x + ceil(w/2.) + obst_map.origin_xi, 71 | ] += 1 72 | return obst_map 73 | 74 | def to_array(self): 75 | return np.array([self.center_x, self.center_y, self.width, self.height]) 76 | 77 | 78 | class ObstacleCircle(Obstacle): 79 | """ 80 | Derived 2D circle Obstacle class 81 | """ 82 | 83 | def __init__( 84 | self, 85 | center_x=0, 86 | center_y=0, 87 | radius=1. 
88 | ): 89 | super().__init__(center_x, center_y) 90 | self.radius = radius 91 | 92 | def is_inside(self, p): 93 | # Check if point p is inside of the discretized circle 94 | return np.linalg.norm(p - self.origin) <= self.radius 95 | 96 | def _add_to_map(self, obst_map): 97 | # Convert dims to cell indices 98 | c_r = ceil(self.radius / obst_map.cell_size) 99 | c_x = ceil(self.center_x / obst_map.cell_size) 100 | c_y = ceil(self.center_y / obst_map.cell_size) 101 | 102 | for i in range(c_y - 2 * c_r + obst_map.origin_yi, c_y + 2 * c_r + obst_map.origin_yi): 103 | for j in range(c_x - 2 * c_r + obst_map.origin_xi, c_x + 2 * c_r + obst_map.origin_xi): 104 | p = np.array([(j - obst_map.origin_xi) * obst_map.cell_size, 105 | (i - obst_map.origin_yi) * obst_map.cell_size]) 106 | if self.is_inside(p): 107 | obst_map.map[i, j] += 1 108 | return obst_map 109 | 110 | def to_array(self): 111 | return np.array([self.center_x, self.center_y, self.radius]) 112 | 113 | 114 | class ObstacleMap: 115 | """ 116 | Generates an occupancy grid. 
117 | """ 118 | def __init__(self, map_dim, cell_size, tensor_args=None): 119 | 120 | assert map_dim[0] % 2 == 0 121 | assert map_dim[1] % 2 == 0 122 | 123 | if tensor_args is None: 124 | tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32} 125 | self.tensor_args = tensor_args 126 | 127 | cmap_dim = [0, 0] 128 | cmap_dim[0] = ceil(map_dim[0]/cell_size) 129 | cmap_dim[1] = ceil(map_dim[1]/cell_size) 130 | 131 | self.map = np.zeros(cmap_dim) 132 | self.cell_size = cell_size 133 | 134 | # Map center (in cells) 135 | self.origin_xi = int(cmap_dim[0]/2) 136 | self.origin_yi = int(cmap_dim[1]/2) 137 | 138 | # self.xlim = map_dim[0] 139 | 140 | self.x_dim, self.y_dim = self.map.shape 141 | x_range = self.cell_size * self.x_dim 142 | y_range = self.cell_size * self.y_dim 143 | self.xlim = [-x_range/2, x_range/2] 144 | self.ylim = [-y_range/2, y_range/2] 145 | 146 | self.c_offset = torch.tensor([self.origin_xi, self.origin_yi], **self.tensor_args) 147 | 148 | def __call__(self, X, **kwargs): 149 | return self.compute_cost(X, **kwargs) 150 | 151 | def convert_map(self): 152 | self.map_torch = torch.Tensor(self.map).to(**self.tensor_args) 153 | return self.map_torch 154 | 155 | def plot(self, save_dir=None, filename="obst_map.png"): 156 | fig = plt.figure() 157 | plt.imshow(self.map) 158 | plt.gca().invert_yaxis() 159 | plt.show() 160 | if save_dir is not None: 161 | plt.savefig(osp.join(save_dir, filename)) 162 | return fig 163 | 164 | def get_xy_grid(self, device): 165 | xv, yv = torch.meshgrid([torch.linspace(self.xlim[0], self.xlim[1], self.x_dim), 166 | torch.linspace(self.ylim[0], self.ylim[1], self.y_dim)]) 167 | xy_grid = torch.stack((xv, yv), dim=2) 168 | return xy_grid.to(device) 169 | 170 | def get_collisions(self, X, *args, **kwargs): 171 | """ 172 | Checks for collision in a batch of trajectories using the generated occupancy grid (i.e. obstacle map), and 173 | returns sum of collision costs for the entire batch. 
174 | 175 | :param weight: weight on obstacle cost, float tensor. 176 | :param X: Tensor of trajectories, of shape (batch_size, traj_length, position_dim) 177 | :return: collision cost on the trajectories 178 | """ 179 | X_occ = X * (1/self.cell_size) + self.c_offset 180 | X_occ = X_occ.floor().int() 181 | 182 | # Project out-of-bounds locations to axis 183 | X_occ[...,0] = X_occ[..., 0].clamp(0, self.map.shape[0]-1) 184 | X_occ[...,1] = X_occ[..., 1].clamp(0, self.map.shape[1]-1) 185 | 186 | # Collisions 187 | collision_vals = self.map_torch[X_occ[..., 1], X_occ[..., 0]] 188 | return collision_vals 189 | 190 | def compute_cost(self, X, *args, **kwargs): 191 | return self.get_collisions(X, *args, **kwargs) 192 | 193 | def zero_grad(self): 194 | pass 195 | -------------------------------------------------------------------------------- /mpot/envs/obst_utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import random 3 | from mpot.envs.obst_map import ObstacleRectangle, ObstacleCircle 4 | 5 | 6 | def round_up(n, decimals=0): 7 | multiplier = 10 ** decimals 8 | return ceil(n * multiplier) / multiplier 9 | 10 | 11 | def random_rect(xlim=(0, 0), ylim=(0, 0), width=2, height=2): 12 | """ 13 | Generates an rectangular obstacle object, with random location and dimensions. 14 | """ 15 | cx = random.uniform(xlim[0], xlim[1]) 16 | cy = random.uniform(ylim[0], ylim[1]) 17 | return ObstacleRectangle(cx, cy, width, height) 18 | 19 | 20 | def random_circle(xlim=(0,0), ylim=(0,0), radius=2): 21 | """ 22 | Generates a circle obstacle object, with random location and dimensions. 
23 | """ 24 | cx = random.uniform(xlim[0], xlim[1]) 25 | cy = random.uniform(ylim[0], ylim[1]) 26 | return ObstacleCircle(cx, cy, radius) 27 | -------------------------------------------------------------------------------- /mpot/envs/occupancy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import pyplot as plt 3 | 4 | from torch_robotics.environments.env_base import EnvBase 5 | from torch_robotics.robots import RobotPointMass 6 | from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS 7 | from torch_robotics.visualizers.planning_visualizer import create_fig_and_axes 8 | from mpot.envs.map_generator import random_obstacles 9 | 10 | 11 | class EnvOccupancy2D(EnvBase): 12 | 13 | def __init__(self, tensor_args=None, **kwargs): 14 | if tensor_args is None: 15 | tensor_args = DEFAULT_TENSOR_ARGS 16 | obst_map, obj_list = random_obstacles( 17 | map_dim=[20, 20], 18 | cell_size=0.1, 19 | num_obst=15, 20 | rand_xy_limits=[[-7.5, 7.5], [-7.5, 7.5]], 21 | rand_rect_shape=[2, 2], 22 | rand_circle_radius=1., 23 | tensor_args=tensor_args 24 | ) 25 | 26 | super().__init__( 27 | name=self.__class__.__name__, 28 | limits=torch.tensor([[-10, -10], [10, 10]], **tensor_args), # environments limits 29 | obj_fixed_list=obj_list, 30 | tensor_args=tensor_args, 31 | **kwargs 32 | ) 33 | self.occupancy_map = obst_map 34 | 35 | 36 | if __name__ == '__main__': 37 | env = EnvOccupancy2D(precompute_sdf_obj_fixed=False, tensor_args=DEFAULT_TENSOR_ARGS) 38 | fig, ax = create_fig_and_axes(env.dim) 39 | env.render(ax) 40 | plt.show() 41 | 42 | # Render sdf 43 | fig, ax = create_fig_and_axes(env.dim) 44 | env.render_sdf(ax, fig) 45 | 46 | # Render gradient of sdf 47 | env.render_grad_sdf(ax, fig) 48 | plt.show() 49 | -------------------------------------------------------------------------------- /mpot/gp/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 
2 | 3 | Copyright (c) 2023 An Thai Le 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ********************************************************************** 24 | The first versions of some files in this gp/ folder were licensed as 25 | "Original Source License" (see below). Some enhancements and developments 26 | have been made by An Thai Le since obtaining the first version.
27 | ********************************************************************** 28 | 29 | Original Source License: 30 | 31 | MIT License 32 | 33 | Copyright (c) 2022 Sasha Lambert 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining a copy 36 | of this software and associated documentation files (the "Software"), to deal 37 | in the Software without restriction, including without limitation the rights 38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 39 | copies of the Software, and to permit persons to whom the Software is 40 | furnished to do so, subject to the following conditions: 41 | 42 | The above copyright notice and this permission notice shall be included in all 43 | copies or substantial portions of the Software. 44 | 45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 51 | SOFTWARE. 52 | -------------------------------------------------------------------------------- /mpot/gp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/gp/__init__.py -------------------------------------------------------------------------------- /mpot/gp/field_factor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FieldFactor: 5 | 6 | def __init__( 7 | self, 8 | n_dof, 9 | sigma, 10 | traj_range, 11 | ): 12 | self.sigma = sigma 13 | self.n_dof = n_dof 14 | self.traj_range = traj_range 15 | self.K = 1. 
/ (sigma**2) 16 | 17 | def get_error( 18 | self, 19 | q_trajs, 20 | field, 21 | q_pos=None, 22 | q_vel=None, 23 | H_pos=None, 24 | q_trajs_interp=None, 25 | q_pos_interp=None, 26 | q_vel_interp=None, 27 | H_pos_interp=None, 28 | calc_jacobian=True, 29 | **kwargs 30 | ): 31 | batch = q_trajs.shape[0] 32 | 33 | if H_pos is not None: 34 | states = H_pos[:, self.traj_range[0]:self.traj_range[1]] 35 | else: 36 | states = q_trajs[:, self.traj_range[0]:self.traj_range[1], :self.n_dof].reshape(-1, self.n_dof) 37 | q_pos_new = q_pos[:, self.traj_range[0]:self.traj_range[1], :] 38 | length = q_pos_new.shape[-2] 39 | error = field.compute_cost(q_pos_new, states, **kwargs).reshape(batch, length) 40 | 41 | if calc_jacobian: 42 | # compute jacobian wrt to the error of the interpolated trajectory 43 | error_interp = error 44 | if H_pos_interp is not None or q_trajs_interp is not None: 45 | # interpolated trajectory 46 | if H_pos_interp is not None: 47 | states = H_pos_interp[:, self.traj_range[0]:self.traj_range[1]] 48 | else: 49 | states = q_trajs_interp[:, self.traj_range[0]:self.traj_range[1], :self.n_dof].reshape(-1, self.n_dof) 50 | q_pos_new = q_pos_interp[:, self.traj_range[0]:self.traj_range[1], :] 51 | length = q_pos_new.shape[-2] 52 | error_interp = field.compute_cost(q_pos_new, states, **kwargs).reshape(batch, length) 53 | 54 | H = -torch.autograd.grad(error_interp.sum(), q_trajs, retain_graph=True)[0][:, self.traj_range[0]:self.traj_range[1], :self.n_dof] 55 | error = error.detach() 56 | field.zero_grad() 57 | return error, H 58 | else: 59 | return error 60 | -------------------------------------------------------------------------------- /mpot/gp/gp_factor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class GPFactor(): 5 | 6 | def __init__( 7 | self, 8 | dim: int, 9 | sigma: float, 10 | dt: float, 11 | num_factors: int, 12 | Q_c_inv: torch.Tensor = None, 13 | tensor_args=None, 14 | ): 15 | self.dim = 
dim 16 | self.dt = dt 17 | self.tensor_args = tensor_args 18 | self.state_dim = self.dim * 2 # position and velocity 19 | self.num_factors = num_factors 20 | self.idx1 = torch.arange(0, self.num_factors, device=tensor_args['device']) 21 | self.idx2 = torch.arange(1, self.num_factors+1, device=tensor_args['device']) 22 | self.phi = self.calc_phi() 23 | if Q_c_inv is None: 24 | Q_c_inv = torch.eye(dim, **tensor_args) / sigma**2 25 | self.Q_c_inv = torch.zeros(num_factors, dim, dim, **tensor_args) + Q_c_inv 26 | self.Q_inv = self.calc_Q_inv() # shape: [num_factors, state_dim, state_dim] 27 | 28 | ## Pre-compute constant Jacobians 29 | self.H1 = self.phi.unsqueeze(0).repeat(self.num_factors, 1, 1) 30 | self.H2 = -1. * torch.eye(self.state_dim).unsqueeze(0).repeat( 31 | self.num_factors, 1, 1, 32 | ) 33 | 34 | def calc_phi(self) -> torch.Tensor: 35 | I = torch.eye(self.dim, **self.tensor_args) 36 | Z = torch.zeros(self.dim, self.dim, **self.tensor_args) 37 | phi_u = torch.cat((I, self.dt * I), dim=1) 38 | phi_l = torch.cat((Z, I), dim=1) 39 | phi = torch.cat((phi_u, phi_l), dim=0) 40 | return phi 41 | 42 | def calc_Q_inv(self) -> torch.Tensor: 43 | m1 = 12. * (self.dt ** -3.) * self.Q_c_inv 44 | m2 = -6. * (self.dt ** -2.) * self.Q_c_inv 45 | m3 = 4. * (self.dt ** -1.) 
* self.Q_c_inv 46 | 47 | Q_inv_u = torch.cat((m1, m2), dim=-1) 48 | Q_inv_l = torch.cat((m2, m3), dim=-1) 49 | Q_inv = torch.cat((Q_inv_u, Q_inv_l), dim=-2) 50 | return Q_inv 51 | 52 | def get_error(self, x_traj: torch.Tensor, calc_jacobian: bool = True) -> torch.Tensor: 53 | state_1 = torch.index_select(x_traj, 1, self.idx1).unsqueeze(-1) 54 | state_2 = torch.index_select(x_traj, 1, self.idx2).unsqueeze(-1) 55 | error = state_2 - self.phi @ state_1 56 | 57 | if calc_jacobian: 58 | H1 = self.H1 59 | H2 = self.H2 60 | # H1 = self.H1.unsqueeze(0).repeat(batch, 1, 1, 1) 61 | # H2 = self.H2.unsqueeze(0).repeat(batch, 1, 1, 1) 62 | return error, H1, H2 63 | else: 64 | return error -------------------------------------------------------------------------------- /mpot/gp/gp_prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as dist 3 | 4 | 5 | class BatchGPPrior: 6 | 7 | def __init__( 8 | self, 9 | traj_len: int, 10 | dt: float, 11 | dim: int, 12 | K_s_inv: torch.Tensor, 13 | K_gp_inv: torch.Tensor, 14 | start_state: torch.Tensor, 15 | means: torch.Tensor = None, 16 | K_g_inv: torch.Tensor = None, 17 | goal_states: torch.Tensor = None, 18 | tensor_args=None, 19 | ): 20 | """ 21 | Motion-Planning prior. 22 | 23 | reference: "Continuous-time Gaussian process motion planning via 24 | probabilistic inference", Mukadam et al. (IJRR 2018) 25 | 26 | Parameters 27 | ---------- 28 | traj_len : int 29 | Planning horizon length (not including start state). 30 | dt : float 31 | Time-step size. 32 | state_dim : int 33 | State state_dimension. 34 | K_s_inv : Tensor 35 | Start-state inverse covariance. Shape: [state_dim, state_dim] 36 | K_gp_inv : Tensor 37 | Gaussian-process single-step inverse covariance i.e. 'Q_inv'. 38 | Assumed constant, meaning homoscedastic noise with constant step-size. 
39 | Shape: [2 * state_dim, 2 * state_dim] 40 | start_state : Tensor 41 | Shape: [state_dim] 42 | (Optional) K_g_inv : Tensor 43 | Goal-state inverse covariance. Shape: [state_dim, state_dim] 44 | (Optional) goal_state : Tensor 45 | Shape: [state_dim] 46 | (Optional) dim : int 47 | Degrees of freedom. 48 | """ 49 | self.dim = dim 50 | self.state_dim = dim * 2 51 | self.traj_len = traj_len 52 | self.M = self.state_dim * (traj_len + 1) 53 | self.tensor_args = tensor_args 54 | 55 | self.goal_directed = (goal_states is not None) 56 | 57 | if means is None: 58 | self.num_modes = goal_states.shape[0] if self.goal_directed else 1 59 | means = self.get_const_vel_mean( 60 | start_state, 61 | goal_states, 62 | dt, 63 | traj_len, 64 | dim) 65 | else: 66 | self.num_modes = means.shape[0] 67 | 68 | # Flatten mean trajectories 69 | self.means = means.reshape(self.num_modes, -1) 70 | 71 | # TODO: Add different goal Covariances 72 | # Assume same goal Cov. for now 73 | Sigma_inv = self.get_const_vel_covariance( 74 | dt, 75 | K_s_inv, 76 | K_gp_inv, 77 | K_g_inv, 78 | ) 79 | 80 | # self.Sigma_inv = Sigma_inv 81 | self.Sigma_inv = Sigma_inv # + torch.eye(Sigma_inv.shape[0], **tensor_args) * 1.e-3 82 | self.Sigma_invs = self.Sigma_inv.repeat(self.num_modes, 1, 1) 83 | self.update_dist(self.means, self.Sigma_invs) 84 | 85 | def update_dist( 86 | self, 87 | means: torch.Tensor, 88 | Sigma_invs: torch.Tensor, 89 | ) -> torch.Tensor: 90 | # Create Multi-variate Normal Distribution 91 | self.dist = dist.MultivariateNormal( 92 | means, 93 | precision_matrix=Sigma_invs, 94 | ) 95 | 96 | def get_mean(self, reshape: bool = True) -> torch.Tensor: 97 | if reshape: 98 | return self.means.clone().detach().reshape( 99 | self.num_modes, self.traj_len + 1, self.state_dim, 100 | ) 101 | else: 102 | self.means.clone().detach() 103 | 104 | def set_mean(self, means: torch.Tensor) -> torch.Tensor: 105 | assert means.shape == self.means.shape 106 | self.means = means.clone().detach() 107 | 
self.update_dist(self.means, self.Sigma_invs) 108 | 109 | def set_Sigma_invs(self, Sigma_invs: torch.Tensor) -> torch.Tensor: 110 | assert Sigma_invs.shape == self.Sigma_invs.shape 111 | self.Sigma_invs = Sigma_invs.clone().detach() 112 | self.update_dist(self.means, self.Sigma_invs) 113 | 114 | def const_vel_trajectory( 115 | self, 116 | start_state: torch.Tensor, 117 | goal_state: torch.Tensor, 118 | dt: float, 119 | traj_len: int, 120 | dim: int, 121 | ) -> torch.Tensor: 122 | state_traj = torch.zeros(traj_len + 1, 2 * dim, **self.tensor_args) 123 | mean_vel = (goal_state[:dim] - start_state[:dim]) / (traj_len * dt) 124 | for i in range(traj_len + 1): 125 | state_traj[i, :dim] = start_state[:dim] * (traj_len - i) * 1. / traj_len \ 126 | + goal_state[:dim] * i * 1./traj_len 127 | state_traj[:, dim:] = mean_vel.unsqueeze(0) 128 | return state_traj 129 | 130 | def get_const_vel_mean( 131 | self, 132 | start_state: torch.Tensor, 133 | goal_states: torch.Tensor, 134 | dt: float, 135 | traj_len: int, 136 | dim: int, 137 | ) -> torch.Tensor: 138 | 139 | # Make mean goal-directed if goal_state is provided. 
140 | if self.goal_directed: 141 | means = [] 142 | for i in range(self.num_modes): 143 | means.append(self.const_vel_trajectory( 144 | start_state, 145 | goal_states[i], 146 | dt, 147 | traj_len, 148 | dim, 149 | )) 150 | return torch.stack(means, dim=0) 151 | else: 152 | return start_state.repeat(traj_len + 1, 1) 153 | 154 | def get_const_vel_covariance( 155 | self, 156 | dt: float, 157 | K_s_inv: torch.Tensor, 158 | K_gp_inv: torch.Tensor, 159 | K_g_inv: torch.Tensor, 160 | precision_matrix: bool = True, 161 | ) -> torch.Tensor: 162 | # Transition matrix 163 | Phi = torch.eye(self.state_dim, **self.tensor_args) 164 | Phi[:self.dim, self.dim:] = torch.eye(self.dim, **self.tensor_args) * dt 165 | diag_Phis = Phi 166 | for _ in range(self.traj_len - 1): 167 | diag_Phis = torch.block_diag(diag_Phis, Phi) 168 | 169 | A = torch.eye(self.M, **self.tensor_args) 170 | A[self.state_dim:, :-self.state_dim] += -1. * diag_Phis 171 | if self.goal_directed: 172 | b = torch.zeros(self.state_dim, self.M, **self.tensor_args) 173 | b[:, -self.state_dim:] = torch.eye(self.state_dim, **self.tensor_args) 174 | A = torch.cat((A, b)) 175 | 176 | Q_inv = K_s_inv 177 | for _ in range(self.traj_len): 178 | Q_inv = torch.block_diag(Q_inv, K_gp_inv).to(**self.tensor_args) 179 | if self.goal_directed: 180 | Q_inv = torch.block_diag(Q_inv, K_g_inv).to(**self.tensor_args) 181 | 182 | K_inv = A.t() @ Q_inv @ A 183 | if precision_matrix: 184 | return K_inv 185 | else: 186 | return torch.inverse(K_inv) 187 | 188 | def sample(self, num_samples: int) -> torch.Tensor: 189 | return self.dist.sample((num_samples,)).view( 190 | num_samples, self.num_modes, self.traj_len + 1, self.state_dim, 191 | ).transpose(1, 0) 192 | 193 | def log_prob(self, X: torch.Tensor) -> torch.Tensor: 194 | return self.dist.log_prob(X) -------------------------------------------------------------------------------- /mpot/gp/unary_factor.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | 4 | class UnaryFactor: 5 | 6 | def __init__( 7 | self, 8 | dim: int, 9 | sigma: float, 10 | mean: torch.Tensor = None, 11 | tensor_args=None, 12 | ): 13 | self.sigma = sigma 14 | if mean is None: 15 | self.mean = torch.zeros(dim, **tensor_args) 16 | else: 17 | self.mean = mean 18 | self.tensor_args = tensor_args 19 | self.K = torch.eye(dim, **tensor_args) / sigma**2 # weight matrix 20 | self.dim = dim 21 | 22 | def get_error(self, X: torch.Tensor, calc_jacobian: bool = True) -> torch.Tensor: 23 | error = self.mean - X 24 | 25 | if calc_jacobian: 26 | H = torch.eye(self.dim, **self.tensor_args).unsqueeze(0).repeat(X.shape[0], 1, 1) 27 | return error.view(X.shape[0], self.dim, 1), H 28 | else: 29 | return error 30 | 31 | def set_mean(self, X: torch.Tensor) -> torch.Tensor: 32 | self.mean = X.clone().detach() 33 | -------------------------------------------------------------------------------- /mpot/ot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/ot/__init__.py -------------------------------------------------------------------------------- /mpot/ot/initializer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Dict, Optional, Sequence, Tuple 3 | 4 | import torch 5 | 6 | from mpot.ot.problem import LinearProblem 7 | 8 | 9 | class SinkhornInitializer(abc.ABC): 10 | """Base class for Sinkhorn initializers.""" 11 | 12 | @abc.abstractmethod 13 | def init_dual_a( 14 | self, 15 | ot_prob: LinearProblem, 16 | ) -> torch.Tensor: 17 | """Initialize Sinkhorn potential f_u. 18 | 19 | Returns: 20 | potential size ``[n,]``. 21 | """ 22 | 23 | @abc.abstractmethod 24 | def init_dual_b( 25 | self, 26 | ot_prob: LinearProblem, 27 | ) -> torch.Tensor: 28 | """Initialize Sinkhorn potential g_v. 29 | 30 | Returns: 31 | potential size ``[m,]``. 
32 | """ 33 | 34 | def __call__( 35 | self, 36 | ot_prob: LinearProblem, 37 | a: Optional[torch.Tensor] = None, 38 | b: Optional[torch.Tensor] = None, 39 | ) -> Tuple[torch.Tensor, torch.Tensor]: 40 | 41 | n, m = ot_prob.C.shape 42 | if a is None: 43 | a = self.init_dual_a(ot_prob) 44 | if b is None: 45 | b = self.init_dual_b(ot_prob) 46 | 47 | assert a.shape == ( 48 | n, 49 | ), f"Expected `f_u` to have shape `{n,}`, found `{a.shape}`." 50 | assert b.shape == ( 51 | m, 52 | ), f"Expected `g_v` to have shape `{m,}`, found `{b.shape}`." 53 | 54 | # cancel dual variables for zero weights 55 | a = torch.where(ot_prob.a > 0., a, -torch.inf) 56 | b = torch.where(ot_prob.b > 0., b, -torch.inf) 57 | 58 | return a, b 59 | 60 | 61 | class DefaultInitializer(SinkhornInitializer): 62 | """Default initialization of Sinkhorn dual potentials scalings.""" 63 | 64 | def init_dual_a( 65 | self, 66 | ot_prob: LinearProblem, 67 | ) -> torch.Tensor: 68 | return torch.zeros_like(ot_prob.a) 69 | 70 | def init_dual_b( 71 | self, 72 | ot_prob: LinearProblem, 73 | ) -> torch.Tensor: 74 | return torch.zeros_like(ot_prob.b) 75 | 76 | 77 | class RandomInitializer(SinkhornInitializer): 78 | """Random initialization of Sinkhorn dual potentials scalings.""" 79 | 80 | def __init__( 81 | self, 82 | seed: Optional[int] = None, 83 | ): 84 | self.seed = seed 85 | 86 | def init_dual_a( 87 | self, 88 | ot_prob: LinearProblem, 89 | ) -> torch.Tensor: 90 | return torch.randn_like(ot_prob.a) 91 | 92 | def init_dual_b( 93 | self, 94 | ot_prob: LinearProblem, 95 | ) -> torch.Tensor: 96 | return torch.randn_like(ot_prob.b) 97 | -------------------------------------------------------------------------------- /mpot/ot/problem.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | 4 | 5 | class Epsilon: 6 | """Epsilon scheduler for Sinkhorn and Sinkhorn Step.""" 7 | 8 | def __init__( 9 | self, 10 | target: float = 0.1, 11 | 
scale_epsilon: float = 1.0, 12 | init: float = 1.0, 13 | decay: float = 1.0 14 | ): 15 | self._target_init = target 16 | self._scale_epsilon = scale_epsilon 17 | self._init = init 18 | self._decay = decay 19 | 20 | @property 21 | def target(self) -> float: 22 | """Return the final regularizer value of scheduler.""" 23 | target = 5e-2 if self._target_init is None else self._target_init 24 | scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon 25 | return scale * target 26 | 27 | def at(self, iteration: int = 1) -> float: 28 | """Return (intermediate) regularizer value at a given iteration.""" 29 | if iteration is None: 30 | return self.target 31 | # check the decay is smaller than 1.0. 32 | decay = min(self._decay, 1.0) 33 | # the multiple is either 1.0 or a larger init value that is decayed. 34 | multiple = max(self._init * (decay ** iteration), 1.0) 35 | return multiple * self.target 36 | 37 | def done(self, eps: float) -> bool: 38 | """Return whether the scheduler is done at a given value.""" 39 | return eps == self.target 40 | 41 | def done_at(self, iteration: int) -> bool: 42 | """Return whether the scheduler is done at a given iteration.""" 43 | return self.done(self.at(iteration)) 44 | 45 | 46 | class LinearEpsilon(Epsilon): 47 | 48 | def __init__(self, target: float = 0.1, 49 | scale_epsilon: float = 1, 50 | init: float = 1, 51 | decay: float = 1): 52 | super().__init__(target, scale_epsilon, init, decay) 53 | 54 | def at(self, iteration: int = 1) -> float: 55 | if iteration is None: 56 | return self.target 57 | 58 | eps = max(self._init - self._decay * iteration, self.target) 59 | return eps * self._scale_epsilon 60 | 61 | 62 | class LinearProblem(): 63 | 64 | def __init__( 65 | self, 66 | C: torch.Tensor, 67 | epsilon: Union[Epsilon, float] = 0.01, 68 | a: torch.Tensor = None, 69 | b: torch.Tensor = None, 70 | tau_a: float = 1.0, 71 | tau_b: float = 1.0, 72 | scaling_cost: bool = True, 73 | ) -> None: 74 | if scaling_cost: 75 | C = 
scale_cost_matrix(C) 76 | self.C = C 77 | self.epsilon = epsilon 78 | self.a = a if a is not None else (torch.ones(C.shape[0]).type_as(C) / C.shape[0]) 79 | self.b = b if b is not None else (torch.ones(C.shape[1]).type_as(C) / C.shape[1]) 80 | self.tau_a = tau_a 81 | self.tau_b = tau_b 82 | 83 | def potential_from_scaling(self, scaling: torch.Tensor) -> torch.Tensor: 84 | """Compute dual potential vector from scaling vector. 85 | 86 | Args: 87 | scaling: vector. 88 | 89 | Returns: 90 | a vector of the same size. 91 | """ 92 | eps = self.epsilon.target if isinstance(self.epsilon, Epsilon) else self.epsilon 93 | return eps * torch.log(scaling) 94 | 95 | def marginal_from_potentials( 96 | self, f: torch.Tensor, g: torch.Tensor, dim: int 97 | ) -> torch.Tensor: 98 | eps = self.epsilon.target if isinstance(self.epsilon, Epsilon) else self.epsilon 99 | h = (f if dim == 1 else g) 100 | z = self.apply_lse_kernel(f, g, eps, dim=dim) 101 | return torch.exp((z + h) / eps) 102 | 103 | def update_potential( 104 | self, f: torch.Tensor, g: torch.Tensor, log_marginal: torch.Tensor, 105 | iteration: int = None, dim: int = 0, 106 | ) -> torch.Tensor: 107 | eps = self.epsilon.at(iteration) if isinstance(self.epsilon, Epsilon) else self.epsilon 108 | app_lse = self.apply_lse_kernel(f, g, eps, dim=dim) 109 | return eps * log_marginal - torch.where(torch.isfinite(app_lse), app_lse, 0) 110 | 111 | def transport_from_potentials( 112 | self, f: torch.Tensor, g: torch.Tensor 113 | ) -> torch.Tensor: 114 | """Output transport matrix from potentials.""" 115 | eps = self.epsilon.target if isinstance(self.epsilon, Epsilon) else self.epsilon 116 | return torch.exp(self._center(f, g) / eps) 117 | 118 | def apply_lse_kernel( 119 | self, f: torch.Tensor, g: torch.Tensor, eps: float, dim: int 120 | ) -> torch.Tensor: 121 | w_res = self._softmax(f, g, eps, dim=dim) 122 | remove = f if dim == 1 else g 123 | return w_res - torch.where(torch.isfinite(remove), remove, 0) 124 | 125 | def _center(self, f: 
torch.Tensor, g: torch.Tensor) -> torch.Tensor: 126 | return f.unsqueeze(1) + g.unsqueeze(0) - self.C 127 | 128 | def _softmax( 129 | self, f: torch.Tensor, g: torch.Tensor, eps: float, dim: int 130 | ) -> torch.Tensor: 131 | """Apply softmax row or column wise""" 132 | 133 | lse_output = torch.logsumexp( 134 | self._center(f, g) / eps, dim=dim 135 | ) 136 | return eps * lse_output 137 | 138 | 139 | class SinkhornState(): 140 | """Holds the state variables used to solve OT with Sinkhorn.""" 141 | 142 | def __init__( 143 | self, 144 | errors: torch.Tensor = None, 145 | fu: torch.Tensor = None, 146 | gv: torch.Tensor = None, 147 | ): 148 | self.errors = errors 149 | self.fu = fu 150 | self.gv = gv 151 | self.converged_at = -1 152 | 153 | def solution_error( 154 | self, 155 | ot_prob: LinearProblem, 156 | parallel_dual_updates: bool, 157 | ) -> torch.Tensor: 158 | """State dependent function to return error.""" 159 | fu, gv = self.fu, self.gv 160 | 161 | return solution_error( 162 | fu, 163 | gv, 164 | ot_prob, 165 | parallel_dual_updates=parallel_dual_updates 166 | ) 167 | 168 | def ent_reg_cost( 169 | self, ot_prob: LinearProblem, 170 | ) -> float: 171 | return ent_reg_cost(self.fu, self.gv, ot_prob) 172 | 173 | 174 | class SinkhornStepState(): 175 | """Holds the state of the barycenter solver. 176 | 177 | Args: 178 | costs: Holds the sequence of regularized GW costs seen through the outer 179 | loop of the solver. 180 | linear_convergence: Holds the sequence of bool convergence flags of the 181 | inner Sinkhorn iterations. 182 | X: optimizing points. 183 | a: weights of the barycenter. 
(not using) 184 | """ 185 | 186 | def __init__(self, 187 | X_init: torch.Tensor, 188 | costs: torch.Tensor = None, 189 | linear_convergence: torch.Tensor = None, 190 | objective_vals: torch.Tensor = None, 191 | X_history: torch.Tensor = None, 192 | displacement_sqnorms: torch.Tensor = None, 193 | a: torch.Tensor = None) -> None: 194 | self.X = X_init 195 | self.costs = costs 196 | self.linear_convergence = linear_convergence 197 | self.objective_vals = objective_vals 198 | self.X_history = X_history 199 | self.displacement_sqnorms = displacement_sqnorms 200 | self.a = a 201 | 202 | 203 | def scale_cost_matrix(M: torch.Tensor) -> torch.Tensor: 204 | min_M = M.min() 205 | if min_M < 0: 206 | M -= min_M 207 | max_M = M.max() 208 | if max_M > 1.: 209 | M /= max_M # for stability 210 | return M 211 | 212 | 213 | def phi_star(h: torch.Tensor, rho: float) -> torch.Tensor: 214 | """Legendre transform of KL, :cite:`sejourne:19`, p. 9.""" 215 | return rho * (torch.exp(h / rho) - 1) 216 | 217 | 218 | def rho(epsilon: float, tau: float) -> float: 219 | return (epsilon * tau) / (1. 
- tau) 220 | 221 | 222 | def derivative_phi_star(f: torch.Tensor, rho: float) -> torch.Tensor: 223 | return torch.exp(f / rho) 224 | 225 | 226 | def grad_of_marginal_fit( 227 | c: torch.Tensor, h: torch.Tensor, tau: float, epsilon: float 228 | ) -> torch.Tensor: 229 | if tau == 1.0: 230 | return c 231 | r = rho(epsilon, tau) 232 | return torch.where(c > 0, c * derivative_phi_star(-h, r), 0.0) 233 | 234 | 235 | def solution_error( 236 | f_u: torch.Tensor, 237 | g_v: torch.Tensor, 238 | ot_prob: LinearProblem, 239 | parallel_dual_updates: bool, 240 | ) -> torch.Tensor: 241 | """Compute error between Sinkhorn solution and target marginals.""" 242 | if not parallel_dual_updates: 243 | return marginal_error( 244 | f_u, g_v, ot_prob.b, ot_prob, dim=0 245 | ) 246 | 247 | grad_a = grad_of_marginal_fit( 248 | ot_prob.a, f_u, ot_prob.tau_a, ot_prob.epsilon 249 | ) 250 | grad_b = grad_of_marginal_fit( 251 | ot_prob.b, g_v, ot_prob.tau_b, ot_prob.epsilon 252 | ) 253 | 254 | err = marginal_error(f_u, g_v, grad_a, ot_prob, dim=1) 255 | err += marginal_error(f_u, g_v, grad_b, ot_prob, dim=0) 256 | return err 257 | 258 | 259 | def marginal_error( 260 | f_u: torch.Tensor, 261 | g_v: torch.Tensor, 262 | target: torch.Tensor, 263 | ot_prob: LinearProblem, 264 | dim: int = 0 265 | ) -> torch.Tensor: 266 | """Output how far Sinkhorn solution is w.r.t target. 267 | 268 | Args: 269 | f_u: a vector of potentials or scalings for the first marginal. 270 | g_v: a vector of potentials or scalings for the second marginal. 271 | target: target marginal. 272 | dim: dim (0 or 1) along which to compute marginal. 273 | 274 | Returns: 275 | Array of floats, quantifying difference between target / marginal. 
276 | """ 277 | marginal = ot_prob.marginal_from_potentials(f_u, g_v, dim=dim) 278 | # L1 distance between target and marginal 279 | return torch.sum( 280 | torch.abs(marginal - target) 281 | ) 282 | 283 | 284 | def ent_reg_cost( 285 | f: torch.Tensor, g: torch.Tensor, ot_prob: LinearProblem, 286 | ) -> float: 287 | 288 | supp_a = ot_prob.a > 0 289 | supp_b = ot_prob.b > 0 290 | fa = ot_prob.potential_from_scaling(ot_prob.a) 291 | if ot_prob.tau_a == 1.0: 292 | div_a = torch.sum(torch.where(supp_a, ot_prob.a * (f - fa), 0.0)) 293 | else: 294 | rho_a = rho(ot_prob.epsilon, ot_prob.tau_a) 295 | div_a = -torch.sum( 296 | torch.where(supp_a, ot_prob.a * phi_star(-(f - fa), rho_a), 0.0) 297 | ) 298 | 299 | gb = ot_prob.potential_from_scaling(ot_prob.b) 300 | if ot_prob.tau_b == 1.0: 301 | div_b = torch.sum(torch.where(supp_b, ot_prob.b * (g - gb), 0.0)) 302 | else: 303 | rho_b = rho(ot_prob.epsilon, ot_prob.tau_b) 304 | div_b = -torch.sum( 305 | torch.where(supp_b, ot_prob.b * phi_star(-(g - gb), rho_b), 0.0) 306 | ) 307 | 308 | # Using https://arxiv.org/pdf/1910.12958.pdf (24) 309 | total_sum = torch.sum(ot_prob.marginal_from_potentials(f, g)) 310 | return div_a + div_b + ot_prob.epsilon * ( 311 | torch.sum(ot_prob.a) * torch.sum(ot_prob.b) - total_sum 312 | ) -------------------------------------------------------------------------------- /mpot/ot/sinkhorn.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Callable, 5 | Literal, 6 | Mapping, 7 | NamedTuple, 8 | Optional, 9 | Sequence, 10 | Tuple, 11 | Union, 12 | ) 13 | import numpy as np 14 | import torch 15 | from mpot.ot.problem import LinearProblem, SinkhornState 16 | from mpot.ot.initializer import DefaultInitializer, RandomInitializer, SinkhornInitializer 17 | 18 | 19 | class Momentum: 20 | """Momentum for Sinkhorn updates. 
Adapted from OTT-JAX 21 | """ 22 | 23 | def __init__( 24 | self, 25 | start: int = 0, 26 | error_threshold: float = torch.inf, 27 | value: float = 1.0, 28 | inner_iterations: int = 1, 29 | ) -> None: 30 | self.start = start 31 | self.error_threshold = error_threshold 32 | self.value = value 33 | self.inner_iterations = inner_iterations 34 | 35 | def weight(self, state: SinkhornState, iteration: int) -> float: 36 | if self.start == 0: 37 | return self.value 38 | idx = self.start // self.inner_iterations 39 | 40 | return self.lehmann(state) if iteration >= self.start and state.errors[idx - 1, -1] < self.error_threshold \ 41 | else self.value 42 | 43 | def lehmann(self, state: SinkhornState) -> float: 44 | """See Lehmann, T., Von Renesse, M.-K., Sambale, A., and 45 | Uschmajew, A. (2021). A note on overrelaxation in the 46 | sinkhorn algorithm. Optimization Letters, pages 1–12. eq. 5.""" 47 | idx = self.start // self.inner_iterations 48 | error_ratio = torch.minimum( 49 | state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99 50 | ) 51 | power = 1.0 / self.inner_iterations 52 | return 2.0 / (1.0 + torch.sqrt(1.0 - error_ratio ** power)) 53 | 54 | def __call__( 55 | self, 56 | weight: float, 57 | value: torch.Tensor, 58 | new_value: torch.Tensor 59 | ) -> torch.Tensor: 60 | value = torch.where(torch.isfinite(value), value, 0.0) 61 | return (1.0 - weight) * value + weight * new_value 62 | 63 | 64 | class Sinkhorn: 65 | 66 | def __init__( 67 | self, 68 | threshold: float = 1e-3, 69 | inner_iterations: int = 1, 70 | min_iterations: int = 1, 71 | max_iterations: int = 100, 72 | parallel_dual_updates: bool = False, 73 | initializer: Literal["default", "random"] = "default", 74 | **kwargs: Any, 75 | ): 76 | self.threshold = threshold 77 | self.inner_iterations = inner_iterations 78 | self.min_iterations = min_iterations 79 | self.max_iterations = max_iterations 80 | self.momentum = Momentum(inner_iterations=inner_iterations) 81 | 82 | self.parallel_dual_updates = 
parallel_dual_updates 83 | self.initializer = initializer 84 | 85 | def __call__( 86 | self, 87 | ot_prob: LinearProblem, 88 | init: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]] = (None, None), 89 | compute_error: bool = True, 90 | ) -> torch.Tensor: 91 | 92 | initializer = self.create_initializer() 93 | init_dual_a, init_dual_b = initializer( 94 | ot_prob, *init 95 | ) 96 | final_state = self.iterations(ot_prob, (init_dual_a, init_dual_b), compute_error=compute_error) 97 | return self.output_from_state(ot_prob, final_state), final_state 98 | 99 | def create_initializer(self) -> SinkhornInitializer: 100 | if isinstance(self.initializer, SinkhornInitializer): 101 | return self.initializer 102 | if self.initializer == "default": 103 | return DefaultInitializer() 104 | if self.initializer == "random": 105 | return RandomInitializer() 106 | raise NotImplementedError( 107 | f"Initializer `{self.initializer}` is not yet implemented." 108 | ) 109 | 110 | def lse_step( 111 | self, ot_prob: LinearProblem, state: SinkhornState, 112 | iteration: int 113 | ) -> SinkhornState: 114 | """Sinkhorn LSE update.""" 115 | 116 | w = self.momentum.weight(state, iteration) 117 | tau_a, tau_b = ot_prob.tau_a, ot_prob.tau_b 118 | old_fu, old_gv = state.fu, state.gv 119 | 120 | # update g potential 121 | new_gv = tau_b * ot_prob.update_potential( 122 | old_fu, old_gv, torch.log(ot_prob.b), iteration, dim=0 123 | ) 124 | gv = self.momentum(w, old_gv, new_gv) 125 | 126 | if not self.parallel_dual_updates: 127 | old_gv = gv 128 | 129 | # update f potential 130 | new_fu = tau_a * ot_prob.update_potential( 131 | old_fu, old_gv, torch.log(ot_prob.a), iteration, dim=1 132 | ) 133 | fu = self.momentum(w, old_fu, new_fu) 134 | 135 | state.fu = fu 136 | state.gv = gv 137 | return state 138 | 139 | def one_iteration( 140 | self, ot_prob: LinearProblem, state: SinkhornState, 141 | iteration: int, compute_error: bool = True 142 | ) -> SinkhornState: 143 | 144 | state = self.lse_step(ot_prob, 
state, iteration) 145 | 146 | # re-computes error if compute_error is True, else set it to -1. 147 | if compute_error: 148 | err = state.solution_error( 149 | ot_prob, 150 | parallel_dual_updates=self.parallel_dual_updates, 151 | ) 152 | else: 153 | err = -1 154 | state.errors[iteration // self.inner_iterations] = err 155 | return state 156 | 157 | def _converged(self, state: SinkhornState, iteration: int) -> bool: 158 | err = state.errors[iteration // self.inner_iterations - 1] 159 | return iteration > self.min_iterations and err < self.threshold 160 | 161 | def _diverged(self, state: SinkhornState, iteration: int) -> bool: 162 | err = state.errors[iteration // self.inner_iterations - 1] 163 | return not torch.isfinite(err) 164 | 165 | def _continue(self, state: SinkhornState, iteration: int) -> bool: 166 | """Continue while not(converged) and not(diverged).""" 167 | return iteration < self.outer_iterations and not self._converged(state, iteration) and not self._diverged(state, iteration) 168 | 169 | @property 170 | def outer_iterations(self) -> int: 171 | """Upper bound on number of times inner_iterations are carried out. 
172 | """ 173 | return np.ceil(self.max_iterations / self.inner_iterations).astype(int) 174 | 175 | def init_state( 176 | self, init: Tuple[torch.Tensor, torch.Tensor] 177 | ) -> SinkhornState: 178 | """Return the initial state of the loop.""" 179 | fu, gv = init 180 | errors = -torch.ones(self.outer_iterations).type_as(fu) 181 | state = SinkhornState(errors=errors, fu=fu, gv=gv) 182 | return state 183 | 184 | def output_from_state( 185 | self, ot_prob: LinearProblem, state: SinkhornState 186 | ) -> torch.Tensor: 187 | """Return the output of the Sinkhorn loop.""" 188 | return ot_prob.transport_from_potentials(state.fu, state.gv) 189 | 190 | def iterations( 191 | self, ot_prob: LinearProblem, init: Tuple[torch.Tensor, torch.Tensor], compute_error: bool = True 192 | ) -> SinkhornState: 193 | state = self.init_state(init) 194 | iteration = 0 195 | while self._continue(state, iteration): 196 | state = self.one_iteration(ot_prob, state, iteration, compute_error=compute_error) 197 | iteration += self.inner_iterations 198 | if self._converged(state, iteration): 199 | state.converged_at = iteration 200 | return state 201 | 202 | 203 | if __name__ == "__main__": 204 | from mpot.ot.problem import Epsilon 205 | from torch_robotics.torch_utils.torch_timer import TimerCUDA 206 | epsilon = Epsilon(target=0.05, init=1., decay=0.8) 207 | ot_prob = LinearProblem( 208 | torch.rand((1000, 1000)), epsilon 209 | ) 210 | sinkhorn = Sinkhorn() 211 | with TimerCUDA() as t: 212 | W, _ = sinkhorn(ot_prob) 213 | print(t.elapsed) 214 | print(W.shape) 215 | -------------------------------------------------------------------------------- /mpot/ot/sinkhorn_step.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | List, 5 | Callable, 6 | Literal, 7 | Mapping, 8 | NamedTuple, 9 | Optional, 10 | Sequence, 11 | Tuple, 12 | Union, 13 | ) 14 | 15 | import torch 16 | from mpot.ot.problem import LinearProblem, 
Epsilon, SinkhornStepState 17 | from mpot.ot.sinkhorn import Sinkhorn 18 | from mpot.utils.polytopes import POLYTOPE_MAP, get_sampled_polytope_vertices 19 | from mpot.utils.misc import MinMaxCenterScaler 20 | 21 | 22 | class SinkhornStep(): 23 | """Sinkhorn Step solver.""" 24 | 25 | def __init__( 26 | self, 27 | dim: int, 28 | objective_fn: Callable, 29 | linear_ot_solver: Sinkhorn, 30 | epsilon: Union[Epsilon, float] , 31 | ent_epsilon: Union[Epsilon, float] = 0.01, 32 | state_scalers: Optional[List[MinMaxCenterScaler]] = None, 33 | polytope_type: str = 'orthoplex', 34 | scale_cost: float = 1.0, 35 | step_radius: float = 1., 36 | probe_radius: float = 2., 37 | num_probe: int = 5, 38 | min_iterations: int = 5, 39 | max_iterations: int = 50, 40 | threshold: float = 1e-3, 41 | store_inner_errors: bool = False, 42 | store_outer_evals: bool = False, 43 | store_history: bool = False, 44 | tensor_args: Optional[Mapping[str, Any]] = None, 45 | **kwargs: Any, 46 | ) -> None: 47 | if tensor_args is None: 48 | tensor_args = {'device': 'cpu', 'dtype': torch.float32} 49 | self.tensor_args = tensor_args 50 | self.dim = dim 51 | self.objective_fn = objective_fn 52 | 53 | # Sinkhorn Step params 54 | self.linear_ot_solver = linear_ot_solver 55 | self.polytope_vertices = POLYTOPE_MAP[polytope_type](torch.zeros((self.dim,), **tensor_args)) 56 | self.epsilon = epsilon 57 | self.ent_epsilon = ent_epsilon 58 | self.state_scalers = state_scalers 59 | self.step_radius = step_radius 60 | self.probe_radius = probe_radius 61 | self.scale_cost = scale_cost 62 | self.num_probe = num_probe 63 | self.min_iterations = min_iterations 64 | self.max_iterations = max_iterations 65 | self.threshold = threshold 66 | self.store_inner_errors = store_inner_errors 67 | self.store_outer_evals = store_outer_evals 68 | self.store_history = store_history 69 | 70 | # TODO: support non-uniform weights for conditional sinkhorn step 71 | 72 | def init_state( 73 | self, 74 | X_init: torch.Tensor, 75 | ) -> 
SinkhornStepState: 76 | num_points, dim = X_init.shape 77 | num_iters = self.max_iterations 78 | if self.store_history: 79 | X_history = torch.zeros((num_iters, num_points, self.dim)).type_as(X_init) 80 | else: 81 | X_history = None 82 | 83 | if self.store_outer_evals: 84 | costs = -torch.ones(num_iters).type_as(X_init) 85 | else: 86 | costs = None 87 | 88 | displacement_sqnorms = -torch.ones(num_iters).type_as(X_init) 89 | linear_convergence = -torch.ones(num_iters).type_as(X_init) 90 | 91 | a = torch.ones((num_points,)).type_as(X_init) / num_points # always uniform weights for now 92 | 93 | return SinkhornStepState( 94 | X_init=X_init, 95 | costs=costs, 96 | linear_convergence=linear_convergence, 97 | X_history=X_history, 98 | displacement_sqnorms=displacement_sqnorms, 99 | a=a, 100 | ) 101 | 102 | def step(self, state: SinkhornStepState, iteration: int, **kwargs) -> SinkhornStepState: 103 | """Run Sinkhorn Step.""" 104 | X = state.X.clone() 105 | 106 | # scale state features into same range 107 | if self.state_scalers is not None: 108 | for scaler in self.state_scalers: 109 | scaler(X) 110 | 111 | eps = self.epsilon.at(iteration) if isinstance(self.epsilon, Epsilon) else self.epsilon 112 | self.step_radius = self.step_radius * (1 - eps) 113 | self.probe_radius = self.probe_radius * (1 - eps) 114 | 115 | # compute sampled polytope vertices 116 | X_vertices, X_probe, vertices = get_sampled_polytope_vertices(X, 117 | polytope_vertices=self.polytope_vertices, 118 | step_radius=self.step_radius, 119 | probe_radius=self.probe_radius, 120 | num_probe=self.num_probe) 121 | 122 | # unscale for cost evaluation 123 | if self.state_scalers is not None: 124 | for scaler in self.state_scalers: 125 | scaler.inverse(X_vertices) 126 | scaler.inverse(X_probe) 127 | 128 | # solve Sinkhorn 129 | optim_dim = X_probe.shape[:-1] 130 | C = self.objective_fn(X_probe, current_trajs=state.X, optim_dim=optim_dim, **kwargs) 131 | ot_prob = LinearProblem(C, epsilon=self.ent_epsilon, 
a=state.a, scaling_cost=True) 132 | W, res = self.linear_ot_solver(ot_prob) 133 | 134 | # barycentric projection 135 | X_new = torch.einsum('bik,bi->bk', X_vertices, W / state.a.unsqueeze(-1)) 136 | 137 | if self.store_outer_evals: 138 | state.costs[iteration] = self.objective_fn.cost(X_new, **kwargs).mean() 139 | 140 | if self.store_history: 141 | state.X_history[iteration] = X_new 142 | 143 | state.linear_convergence[iteration] = res.converged_at 144 | state.displacement_sqnorms[iteration] = torch.square(X_new - state.X).sum() 145 | state.X = X_new 146 | return state 147 | 148 | def _converged(self, state: SinkhornStepState, iteration: int) -> bool: 149 | dqsnorm, i, tol = state.displacement_sqnorms, iteration, self.threshold 150 | return torch.isclose(dqsnorm[i - 2], dqsnorm[i - 1], rtol=tol) 151 | 152 | def _continue(self, state: SinkhornStepState, iteration: int) -> bool: 153 | """Continue while not(converged)""" 154 | return iteration <= self.min_iterations or (not self._converged(state, iteration) and iteration < self.max_iterations) 155 | -------------------------------------------------------------------------------- /mpot/planner.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, List, Callable 2 | import torch 3 | import numpy as np 4 | from mpot.ot.sinkhorn_step import SinkhornStep, SinkhornStepState 5 | from mpot.ot.sinkhorn import Sinkhorn 6 | from mpot.utils.misc import MinMaxCenterScaler 7 | from mpot.gp.gp_prior import BatchGPPrior 8 | from mpot.gp.gp_factor import GPFactor 9 | from mpot.gp.unary_factor import UnaryFactor 10 | 11 | 12 | class MPOT(object): 13 | """Batch First-order Trajectory Optimization with Sinkhorn Step.""" 14 | 15 | def __init__( 16 | self, 17 | dim: int, 18 | objective_fn: Callable, 19 | linear_ot_solver: Sinkhorn, 20 | ss_params: dict, 21 | traj_len: int = 64, 22 | num_particles_per_goal: int = 16, 23 | dt: float = 0.02, 24 | start_state: 
torch.Tensor = None, 25 | multi_goal_states: torch.Tensor = None, 26 | initial_particle_means: torch.Tensor = None, 27 | pos_limits=[-10, 10], 28 | vel_limits=[-10, 10], 29 | polytope: str = 'orthoplex', 30 | fixed_start: bool = True, 31 | fixed_goal: bool = False, 32 | sigma_start_init: float = 0.001, 33 | sigma_goal_init: float = 1., 34 | sigma_gp_init: float = 0.001, 35 | seed: int = 0, 36 | tensor_args=None, 37 | **kwargs 38 | ): 39 | self.dim = dim 40 | self.state_dim = dim * 2 41 | self.traj_len = traj_len 42 | self.dt = dt 43 | self.seed = seed 44 | if tensor_args is None: 45 | tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32} 46 | self.tensor_args = tensor_args 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | 50 | self.start_state = start_state 51 | self.multi_goal_states = multi_goal_states 52 | if multi_goal_states is None: # NOTE: if there is no goal, we assume here is at least one goal 53 | self.num_goals = 1 54 | else: 55 | assert multi_goal_states.ndim == 2 56 | self.num_goals = multi_goal_states.shape[0] 57 | self.num_particles_per_goal = num_particles_per_goal 58 | self.num_particles = num_particles_per_goal * self.num_goals 59 | self.polytope = polytope 60 | self.fixed_start = fixed_start 61 | self.fixed_goal = fixed_goal 62 | self.sigma_start_init = sigma_start_init 63 | self.sigma_goal_init = sigma_goal_init 64 | self.sigma_gp_init = sigma_gp_init 65 | self._traj_dist = None 66 | 67 | self.reset(initial_particle_means=initial_particle_means) 68 | # scaling operations 69 | if isinstance(pos_limits, torch.Tensor): 70 | self.pos_limits = pos_limits.clone().to(**self.tensor_args) 71 | else: 72 | self.pos_limits = torch.tensor(pos_limits, **self.tensor_args) 73 | if self.pos_limits.ndim == 1: 74 | self.pos_limits = self.pos_limits.unsqueeze(0).repeat(self.dim, 1) 75 | self.pos_scaler = MinMaxCenterScaler(dim_range=[0, self.dim], min=self.pos_limits[:, 0], max=self.pos_limits[:, 1]) 76 | if isinstance(vel_limits, 
torch.Tensor): 77 | self.vel_limits = vel_limits.clone().to(**self.tensor_args) 78 | else: 79 | self.vel_limits = torch.tensor(vel_limits, **self.tensor_args) 80 | if self.vel_limits.ndim == 1: 81 | self.vel_limits = self.vel_limits.unsqueeze(0).repeat(self.dim, 1) 82 | self.vel_scaler = MinMaxCenterScaler(dim_range=[self.dim, self.state_dim], min=self.vel_limits[:, 0], max=self.vel_limits[:, 1]) 83 | 84 | # init solver 85 | self.ss_params = ss_params 86 | self.sinkhorn_step = SinkhornStep( 87 | self.state_dim, 88 | objective_fn=objective_fn, 89 | linear_ot_solver=linear_ot_solver, 90 | state_scalers=[self.pos_scaler, self.vel_scaler], 91 | **self.ss_params, 92 | ) 93 | 94 | def reset( 95 | self, 96 | start_state: torch.Tensor = None, 97 | multi_goal_states: torch.Tensor = None, 98 | initial_particle_means: torch.Tensor = None, 99 | ): 100 | if start_state is not None: 101 | self.start_state = start_state.detach().clone() 102 | assert self.start_state.shape[-1] == self.state_dim, "start_state dimension should be dim * 2" 103 | 104 | if multi_goal_states is not None: 105 | self.multi_goal_states = multi_goal_states.detach().clone() 106 | assert self.multi_goal_states.shape[-1] == self.state_dim, "multi_goal_states dimension should be dim * 2" 107 | 108 | self.get_prior_dists(initial_particle_means=initial_particle_means) 109 | 110 | def get_prior_dists(self, initial_particle_means: torch.Tensor = None): 111 | if initial_particle_means is None: 112 | self.init_trajs = self.get_random_trajs() 113 | else: 114 | self.init_trajs = initial_particle_means.clone() 115 | self.flatten_trajs = self.init_trajs.flatten(0, 1) 116 | 117 | def get_GP_prior( 118 | self, 119 | start_K: torch.Tensor, 120 | gp_K: torch.Tensor, 121 | goal_K: torch.Tensor, 122 | state_init: torch.Tensor, 123 | particle_means: torch.Tensor = None, 124 | goal_states: torch.Tensor = None, 125 | tensor_args=None, 126 | ) -> BatchGPPrior: 127 | if tensor_args is None: 128 | tensor_args = self.tensor_args 129 
| return BatchGPPrior( 130 | self.traj_len - 1, 131 | self.dt, 132 | self.dim, 133 | start_K, 134 | gp_K, 135 | state_init, 136 | means=particle_means, 137 | K_g_inv=goal_K, 138 | goal_states=goal_states, 139 | tensor_args=tensor_args, 140 | ) 141 | 142 | def get_random_trajs(self) -> torch.Tensor: 143 | # force torch.float64 144 | tensor_args = dict(dtype=torch.float64, device=self.tensor_args['device']) 145 | # set zero velocity for GP prior 146 | start_state = self.start_state.to(**tensor_args) 147 | if self.multi_goal_states is not None: 148 | multi_goal_states = self.multi_goal_states.to(**tensor_args) 149 | else: 150 | multi_goal_states = None 151 | #========= Initialization factors =============== 152 | self.start_prior_init = UnaryFactor( 153 | self.dim * 2, 154 | self.sigma_start_init, 155 | start_state, 156 | tensor_args=tensor_args, 157 | ) 158 | 159 | self.gp_prior_init = GPFactor( 160 | self.dim, 161 | self.sigma_gp_init, 162 | self.dt, 163 | self.traj_len - 1, 164 | tensor_args=tensor_args, 165 | ) 166 | 167 | self.multi_goal_prior_init = [] 168 | if multi_goal_states is not None: 169 | for i in range(self.num_goals): 170 | self.multi_goal_prior_init.append( 171 | UnaryFactor( 172 | self.dim * 2, 173 | self.sigma_goal_init, 174 | multi_goal_states[i], 175 | tensor_args=tensor_args, 176 | ) 177 | ) 178 | self._traj_dist = self.get_GP_prior( 179 | self.start_prior_init.K, 180 | self.gp_prior_init.Q_inv[0], 181 | self.multi_goal_prior_init[0].K if multi_goal_states is not None else None, 182 | start_state[..., :self.dim], 183 | goal_states=multi_goal_states[..., :self.dim], 184 | tensor_args=tensor_args, 185 | ) 186 | particles = self._traj_dist.sample(self.num_particles_per_goal).to(**tensor_args) 187 | self.traj_dim = particles.shape 188 | del self._traj_dist # free memory 189 | return particles.flatten(0, 1).to(**self.tensor_args) 190 | 191 | def optimize(self) -> Tuple[torch.Tensor, SinkhornStepState, int]: 192 | state = 
self.sinkhorn_step.init_state(self.flatten_trajs) 193 | iteration = 0 194 | while self.sinkhorn_step._continue(state, iteration): 195 | state = self.sinkhorn_step.step(state, iteration, traj_dim=self.traj_dim) 196 | trajs = state.X.view(self.traj_dim) 197 | # option to hard fixing start and goal states 198 | if self.fixed_start: 199 | trajs[:, :, 0, :] = self.start_state 200 | if self.fixed_goal: 201 | trajs[:, :, -1, :] = self.multi_goal_states.unsqueeze(1) 202 | state.X = trajs.view(-1, self.state_dim) 203 | iteration += 1 204 | 205 | trajs = state.X.view(self.traj_dim) 206 | return trajs, state, iteration 207 | -------------------------------------------------------------------------------- /mpot/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/utils/__init__.py -------------------------------------------------------------------------------- /mpot/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Optional, List, Tuple 3 | 4 | 5 | # min max scaler 6 | class MinMaxScaler(): 7 | def __init__(self, min: float = None, max: float = None): 8 | self.min = min 9 | self.max = max 10 | 11 | def __call__(self, X: torch.Tensor) -> torch.Tensor: 12 | if self.min is None: 13 | self.min = X.min() 14 | if self.max is None: 15 | self.max = X.max() 16 | return (X - self.min) / (self.max - self.min) 17 | 18 | def inverse(self, X: torch.Tensor) -> torch.Tensor: 19 | return X * (self.max - self.min) + self.min 20 | 21 | 22 | # scale to [-1, 1] 23 | class MinMaxCenterScaler(): 24 | def __init__(self, dim_range: List[float], min: float = -1, max: float = 1): 25 | self.dim_range = dim_range 26 | self.dim = dim_range[1] - dim_range[0] 27 | self.min = min 28 | self.max = max 29 | 30 | def __call__(self, X: torch.Tensor): 31 | X[..., 
self.dim_range[0]:self.dim_range[1]] = 2 * (X[..., self.dim_range[0]:self.dim_range[1]] - self.min) / (self.max - self.min) - 1 32 | 33 | def inverse(self, X: torch.Tensor): 34 | X[..., self.dim_range[0]:self.dim_range[1]] = (X[..., self.dim_range[0]:self.dim_range[1]] + 1) * (self.max - self.min) / 2 + self.min 35 | 36 | 37 | # min max mean scaler 38 | class MinMaxMeanScaler(): 39 | '''For torch tensors, NOTE: all are in-place operations''' 40 | def __init__(self, dim_range: List[float], min: float = -1, max: float = 1, mean: torch.Tensor = None): 41 | self.min = min 42 | self.max = max 43 | self.mean = mean 44 | self.dim_range = dim_range 45 | self.dim = dim_range[1] - dim_range[0] 46 | 47 | def __call__(self, X: torch.Tensor): 48 | if self.mean is None: 49 | self.mean = X[..., self.dim_range[0]:self.dim_range[1]].view((-1, self.dim)).mean(0) 50 | X[..., self.dim_range[0]:self.dim_range[1]] = (X[..., self.dim_range[0]:self.dim_range[1]] - self.mean) / (self.max - self.min) 51 | 52 | def inverse(self, X: torch.Tensor): 53 | # clamp min max of input 54 | # X[..., self.dim_range[0]:self.dim_range[1]] = torch.clamp(X[..., self.dim_range[0]:self.dim_range[1]], -1., 1.) 
55 | X[..., self.dim_range[0]:self.dim_range[1]] = X[..., self.dim_range[0]:self.dim_range[1]] * (self.max - self.min) + self.mean 56 | 57 | 58 | # STANDARD SCALER 59 | class StandardScaler(): 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, X: torch.Tensor) -> torch.Tensor: 65 | return (X - self.mean) / self.std 66 | 67 | def inverse(self, X: torch.Tensor) -> torch.Tensor: 68 | return X * self.std + self.mean 69 | -------------------------------------------------------------------------------- /mpot/utils/polytopes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Optional, Tuple 3 | import numpy as np 4 | from itertools import product 5 | import math 6 | from mpot.utils.probe import get_random_probe_points, get_probe_points 7 | from mpot.utils.rotation import get_random_maximal_torus_matrix 8 | 9 | 10 | def get_cube_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor: 11 | dim = origin.shape[-1] 12 | points = torch.tensor(list(product([1, -1], repeat=dim))) / np.sqrt(dim) 13 | points = points.type_as(origin) * radius + origin 14 | return points 15 | 16 | 17 | def get_orthoplex_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor: 18 | dim = origin.shape[-1] 19 | points = torch.zeros((2 * dim, dim)).type_as(origin) 20 | first = torch.arange(0, dim) 21 | second = torch.arange(dim, 2 * dim) 22 | points[first, first] = radius 23 | points[second, first] = -radius 24 | points = points + origin 25 | return points 26 | 27 | 28 | def get_simplex_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor: 29 | ''' 30 | Simplex coordinates: https://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_a_regular_n-dimensional_simplex_in_Rn 31 | ''' 32 | dim = origin.shape[-1] 33 | points = math.sqrt(1 + 1/dim) * torch.eye(dim) - ((math.sqrt(dim + 1) + 1) / math.sqrt(dim 
** 3)) * torch.ones((dim, dim)) 34 | points = torch.concatenate([points, (1 / math.sqrt(dim)) * torch.ones((1, dim))], dim=0) 35 | points = points.type_as(origin) * radius + origin 36 | return points 37 | 38 | 39 | def get_sampled_polytope_vertices(origin: torch.Tensor, 40 | polytope_vertices: torch.Tensor, 41 | step_radius: float = 1., 42 | probe_radius: float = 2., 43 | num_probe: int = 5, 44 | **kwargs) -> Tuple[torch.Tensor]: 45 | if origin.ndim == 1: 46 | origin = origin.unsqueeze(0) 47 | batch, dim = origin.shape 48 | polytope_vertices = polytope_vertices.repeat(batch, 1, 1) # [batch, num_vertices, dim] 49 | 50 | # batch random polytope 51 | maximal_torus_mat = get_random_maximal_torus_matrix(origin) 52 | polytope_vertices = polytope_vertices @ maximal_torus_mat 53 | step_points = polytope_vertices * step_radius + origin.unsqueeze(1) # [batch, num_vertices, dim] 54 | probe_points = get_probe_points(origin, polytope_vertices, probe_radius, num_probe) # [batch, num_vertices, num_probe, dim] 55 | return step_points, probe_points, polytope_vertices 56 | 57 | 58 | def get_sampled_points_on_sphere(origin: torch.Tensor, 59 | step_radius: float = 1., 60 | probe_radius: float = 2., 61 | num_probe: int = 5, 62 | num_sphere_point: int = 50, 63 | random_probe: bool = False, **kwargs) -> Tuple[torch.Tensor]: 64 | if origin.ndim == 1: 65 | origin = origin.unsqueeze(0) 66 | batch, dim = origin.shape 67 | # marsaglia method 68 | points = torch.randn(batch, num_sphere_point, dim).type_as(origin) # [batch, num_points, dim] 69 | points = points / points.norm(dim=-1, keepdim=True) 70 | step_points = points * step_radius + origin.unsqueeze(1) # [batch, num_points, dim] 71 | if random_probe: 72 | probe_points = get_random_probe_points(origin, points, probe_radius, num_probe) 73 | else: 74 | probe_points = get_probe_points(origin, points, probe_radius, num_probe) # [batch, 2 * dim, num_probe, dim] 75 | return step_points, probe_points, points 76 | 77 | 78 | POLYTOPE_MAP = { 79 | 
'cube': get_cube_vertices, 80 | 'orthoplex': get_orthoplex_vertices, 81 | 'simplex': get_simplex_vertices, 82 | } 83 | 84 | POLYTOPE_NUM_VERTICES_MAP = { 85 | 'cube': lambda dim: 2 ** dim, 86 | 'orthoplex': lambda dim: 2 * dim, 87 | 'simplex': lambda dim: dim + 1, 88 | } 89 | 90 | SAMPLE_POLYTOPE_MAP = { 91 | 'polytope': get_sampled_polytope_vertices, 92 | 'random': get_sampled_points_on_sphere, 93 | } 94 | -------------------------------------------------------------------------------- /mpot/utils/probe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_random_probe_points(origin: torch.Tensor, 5 | points: torch.Tensor, 6 | probe_radius: float = 2., 7 | num_probe: int = 5) -> torch.Tensor: 8 | batch, num_points, dim = points.shape 9 | alpha = torch.rand(batch, num_points, num_probe, 1).type_as(points) 10 | probe_points = points * probe_radius 11 | probe_points = probe_points.unsqueeze(-2) * alpha + origin.unsqueeze(1).unsqueeze(1) # [batch, num_points, num_probe, dim] 12 | return probe_points 13 | 14 | 15 | def get_probe_points(origin: torch.Tensor, 16 | points: torch.Tensor, 17 | probe_radius: float = 2., 18 | num_probe: int = 5) -> torch.Tensor: 19 | alpha = torch.linspace(0, 1, num_probe + 2).type_as(points)[1:num_probe + 1].view(1, 1, -1, 1) 20 | probe_points = points * probe_radius 21 | probe_points = probe_points.unsqueeze(-2) * alpha + origin.unsqueeze(1).unsqueeze(1) # [batch, num_points, num_probe, dim] 22 | return probe_points 23 | 24 | 25 | def get_shifted_points(new_origins: torch.Tensor, points: torch.Tensor) -> torch.Tensor: 26 | ''' 27 | Args: 28 | new_origins: [no, dim] 29 | points: [nb, dim] 30 | Returns: 31 | shifted_points: [no, nb, dim] 32 | ''' 33 | # asumming points has centroid at origin 34 | shifted_points = points + new_origins.unsqueeze(1) 35 | return shifted_points 36 | 37 | 38 | def get_projecting_points(X1: torch.Tensor, X2: torch.Tensor, probe_step_size: float, 
num_probe: int = 5) -> torch.Tensor: 39 | ''' 40 | X1: [nb1 x dim] 41 | X2: [nb2 x dim] or [nb1 x nb2 x dim] 42 | return [nb1 x nb2 x num_probe x dim] 43 | ''' 44 | if X2.ndim == 2: 45 | X1 = X1.unsqueeze(1).unsqueeze(-2) 46 | X2 = X2.unsqueeze(0).unsqueeze(-2) 47 | elif X2.ndim == 3: 48 | assert X2.shape[0] == X1.shape[0] 49 | X1 = X1.unsqueeze(1).unsqueeze(-2) 50 | X2 = X2.unsqueeze(-2) 51 | alpha = torch.arange(1, num_probe + 1).type_as(X1) * probe_step_size 52 | alpha = alpha.view(1, 1, -1, 1) 53 | points = X1 + (X2 - X1) * alpha 54 | return points 55 | 56 | 57 | if __name__ == '__main__': 58 | X1 = torch.tensor([ 59 | [0, 0], 60 | [2, 2] 61 | ], dtype=torch.float32) 62 | X2 = torch.tensor([ 63 | [0, 2], 64 | [2, 4] 65 | ], dtype=torch.float32) 66 | print(get_projecting_points(X1, X2, 0.5, 1)) 67 | -------------------------------------------------------------------------------- /mpot/utils/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rotation_matrix(theta: torch.Tensor) -> torch.Tensor: 5 | theta = theta.unsqueeze(1).unsqueeze(1) 6 | dim1 = torch.cat([torch.cos(theta), -torch.sin(theta)], dim=2) 7 | dim2 = torch.cat([torch.sin(theta), torch.cos(theta)], dim=2) 8 | mat = torch.cat([dim1, dim2], dim=1) 9 | return mat 10 | 11 | 12 | def get_random_maximal_torus_matrix(origin: torch.Tensor, 13 | angle_range=[0, 2 * torch.pi], **kwargs) -> torch.Tensor: 14 | batch, dim = origin.shape 15 | assert dim % 2 == 0, 'Only work with even dim for random rotation for now.' 
16 | theta = torch.rand(dim // 2, batch).type_as(origin) * (angle_range[1] - angle_range[0]) + angle_range[0] # [batch, dim // 2] 17 | rot_mat = torch.vmap(rotation_matrix)(theta).transpose(0, 1) 18 | # make batch block diag 19 | max_torus_mat = torch.diag_embed(rot_mat[:, :, [0, 1], [0, 1]].flatten(-2, -1), offset=0) 20 | even, odd = torch.arange(0, dim, 2), torch.arange(1, dim, 2) 21 | max_torus_mat[:, even, odd] = rot_mat[:, :, 0, 1] 22 | max_torus_mat[:, odd, even] = rot_mat[:, :, 1, 0] 23 | return max_torus_mat 24 | 25 | 26 | def get_random_uniform_rot_matrix(origin: torch.Tensor, 27 | **kwargs) -> torch.Tensor: 28 | """Compute a uniformly random rotation matrix drawn from the Haar distribution 29 | (the only uniform distribution on SO(n)). This is less efficient than maximal torus. 30 | See: Stewart, G.W., "The efficient generation of random orthogonal 31 | matrices with an application to condition estimators", SIAM Journal 32 | on Numerical Analysis, 17(3), pp. 403-409, 1980. 33 | For more information see 34 | http://en.wikipedia.org/wiki/Orthogonal_matrix#Randomization""" 35 | 36 | batch, dim = origin.shape 37 | H = torch.eye(dim).repeat(batch, 1, 1).type_as(origin) 38 | D = torch.ones((batch, dim)).type_as(origin) 39 | for i in range(1, dim): 40 | v = torch.normal(0, 1., size=(batch, dim - i + 1)).type_as(origin) 41 | D[:, i - 1] = torch.sign(v[:, 0]) 42 | v[:, 0] -= D[:, i - 1] * torch.norm(v, dim=-1) 43 | # Householder transformation 44 | outer = v.unsqueeze(-2) * v.unsqueeze(-1) 45 | Hx = torch.eye(dim - i + 1).repeat(batch, 1, 1).type_as(origin) - 2 * outer / torch.square(v).sum(dim=-1).unsqueeze(-1).unsqueeze(-1) 46 | T = torch.eye(dim).repeat(batch, 1, 1).type_as(origin) 47 | T[:, i - 1:, i - 1:] = Hx 48 | H = torch.matmul(H, T) 49 | # Fix the last sign such that the determinant is 1 50 | D[:, -1] = (-1)**(1 - dim % 2) * D.prod(dim=-1) 51 | R = (D.unsqueeze(-1) * H.mT).mT 52 | return R 53 | 54 | 55 | if __name__ == '__main__': 56 | from 
torch_robotics.torch_utils.torch_timer import TimerCUDA 57 | origin = torch.zeros(100, 8).cuda() 58 | with TimerCUDA() as t: 59 | rot_mat = get_random_maximal_torus_matrix(origin) 60 | print(t.elapsed) 61 | print(rot_mat.shape) 62 | with TimerCUDA() as t: 63 | rot_mat = get_random_uniform_rot_matrix(origin) 64 | print(t.elapsed) 65 | print(rot_mat.shape) 66 | print(torch.det(rot_mat)) 67 | -------------------------------------------------------------------------------- /mpot/utils/trajectory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def interpolate_trajectory(trajs: torch.Tensor, num_interpolation: int = 3) -> torch.Tensor: 5 | # Interpolates a trajectory linearly between waypoints 6 | dim = trajs.shape[-1] 7 | if num_interpolation > 0: 8 | assert trajs.ndim > 1 9 | traj_dim = trajs.shape 10 | alpha = torch.linspace(0, 1, num_interpolation + 2).type_as(trajs)[1:num_interpolation + 1] 11 | alpha = alpha.view((1,) * len(traj_dim[:-1]) + (-1, 1)) 12 | interpolated_trajs = trajs[..., 0:traj_dim[-2] - 1, None, :] * alpha + \ 13 | trajs[..., 1:traj_dim[-2], None, :] * (1 - alpha) 14 | interpolated_trajs = interpolated_trajs.view(traj_dim[:-2] + (-1, dim)) 15 | else: 16 | interpolated_trajs = trajs 17 | return interpolated_trajs 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | numpy 3 | matplotlib 4 | labmaze 5 | torch 6 | einops 7 | torch_robotics @ git+https://github.com/anindex/torch_robotics.git@68d28399ff67f34ca3be0130acdc7cba2fabd61b 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from codecs import open 3 | from os import path 4 | 5 | 6 | ext_modules = [] 7 | 8 | here = 
path.abspath(path.dirname(__file__)) 9 | requires_list = [] 10 | with open(path.join(here, 'requirements.txt'), encoding='utf-8') as f: 11 | for line in f: 12 | requires_list.append(str(line)) 13 | 14 | setuptools.setup( 15 | name='mpot', 16 | description="Implementation of MPOT in PyTorch", 17 | author="An T. Le", 18 | author_email="an@robot-learning.de", 19 | packages=setuptools.find_namespace_packages(), 20 | install_requires=requires_list, 21 | ) 22 | --------------------------------------------------------------------------------