├── .gitignore
├── LICENSE
├── README.md
├── demos
│   ├── occupancy.gif
│   ├── panda.gif
│   └── sdf_grid.gif
├── examples
│   ├── mpot_occupancy.py
│   ├── mpot_panda.py
│   └── mpot_sdf.py
├── mpot
│   ├── __init__.py
│   ├── costs.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── map_generator.py
│   │   ├── obst_map.py
│   │   ├── obst_utils.py
│   │   └── occupancy.py
│   ├── gp
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── field_factor.py
│   │   ├── gp_factor.py
│   │   ├── gp_prior.py
│   │   └── unary_factor.py
│   ├── ot
│   │   ├── __init__.py
│   │   ├── initializer.py
│   │   ├── problem.py
│   │   ├── sinkhorn.py
│   │   └── sinkhorn_step.py
│   ├── planner.py
│   └── utils
│       ├── __init__.py
│       ├── misc.py
│       ├── polytopes.py
│       ├── probe.py
│       ├── rotation.py
│       └── trajectory.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | data/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 An Thai Le
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Accelerating Motion Planning via Optimal Transport
2 |
3 | This repository implements Motion Planning via Optimal Transport `mpot` in PyTorch.
4 | The philosophy of `mpot` follows the Monte Carlo argument: more samples can discover more (and better) modes, given sufficiently high initialization variance.
5 | In other words, within the multi-modal motion planning scope, `mpot` enables better **brute-force** planning with GPU vectorization. This enhances robustness against bad local minima, a common issue in optimization-based motion planning.
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | If you are interested in the standalone Sinkhorn Step as a general-purpose, batched, gradient-free solver for non-convex optimization problems, please check out [ssax](https://github.com/anindex/ssax).
14 |
15 | ## Paper Preprint
16 |
17 | This work has been accepted to NeurIPS 2023. Please find the pre-print here:
18 |
19 | [Preprint (PDF)](https://www.ias.informatik.tu-darmstadt.de/uploads/Team/AnThaiLe/mpot_preprint.pdf)
20 |
21 | ## Installation
22 |
23 | Simply activate your conda/Python environment, navigate to the `mpot` root directory, and run
24 |
25 | ```bash
26 | pip install -e .
27 | ```
28 |
29 | The `mpot` algorithm is specifically designed for GPUs. Please check that you have installed PyTorch with CUDA support.
30 |
31 | ## Examples
32 |
33 | The `examples/` folder contains a demo of vectorized planning in a planar environment with an occupancy map:
34 |
35 | ```bash
36 | python examples/mpot_occupancy.py
37 | ```
38 |
39 | and one with a signed distance field (SDF):
40 |
41 | ```bash
42 | python examples/mpot_sdf.py
43 | ```
44 |
45 | We also include a demo of vectorized Panda arm planning in an environment with dense obstacles (SDF):
46 |
47 | ```bash
48 | python examples/mpot_panda.py
49 | ```
50 |
51 | Every run uses **a different seed**. The resulting optimization visualizations are saved to your current working directory.
52 | Please refer to the example scripts to play around with the options and different goal points. Note that in all cases we normalize the state space by the joint position and velocity limits, then perform the Sinkhorn Step in the normalized state space. Changing any hyperparameter may require re-tuning.
53 |
54 | **Tuning Tips**: The most sensitive parameters are:
55 |
56 | - `polytope`: for state dimensions below 10, `cube` is a good choice. For much higher state dimensions, sensible choices are `orthoplex` or `simplex`.
57 | - `step_radius`: the step size.
58 | - `probe_radius`: the probing radius, along which probes are projected towards the polytope vertices to compute the cost-to-go. Note: `probe_radius` >= `step_radius`.
59 | - `num_probe`: the number of probing points along the probe radius. This is critical for performance; 3-5 is usually enough.
60 | - `epsilon`: the decay rate of the step/probe size, usually 0.01-0.05.
61 | - `ent_epsilon`: the Sinkhorn entropy regularization, usually 1e-2 to 5e-2, balancing the sharpness of the optimal coupling against speed.
62 | - The various cost-term weights, which depend on your application.
63 |
64 | ## Troubleshooting
65 |
66 | If you encounter memory problems, try:
67 |
68 | ```bash
69 | export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'
70 | ```
71 |
72 | to reduce memory fragmentation.
73 |
74 | ## Acknowledgement
75 |
76 | The Gaussian Process prior implementation is adapted from Sasha Lambert's [`mpc_trajopt`](https://github.com/sashalambert/mpc_trajopt/blob/main/mpc_trajopt/factors/gp_factor.py).
77 |
78 | ## Citation
79 |
80 | If you find this repository useful, please consider citing:
81 |
82 | ```bibtex
83 | @inproceedings{le2023accelerating,
84 | title={Accelerating Motion Planning via Optimal Transport},
85 | author={Le, An T. and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan},
86 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
87 | year={2023}
88 | }
89 | ```
--------------------------------------------------------------------------------
/demos/occupancy.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/occupancy.gif
--------------------------------------------------------------------------------
/demos/panda.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/panda.gif
--------------------------------------------------------------------------------
/demos/sdf_grid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/demos/sdf_grid.gif
--------------------------------------------------------------------------------
/examples/mpot_occupancy.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import matplotlib.pyplot as plt
5 | import torch
6 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
7 |
8 | from mpot.ot.problem import Epsilon
9 | from mpot.ot.sinkhorn import Sinkhorn
10 | from mpot.planner import MPOT
11 | from mpot.costs import CostGPHolonomic, CostField, CostComposite
12 | from mpot.envs.occupancy import EnvOccupancy2D
13 | from mpot.utils.trajectory import interpolate_trajectory
14 |
15 | from torch_robotics.robots.robot_point_mass import RobotPointMass
16 | from torch_robotics.torch_utils.seed import fix_random_seed
17 | from torch_robotics.torch_utils.torch_timer import TimerCUDA
18 | from torch_robotics.torch_utils.torch_utils import get_torch_device
19 | from torch_robotics.tasks.tasks import PlanningTask
20 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
21 |
allow_ops_in_compiled_graph()


if __name__ == "__main__":
    # Demo: vectorized MPOT planning for a 2D point-mass robot in an occupancy-map
    # environment, from one start state to several goal states.
    # Every run uses a different, time-based seed.
    seed = int(time.time())
    fix_random_seed(seed)

    device = get_torch_device()
    tensor_args = {'device': device, 'dtype': torch.float32}

    # ---------------------------- Environment, Robot, PlanningTask ---------------------------------
    q_limits = torch.tensor([[-10, -10], [10, 10]], **tensor_args)
    env = EnvOccupancy2D(
        precompute_sdf_obj_fixed=False,
        tensor_args=tensor_args
    )

    robot = RobotPointMass(
        q_limits=q_limits,  # joint limits
        tensor_args=tensor_args
    )

    task = PlanningTask(
        env=env,
        robot=robot,
        ws_limits=q_limits,  # workspace limits (same as the joint limits for a point mass)
        obstacle_cutoff_margin=0.05,
        tensor_args=tensor_args
    )

    # -------------------------------- Params ---------------------------------
    # NOTE: these parameters are tuned for this environment
    step_radius = 0.15
    probe_radius = 0.15  # probe radius >= step radius

    # NOTE: changing polytope may require tuning again
    polytope = 'cube'  # 'simplex' | 'orthoplex' | 'cube';

    epsilon = 0.01
    ent_epsilon = Epsilon(1e-2)
    num_probe = 5  # number of probe points for each polytope vertex
    num_particles_per_goal = 33  # number of plans per goal
    pos_limits = [-10, 10]
    vel_limits = [-10, 10]
    w_coll = 5e-3  # for tuning the obstacle cost
    w_smooth = 1e-7  # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2
    sigma_gp = 0.1  # for tuning the GP cost: Q_c = sigma_gp^2 * I
    sigma_gp_init = 1.6  # for controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I
    max_inner_iters = 100  # max inner iterations for Sinkhorn-Knopp
    max_outer_iters = 100  # max outer iterations for MPOT

    # State layout: [x, y, vx, vy]
    start_state = torch.tensor([-9, -9, 0., 0.], **tensor_args)

    # NOTE: change goal states here (zero vel goals)
    multi_goal_states = torch.tensor([
        [0, 9, 0., 0.],
        [9, 9, 0., 0.],
        [9, 0, 0., 0.]
    ], **tensor_args)

    traj_len = 64
    dt = 0.1

    # --------------------------------- Cost function ---------------------------------

    cost_coll = CostField(
        robot, traj_len,
        field=env.occupancy_map,
        sigma_coll=1.0,
        tensor_args=tensor_args
    )
    cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args)
    cost_func_list = [cost_coll, cost_gp]
    weights_cost_l = [w_coll, w_smooth]
    cost = CostComposite(
        robot, traj_len, cost_func_list,
        weights_cost_l=weights_cost_l,
        tensor_args=tensor_args
    )

    # --------------------------------- MPOT Init ---------------------------------

    linear_ot_solver = Sinkhorn(
        threshold=1e-6,
        inner_iterations=1,
        max_iterations=max_inner_iters,
    )
    ss_params = dict(
        epsilon=epsilon,
        ent_epsilon=ent_epsilon,
        step_radius=step_radius,
        probe_radius=probe_radius,
        num_probe=num_probe,
        min_iterations=5,
        max_iterations=max_outer_iters,
        threshold=2e-3,
        store_history=True,
        tensor_args=tensor_args,
    )

    mpot_params = dict(
        objective_fn=cost,
        linear_ot_solver=linear_ot_solver,
        ss_params=ss_params,
        dim=2,
        traj_len=traj_len,
        num_particles_per_goal=num_particles_per_goal,
        dt=dt,
        start_state=start_state,
        multi_goal_states=multi_goal_states,
        pos_limits=pos_limits,
        vel_limits=vel_limits,
        polytope=polytope,
        fixed_goal=True,
        sigma_start_init=0.001,
        sigma_goal_init=0.001,
        sigma_gp_init=sigma_gp_init,
        seed=seed,
        tensor_args=tensor_args,
    )
    planner = MPOT(**mpot_params)

    # --------------------------------- Optimize ---------------------------------

    # Only the optimization itself is timed; evaluation below is excluded.
    with TimerCUDA() as t:
        trajs, optim_state, opt_iters = planner.optimize()

    # Collision-check an interpolated version of the trajectories (presumably denser than
    # the support points) using the x/y position columns; `.any(dim=1)` gives one flag per
    # trajectory (assumes dim 1 is the time axis -- TODO confirm against interpolate_trajectory).
    int_trajs = interpolate_trajectory(trajs, num_interpolation=3)
    colls = env.occupancy_map.get_collisions(int_trajs[..., :2]).any(dim=1)
    sinkhorn_iters = optim_state.linear_convergence[:opt_iters]
    print(f'Optimization finished at {opt_iters}! Parallelization Quality (GOOD [%]): {(1 - colls.float().mean()) * 100:.2f}')
    print(f'Time(s) optim: {t.elapsed:.3f} sec')
    print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}')

    # -------------------------------- Visualize ---------------------------------
    planner_visualizer = PlanningVisualizer(
        task=task,
        planner=planner
    )

    # History of all particles per outer iteration: (iters, num_plans, traj_len, 4 = 2 pos + 2 vel)
    traj_history = optim_state.X_history[:opt_iters]
    traj_history = traj_history.view(opt_iters, -1, traj_len, 4)
    base_file_name = Path(os.path.basename(__file__)).stem
    pos_trajs_iters = robot.get_position(traj_history)

    planner_visualizer.animate_opt_iters_joint_space_state(
        trajs=traj_history,
        pos_start_state=start_state,
        vel_start_state=torch.zeros_like(start_state),
        video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4',
        n_frames=max(2, opt_iters // 5),
        anim_time=5
    )

    planner_visualizer.animate_opt_iters_robots(
        trajs=pos_trajs_iters, start_state=start_state,
        video_filepath=f'{base_file_name}-traj-opt-iters.mp4',
        n_frames=max(2, opt_iters // 5),
        anim_time=5
    )

    plt.show()
183 |
--------------------------------------------------------------------------------
/examples/mpot_panda.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import time
5 | import matplotlib.pyplot as plt
6 | import torch
7 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
8 |
9 | from mpot.ot.problem import Epsilon
10 | from mpot.ot.sinkhorn import Sinkhorn
11 | from mpot.planner import MPOT
12 | from mpot.costs import CostGPHolonomic, CostField, CostComposite
13 |
14 | from torch_robotics.environments.env_spheres_3d import EnvSpheres3D
15 | from torch_robotics.robots.robot_panda import RobotPanda
16 | from torch_robotics.tasks.tasks import PlanningTask
17 | from torch_robotics.torch_utils.seed import fix_random_seed
18 | from torch_robotics.torch_utils.torch_timer import TimerCUDA
19 | from torch_robotics.torch_utils.torch_utils import get_torch_device
20 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
21 |
allow_ops_in_compiled_graph()


if __name__ == "__main__":
    # Demo: vectorized MPOT planning for a 7-DoF Panda arm among spherical obstacles,
    # using the task's collision fields as costs. Every run uses a different, time-based seed.
    base_file_name = Path(os.path.basename(__file__)).stem

    seed = int(time.time())
    fix_random_seed(seed)

    device = get_torch_device()
    tensor_args = {'device': device, 'dtype': torch.float32}

    # ---------------------------- Environment, Robot, PlanningTask ---------------------------------
    env = EnvSpheres3D(
        precompute_sdf_obj_fixed=False,
        sdf_cell_size=0.01,
        tensor_args=tensor_args
    )

    robot = RobotPanda(
        use_self_collision_storm=False,
        tensor_args=tensor_args
    )

    task = PlanningTask(
        env=env,
        robot=robot,
        ws_limits=torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], **tensor_args),  # workspace limits
        obstacle_cutoff_margin=0.03,
        tensor_args=tensor_args
    )

    # -------------------------------- Params ---------------------------------
    # Sample a collision-free start/goal pair whose end-effector positions are more than
    # `min_ee_dist` apart, giving up after `max_sampling_attempts` tries.
    max_sampling_attempts = 100
    min_ee_dist = 0.5
    for _ in range(max_sampling_attempts):
        q_free = task.random_coll_free_q(n_samples=2)
        start_state = q_free[0]
        goal_state = q_free[1]

        # check if the EE positions are "enough" far apart
        start_state_ee_pos = robot.get_EE_position(start_state).squeeze()
        goal_state_ee_pos = robot.get_EE_position(goal_state).squeeze()

        if torch.linalg.norm(start_state_ee_pos - goal_state_ee_pos) > min_ee_dist:
            break
    else:
        # No `break`: every sampled pair was too close. Keep the last pair, but say so
        # instead of silently continuing with a possibly trivial planning problem.
        print(f'Warning: no start/goal pair with EE distance > {min_ee_dist} found after '
              f'{max_sampling_attempts} attempts; using the last sampled pair.')

    # Uncomment to use a fixed start/goal pair instead of the random one above:
    # start_state = torch.tensor([1.0403, 0.0493, 0.0251, -1.2673, 1.6676, 3.3611, -1.5428], **tensor_args)
    # goal_state = torch.tensor([1.1142, 1.7289, -0.1771, -0.9284, 2.7171, 1.2497, 1.7724], **tensor_args)

    print('Start state: ', start_state)
    print('Goal state: ', goal_state)
    # Full states are joint positions followed by (zero) joint velocities.
    start_state = torch.cat((start_state, torch.zeros_like(start_state)))
    goal_state = torch.cat((goal_state, torch.zeros_like(goal_state)))
    multi_goal_states = goal_state.unsqueeze(0)  # a single goal

    # Construct planner
    duration = 5  # sec
    traj_len = 64
    dt = duration / traj_len
    num_particles_per_goal = 10  # number of plans per goal # NOTE: if memory is not enough, reduce this number

    # NOTE: these parameters are tuned for this environment
    step_radius = 0.03
    probe_radius = 0.08  # probe radius >= step radius

    # NOTE: changing polytope may require tuning again
    # NOTE: cube in this case could lead to memory insufficiency, depending how many plans are optimized
    polytope = 'cube'  # 'simplex' | 'orthoplex' | 'cube';

    epsilon = 0.02
    ent_epsilon = Epsilon(1e-2)
    num_probe = 3  # number of probe points for each polytope vertex
    # panda joint limits
    q_max = torch.tensor([2.7437, 1.7837, 2.9007, -0.1518, 2.8065, 4.5169, 3.0159], **tensor_args)
    q_min = torch.tensor([-2.7437, -1.7837, -2.9007, -3.0421, -2.8065, 0.5445, -3.0159], **tensor_args)
    pos_limits = torch.stack([q_min, q_max], dim=1)  # (7, 2): per-joint [min, max]
    vel_limits = [-5, 5]
    w_coll = 1e-1  # for tuning the obstacle cost
    w_smooth = 1e-7  # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2
    sigma_gp = 0.01  # for tuning the GP cost: Q_c = sigma_gp^2 * I
    sigma_gp_init = 0.5  # for controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I
    max_inner_iters = 100  # max inner iterations for Sinkhorn-Knopp
    max_outer_iters = 40  # max outer iterations for MPOT

    # --------------------------------- Cost function ---------------------------------

    # One collision cost per collision field exposed by the task, plus the GP smoothness cost.
    cost_func_list = []
    weights_cost_l = []
    for collision_field in task.get_collision_fields():
        cost_func_list.append(
            CostField(
                robot, traj_len,
                field=collision_field,
                sigma_coll=1.0,
                tensor_args=tensor_args
            )
        )
        weights_cost_l.append(w_coll)
    cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args)
    cost_func_list.append(cost_gp)
    weights_cost_l.append(w_smooth)
    cost = CostComposite(
        robot, traj_len, cost_func_list,
        weights_cost_l=weights_cost_l,
        tensor_args=tensor_args
    )

    # --------------------------------- MPOT Init ---------------------------------

    linear_ot_solver = Sinkhorn(
        threshold=1e-3,
        inner_iterations=1,
        max_iterations=max_inner_iters,
    )
    ss_params = dict(
        epsilon=epsilon,
        ent_epsilon=ent_epsilon,
        step_radius=step_radius,
        probe_radius=probe_radius,
        num_probe=num_probe,
        min_iterations=5,
        max_iterations=max_outer_iters,
        threshold=2e-3,
        store_history=True,
        tensor_args=tensor_args,
    )

    mpot_params = dict(
        objective_fn=cost,
        linear_ot_solver=linear_ot_solver,
        ss_params=ss_params,
        dim=7,
        traj_len=traj_len,
        num_particles_per_goal=num_particles_per_goal,
        dt=dt,
        start_state=start_state,
        multi_goal_states=multi_goal_states,
        pos_limits=pos_limits,
        vel_limits=vel_limits,
        polytope=polytope,
        fixed_goal=True,
        sigma_start_init=0.001,
        sigma_goal_init=0.001,
        sigma_gp_init=sigma_gp_init,
        seed=seed,
        tensor_args=tensor_args,
    )
    planner = MPOT(**mpot_params)

    # Optimize (only the optimization itself is timed)
    with TimerCUDA() as t:
        trajs, optim_state, opt_iters = planner.optimize()
    sinkhorn_iters = optim_state.linear_convergence[:opt_iters]
    print(f'Optimization finished at {opt_iters}! Optimization time: {t.elapsed:.3f} sec')
    print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}')

    # -------------------------------- Visualize ---------------------------------
    planner_visualizer = PlanningVisualizer(
        task=task,
        planner=planner
    )

    traj_history = optim_state.X_history[:opt_iters]
    traj_history = traj_history.view(opt_iters, -1, traj_len, 14)  # 7 + 7
    # Merge the first two axes of the returned trajectories into one batch of plans.
    trajs = trajs.flatten(0, 1)
    trajs_coll, trajs_free = task.get_trajs_collision_and_free(trajs)

    planner_visualizer.animate_opt_iters_joint_space_state(
        trajs=traj_history,
        pos_start_state=start_state, pos_goal_state=goal_state,
        vel_start_state=torch.zeros_like(start_state), vel_goal_state=torch.zeros_like(goal_state),
        video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4',
        n_frames=max(2, opt_iters // 2),
        anim_time=5
    )

    if trajs_free is not None:
        planner_visualizer.animate_robot_trajectories(
            trajs=trajs_free, start_state=start_state, goal_state=goal_state,
            plot_trajs=False,
            draw_links_spheres=False,
            video_filepath=f'{base_file_name}-robot-traj.mp4',
            n_frames=trajs_free.shape[-2],
            anim_time=duration
        )

    plt.show()
210 |
--------------------------------------------------------------------------------
/examples/mpot_sdf.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import matplotlib.pyplot as plt
5 | import torch
6 | from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
7 |
8 | from mpot.ot.problem import Epsilon
9 | from mpot.ot.sinkhorn import Sinkhorn
10 | from mpot.planner import MPOT
11 | from mpot.costs import CostGPHolonomic, CostField, CostComposite
12 |
13 | from torch_robotics.environments.env_grid_circles_2d import EnvGridCircles2D
14 | from torch_robotics.robots.robot_point_mass import RobotPointMass
15 | from torch_robotics.tasks.tasks import PlanningTask
16 | from torch_robotics.torch_utils.seed import fix_random_seed
17 | from torch_robotics.torch_utils.torch_timer import TimerCUDA
18 | from torch_robotics.torch_utils.torch_utils import get_torch_device
19 | from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
20 |
allow_ops_in_compiled_graph()


if __name__ == "__main__":
    # Demo: vectorized MPOT planning for a 2D point-mass robot in a grid-of-circles
    # environment, using the task's SDF-based collision field as the obstacle cost.
    # Every run uses a different, time-based seed.
    seed = int(time.time())
    fix_random_seed(seed)

    device = get_torch_device()
    tensor_args = {'device': device, 'dtype': torch.float32}

    # ---------------------------- Environment, Robot, PlanningTask ---------------------------------
    env = EnvGridCircles2D(
        precompute_sdf_obj_fixed=False,
        sdf_cell_size=0.001,
        tensor_args=tensor_args
    )

    robot = RobotPointMass(
        q_limits=torch.tensor([[-1, -1], [1, 1]], **tensor_args),  # joint limits
        tensor_args=tensor_args
    )

    task = PlanningTask(
        env=env,
        robot=robot,
        ws_limits=torch.tensor([[-0.81, -0.81], [0.95, 0.95]], **tensor_args),  # workspace limits
        obstacle_cutoff_margin=0.005,
        tensor_args=tensor_args
    )

    # -------------------------------- Params ---------------------------------

    # NOTE: these parameters are tuned for this environment
    step_radius = 0.15
    probe_radius = 0.15  # probe radius >= step radius

    # NOTE: changing polytope may require tuning again
    polytope = 'cube'  # 'simplex' | 'orthoplex' | 'cube';

    epsilon = 0.01
    ent_epsilon = Epsilon(1e-2)
    num_probe = 5  # number of probe points for each polytope vertex
    num_particles_per_goal = 50  # number of plans per goal # NOTE: if memory is not enough, reduce this number
    pos_limits = [-1, 1]
    vel_limits = [-1, 1]
    w_coll = 7e-2  # for tuning the obstacle cost
    w_smooth = 1e-7  # for tuning the GP cost: error = w_smooth * || Phi x(t) - x(t+1) ||^2
    sigma_gp = 0.03  # for tuning the GP cost: Q_c = sigma_gp^2 * I
    sigma_gp_init = 0.8  # for controlling the initial GP variance: Q0_c = sigma_gp_init^2 * I
    max_inner_iters = 100  # max inner iterations for Sinkhorn-Knopp
    max_outer_iters = 70  # max outer iterations for MPOT

    # State layout: [x, y, vx, vy]; goals have zero velocity.
    start_state = torch.tensor([-0.8, -0.8, 0., 0.], **tensor_args)
    multi_goal_states = torch.tensor([
        [0, 0.75, 0., 0.],
        [0.75, 0.75, 0., 0.],
        [0.75, 0, 0., 0.]
    ], **tensor_args)

    traj_len = 64
    dt = 0.04

    # --------------------------------- Cost function ---------------------------------

    cost_coll = CostField(
        robot, traj_len,
        field=task.df_collision_objects,
        sigma_coll=1.0,
        tensor_args=tensor_args
    )
    cost_gp = CostGPHolonomic(robot, traj_len, dt, sigma_gp, [0, 1], weight=w_smooth, tensor_args=tensor_args)
    cost_func_list = [cost_coll, cost_gp]
    weights_cost_l = [w_coll, w_smooth]
    cost = CostComposite(
        robot, traj_len, cost_func_list,
        weights_cost_l=weights_cost_l,
        tensor_args=tensor_args
    )

    # --------------------------------- MPOT Init ---------------------------------

    linear_ot_solver = Sinkhorn(
        threshold=1e-4,
        inner_iterations=1,
        max_iterations=max_inner_iters,
    )
    ss_params = dict(
        epsilon=epsilon,
        ent_epsilon=ent_epsilon,
        step_radius=step_radius,
        probe_radius=probe_radius,
        num_probe=num_probe,
        min_iterations=5,
        max_iterations=max_outer_iters,
        threshold=1e-3,
        store_history=True,
        tensor_args=tensor_args,
    )

    mpot_params = dict(
        objective_fn=cost,
        linear_ot_solver=linear_ot_solver,
        ss_params=ss_params,
        dim=2,
        traj_len=traj_len,
        num_particles_per_goal=num_particles_per_goal,
        dt=dt,
        start_state=start_state,
        multi_goal_states=multi_goal_states,
        pos_limits=pos_limits,
        vel_limits=vel_limits,
        polytope=polytope,
        fixed_goal=True,
        sigma_start_init=0.001,
        sigma_goal_init=0.001,
        sigma_gp_init=sigma_gp_init,
        seed=seed,
        tensor_args=tensor_args,
    )
    planner = MPOT(**mpot_params)

    # Optimize (only the optimization itself is timed)
    with TimerCUDA() as t:
        trajs, optim_state, opt_iters = planner.optimize()
    sinkhorn_iters = optim_state.linear_convergence[:opt_iters]
    print(f'Optimization finished at {opt_iters}! Optimization time: {t.elapsed:.3f} sec')
    print(f'Average Sinkhorn Iterations: {sinkhorn_iters.mean():.2f}, min: {sinkhorn_iters.min():.2f}, max: {sinkhorn_iters.max():.2f}')

    # -------------------------------- Visualize ---------------------------------
    planner_visualizer = PlanningVisualizer(
        task=task,
        planner=planner
    )

    # History of all particles per outer iteration: (iters, num_plans, traj_len, 4 = 2 pos + 2 vel)
    traj_history = optim_state.X_history[:opt_iters]
    traj_history = traj_history.view(opt_iters, -1, traj_len, 4)
    base_file_name = Path(os.path.basename(__file__)).stem
    pos_trajs_iters = robot.get_position(traj_history)

    planner_visualizer.animate_opt_iters_joint_space_state(
        trajs=traj_history,
        pos_start_state=start_state,
        vel_start_state=torch.zeros_like(start_state),
        video_filepath=f'{base_file_name}-joint-space-opt-iters.mp4',
        n_frames=max(2, opt_iters // 5),
        anim_time=5
    )

    planner_visualizer.animate_opt_iters_robots(
        trajs=pos_trajs_iters, start_state=start_state,
        video_filepath=f'{base_file_name}-traj-opt-iters.mp4',
        n_frames=max(2, opt_iters // 5),
        anim_time=5
    )

    plt.show()
176 |
--------------------------------------------------------------------------------
/mpot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/__init__.py
--------------------------------------------------------------------------------
/mpot/costs.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Optional, Tuple, List, Callable
3 | import torch
4 | import einops
5 |
6 | from mpot.gp.field_factor import FieldFactor
7 |
8 |
class Cost(ABC):
    """Abstract base class for trajectory costs evaluated against a robot model.

    Subclasses implement :meth:`eval`; calling an instance forwards to it.
    """

    def __init__(self, robot, n_support_points, tensor_args=None, **kwargs):
        self.robot = robot
        self.n_support_points = n_support_points
        self.tensor_args = tensor_args
        self.n_dof = robot.q_dim
        # A state stacks position and velocity, hence twice the configuration size.
        self.dim = self.n_dof * 2

    def set_cost_factors(self):
        """Hook for subclasses that precompute cost factors; does nothing here."""
        return None

    def __call__(self, trajs, **kwargs):
        return self.eval(trajs, **kwargs)

    @abstractmethod
    def eval(self, trajs, **kwargs):
        """Return the cost of ``trajs``; implemented by subclasses."""

    def get_q_pos_vel_and_fk_map(self, trajs, **kwargs):
        """Split trajectories into positions/velocities and run collision forward kinematics.

        :param trajs: 3-D ``[B, H, D]`` or 4-D ``[N, B, H, D]`` tensor; a 4-D input has its
            two leading axes merged into one batch axis first.
        :return: ``(trajs, q_pos, q_vel, H_positions)`` where ``trajs`` is the (possibly
            flattened) 3-D tensor and the rest come from the robot's
            ``get_position`` / ``get_velocity`` / ``fk_map_collision``.
        """
        assert trajs.ndim in (3, 4)
        if trajs.ndim == 4:
            n_lead, n_batch, horizon, width = trajs.shape
            trajs = trajs.reshape(n_lead * n_batch, horizon, width)

        positions = self.robot.get_position(trajs)
        velocities = self.robot.get_velocity(trajs)
        # Collision-link transforms computed from the joint positions.
        fk_collision = self.robot.fk_map_collision(positions)
        return trajs, positions, velocities, fk_collision
41 |
42 |
class CostField(Cost):
    """Cost of trajectories measured against an external cost field (e.g. obstacles).

    When ``field`` is None every cost evaluates to ``0``.
    """

    def __init__(
        self,
        robot,
        n_support_points: int,
        field: Callable = None,
        sigma: float = 1.0,
        **kwargs
    ):
        super().__init__(robot, n_support_points, **kwargs)
        self.field, self.sigma = field, sigma
        self.set_cost_factors()

    def set_cost_factors(self):
        # ========= Cost factors ===============
        # The factor covers the whole trajectory: slice bounds [0, None].
        self.obst_factor = FieldFactor(self.n_dof, self.sigma, [0, None])

    def cost(self, trajs: torch.Tensor, **observation) -> torch.Tensor:
        """Mean field cost along each trajectory.

        ``observation['traj_dim']`` gives the ``[..., traj_len, dim]`` view of ``trajs``; the
        field is queried on the first ``n_dof`` entries of each state.
        """
        traj_dim = observation.get('traj_dim', None)
        if self.field is None:
            return 0
        positions = trajs.view(traj_dim)[..., :self.n_dof]
        per_point = self.field.compute_cost(positions, **observation)
        # Average over the traj_len axis.
        return per_point.view(traj_dim[:-1]).mean(-1)

    def eval(self, trajs: torch.Tensor, q_pos=None, q_vel=None, H_positions=None, **observation):
        """Weighted field error averaged per trajectory, reshaped to ``optim_dim[:2]``."""
        optim_dim = observation.get('optim_dim')
        if self.field is None:
            return 0
        err_obst = self.obst_factor.get_error(
            trajs,
            self.field,
            q_pos=q_pos,
            q_vel=q_vel,
            H_pos=H_positions,
            calc_jacobian=False,
        )
        weighted = self.obst_factor.K * err_obst.mean(-1)
        return weighted.reshape(optim_dim[:2])
95 |
96 |
class CostGPHolonomic(Cost):
    """Constant-velocity Gaussian-process smoothness cost.

    For consecutive states ``x_t`` and ``x_{t+1}`` the residual is ``x_{t+1} - Phi x_t`` with
    ``Phi = [[I, dt*I], [0, I]]``. It is scored as a Mahalanobis distance under the GP
    precision ``Q_inv``.
    """

    def __init__(
        self,
        robot,
        n_support_points: int,
        dt: float,
        sigma: float,
        probe_range: Tuple[int, int],
        Q_c_inv: torch.Tensor = None,
        **kwargs
    ):
        super().__init__(robot, n_support_points, **kwargs)
        self.dt = dt
        self.phi = self.calc_phi()
        self.phi_T = self.phi.T
        if Q_c_inv is None:
            # Isotropic power-spectral density 1/sigma^2 when none is supplied.
            Q_c_inv = torch.eye(self.n_dof, **self.tensor_args) / sigma**2
        n_transitions = self.n_support_points - 1
        self.Q_c_inv = torch.zeros(n_transitions, self.n_dof, self.n_dof, **self.tensor_args) + Q_c_inv
        self.Q_inv = self.calc_Q_inv()
        # Only the first transition's precision is used by cost/eval; [[0]] keeps it 3-D.
        self.single_Q_inv = self.Q_inv[[0]]
        self.probe_range = probe_range

    def calc_phi(self) -> torch.Tensor:
        """State-transition matrix ``[[I, dt*I], [0, I]]`` of shape ``[dim, dim]``."""
        n = self.n_dof
        phi = torch.eye(2 * n, **self.tensor_args)
        phi[:n, n:] = self.dt * torch.eye(n, **self.tensor_args)
        return phi

    def calc_Q_inv(self) -> torch.Tensor:
        """Per-transition GP precision ``[[12/dt^3, -6/dt^2], [-6/dt^2, 4/dt]] (x) Q_c_inv``."""
        dt = self.dt
        top_left = 12. * (dt ** -3.) * self.Q_c_inv
        off_diag = -6. * (dt ** -2.) * self.Q_c_inv
        bottom_right = 4. * (dt ** -1.) * self.Q_c_inv
        upper = torch.cat((top_left, off_diag), dim=-1)
        lower = torch.cat((off_diag, bottom_right), dim=-1)
        return torch.cat((upper, lower), dim=-2)

    def cost(self, trajs: torch.Tensor, **observation) -> torch.Tensor:
        """Mean GP residual cost along each trajectory viewed as ``observation['traj_dim']``."""
        traj_dim = observation.get('traj_dim', None)
        x = trajs.view(traj_dim)  # [..., n_support_points, dim]
        residuals = x[..., 1:, :] - x[..., :-1, :] @ self.phi_T  # [..., n_support_points-1, dim]
        mahalanobis = torch.einsum('...ij,...ijk,...ik->...i', residuals, self.single_Q_inv, residuals)
        # Average over the transitions.
        return mahalanobis.mean(dim=-1)

    def eval(self, trajs: torch.Tensor, **observation) -> torch.Tensor:
        """GP cost of probe states relative to their neighbours in ``current_trajs``.

        Each probe state is scored against the preceding current state (forward residual) and
        the following current state (backward residual). Costs are written into the probe
        slots given by ``probe_range`` and averaged over the last axis of ``optim_dim``.
        """
        traj_dim = observation.get('traj_dim')
        optim_dim = observation.get('optim_dim')
        n_sp = self.n_support_points
        lo, hi = self.probe_range[0], self.probe_range[1]

        # [..., n_support_points, 1, 1, dim]: broadcast over the (nb2, num_probe) axes.
        anchors = observation.get('current_trajs').view(traj_dim)[..., None, None, :]

        # [..., n_support_points] + [nb2, num_probe]
        costs = torch.zeros(traj_dim[:-1] + optim_dim[1:3], **self.tensor_args)

        probes = trajs[..., lo:hi, :]
        n_probe = probes.shape[-2]
        probes = probes.view(traj_dim[:-1] + (optim_dim[1], n_probe, self.dim,))

        # [..., n_support_points-1, nb2, num_eval, dim]
        fwd = probes[..., 1:n_sp, :, :, :] - anchors[..., 0:n_sp - 1, :, :, :] @ self.phi_T
        bwd = anchors[..., 1:n_sp, :, :, :] - probes[..., 0:n_sp - 1, :, :, :] @ self.phi_T
        # Mahalanobis distances, [..., n_support_points-1, nb2, num_eval]
        bwd_cost = torch.einsum('...ij,...ijk,...ik->...i', bwd, self.single_Q_inv, bwd)
        fwd_cost = torch.einsum('...ij,...ijk,...ik->...i', fwd, self.single_Q_inv, fwd)

        costs[..., 0:n_sp - 1, :, lo:hi] += bwd_cost
        costs[..., 1:n_sp, :, lo:hi] += fwd_cost
        # Average over the probe axis.
        return costs.view(optim_dim).mean(dim=-1)
171 |
172 |
class CostComposite(Cost):
    """Weighted combination of several cost terms sharing one forward-kinematics pass.

    :param cost_list: callables with the ``Cost.__call__`` signature.
    :param weights_cost_l: one weight per cost; defaults to ``1.0`` for each.
    """

    def __init__(
        self,
        robot,
        n_support_points,
        cost_list,
        weights_cost_l=None,
        **kwargs
    ):
        super().__init__(robot, n_support_points, **kwargs)
        self.cost_l = cost_list
        self.weight_cost_l = weights_cost_l if weights_cost_l is not None else [1.0] * len(cost_list)

    def eval(self, trajs, trajs_interpolated=None, return_invidual_costs_and_weights=False, **kwargs):
        """Evaluate all cost terms on ``trajs``.

        Positions, velocities and collision kinematics are computed once from ``trajs`` and
        passed to every term.

        :param trajs_interpolated: if given, EVERY term receives it instead of ``trajs``.
            The kinematic quantities are still those of ``trajs``. NOTE(review): an
            older comment suggested that only collision terms should see the interpolated
            trajectories; confirm the intended behavior.
        :param return_invidual_costs_and_weights: if True, return
            ``(list_of_unweighted_costs, weights)``; otherwise return the weighted sum.
        """
        trajs, q_pos, q_vel, H_positions = self.get_q_pos_vel_and_fk_map(trajs)
        trajs_eval = trajs if trajs_interpolated is None else trajs_interpolated
        shared = dict(q_pos=q_pos, q_vel=q_vel, H_positions=H_positions)

        if return_invidual_costs_and_weights:
            cost_l = [cost(trajs_eval, **shared, **kwargs) for cost in self.cost_l]
            return cost_l, self.weight_cost_l

        cost_total = 0
        for cost, weight_cost in zip(self.cost_l, self.weight_cost_l):
            cost_total += weight_cost * cost(trajs_eval, **shared, **kwargs)
        return cost_total
215 |
--------------------------------------------------------------------------------
/mpot/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/envs/__init__.py
--------------------------------------------------------------------------------
/mpot/envs/map_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from mpot.envs.obst_map import ObstacleRectangle, ObstacleMap, ObstacleCircle
5 | from mpot.envs.obst_utils import random_rect, random_circle
6 | import copy
7 |
8 | from torch_robotics.environments.primitives import MultiSphereField, ObjectField, MultiBoxField
9 |
10 |
def _place_random_obstacles(sample_obstacle, count, obst_map, max_attempts, label):
    """Rejection-sample ``count`` non-overlapping obstacles and rasterize them into ``obst_map``.

    Each obstacle gets up to ``max_attempts + 1`` draws from ``sample_obstacle()``. If none of
    them fits, a warning is printed and that obstacle is skipped.

    :return: list of ``obst.to_array()`` rows for the obstacles actually placed.
    """
    placed = []
    for _ in range(count):
        for attempt in range(max_attempts + 1):
            obst = sample_obstacle()
            # Do not overlap obstacles already in the map
            if obst._obstacle_collision_check(obst_map):
                obst._add_to_map(obst_map)
                placed.append(obst.to_array())
                break
            if attempt == max_attempts:
                print("Obstacle generation: Max. number of attempts reached. ")
                print(f"Total num. {label}: {len(placed)}")
    return placed


def random_obstacles(
    map_dim=(2, 2),
    cell_size: float = 1.,
    num_obst: int = 5,
    rand_xy_limits=((-1, 1), (-1, 1)),
    rand_rect_shape=(2, 2),
    rand_circle_radius: float = 1,
    max_attempts: int = 50,
    tensor_args=None,
):
    """Generate a random occupancy map with non-overlapping boxes and circles.

    The number of boxes is drawn with ``np.random.randint(0, num_obst)``, so it is at most
    ``num_obst - 1``; the rest are circles. Positions are drawn by ``random_rect`` /
    ``random_circle``.

    :param map_dim: map extent in world units; ObstacleMap requires even values.
    :param cell_size: grid resolution in world units.
    :param num_obst: total number of obstacles to attempt.
    :param rand_xy_limits: ``(xlim, ylim)`` ranges for obstacle centres.
    :param rand_rect_shape: ``(width, height)`` of every box.
    :param rand_circle_radius: radius of every circle.
    :param max_attempts: rejection-sampling retries per obstacle.
    :param tensor_args: torch ``device``/``dtype`` dict; ObstacleMap's default is used if None.
    :return: ``(obst_map, obj_list)``. ``obj_list`` holds a 'random-boxes' and/or a
        'random-spheres' ObjectField; a group that ends up empty is omitted.
    """
    obst_map = ObstacleMap(map_dim, cell_size, tensor_args=tensor_args)
    # ObstacleMap substitutes a default when tensor_args is None; reuse the resolved value.
    tensor_args = obst_map.tensor_args

    num_boxes = np.random.randint(0, num_obst) if num_obst > 0 else 0
    num_circles = num_obst - num_boxes
    xlim, ylim = rand_xy_limits[0], rand_xy_limits[1]
    width, height = rand_rect_shape

    boxes = _place_random_obstacles(
        lambda: random_rect(xlim, ylim, width, height),
        num_boxes, obst_map, max_attempts, 'boxes',
    )
    circles = _place_random_obstacles(
        lambda: random_circle(xlim, ylim, rand_circle_radius),
        num_circles, obst_map, max_attempts, 'circles',
    )

    obj_list = []
    # An empty group becomes a shape-(0,) tensor whose [:, :2] indexing raises; skip it.
    if boxes:
        boxes = torch.tensor(np.array(boxes), **tensor_args)
        cubes = MultiBoxField(boxes[:, :2], boxes[:, 2:], tensor_args=tensor_args)
        obj_list.append(ObjectField([cubes], 'random-boxes'))
    if circles:
        circles = torch.tensor(np.array(circles), **tensor_args)
        spheres = MultiSphereField(circles[:, :2], circles[:, 2], tensor_args=tensor_args)
        obj_list.append(ObjectField([spheres], 'random-spheres'))

    obst_map.convert_map()
    return obst_map, obj_list
81 |
82 |
if __name__ == "__main__":
    import random

    cell_size = 0.1
    map_dim = [20, 20]
    seed = 2
    # Obstacle centres come from `random` (obst_utils) and the box/circle split from numpy,
    # so both RNGs must be seeded for a reproducible map.
    random.seed(seed)
    np.random.seed(seed)
    tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32}
    obst_map, obst_list = random_obstacles(
        map_dim, cell_size,
        num_obst=5,
        rand_xy_limits=[[-5, 5], [-5, 5]],
        rand_rect_shape=[2, 2],
        rand_circle_radius=1,
        tensor_args=tensor_args,
    )
    fig = obst_map.plot()
97 |
--------------------------------------------------------------------------------
/mpot/envs/obst_map.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 | from math import ceil
6 | from abc import ABC, abstractmethod
7 | import os.path as osp
8 | from copy import deepcopy
9 |
10 |
class Obstacle(ABC):
    """
    Base 2D Obstacle class

    Subclasses rasterize themselves into an ObstacleMap-like object (anything exposing a
    numpy ``map`` grid) via :meth:`_add_to_map`.
    """

    def __init__(self, center_x, center_y):
        self.center_x = center_x
        self.center_y = center_y
        self.origin = np.array([self.center_x, self.center_y])

    def _obstacle_collision_check(self, obst_map):
        """Return True if this obstacle can be added without overlapping an existing one.

        The trial insertion is done on a deep copy, so ``obst_map`` is left untouched.
        """
        trial = self._add_to_map(deepcopy(obst_map))
        # After insertion, any cell counted more than once is an overlap.
        return not np.any(trial.map > 1)

    def _point_collision_check(self, obst_map, pts):
        """Return True if no point in ``pts`` lies in an occupied cell once this obstacle is added.

        :param obst_map: ObstacleMap-like object. It is deep-copied and not modified.
        :param pts: iterable of points, or None (always valid). Each coordinate is rounded up
            with ``ceil`` and used directly as ``map[pt[0], pt[1]]``. NOTE(review): the grid is
            indexed [y, x] elsewhere; confirm callers pass cell indices in that order.
        """
        if pts is None:
            return True
        occupancy = self._add_to_map(deepcopy(obst_map)).map
        return all(occupancy[ceil(pt[0]), ceil(pt[1])] < 1 for pt in pts)

    @abstractmethod
    def _add_to_map(self, obst_map):
        """Rasterize this obstacle into ``obst_map`` and return it."""
37 |
38 | @abstractmethod
39 | def _add_to_map(self, obst_map):
40 | pass
41 |
42 |
class ObstacleRectangle(Obstacle):
    """
    Derived 2D rectangular Obstacle class

    Axis-aligned rectangle described by its centre and ``(width, height)`` in world units.
    """

    def __init__(
        self,
        center_x=0,
        center_y=0,
        width=None,
        height=None,
    ):
        super().__init__(center_x, center_y)
        self.width = width
        self.height = height

    def _add_to_map(self, obst_map):
        """Increment every grid cell covered by this rectangle and return ``obst_map``.

        Rows are offset by ``origin_yi`` and columns by ``origin_xi``. Parts of the rectangle
        outside the grid are ignored.
        """
        # Convert dims to cell indices
        w = ceil(self.width / obst_map.cell_size)
        h = ceil(self.height / obst_map.cell_size)
        c_x = ceil(self.center_x / obst_map.cell_size)
        c_y = ceil(self.center_y / obst_map.cell_size)
        half_w = ceil(w / 2.)
        half_h = ceil(h / 2.)

        # numpy reads negative slice bounds from the end of the axis, so a rectangle crossing
        # the low border would be drawn on the far side or dropped; clamp both bounds at 0.
        # Bounds past the high border are already truncated by slicing.
        row_start = max(0, c_y - half_h + obst_map.origin_yi)
        row_stop = max(0, c_y + half_h + obst_map.origin_yi)
        col_start = max(0, c_x - half_w + obst_map.origin_xi)
        col_stop = max(0, c_x + half_w + obst_map.origin_xi)

        obst_map.map[row_start:row_stop, col_start:col_stop] += 1
        return obst_map

    def to_array(self):
        """Return ``[center_x, center_y, width, height]``."""
        return np.array([self.center_x, self.center_y, self.width, self.height])
76 |
77 |
class ObstacleCircle(Obstacle):
    """
    Derived 2D circle Obstacle class

    Disc described by its centre and ``radius`` in world units.
    """

    def __init__(
        self,
        center_x=0,
        center_y=0,
        radius=1.
    ):
        super().__init__(center_x, center_y)
        self.radius = radius

    def is_inside(self, p):
        """Return True if world point ``p`` is within ``radius`` of the centre."""
        return np.linalg.norm(p - self.origin) <= self.radius

    def _add_to_map(self, obst_map):
        """Increment every grid cell whose world position lies inside the disc; return ``obst_map``.

        Only cells inside the grid are visited. Out-of-range indices would otherwise wrap
        around (negative) or raise IndexError (too large).
        """
        # Convert dims to cell indices
        c_r = ceil(self.radius / obst_map.cell_size)
        c_x = ceil(self.center_x / obst_map.cell_size)
        c_y = ceil(self.center_y / obst_map.cell_size)

        n_rows, n_cols = obst_map.map.shape
        row_lo = max(0, c_y - 2 * c_r + obst_map.origin_yi)
        row_hi = min(n_rows, c_y + 2 * c_r + obst_map.origin_yi)
        col_lo = max(0, c_x - 2 * c_r + obst_map.origin_xi)
        col_hi = min(n_cols, c_x + 2 * c_r + obst_map.origin_xi)

        for i in range(row_lo, row_hi):
            for j in range(col_lo, col_hi):
                # World coordinates of cell (row i, col j).
                p = np.array([(j - obst_map.origin_xi) * obst_map.cell_size,
                              (i - obst_map.origin_yi) * obst_map.cell_size])
                if self.is_inside(p):
                    obst_map.map[i, j] += 1
        return obst_map

    def to_array(self):
        """Return ``[center_x, center_y, radius]``."""
        return np.array([self.center_x, self.center_y, self.radius])
112 |
113 |
class ObstacleMap:
    """
    Generates an occupancy grid.

    The grid spans ``map_dim`` world units centred on the origin, in square cells of side
    ``cell_size``. ``self.map`` is a numpy array indexed ``[row, col] == [y, x]``; each cell
    counts how many obstacles cover it. Call :meth:`convert_map` after adding obstacles so
    that :meth:`get_collisions` can query the torch copy.
    """

    def __init__(self, map_dim, cell_size, tensor_args=None):
        """
        :param map_dim: ``(x_extent, y_extent)`` in world units; both must be even.
        :param cell_size: side of one grid cell in world units.
        :param tensor_args: torch ``device``/``dtype`` dict; defaults to CPU float32.
        """
        assert map_dim[0] % 2 == 0
        assert map_dim[1] % 2 == 0

        if tensor_args is None:
            tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32}
        self.tensor_args = tensor_args

        n_cells_x = ceil(map_dim[0] / cell_size)
        n_cells_y = ceil(map_dim[1] / cell_size)

        # Rows are y and columns are x. This matches how obstacles write map[y, x] and how
        # get_collisions reads map_torch[y, x], so non-square maps are handled consistently.
        self.map = np.zeros((n_cells_y, n_cells_x))
        self.cell_size = cell_size

        # Map center (in cells)
        self.origin_xi = int(n_cells_x / 2)
        self.origin_yi = int(n_cells_y / 2)

        self.y_dim, self.x_dim = self.map.shape
        x_range = self.cell_size * self.x_dim
        y_range = self.cell_size * self.y_dim
        self.xlim = [-x_range / 2, x_range / 2]
        self.ylim = [-y_range / 2, y_range / 2]

        # Cell offset of the world origin, added to scaled (x, y) coordinates.
        self.c_offset = torch.tensor([self.origin_xi, self.origin_yi], **self.tensor_args)

    def __call__(self, X, **kwargs):
        return self.compute_cost(X, **kwargs)

    def convert_map(self):
        """Copy ``self.map`` into a torch tensor (``self.map_torch``) and return it."""
        self.map_torch = torch.tensor(self.map, **self.tensor_args)
        return self.map_torch

    def plot(self, save_dir=None, filename="obst_map.png"):
        """Show the grid with imshow (y axis pointing up) and optionally save it.

        The figure is saved before ``plt.show()``; once a blocking show returns, the figure
        may already be closed.
        """
        fig = plt.figure()
        plt.imshow(self.map)
        plt.gca().invert_yaxis()
        if save_dir is not None:
            fig.savefig(osp.join(save_dir, filename))
        plt.show()
        return fig

    def get_xy_grid(self, device):
        """Return a ``[x_dim, y_dim, 2]`` grid of world (x, y) sample coordinates on ``device``."""
        xs = torch.linspace(self.xlim[0], self.xlim[1], self.x_dim)
        ys = torch.linspace(self.ylim[0], self.ylim[1], self.y_dim)
        xv, yv = torch.meshgrid(xs, ys, indexing='ij')
        xy_grid = torch.stack((xv, yv), dim=2)
        return xy_grid.to(device)

    def get_collisions(self, X, *args, **kwargs):
        """
        Look up the occupancy value at each query point.

        Points outside the grid are projected onto its border.

        :param X: tensor of world points ``[..., 2]`` ordered (x, y), e.g.
            ``(batch_size, traj_length, 2)``.
        :return: tensor of shape ``X.shape[:-1]`` with the grid value at each point.
        """
        X_occ = X * (1 / self.cell_size) + self.c_offset
        X_occ = X_occ.floor().long()

        # Project out-of-bounds locations to axis: x indexes columns, y indexes rows.
        X_occ[..., 0] = X_occ[..., 0].clamp(0, self.x_dim - 1)
        X_occ[..., 1] = X_occ[..., 1].clamp(0, self.y_dim - 1)

        return self.map_torch[X_occ[..., 1], X_occ[..., 0]]

    def compute_cost(self, X, *args, **kwargs):
        return self.get_collisions(X, *args, **kwargs)

    def zero_grad(self):
        pass
195 |
--------------------------------------------------------------------------------
/mpot/envs/obst_utils.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | import random
3 | from mpot.envs.obst_map import ObstacleRectangle, ObstacleCircle
4 |
5 |
def round_up(n, decimals=0):
    """Round ``n`` up (towards +infinity) to ``decimals`` decimal places.

    >>> round_up(1.21, 1)
    1.3
    >>> round_up(1.1, 1)
    1.1
    """
    multiplier = 10 ** decimals
    scaled = n * multiplier
    # Binary floats make e.g. 1.1 * 10 == 11.000000000000002, which ceil would push a whole
    # step too far; round away that representation noise before taking the ceiling.
    return ceil(round(scaled, 9)) / multiplier
9 |
10 |
def random_rect(xlim=(0, 0), ylim=(0, 0), width=2, height=2):
    """
    Create an ObstacleRectangle of fixed ``width`` x ``height`` at a random centre.

    The centre x and y are drawn (in that order) with ``random.uniform`` from ``xlim`` and
    ``ylim``; the size itself is not randomized.
    """
    cx, cy = (random.uniform(bounds[0], bounds[1]) for bounds in (xlim, ylim))
    return ObstacleRectangle(cx, cy, width, height)
18 |
19 |
def random_circle(xlim=(0,0), ylim=(0,0), radius=2):
    """
    Create an ObstacleCircle of fixed ``radius`` at a random centre.

    The centre x and y are drawn (in that order) with ``random.uniform`` from ``xlim`` and
    ``ylim``; the radius itself is not randomized.
    """
    cx, cy = (random.uniform(bounds[0], bounds[1]) for bounds in (xlim, ylim))
    return ObstacleCircle(cx, cy, radius)
27 |
--------------------------------------------------------------------------------
/mpot/envs/occupancy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import pyplot as plt
3 |
4 | from torch_robotics.environments.env_base import EnvBase
5 | from torch_robotics.robots import RobotPointMass
6 | from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS
7 | from torch_robotics.visualizers.planning_visualizer import create_fig_and_axes
8 | from mpot.envs.map_generator import random_obstacles
9 |
10 |
class EnvOccupancy2D(EnvBase):
    """2D environment on [-10, 10]^2 filled with randomly generated boxes and circles.

    The analytic obstacle fields go to ``EnvBase`` as fixed objects. The rasterized occupancy
    grid is kept as ``self.occupancy_map``.
    """

    def __init__(self, tensor_args=None, **kwargs):
        tensor_args = DEFAULT_TENSOR_ARGS if tensor_args is None else tensor_args
        generator_cfg = dict(
            map_dim=[20, 20],
            cell_size=0.1,
            num_obst=15,
            rand_xy_limits=[[-7.5, 7.5], [-7.5, 7.5]],
            rand_rect_shape=[2, 2],
            rand_circle_radius=1.,
        )
        obst_map, obj_list = random_obstacles(tensor_args=tensor_args, **generator_cfg)

        env_limits = torch.tensor([[-10, -10], [10, 10]], **tensor_args)  # environments limits
        super().__init__(
            name=self.__class__.__name__,
            limits=env_limits,
            obj_fixed_list=obj_list,
            tensor_args=tensor_args,
            **kwargs
        )
        self.occupancy_map = obst_map
34 |
35 |
if __name__ == '__main__':
    def _show_occupancy_demo(tensor_args):
        """Render the environment, then its SDF and SDF gradient on a second figure."""
        env = EnvOccupancy2D(precompute_sdf_obj_fixed=False, tensor_args=tensor_args)

        _, ax = create_fig_and_axes(env.dim)
        env.render(ax)
        plt.show()

        fig, ax = create_fig_and_axes(env.dim)
        env.render_sdf(ax, fig)       # signed distance field
        env.render_grad_sdf(ax, fig)  # its gradient, on the same axes
        plt.show()

    _show_occupancy_demo(DEFAULT_TENSOR_ARGS)
49 |
--------------------------------------------------------------------------------
/mpot/gp/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 An Thai Le
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | **********************************************************************
24 | The first versions of some files in this gp/ folder were licensed under the
25 | "Original Source License" (see below). Some enhancements and developments
26 | were made by An Thai Le after obtaining the first version.
27 | **********************************************************************
28 |
29 | Original Source License:
30 |
31 | MIT License
32 |
33 | Copyright (c) 2022 Sasha Lambert
34 |
35 | Permission is hereby granted, free of charge, to any person obtaining a copy
36 | of this software and associated documentation files (the "Software"), to deal
37 | in the Software without restriction, including without limitation the rights
38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
39 | copies of the Software, and to permit persons to whom the Software is
40 | furnished to do so, subject to the following conditions:
41 |
42 | The above copyright notice and this permission notice shall be included in all
43 | copies or substantial portions of the Software.
44 |
45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
51 | SOFTWARE.
52 |
--------------------------------------------------------------------------------
/mpot/gp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/gp/__init__.py
--------------------------------------------------------------------------------
/mpot/gp/field_factor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class FieldFactor:
    """Factor that scores a window of trajectory support points against a cost field.

    :meth:`get_error` evaluates ``field.compute_cost`` over the support points selected by
    ``traj_range``. Optionally it also returns the negated gradient of that error with
    respect to the position part of the trajectory.
    """

    def __init__(
        self,
        n_dof,
        sigma,
        traj_range,
    ):
        """
        :param n_dof: number of position coordinates at the front of each state.
        :param sigma: noise scale; the factor weight is ``K = 1 / sigma**2``.
        :param traj_range: ``[start, stop]`` slice bounds along the support-point axis
            (``None`` allowed, e.g. ``[0, None]``).
        """
        self.n_dof = n_dof
        self.sigma = sigma
        self.traj_range = traj_range
        self.K = 1. / (sigma**2)

    def _field_error(self, field, q_trajs, q_pos, H_pos, batch, extra):
        """Evaluate ``field`` on the ``traj_range`` window, reshaped to ``[batch, length]``."""
        lo, hi = self.traj_range[0], self.traj_range[1]
        if H_pos is None:
            states = q_trajs[:, lo:hi, :self.n_dof].reshape(-1, self.n_dof)
        else:
            states = H_pos[:, lo:hi]
        q_window = q_pos[:, lo:hi, :]
        return field.compute_cost(q_window, states, **extra).reshape(batch, q_window.shape[-2])

    def get_error(
        self,
        q_trajs,
        field,
        q_pos=None,
        q_vel=None,
        H_pos=None,
        q_trajs_interp=None,
        q_pos_interp=None,
        q_vel_interp=None,
        H_pos_interp=None,
        calc_jacobian=True,
        **kwargs
    ):
        """Field error per support point, optionally with its Jacobian.

        ``H_pos`` (when given) is passed to the field as the states; otherwise the first
        ``n_dof`` entries of ``q_trajs`` are used. ``q_pos`` must be provided.

        :return: ``error`` of shape ``[batch, length]`` if ``calc_jacobian`` is False;
            otherwise ``(error.detach(), H)``. Here ``H`` is minus the gradient of the summed
            error w.r.t. ``q_trajs`` (window, position part). It uses the interpolated inputs
            when any are given. ``field.zero_grad()`` is called afterwards.
        """
        batch = q_trajs.shape[0]
        error = self._field_error(field, q_trajs, q_pos, H_pos, batch, kwargs)
        if not calc_jacobian:
            return error

        # The Jacobian is taken of the interpolated trajectory's error when one is supplied.
        if H_pos_interp is not None or q_trajs_interp is not None:
            error_for_grad = self._field_error(
                field, q_trajs_interp, q_pos_interp, H_pos_interp, batch, kwargs)
        else:
            error_for_grad = error

        lo, hi = self.traj_range[0], self.traj_range[1]
        grad = torch.autograd.grad(error_for_grad.sum(), q_trajs, retain_graph=True)[0]
        H = -grad[:, lo:hi, :self.n_dof]
        error = error.detach()
        field.zero_grad()
        return error, H
60 |
--------------------------------------------------------------------------------
/mpot/gp/gp_factor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class GPFactor():
    """Constant-velocity GP prior factor between consecutive support states.

    A state stacks position and velocity (size ``2 * dim``). For each of ``num_factors``
    transitions the error is ``x_{t+1} - Phi x_t`` with ``Phi = [[I, dt*I], [0, I]]``, and
    ``Q_inv`` holds the matching GP precision.
    """

    def __init__(
        self,
        dim: int,
        sigma: float,
        dt: float,
        num_factors: int,
        Q_c_inv: torch.Tensor = None,
        tensor_args=None,
    ):
        """
        :param dim: degrees of freedom (position size).
        :param sigma: isotropic noise scale used when ``Q_c_inv`` is None (``I / sigma**2``).
        :param dt: time step between support states.
        :param num_factors: number of transitions (support states minus one).
        :param Q_c_inv: optional ``[dim, dim]`` (or per-factor) power-spectral-density inverse.
        :param tensor_args: torch ``device``/``dtype`` dict; defaults to CPU float32.
        """
        if tensor_args is None:
            tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32}
        self.dim = dim
        self.dt = dt
        self.tensor_args = tensor_args
        self.state_dim = self.dim * 2  # position and velocity
        self.num_factors = num_factors
        self.idx1 = torch.arange(0, self.num_factors, device=tensor_args['device'])
        self.idx2 = torch.arange(1, self.num_factors + 1, device=tensor_args['device'])
        self.phi = self.calc_phi()
        if Q_c_inv is None:
            Q_c_inv = torch.eye(dim, **tensor_args) / sigma**2
        self.Q_c_inv = torch.zeros(num_factors, dim, dim, **tensor_args) + Q_c_inv
        self.Q_inv = self.calc_Q_inv()  # shape: [num_factors, state_dim, state_dim]

        ## Pre-compute constant Jacobians
        self.H1 = self.phi.unsqueeze(0).repeat(self.num_factors, 1, 1)
        # Built with tensor_args so it shares the device/dtype of phi and the trajectories.
        self.H2 = -1. * torch.eye(self.state_dim, **tensor_args).unsqueeze(0).repeat(
            self.num_factors, 1, 1,
        )

    def calc_phi(self) -> torch.Tensor:
        """State-transition matrix ``[[I, dt*I], [0, I]]`` of shape ``[state_dim, state_dim]``."""
        I = torch.eye(self.dim, **self.tensor_args)
        Z = torch.zeros(self.dim, self.dim, **self.tensor_args)
        phi_u = torch.cat((I, self.dt * I), dim=1)
        phi_l = torch.cat((Z, I), dim=1)
        phi = torch.cat((phi_u, phi_l), dim=0)
        return phi

    def calc_Q_inv(self) -> torch.Tensor:
        """Per-factor GP precision ``[[12/dt^3, -6/dt^2], [-6/dt^2, 4/dt]] (x) Q_c_inv``."""
        m1 = 12. * (self.dt ** -3.) * self.Q_c_inv
        m2 = -6. * (self.dt ** -2.) * self.Q_c_inv
        m3 = 4. * (self.dt ** -1.) * self.Q_c_inv

        Q_inv_u = torch.cat((m1, m2), dim=-1)
        Q_inv_l = torch.cat((m2, m3), dim=-1)
        Q_inv = torch.cat((Q_inv_u, Q_inv_l), dim=-2)
        return Q_inv

    def get_error(self, x_traj: torch.Tensor, calc_jacobian: bool = True) -> torch.Tensor:
        """Transition errors for a batch of trajectories.

        :param x_traj: states along axis 1 with ``state_dim`` as the last axis; presumably
            ``[batch, num_factors + 1, state_dim]`` (TODO confirm with callers).
        :return: ``error`` of shape ``[..., num_factors, state_dim, 1]``; with
            ``calc_jacobian`` also the constant ``H1 = Phi`` and ``H2 = -I``.
            NOTE(review): these are the Jacobians of ``Phi x_t - x_{t+1}`` (the negated error);
            confirm the consumer expects this sign convention.
        """
        state_1 = torch.index_select(x_traj, 1, self.idx1).unsqueeze(-1)
        state_2 = torch.index_select(x_traj, 1, self.idx2).unsqueeze(-1)
        error = state_2 - self.phi @ state_1

        if calc_jacobian:
            return error, self.H1, self.H2
        return error
--------------------------------------------------------------------------------
/mpot/gp/gp_prior.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions as dist
3 |
4 |
5 | class BatchGPPrior:
6 |
7 | def __init__(
8 | self,
9 | traj_len: int,
10 | dt: float,
11 | dim: int,
12 | K_s_inv: torch.Tensor,
13 | K_gp_inv: torch.Tensor,
14 | start_state: torch.Tensor,
15 | means: torch.Tensor = None,
16 | K_g_inv: torch.Tensor = None,
17 | goal_states: torch.Tensor = None,
18 | tensor_args=None,
19 | ):
20 | """
21 | Motion-Planning prior.
22 |
23 | reference: "Continuous-time Gaussian process motion planning via
24 | probabilistic inference", Mukadam et al. (IJRR 2018)
25 |
26 | Parameters
27 | ----------
28 | traj_len : int
29 | Planning horizon length (not including start state).
30 | dt : float
31 | Time-step size.
32 | state_dim : int
33 | State state_dimension.
34 | K_s_inv : Tensor
35 | Start-state inverse covariance. Shape: [state_dim, state_dim]
36 | K_gp_inv : Tensor
37 | Gaussian-process single-step inverse covariance i.e. 'Q_inv'.
38 | Assumed constant, meaning homoscedastic noise with constant step-size.
39 | Shape: [2 * state_dim, 2 * state_dim]
40 | start_state : Tensor
41 | Shape: [state_dim]
42 | (Optional) K_g_inv : Tensor
43 | Goal-state inverse covariance. Shape: [state_dim, state_dim]
44 | (Optional) goal_state : Tensor
45 | Shape: [state_dim]
46 | (Optional) dim : int
47 | Degrees of freedom.
48 | """
49 | self.dim = dim
50 | self.state_dim = dim * 2
51 | self.traj_len = traj_len
52 | self.M = self.state_dim * (traj_len + 1)
53 | self.tensor_args = tensor_args
54 |
55 | self.goal_directed = (goal_states is not None)
56 |
57 | if means is None:
58 | self.num_modes = goal_states.shape[0] if self.goal_directed else 1
59 | means = self.get_const_vel_mean(
60 | start_state,
61 | goal_states,
62 | dt,
63 | traj_len,
64 | dim)
65 | else:
66 | self.num_modes = means.shape[0]
67 |
68 | # Flatten mean trajectories
69 | self.means = means.reshape(self.num_modes, -1)
70 |
71 | # TODO: Add different goal Covariances
72 | # Assume same goal Cov. for now
73 | Sigma_inv = self.get_const_vel_covariance(
74 | dt,
75 | K_s_inv,
76 | K_gp_inv,
77 | K_g_inv,
78 | )
79 |
80 | # self.Sigma_inv = Sigma_inv
81 | self.Sigma_inv = Sigma_inv # + torch.eye(Sigma_inv.shape[0], **tensor_args) * 1.e-3
82 | self.Sigma_invs = self.Sigma_inv.repeat(self.num_modes, 1, 1)
83 | self.update_dist(self.means, self.Sigma_invs)
84 |
85 | def update_dist(
86 | self,
87 | means: torch.Tensor,
88 | Sigma_invs: torch.Tensor,
89 | ) -> torch.Tensor:
90 | # Create Multi-variate Normal Distribution
91 | self.dist = dist.MultivariateNormal(
92 | means,
93 | precision_matrix=Sigma_invs,
94 | )
95 |
96 | def get_mean(self, reshape: bool = True) -> torch.Tensor:
97 | if reshape:
98 | return self.means.clone().detach().reshape(
99 | self.num_modes, self.traj_len + 1, self.state_dim,
100 | )
101 | else:
102 | self.means.clone().detach()
103 |
104 | def set_mean(self, means: torch.Tensor) -> torch.Tensor:
105 | assert means.shape == self.means.shape
106 | self.means = means.clone().detach()
107 | self.update_dist(self.means, self.Sigma_invs)
108 |
def set_Sigma_invs(self, Sigma_invs: torch.Tensor) -> torch.Tensor:
    """Replace the precision matrices and rebuild the distribution.

    Args:
        Sigma_invs: New precision matrices; must match ``self.Sigma_invs``
            in shape.

    Raises:
        AssertionError: If the shape differs (when assertions are enabled).
    """
    assert Sigma_invs.shape == self.Sigma_invs.shape
    fresh = Sigma_invs.clone().detach()
    self.Sigma_invs = fresh
    self.update_dist(self.means, fresh)
113 |
def const_vel_trajectory(
    self,
    start_state: torch.Tensor,
    goal_state: torch.Tensor,
    dt: float,
    traj_len: int,
    dim: int,
) -> torch.Tensor:
    """Straight-line, constant-velocity trajectory from start to goal.

    Positions (the first ``dim`` entries) are interpolated linearly over
    ``traj_len + 1`` waypoints. Every waypoint gets the same velocity,
    ``(goal - start) / (traj_len * dt)``.

    Returns:
        Tensor of shape ``[traj_len + 1, 2 * dim]`` created with
        ``self.tensor_args``.
    """
    start_pos = start_state[:dim]
    goal_pos = goal_state[:dim]
    traj = torch.zeros(traj_len + 1, 2 * dim, **self.tensor_args)
    for step in range(traj_len + 1):
        from_start = start_pos * (traj_len - step) * 1. / traj_len
        toward_goal = goal_pos * step * 1. / traj_len
        traj[step, :dim] = from_start + toward_goal
    velocity = (goal_pos - start_pos) / (traj_len * dt)
    traj[:, dim:] = velocity.unsqueeze(0)
    return traj
129 |
def get_const_vel_mean(
    self,
    start_state: torch.Tensor,
    goal_states: torch.Tensor,
    dt: float,
    traj_len: int,
    dim: int,
) -> torch.Tensor:
    """Constant-velocity mean trajectories.

    When ``self.goal_directed`` is set, one ``const_vel_trajectory`` is built
    per goal (``goal_states[k]`` for ``k < self.num_modes``), and the results
    are stacked along a new leading axis. Otherwise ``start_state`` is
    repeated ``traj_len + 1`` times (for a 1-D start state this gives
    ``[traj_len + 1, state_dim]``).
    """
    if not self.goal_directed:
        return start_state.repeat(traj_len + 1, 1)

    per_goal = [
        self.const_vel_trajectory(start_state, goal_states[k], dt, traj_len, dim)
        for k in range(self.num_modes)
    ]
    return torch.stack(per_goal, dim=0)
153 |
def get_const_vel_covariance(
    self,
    dt: float,
    K_s_inv: torch.Tensor,
    K_gp_inv: torch.Tensor,
    K_g_inv: torch.Tensor,
    precision_matrix: bool = True,
) -> torch.Tensor:
    """Joint precision (or covariance) of a constant-velocity GP trajectory.

    Computes ``K^{-1} = A^T Q^{-1} A``:

    - The rows of ``A`` select the start state and the transition residuals
      ``x_{k+1} - Phi x_k``. When goal-directed, they also select the final
      state.
    - ``Q^{-1}`` is block-diagonal: ``K_s_inv``, then ``traj_len`` copies of
      ``K_gp_inv``, then ``K_g_inv`` when goal-directed.

    Args:
        dt: Time step used in the transition matrix ``Phi``.
        K_s_inv: Start-state precision (expected ``[state_dim, state_dim]``).
        K_gp_inv: GP transition precision (expected ``[state_dim, state_dim]``).
        K_g_inv: Goal-state precision; only used when goal-directed.
        precision_matrix: Return the precision if True, else its inverse.

    Returns:
        ``[M, M]`` matrix with ``M = state_dim * (traj_len + 1)``.
    """
    sd, d = self.state_dim, self.dim

    # Constant-velocity transition: position += velocity * dt.
    Phi = torch.eye(sd, **self.tensor_args)
    Phi[:d, d:] = torch.eye(d, **self.tensor_args) * dt
    Phi_blocks = Phi
    for _ in range(1, self.traj_len):
        Phi_blocks = torch.block_diag(Phi_blocks, Phi)

    # Identity on the diagonal, -Phi on the block sub-diagonal.
    A = torch.eye(self.M, **self.tensor_args)
    A[sd:, :-sd] -= Phi_blocks
    if self.goal_directed:
        goal_rows = torch.zeros(sd, self.M, **self.tensor_args)
        goal_rows[:, -sd:] = torch.eye(sd, **self.tensor_args)
        A = torch.cat((A, goal_rows))

    Q_inv = K_s_inv
    for _ in range(self.traj_len):
        Q_inv = torch.block_diag(Q_inv, K_gp_inv).to(**self.tensor_args)
    if self.goal_directed:
        Q_inv = torch.block_diag(Q_inv, K_g_inv).to(**self.tensor_args)

    K_inv = A.t() @ Q_inv @ A
    return K_inv if precision_matrix else torch.inverse(K_inv)
187 |
def sample(self, num_samples: int) -> torch.Tensor:
    """Draw ``num_samples`` trajectories per mode.

    Returns:
        Tensor of shape ``[num_modes, num_samples, traj_len + 1, state_dim]``.
    """
    draws = self.dist.sample((num_samples,))
    draws = draws.view(num_samples, self.num_modes, self.traj_len + 1, self.state_dim)
    # Put the mode axis first.
    return draws.transpose(1, 0)
192 |
def log_prob(self, X: torch.Tensor) -> torch.Tensor:
    """Log-density of ``X`` under the current distribution ``self.dist``."""
    distribution = self.dist
    return distribution.log_prob(X)
--------------------------------------------------------------------------------
/mpot/gp/unary_factor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class UnaryFactor:
    """Unary factor that compares a batch of states against a fixed mean.

    Args:
        dim: Dimension of the constrained state.
        sigma: Standard deviation; the weight matrix is ``I / sigma**2``.
        mean: Target state; defaults to a zero vector of size ``dim``.
        tensor_args: Keyword arguments (``device``/``dtype``) for the tensors
            this factor creates. Defaults to CPU float32; previously ``None``
            crashed on ``**None``.
    """

    def __init__(
        self,
        dim: int,
        sigma: float,
        mean: torch.Tensor = None,
        tensor_args=None,
    ):
        if tensor_args is None:
            tensor_args = {'device': 'cpu', 'dtype': torch.float32}
        self.sigma = sigma
        self.mean = torch.zeros(dim, **tensor_args) if mean is None else mean
        self.tensor_args = tensor_args
        self.K = torch.eye(dim, **tensor_args) / sigma**2  # weight matrix
        self.dim = dim

    def get_error(self, X: torch.Tensor, calc_jacobian: bool = True) -> torch.Tensor:
        """Residual ``mean - X`` for a batch of states ``X``.

        Args:
            X: Batch of states; the batch size is ``X.shape[0]``, and each
                state has ``self.dim`` entries.
            calc_jacobian: If True, also return ``H``.

        Returns:
            If ``calc_jacobian``: a tuple ``(error, H)`` with ``error`` reshaped
            to ``[batch, dim, 1]`` and ``H`` a batch of ``dim x dim`` identity
            matrices. Otherwise just ``mean - X``.
        """
        # NOTE(review): the error is ``mean - X`` while ``H`` is +I; confirm
        # this sign convention matches the consuming solver.
        error = self.mean - X
        if not calc_jacobian:
            return error
        batch = X.shape[0]
        H = torch.eye(self.dim, **self.tensor_args).unsqueeze(0).repeat(batch, 1, 1)
        return error.view(batch, self.dim, 1), H

    def set_mean(self, X: torch.Tensor) -> torch.Tensor:
        """Use a detached copy of ``X`` as the new mean."""
        self.mean = X.clone().detach()
33 |
--------------------------------------------------------------------------------
/mpot/ot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/ot/__init__.py
--------------------------------------------------------------------------------
/mpot/ot/initializer.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Any, Dict, Optional, Sequence, Tuple
3 |
4 | import torch
5 |
6 | from mpot.ot.problem import LinearProblem
7 |
8 |
class SinkhornInitializer(abc.ABC):
    """Base class for Sinkhorn initializers.

    Subclasses implement ``init_dual_a`` and ``init_dual_b``. Calling an
    instance fills in any potential the caller did not supply and checks
    shapes against ``ot_prob.C``. It then sets potentials to ``-inf`` where
    the corresponding marginal weight is not positive.
    """

    @abc.abstractmethod
    def init_dual_a(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Initialize Sinkhorn potential f_u.

        Returns:
            potential size ``[n,]``.
        """

    @abc.abstractmethod
    def init_dual_b(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Initialize Sinkhorn potential g_v.

        Returns:
            potential size ``[m,]``.
        """

    def __call__(
        self,
        ot_prob: LinearProblem,
        a: Optional[torch.Tensor] = None,
        b: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the initial potentials ``(f_u, g_v)`` for ``ot_prob``.

        Args:
            ot_prob: Problem whose cost matrix ``C`` fixes the shapes ``(n, m)``.
            a: Optional user-provided ``f_u``; generated when ``None``.
            b: Optional user-provided ``g_v``; generated when ``None``.

        Raises:
            AssertionError: If a potential does not have shape ``(n,)``/``(m,)``.
        """
        n, m = ot_prob.C.shape
        a = self.init_dual_a(ot_prob) if a is None else a
        b = self.init_dual_b(ot_prob) if b is None else b

        assert a.shape == (n,), f"Expected `f_u` to have shape `{n,}`, found `{a.shape}`."
        assert b.shape == (m,), f"Expected `g_v` to have shape `{m,}`, found `{b.shape}`."

        # cancel dual variables for zero weights
        return tuple(
            torch.where(weights > 0., potential, -torch.inf)
            for weights, potential in ((ot_prob.a, a), (ot_prob.b, b))
        )
59 |
60 |
class DefaultInitializer(SinkhornInitializer):
    """Default initialization: all-zero dual potentials."""

    def init_dual_a(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Zero potential shaped like the first marginal ``ot_prob.a``."""
        weights = ot_prob.a
        return torch.zeros_like(weights)

    def init_dual_b(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Zero potential shaped like the second marginal ``ot_prob.b``."""
        weights = ot_prob.b
        return torch.zeros_like(weights)
75 |
76 |
class RandomInitializer(SinkhornInitializer):
    """Random (standard normal) initialization of Sinkhorn dual potentials.

    Args:
        seed: If given, potentials are drawn from a dedicated
            ``torch.Generator`` seeded once with this value. Repeated runs
            with the same seed and call order then reproduce the same
            potentials. If ``None``, the global torch RNG is used (the
            previous behavior, where ``seed`` was silently ignored).
    """

    def __init__(
        self,
        seed: Optional[int] = None,
    ):
        self.seed = seed
        self._generator = None

    def _randn_like(self, ref: torch.Tensor) -> torch.Tensor:
        """Standard-normal tensor with the shape, dtype and device of ``ref``."""
        if self.seed is None:
            return torch.randn_like(ref)
        gen = getattr(self, "_generator", None)
        # A generator is bound to one device; (re)create it if the device changes.
        if gen is None or gen.device != ref.device:
            gen = torch.Generator(device=ref.device)
            gen.manual_seed(self.seed)
            self._generator = gen
        return torch.randn(ref.shape, generator=gen, dtype=ref.dtype, device=ref.device)

    def init_dual_a(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Random potential shaped like ``ot_prob.a``."""
        return self._randn_like(ot_prob.a)

    def init_dual_b(
        self,
        ot_prob: LinearProblem,
    ) -> torch.Tensor:
        """Random potential shaped like ``ot_prob.b``."""
        return self._randn_like(ot_prob.b)
97 |
--------------------------------------------------------------------------------
/mpot/ot/problem.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import torch
3 |
4 |
class Epsilon:
    """Epsilon scheduler for Sinkhorn and Sinkhorn Step.

    ``at(i)`` returns ``max(init * decay**i, 1) * target``. This is a geometric
    decay from ``init * target`` that never drops below ``target``, where
    ``target`` already includes ``scale_epsilon``.

    Args:
        target: Final regularizer before scaling (``None`` falls back to 5e-2).
        scale_epsilon: Multiplier for the target (``None`` means 1.0).
        init: Initial multiple of the target.
        decay: Geometric decay factor per iteration; values above 1 are
            clipped to 1.
    """

    def __init__(
        self,
        target: float = 0.1,
        scale_epsilon: float = 1.0,
        init: float = 1.0,
        decay: float = 1.0
    ):
        self._target_init = target
        self._scale_epsilon = scale_epsilon
        self._init = init
        self._decay = decay

    @property
    def _base_target(self) -> float:
        """Unscaled target, with ``None`` mapped to the 5e-2 default."""
        return 5e-2 if self._target_init is None else self._target_init

    @property
    def _scale(self) -> float:
        """Scale factor, with ``None`` mapped to 1.0."""
        return 1.0 if self._scale_epsilon is None else self._scale_epsilon

    @property
    def target(self) -> float:
        """Return the final regularizer value of scheduler."""
        return self._scale * self._base_target

    def at(self, iteration: int = 1) -> float:
        """Return (intermediate) regularizer value at a given iteration."""
        if iteration is None:
            return self.target
        # check the decay is smaller than 1.0.
        decay = min(self._decay, 1.0)
        # the multiple is either 1.0 or a larger init value that is decayed.
        multiple = max(self._init * (decay ** iteration), 1.0)
        return multiple * self.target

    def done(self, eps: float) -> bool:
        """Return whether the scheduler is done at a given value."""
        return eps == self.target

    def done_at(self, iteration: int) -> bool:
        """Return whether the scheduler is done at a given iteration."""
        return self.done(self.at(iteration))


class LinearEpsilon(Epsilon):
    """Linearly decaying epsilon schedule.

    ``at(i)`` returns ``max(init - decay * i, target) * scale_epsilon``. The
    floor is the *unscaled* target, so the schedule ends exactly at
    ``self.target`` and ``done_at`` can become True.
    """

    def __init__(self, target: float = 0.1,
                 scale_epsilon: float = 1,
                 init: float = 1,
                 decay: float = 1):
        super().__init__(target, scale_epsilon, init, decay)

    def at(self, iteration: int = 1) -> float:
        """Return the regularizer value at ``iteration`` (target if ``None``)."""
        if iteration is None:
            return self.target
        # Floor with the unscaled target. Flooring with ``self.target`` (already
        # scaled) and then scaling again would end at scale**2 * target.
        eps = max(self._init - self._decay * iteration, self._base_target)
        return eps * self._scale
60 |
61 |
class LinearProblem():
    """Entropic linear OT problem between two discrete marginals.

    Args:
        C: Cost matrix; ``C.shape[0]`` source points and ``C.shape[1]`` targets.
        epsilon: Entropic regularizer, either a float or an ``Epsilon`` schedule.
        a: Source weights; uniform when omitted.
        b: Target weights; uniform when omitted.
        tau_a, tau_b: Marginal relaxation factors (``1.0`` means balanced).
        scaling_cost: If True, ``C`` is passed through ``scale_cost_matrix``.
    """

    def __init__(
        self,
        C: torch.Tensor,
        epsilon: Union[Epsilon, float] = 0.01,
        a: torch.Tensor = None,
        b: torch.Tensor = None,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        scaling_cost: bool = True,
    ) -> None:
        cost = scale_cost_matrix(C) if scaling_cost else C
        if a is None:
            a = torch.ones(cost.shape[0]).type_as(cost) / cost.shape[0]
        if b is None:
            b = torch.ones(cost.shape[1]).type_as(cost) / cost.shape[1]
        self.C = cost
        self.epsilon = epsilon
        self.a = a
        self.b = b
        self.tau_a = tau_a
        self.tau_b = tau_b

    def _final_epsilon(self) -> float:
        """Regularizer value: the schedule's target, or the float itself."""
        return self.epsilon.target if isinstance(self.epsilon, Epsilon) else self.epsilon

    def potential_from_scaling(self, scaling: torch.Tensor) -> torch.Tensor:
        """Compute dual potential vector from scaling vector.

        Args:
            scaling: vector.

        Returns:
            a vector of the same size, ``eps * log(scaling)``.
        """
        return self._final_epsilon() * torch.log(scaling)

    def marginal_from_potentials(
        self, f: torch.Tensor, g: torch.Tensor, dim: int
    ) -> torch.Tensor:
        """Marginal of the plan induced by ``(f, g)``, reduced along ``dim``."""
        eps = self._final_epsilon()
        own = f if dim == 1 else g
        return torch.exp((self.apply_lse_kernel(f, g, eps, dim=dim) + own) / eps)

    def update_potential(
        self, f: torch.Tensor, g: torch.Tensor, log_marginal: torch.Tensor,
        iteration: int = None, dim: int = 0,
    ) -> torch.Tensor:
        """One Sinkhorn potential update using the scheduled epsilon at ``iteration``."""
        eps = self.epsilon.at(iteration) if isinstance(self.epsilon, Epsilon) else self.epsilon
        lse = self.apply_lse_kernel(f, g, eps, dim=dim)
        # Non-finite LSE values (e.g. from -inf potentials) contribute 0.
        lse = torch.where(torch.isfinite(lse), lse, 0)
        return eps * log_marginal - lse

    def transport_from_potentials(
        self, f: torch.Tensor, g: torch.Tensor
    ) -> torch.Tensor:
        """Output transport matrix from potentials."""
        return torch.exp(self._center(f, g) / self._final_epsilon())

    def apply_lse_kernel(
        self, f: torch.Tensor, g: torch.Tensor, eps: float, dim: int
    ) -> torch.Tensor:
        """Soft-min along ``dim`` minus the (finite part of the) reduced potential."""
        reduced = self._softmax(f, g, eps, dim=dim)
        other = f if dim == 1 else g
        return reduced - torch.where(torch.isfinite(other), other, 0)

    def _center(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        """Broadcast ``f[i] + g[j] - C[i, j]``."""
        return f[:, None] + g[None, :] - self.C

    def _softmax(
        self, f: torch.Tensor, g: torch.Tensor, eps: float, dim: int
    ) -> torch.Tensor:
        """Apply softmax row or column wise"""
        scaled = self._center(f, g) / eps
        return eps * torch.logsumexp(scaled, dim=dim)
137 |
138 |
class SinkhornState():
    """Holds the state variables used to solve OT with Sinkhorn.

    Attributes:
        errors: Error buffer filled by the solver.
        fu, gv: Current dual potentials.
        converged_at: Iteration at which convergence was recorded; starts at -1.
    """

    def __init__(
        self,
        errors: torch.Tensor = None,
        fu: torch.Tensor = None,
        gv: torch.Tensor = None,
    ):
        self.errors, self.fu, self.gv = errors, fu, gv
        self.converged_at = -1

    def solution_error(
        self,
        ot_prob: LinearProblem,
        parallel_dual_updates: bool,
    ) -> torch.Tensor:
        """State dependent function to return error."""
        return solution_error(
            self.fu, self.gv, ot_prob, parallel_dual_updates=parallel_dual_updates
        )

    def ent_reg_cost(
        self, ot_prob: LinearProblem,
    ) -> float:
        """Entropy-regularized cost of the current potentials."""
        return ent_reg_cost(self.fu, self.gv, ot_prob)
172 |
173 |
class SinkhornStepState():
    """Holds the state of the Sinkhorn Step solver.

    Args:
        X_init: Initial optimizing points (stored as ``X``).
        costs: Optional per-iteration objective values.
        linear_convergence: Per-iteration convergence info of the inner
            Sinkhorn solver.
        objective_vals: Optional objective values.
        X_history: Optional per-iteration copies of ``X``.
        displacement_sqnorms: Per-iteration squared displacement norms.
        a: Weights of the points.
    """

    def __init__(self,
                 X_init: torch.Tensor,
                 costs: torch.Tensor = None,
                 linear_convergence: torch.Tensor = None,
                 objective_vals: torch.Tensor = None,
                 X_history: torch.Tensor = None,
                 displacement_sqnorms: torch.Tensor = None,
                 a: torch.Tensor = None) -> None:
        self.X = X_init
        self.costs, self.linear_convergence = costs, linear_convergence
        self.objective_vals, self.X_history = objective_vals, X_history
        self.displacement_sqnorms, self.a = displacement_sqnorms, a
201 |
202 |
def scale_cost_matrix(M: torch.Tensor) -> torch.Tensor:
    """Shift ``M`` to be non-negative, then divide by its max if that exceeds 1.

    The input tensor is never modified in place (it previously was, which
    silently altered the caller's cost matrix). A new tensor is returned
    whenever shifting or scaling happens; otherwise ``M`` itself is returned.
    """
    min_M = M.min()
    if min_M < 0:
        M = M - min_M
    max_M = M.max()
    if max_M > 1.:
        M = M / max_M  # for stability
    return M
211 |
212 |
def phi_star(h: torch.Tensor, rho: float) -> torch.Tensor:
    """Legendre transform of KL, :cite:`sejourne:19`, p. 9: ``rho * (exp(h / rho) - 1)``."""
    scaled = h / rho
    return rho * (torch.exp(scaled) - 1)
216 |
217 |
def rho(epsilon: float, tau: float) -> float:
    """Relaxation strength ``epsilon * tau / (1 - tau)``.

    Divides by zero for the balanced case ``tau == 1``.
    """
    scaled_eps = epsilon * tau
    return scaled_eps / (1. - tau)
220 |
221 |
def derivative_phi_star(f: torch.Tensor, rho: float) -> torch.Tensor:
    """Derivative of ``phi_star`` with respect to its argument: ``exp(f / rho)``."""
    ratio = f / rho
    return torch.exp(ratio)
224 |
225 |
def grad_of_marginal_fit(
    c: torch.Tensor, h: torch.Tensor, tau: float, epsilon: float
) -> torch.Tensor:
    """Gradient of the marginal-fit term.

    For balanced marginals (``tau == 1``) this is the target ``c`` itself.
    Otherwise ``c`` is reweighted by ``exp(-h / rho(epsilon, tau))`` on its
    support and set to 0 elsewhere.
    """
    if tau != 1.0:
        relaxation = rho(epsilon, tau)
        return torch.where(c > 0, c * derivative_phi_star(-h, relaxation), 0.0)
    return c
233 |
234 |
def solution_error(
    f_u: torch.Tensor,
    g_v: torch.Tensor,
    ot_prob: LinearProblem,
    parallel_dual_updates: bool,
) -> torch.Tensor:
    """Compute error between Sinkhorn solution and target marginals.

    With sequential dual updates, only the second marginal (``dim=0``) is
    compared against ``ot_prob.b``. With parallel updates, both marginals are
    compared against the targets returned by ``grad_of_marginal_fit``.

    ``ot_prob.epsilon`` may be an ``Epsilon`` schedule. Its final ``target``
    value is used, since ``rho`` needs a number (the schedule object used to
    be passed through unchanged).
    """
    if not parallel_dual_updates:
        return marginal_error(
            f_u, g_v, ot_prob.b, ot_prob, dim=0
        )

    eps = ot_prob.epsilon
    eps = eps.target if hasattr(eps, "target") else eps

    grad_a = grad_of_marginal_fit(ot_prob.a, f_u, ot_prob.tau_a, eps)
    grad_b = grad_of_marginal_fit(ot_prob.b, g_v, ot_prob.tau_b, eps)

    err = marginal_error(f_u, g_v, grad_a, ot_prob, dim=1)
    err += marginal_error(f_u, g_v, grad_b, ot_prob, dim=0)
    return err
257 |
258 |
def marginal_error(
    f_u: torch.Tensor,
    g_v: torch.Tensor,
    target: torch.Tensor,
    ot_prob: LinearProblem,
    dim: int = 0
) -> torch.Tensor:
    """Output how far Sinkhorn solution is w.r.t target.

    Args:
        f_u: a vector of potentials or scalings for the first marginal.
        g_v: a vector of potentials or scalings for the second marginal.
        target: target marginal.
        ot_prob: problem providing ``marginal_from_potentials``.
        dim: dim (0 or 1) along which to compute marginal.

    Returns:
        L1 distance (summed absolute difference) between target and marginal.
    """
    current = ot_prob.marginal_from_potentials(f_u, g_v, dim=dim)
    return (current - target).abs().sum()
282 |
283 |
def ent_reg_cost(
    f: torch.Tensor, g: torch.Tensor, ot_prob: "LinearProblem",
) -> float:
    """Entropy-regularized OT cost evaluated at the dual potentials ``(f, g)``.

    Uses https://arxiv.org/pdf/1910.12958.pdf (24). Fixes over the previous
    version:

    - ``marginal_from_potentials`` requires ``dim``; it was called without it
      and raised ``TypeError``.
    - ``ot_prob.epsilon`` may be an ``Epsilon`` schedule. Its final ``target``
      is used for the arithmetic.
    """
    eps = ot_prob.epsilon
    eps = eps.target if hasattr(eps, "target") else eps

    div_a = _marginal_divergence(
        f, ot_prob.a, ot_prob.potential_from_scaling(ot_prob.a), ot_prob.tau_a, eps
    )
    div_b = _marginal_divergence(
        g, ot_prob.b, ot_prob.potential_from_scaling(ot_prob.b), ot_prob.tau_b, eps
    )

    # The total transported mass is the same whichever marginal is summed.
    total_sum = torch.sum(ot_prob.marginal_from_potentials(f, g, dim=0))
    return div_a + div_b + eps * (
        torch.sum(ot_prob.a) * torch.sum(ot_prob.b) - total_sum
    )


def _marginal_divergence(
    potential: torch.Tensor,
    weights: torch.Tensor,
    ref_potential: torch.Tensor,
    tau: float,
    eps: float,
) -> torch.Tensor:
    """Dual divergence term of one marginal (balanced when ``tau == 1``)."""
    support = weights > 0
    shifted = potential - ref_potential
    if tau == 1.0:
        return torch.sum(torch.where(support, weights * shifted, 0.0))
    return -torch.sum(
        torch.where(support, weights * phi_star(-shifted, rho(eps, tau)), 0.0)
    )
--------------------------------------------------------------------------------
/mpot/ot/sinkhorn.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | TYPE_CHECKING,
3 | Any,
4 | Callable,
5 | Literal,
6 | Mapping,
7 | NamedTuple,
8 | Optional,
9 | Sequence,
10 | Tuple,
11 | Union,
12 | )
13 | import numpy as np
14 | import torch
15 | from mpot.ot.problem import LinearProblem, SinkhornState
16 | from mpot.ot.initializer import DefaultInitializer, RandomInitializer, SinkhornInitializer
17 |
18 |
class Momentum:
    """Momentum (over-relaxation) for Sinkhorn updates. Adapted from OTT-JAX.

    Args:
        start: Iteration from which adaptive (Lehmann) momentum may be used;
            ``0`` means always use the fixed ``value``.
        error_threshold: Adaptive momentum is used only when the error stored
            just before ``start`` is below this value.
        value: Fixed momentum weight.
        inner_iterations: Solver iterations per stored error entry.
    """

    def __init__(
        self,
        start: int = 0,
        error_threshold: float = torch.inf,
        value: float = 1.0,
        inner_iterations: int = 1,
    ) -> None:
        self.start = start
        self.error_threshold = error_threshold
        self.value = value
        self.inner_iterations = inner_iterations

    def _error_before_start(self, state: "SinkhornState", lag: int) -> torch.Tensor:
        """Stored error ``lag`` slots before ``start // inner_iterations``.

        ``Sinkhorn.init_state`` builds ``errors`` as a 1-D buffer. The
        OTT-JAX-style 2-D layout (last column) is still accepted. Indexing
        ``[idx, -1]`` on the 1-D buffer used to raise ``IndexError``.
        """
        idx = self.start // self.inner_iterations - lag
        errors = state.errors
        return errors[idx, -1] if errors.ndim > 1 else errors[idx]

    def weight(self, state: "SinkhornState", iteration: int) -> float:
        """Momentum weight to use at ``iteration``."""
        if self.start == 0:
            return self.value
        if iteration >= self.start and self._error_before_start(state, 1) < self.error_threshold:
            return self.lehmann(state)
        return self.value

    def lehmann(self, state: "SinkhornState") -> float:
        """See Lehmann, T., Von Renesse, M.-K., Sambale, A., and
        Uschmajew, A. (2021). A note on overrelaxation in the
        sinkhorn algorithm. Optimization Letters, pages 1–12. eq. 5."""
        ratio = self._error_before_start(state, 1) / self._error_before_start(state, 2)
        # torch.minimum needs a tensor for both operands; clamp takes a Python bound.
        error_ratio = torch.clamp(torch.as_tensor(ratio), max=0.99)
        power = 1.0 / self.inner_iterations
        return 2.0 / (1.0 + torch.sqrt(1.0 - error_ratio ** power))

    def __call__(
        self,
        weight: float,
        value: torch.Tensor,
        new_value: torch.Tensor
    ) -> torch.Tensor:
        """Blend ``(1 - weight) * value + weight * new_value``, with non-finite ``value`` entries treated as 0."""
        value = torch.where(torch.isfinite(value), value, 0.0)
        return (1.0 - weight) * value + weight * new_value
62 |
63 |
class Sinkhorn:
    """Log-domain Sinkhorn solver for entropic OT (adapted from OTT-JAX).

    Args:
        threshold: Convergence tolerance on the stored marginal error.
        inner_iterations: Amount the iteration counter advances per outer
            step. Errors are stored at index ``iteration // inner_iterations``.
        min_iterations: Convergence is only declared once the counter
            exceeds this value.
        max_iterations: Bounds the number of outer steps via
            ``outer_iterations``.
        parallel_dual_updates: If True, both potentials are updated from the
            previous iterate. Otherwise the ``f`` update uses the fresh ``g``.
        initializer: ``"default"``, ``"random"`` or a ``SinkhornInitializer``.

    NOTE(review): each outer step runs a single ``lse_step`` while the counter
    advances by ``inner_iterations``. Also, with ``compute_error=False`` the
    stored ``-1`` sentinel is below ``threshold``, so the loop stops as soon
    as ``min_iterations`` is exceeded. Confirm both are intended.
    """

    def __init__(
        self,
        threshold: float = 1e-3,
        inner_iterations: int = 1,
        min_iterations: int = 1,
        max_iterations: int = 100,
        parallel_dual_updates: bool = False,
        initializer: Literal["default", "random"] = "default",
        **kwargs: Any,
    ):
        self.threshold = threshold
        self.inner_iterations = inner_iterations
        self.min_iterations = min_iterations
        self.max_iterations = max_iterations
        self.momentum = Momentum(inner_iterations=inner_iterations)
        self.parallel_dual_updates = parallel_dual_updates
        self.initializer = initializer

    def __call__(
        self,
        ot_prob: LinearProblem,
        init: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]] = (None, None),
        compute_error: bool = True,
    ) -> torch.Tensor:
        """Solve ``ot_prob``; return ``(transport_matrix, final_state)``."""
        fu_start, gv_start = self.create_initializer()(ot_prob, *init)
        final_state = self.iterations(
            ot_prob, (fu_start, gv_start), compute_error=compute_error
        )
        plan = self.output_from_state(ot_prob, final_state)
        return plan, final_state

    def create_initializer(self) -> SinkhornInitializer:
        """Resolve ``self.initializer`` to a ``SinkhornInitializer`` instance."""
        choice = self.initializer
        if isinstance(choice, SinkhornInitializer):
            return choice
        for key, builder in (("default", DefaultInitializer), ("random", RandomInitializer)):
            if choice == key:
                return builder()
        raise NotImplementedError(
            f"Initializer `{self.initializer}` is not yet implemented."
        )

    def lse_step(
        self, ot_prob: LinearProblem, state: SinkhornState,
        iteration: int
    ) -> SinkhornState:
        """Sinkhorn LSE update of ``g`` then ``f`` (with momentum), in place on ``state``."""
        weight = self.momentum.weight(state, iteration)
        fu_prev, gv_prev = state.fu, state.gv

        gv_candidate = ot_prob.tau_b * ot_prob.update_potential(
            fu_prev, gv_prev, torch.log(ot_prob.b), iteration, dim=0
        )
        gv_next = self.momentum(weight, gv_prev, gv_candidate)

        # Sequential updates feed the freshly updated g into the f update.
        gv_for_f = gv_prev if self.parallel_dual_updates else gv_next
        fu_candidate = ot_prob.tau_a * ot_prob.update_potential(
            fu_prev, gv_for_f, torch.log(ot_prob.a), iteration, dim=1
        )
        fu_next = self.momentum(weight, fu_prev, fu_candidate)

        state.fu, state.gv = fu_next, gv_next
        return state

    def one_iteration(
        self, ot_prob: LinearProblem, state: SinkhornState,
        iteration: int, compute_error: bool = True
    ) -> SinkhornState:
        """One LSE step, then store the error (``-1`` when not computed)."""
        state = self.lse_step(ot_prob, state, iteration)
        if compute_error:
            err = state.solution_error(
                ot_prob,
                parallel_dual_updates=self.parallel_dual_updates,
            )
        else:
            err = -1
        state.errors[iteration // self.inner_iterations] = err
        return state

    def _latest_error(self, state: SinkhornState, iteration: int):
        """Error stored by the previous outer step."""
        return state.errors[iteration // self.inner_iterations - 1]

    def _converged(self, state: SinkhornState, iteration: int) -> bool:
        err = self._latest_error(state, iteration)
        return iteration > self.min_iterations and err < self.threshold

    def _diverged(self, state: SinkhornState, iteration: int) -> bool:
        return not torch.isfinite(self._latest_error(state, iteration))

    def _continue(self, state: SinkhornState, iteration: int) -> bool:
        """Continue while not(converged) and not(diverged)."""
        if not iteration < self.outer_iterations:
            return False
        return not (self._converged(state, iteration) or self._diverged(state, iteration))

    @property
    def outer_iterations(self) -> int:
        """Upper bound on number of times inner_iterations are carried out."""
        return np.ceil(self.max_iterations / self.inner_iterations).astype(int)

    def init_state(
        self, init: Tuple[torch.Tensor, torch.Tensor]
    ) -> SinkhornState:
        """Return the initial state of the loop (error buffer filled with -1)."""
        fu, gv = init
        errors = -torch.ones(self.outer_iterations).type_as(fu)
        return SinkhornState(errors=errors, fu=fu, gv=gv)

    def output_from_state(
        self, ot_prob: LinearProblem, state: SinkhornState
    ) -> torch.Tensor:
        """Return the output of the Sinkhorn loop."""
        return ot_prob.transport_from_potentials(state.fu, state.gv)

    def iterations(
        self, ot_prob: LinearProblem, init: Tuple[torch.Tensor, torch.Tensor], compute_error: bool = True
    ) -> SinkhornState:
        """Run outer steps until convergence, divergence or the iteration cap."""
        state = self.init_state(init)
        counter = 0
        while self._continue(state, counter):
            state = self.one_iteration(ot_prob, state, counter, compute_error=compute_error)
            counter += self.inner_iterations
        if self._converged(state, counter):
            state.converged_at = counter
        return state
201 |
202 |
if __name__ == "__main__":
    # Timing demo on a random CPU problem. Uses the standard library timer
    # instead of torch_robotics' TimerCUDA, so the module's self-test does not
    # pull in an extra dependency.
    import time

    from mpot.ot.problem import Epsilon

    epsilon = Epsilon(target=0.05, init=1., decay=0.8)
    ot_prob = LinearProblem(
        torch.rand((1000, 1000)), epsilon
    )
    sinkhorn = Sinkhorn()
    tic = time.perf_counter()
    W, _ = sinkhorn(ot_prob)
    print(time.perf_counter() - tic)
    print(W.shape)
215 |
--------------------------------------------------------------------------------
/mpot/ot/sinkhorn_step.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | TYPE_CHECKING,
3 | Any,
4 | List,
5 | Callable,
6 | Literal,
7 | Mapping,
8 | NamedTuple,
9 | Optional,
10 | Sequence,
11 | Tuple,
12 | Union,
13 | )
14 |
15 | import torch
16 | from mpot.ot.problem import LinearProblem, Epsilon, SinkhornStepState
17 | from mpot.ot.sinkhorn import Sinkhorn
18 | from mpot.utils.polytopes import POLYTOPE_MAP, get_sampled_polytope_vertices
19 | from mpot.utils.misc import MinMaxCenterScaler
20 |
21 |
class SinkhornStep():
    """Sinkhorn Step solver.

    ``step`` updates the batch of points ``state.X`` in five stages:

    1. Optionally normalise the points with ``state_scalers``.
    2. Sample polytope vertices and probe points around them with
       ``get_sampled_polytope_vertices``.
    3. Score the probe points with ``objective_fn``.
    4. Solve an entropic OT plan ``W`` between points and vertices with
       ``linear_ot_solver``.
    5. Move each point to the barycentric projection of its row of ``W``.

    Args:
        dim: Dimension of the optimised points.
        objective_fn: Called as ``objective_fn(X_probe, current_trajs=...,
            optim_dim=..., **kwargs)`` to build the cost matrix. Its ``.cost``
            method is used when ``store_outer_evals`` is set.
        linear_ot_solver: Entropic OT solver returning ``(W, solver_state)``.
        epsilon: Float or ``Epsilon`` schedule. At every step both radii are
            multiplied by ``1 - epsilon``.
        ent_epsilon: Entropic regularizer of the inner OT problem.
        state_scalers: Scalers applied to the points before sampling and
            inverted on the sampled vertices/probes.
        polytope_type: Key into ``POLYTOPE_MAP``.
        scale_cost: Stored; not used by this class.
        step_radius, probe_radius, num_probe: Sampling parameters forwarded
            to ``get_sampled_polytope_vertices``.
        min_iterations, max_iterations, threshold: Stopping criteria (see
            ``_continue``/``_converged``).
        store_inner_errors: Stored; not used by this class.
        store_outer_evals: Record the mean cost of each new iterate.
        store_history: Record every iterate.
        tensor_args: ``device``/``dtype`` kwargs; CPU float32 by default.
    """

    def __init__(
        self,
        dim: int,
        objective_fn: Callable,
        linear_ot_solver: Sinkhorn,
        epsilon: Union[Epsilon, float],
        ent_epsilon: Union[Epsilon, float] = 0.01,
        state_scalers: Optional[List[MinMaxCenterScaler]] = None,
        polytope_type: str = 'orthoplex',
        scale_cost: float = 1.0,
        step_radius: float = 1.,
        probe_radius: float = 2.,
        num_probe: int = 5,
        min_iterations: int = 5,
        max_iterations: int = 50,
        threshold: float = 1e-3,
        store_inner_errors: bool = False,
        store_outer_evals: bool = False,
        store_history: bool = False,
        tensor_args: Optional[Mapping[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.tensor_args = (
            {'device': 'cpu', 'dtype': torch.float32} if tensor_args is None else tensor_args
        )
        self.dim = dim
        self.objective_fn = objective_fn

        # Sinkhorn Step params
        self.linear_ot_solver = linear_ot_solver
        origin = torch.zeros((self.dim,), **self.tensor_args)
        self.polytope_vertices = POLYTOPE_MAP[polytope_type](origin)
        self.epsilon = epsilon
        self.ent_epsilon = ent_epsilon
        self.state_scalers = state_scalers
        self.step_radius = step_radius
        self.probe_radius = probe_radius
        self.scale_cost = scale_cost
        self.num_probe = num_probe
        self.min_iterations = min_iterations
        self.max_iterations = max_iterations
        self.threshold = threshold
        self.store_inner_errors = store_inner_errors
        self.store_outer_evals = store_outer_evals
        self.store_history = store_history

        # TODO: support non-uniform weights for conditional sinkhorn step

    def init_state(
        self,
        X_init: torch.Tensor,
    ) -> SinkhornStepState:
        """Allocate the per-iteration buffers for points ``X_init`` (2-D)."""
        num_points, _ = X_init.shape
        n_iters = self.max_iterations

        X_history = (
            torch.zeros((n_iters, num_points, self.dim)).type_as(X_init)
            if self.store_history else None
        )
        costs = -torch.ones(n_iters).type_as(X_init) if self.store_outer_evals else None

        return SinkhornStepState(
            X_init=X_init,
            costs=costs,
            linear_convergence=-torch.ones(n_iters).type_as(X_init),
            X_history=X_history,
            displacement_sqnorms=-torch.ones(n_iters).type_as(X_init),
            a=torch.ones((num_points,)).type_as(X_init) / num_points,  # always uniform weights for now
        )

    def step(self, state: SinkhornStepState, iteration: int, **kwargs) -> SinkhornStepState:
        """Run one Sinkhorn Step and write the result into ``state``.

        NOTE: ``step_radius`` and ``probe_radius`` are shrunk on the solver
        object itself, so the shrinkage accumulates across calls.
        """
        scalers = () if self.state_scalers is None else self.state_scalers
        X = state.X.clone()

        # scale state features into same range (scalers act in place)
        for scaler in scalers:
            scaler(X)

        eps = self.epsilon.at(iteration) if isinstance(self.epsilon, Epsilon) else self.epsilon
        shrink = 1 - eps
        self.step_radius = self.step_radius * shrink
        self.probe_radius = self.probe_radius * shrink

        X_vertices, X_probe, _ = get_sampled_polytope_vertices(
            X,
            polytope_vertices=self.polytope_vertices,
            step_radius=self.step_radius,
            probe_radius=self.probe_radius,
            num_probe=self.num_probe,
        )

        # unscale for cost evaluation
        for scaler in scalers:
            scaler.inverse(X_vertices)
            scaler.inverse(X_probe)

        C = self.objective_fn(X_probe, current_trajs=state.X, optim_dim=X_probe.shape[:-1], **kwargs)
        ot_prob = LinearProblem(C, epsilon=self.ent_epsilon, a=state.a, scaling_cost=True)
        W, res = self.linear_ot_solver(ot_prob)

        # Barycentric projection: vertices of each point weighted by W / a.
        X_new = torch.einsum('bik,bi->bk', X_vertices, W / state.a.unsqueeze(-1))

        if self.store_outer_evals:
            state.costs[iteration] = self.objective_fn.cost(X_new, **kwargs).mean()
        if self.store_history:
            state.X_history[iteration] = X_new

        state.linear_convergence[iteration] = res.converged_at
        state.displacement_sqnorms[iteration] = (X_new - state.X).square().sum()
        state.X = X_new
        return state

    def _converged(self, state: SinkhornStepState, iteration: int) -> bool:
        """True when the last two squared displacements agree within relative ``threshold``."""
        norms = state.displacement_sqnorms
        previous, latest = norms[iteration - 2], norms[iteration - 1]
        return torch.isclose(previous, latest, rtol=self.threshold)

    def _continue(self, state: SinkhornStepState, iteration: int) -> bool:
        """Continue while not(converged); always run the first ``min_iterations``."""
        if iteration <= self.min_iterations:
            return True
        return not self._converged(state, iteration) and iteration < self.max_iterations
155 |
--------------------------------------------------------------------------------
/mpot/planner.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Tuple, List, Callable
2 | import torch
3 | import numpy as np
4 | from mpot.ot.sinkhorn_step import SinkhornStep, SinkhornStepState
5 | from mpot.ot.sinkhorn import Sinkhorn
6 | from mpot.utils.misc import MinMaxCenterScaler
7 | from mpot.gp.gp_prior import BatchGPPrior
8 | from mpot.gp.gp_factor import GPFactor
9 | from mpot.gp.unary_factor import UnaryFactor
10 |
11 |
12 | class MPOT(object):
13 | """Batch First-order Trajectory Optimization with Sinkhorn Step."""
14 |
def __init__(
    self,
    dim: int,
    objective_fn: Callable,
    linear_ot_solver: Sinkhorn,
    ss_params: dict,
    traj_len: int = 64,
    num_particles_per_goal: int = 16,
    dt: float = 0.02,
    start_state: torch.Tensor = None,
    multi_goal_states: torch.Tensor = None,
    initial_particle_means: torch.Tensor = None,
    pos_limits=(-10, 10),
    vel_limits=(-10, 10),
    polytope: str = 'orthoplex',
    fixed_start: bool = True,
    fixed_goal: bool = False,
    sigma_start_init: float = 0.001,
    sigma_goal_init: float = 1.,
    sigma_gp_init: float = 0.001,
    seed: int = 0,
    tensor_args=None,
    **kwargs
):
    """Set up particles, state scalers and the Sinkhorn Step solver.

    Args:
        dim: Degrees of freedom. States are ``2 * dim`` long, with positions
            in ``[0, dim)`` and velocities in ``[dim, 2 * dim)``.
        objective_fn: Forwarded to ``SinkhornStep``.
        linear_ot_solver: Forwarded to ``SinkhornStep``.
        ss_params: Extra keyword arguments for ``SinkhornStep``.
        traj_len, dt: Trajectory length and time step (stored).
        num_particles_per_goal: Particles per goal. The total is this value
            times the number of goals.
        start_state: Start state (stored).
        multi_goal_states: 2-D tensor with one goal per row; ``None`` counts
            as a single goal.
        initial_particle_means: Forwarded to ``reset``.
        pos_limits, vel_limits: Either a ``[min, max]`` pair shared by every
            dimension, or per-dimension limits (column 0 = min, column 1 =
            max). Used to build the min-max scalers. Tuple defaults replace
            the previous mutable list defaults.
        polytope: Stored. NOTE(review): not forwarded to ``SinkhornStep``,
            whose polytope comes from ``ss_params``.
        fixed_start, fixed_goal, sigma_start_init, sigma_goal_init,
            sigma_gp_init: Stored for use by other methods.
        seed: Seeds the global numpy and torch RNGs.
        tensor_args: ``device``/``dtype`` kwargs; CPU float32 by default.
    """
    self.dim = dim
    self.state_dim = dim * 2
    self.traj_len = traj_len
    self.dt = dt
    self.seed = seed
    if tensor_args is None:
        tensor_args = {'device': torch.device('cpu'), 'dtype': torch.float32}
    self.tensor_args = tensor_args
    np.random.seed(seed)
    torch.manual_seed(seed)

    self.start_state = start_state
    self.multi_goal_states = multi_goal_states
    if multi_goal_states is None:  # NOTE: if there is no goal, we assume here is at least one goal
        self.num_goals = 1
    else:
        assert multi_goal_states.ndim == 2
        self.num_goals = multi_goal_states.shape[0]
    self.num_particles_per_goal = num_particles_per_goal
    self.num_particles = num_particles_per_goal * self.num_goals
    self.polytope = polytope
    self.fixed_start = fixed_start
    self.fixed_goal = fixed_goal
    self.sigma_start_init = sigma_start_init
    self.sigma_goal_init = sigma_goal_init
    self.sigma_gp_init = sigma_gp_init
    self._traj_dist = None

    self.reset(initial_particle_means=initial_particle_means)

    def _as_limits(limits) -> torch.Tensor:
        """Convert ``limits`` to a per-dimension tensor on ``self.tensor_args``."""
        if isinstance(limits, torch.Tensor):
            limits = limits.clone().to(**self.tensor_args)
        else:
            limits = torch.tensor(limits, **self.tensor_args)
        if limits.ndim == 1:
            # One [min, max] pair shared by every degree of freedom.
            limits = limits.unsqueeze(0).repeat(self.dim, 1)
        return limits

    # scaling operations
    self.pos_limits = _as_limits(pos_limits)
    self.pos_scaler = MinMaxCenterScaler(dim_range=[0, self.dim], min=self.pos_limits[:, 0], max=self.pos_limits[:, 1])
    self.vel_limits = _as_limits(vel_limits)
    self.vel_scaler = MinMaxCenterScaler(dim_range=[self.dim, self.state_dim], min=self.vel_limits[:, 0], max=self.vel_limits[:, 1])

    # init solver
    self.ss_params = ss_params
    self.sinkhorn_step = SinkhornStep(
        self.state_dim,
        objective_fn=objective_fn,
        linear_ot_solver=linear_ot_solver,
        state_scalers=[self.pos_scaler, self.vel_scaler],
        **self.ss_params,
    )
93 |
def reset(
    self,
    start_state: torch.Tensor = None,
    multi_goal_states: torch.Tensor = None,
    initial_particle_means: torch.Tensor = None,
):
    """Optionally replace the start / goal states, then re-create the initial trajectories.

    Args:
        start_state: new start state (stored as a detached clone); its last
            dimension must equal ``self.state_dim``.
        multi_goal_states: new goals, one per row (stored as a detached
            clone); must be 2-D with last dimension ``self.state_dim``.
            ``self.num_goals`` and ``self.num_particles`` are updated to match.
        initial_particle_means: forwarded to ``get_prior_dists``.
    """
    if start_state is not None:
        self.start_state = start_state.detach().clone()
        assert self.start_state.shape[-1] == self.state_dim, "start_state dimension should be dim * 2"

    if multi_goal_states is not None:
        self.multi_goal_states = multi_goal_states.detach().clone()
        assert self.multi_goal_states.shape[-1] == self.state_dim, "multi_goal_states dimension should be dim * 2"
        assert self.multi_goal_states.ndim == 2, "multi_goal_states should have shape [num_goals, dim * 2]"
        # keep the particle bookkeeping consistent with the new goal set
        self.num_goals = self.multi_goal_states.shape[0]
        self.num_particles = self.num_particles_per_goal * self.num_goals

    self.get_prior_dists(initial_particle_means=initial_particle_means)
109 |
def get_prior_dists(self, initial_particle_means: torch.Tensor = None):
    """Set the initial trajectories and their flattened form used by the solver.

    Args:
        initial_particle_means: optional initial trajectories, expected with
            trailing dims ``(traj_len, state_dim)`` (assumption — confirm with
            callers). When ``None``, trajectories are sampled with
            ``get_random_trajs`` (which also records ``self.traj_dim``).

    Side effects:
        Sets ``self.init_trajs``, ``self.flatten_trajs`` (first two axes
        merged) and, for user-provided means, ``self.traj_dim``.
    """
    if initial_particle_means is None:
        self.init_trajs = self.get_random_trajs()
    else:
        self.init_trajs = initial_particle_means.clone()
        # optimize() reshapes the flat particles with self.traj_dim, which
        # get_random_trajs() would otherwise be the only place to set;
        # layout is [num_goals, num_particles_per_goal, traj_len, state_dim]
        self.traj_dim = torch.Size(
            (self.num_goals, self.num_particles_per_goal) + tuple(self.init_trajs.shape[-2:])
        )
    self.flatten_trajs = self.init_trajs.flatten(0, 1)
116 |
def get_GP_prior(
    self,
    start_K: torch.Tensor,
    gp_K: torch.Tensor,
    goal_K: torch.Tensor,
    state_init: torch.Tensor,
    particle_means: torch.Tensor = None,
    goal_states: torch.Tensor = None,
    tensor_args=None,
) -> BatchGPPrior:
    """Build a ``BatchGPPrior`` for this planner's horizon.

    The prior is created with ``traj_len - 1`` as its first argument, then
    ``self.dt`` and ``self.dim``. ``start_K`` and ``gp_K`` are passed
    positionally, and ``goal_K`` is passed as ``K_g_inv``. ``particle_means``
    become ``means``.

    Args:
        tensor_args: device/dtype for the prior; falls back to ``self.tensor_args``.
    """
    resolved_args = self.tensor_args if tensor_args is None else tensor_args
    optional = dict(
        means=particle_means,
        K_g_inv=goal_K,
        goal_states=goal_states,
        tensor_args=resolved_args,
    )
    return BatchGPPrior(self.traj_len - 1, self.dt, self.dim, start_K, gp_K, state_init, **optional)
141 |
def get_random_trajs(self) -> torch.Tensor:
    """Sample initial trajectories from a GP prior built around the start (and goal) states.

    The factors and prior are built in float64 on the configured device. Only
    the position part (``[..., :dim]``) of the start/goal states is passed to
    the prior. ``num_particles_per_goal`` samples are drawn and their shape is
    recorded in ``self.traj_dim``.

    Returns:
        The samples with their first two axes merged, cast to ``self.tensor_args``.
    """
    # force torch.float64 while building the prior and sampling
    tensor_args = dict(dtype=torch.float64, device=self.tensor_args['device'])
    start_state = self.start_state.to(**tensor_args)
    has_goals = self.multi_goal_states is not None
    multi_goal_states = self.multi_goal_states.to(**tensor_args) if has_goals else None

    #========= Initialization factors ===============
    self.start_prior_init = UnaryFactor(
        self.dim * 2,
        self.sigma_start_init,
        start_state,
        tensor_args=tensor_args,
    )

    self.gp_prior_init = GPFactor(
        self.dim,
        self.sigma_gp_init,
        self.dt,
        self.traj_len - 1,
        tensor_args=tensor_args,
    )

    self.multi_goal_prior_init = []
    if has_goals:
        self.multi_goal_prior_init = [
            UnaryFactor(
                self.dim * 2,
                self.sigma_goal_init,
                multi_goal_states[i],
                tensor_args=tensor_args,
            )
            for i in range(self.num_goals)
        ]

    # Without goals, pass None for both the goal precision and the goal
    # positions instead of slicing None (which raised TypeError).
    # NOTE(review): assumes BatchGPPrior accepts goal_states=None, as it
    # already does for K_g_inv — confirm.
    goal_K = self.multi_goal_prior_init[0].K if has_goals else None
    goal_positions = multi_goal_states[..., :self.dim] if has_goals else None
    self._traj_dist = self.get_GP_prior(
        self.start_prior_init.K,
        self.gp_prior_init.Q_inv[0],
        goal_K,
        start_state[..., :self.dim],
        goal_states=goal_positions,
        tensor_args=tensor_args,
    )
    particles = self._traj_dist.sample(self.num_particles_per_goal).to(**tensor_args)
    self.traj_dim = particles.shape
    # free memory but keep the attribute defined (a `del` would remove it)
    self._traj_dist = None
    return particles.flatten(0, 1).to(**self.tensor_args)
190 |
def optimize(self) -> Tuple[torch.Tensor, SinkhornStepState, int]:
    """Run Sinkhorn Step iterations until the solver says to stop.

    After every step, the flat particle matrix is viewed as ``self.traj_dim``.
    The first / last waypoints are optionally overwritten with the start / goal
    states (``fixed_start`` / ``fixed_goal``), and the result is flattened back
    to ``[-1, state_dim]``.

    Returns:
        ``(trajectories viewed as self.traj_dim, final solver state, number of steps run)``.
    """
    solver = self.sinkhorn_step
    state = solver.init_state(self.flatten_trajs)
    n_steps = 0
    while solver._continue(state, n_steps):
        state = solver.step(state, n_steps, traj_dim=self.traj_dim)
        # the view shares storage with state.X, so the pinning below is in place
        shaped = state.X.view(self.traj_dim)
        if self.fixed_start:
            shaped[:, :, 0, :] = self.start_state
        if self.fixed_goal:
            shaped[:, :, -1, :] = self.multi_goal_states.unsqueeze(1)
        state.X = shaped.view(-1, self.state_dim)
        n_steps += 1

    return state.X.view(self.traj_dim), state, n_steps
207 |
--------------------------------------------------------------------------------
/mpot/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anindex/mpot/a5c49d4fa83d7601bcd91246e07e254b813dc322/mpot/utils/__init__.py
--------------------------------------------------------------------------------
/mpot/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Optional, List, Tuple
3 |
4 |
5 | # min max scaler
class MinMaxScaler():
    """Map values onto [0, 1] via ``(X - min) / (max - min)``.

    Any bound left as ``None`` is fitted lazily from the first tensor passed to
    ``__call__`` (global min / max over all elements) and then kept fixed.
    There is no guard against ``max == min``.
    """

    def __init__(self, min: float = None, max: float = None):
        self.min = min
        self.max = max

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        # fit missing bounds on first use
        self.min = X.min() if self.min is None else self.min
        self.max = X.max() if self.max is None else self.max
        span = self.max - self.min
        return (X - self.min) / span

    def inverse(self, X: torch.Tensor) -> torch.Tensor:
        span = self.max - self.min
        return X * span + self.min
20 |
21 |
22 | # scale to [-1, 1]
class MinMaxCenterScaler():
    """In-place affine rescaling of a slice of the last axis from [min, max] onto [-1, 1].

    Only columns ``dim_range[0]:dim_range[1]`` of the last dimension are touched.
    Both ``__call__`` and ``inverse`` modify ``X`` in place and return ``None``.
    ``min`` / ``max`` may be scalars or tensors broadcastable to that slice.
    """

    def __init__(self, dim_range: List[float], min: float = -1, max: float = 1):
        self.dim_range = dim_range
        start, stop = dim_range[0], dim_range[1]
        self.dim = stop - start
        self.min = min
        self.max = max

    def __call__(self, X: torch.Tensor):
        cols = slice(self.dim_range[0], self.dim_range[1])
        span = self.max - self.min
        X[..., cols] = 2 * (X[..., cols] - self.min) / span - 1

    def inverse(self, X: torch.Tensor):
        cols = slice(self.dim_range[0], self.dim_range[1])
        span = self.max - self.min
        X[..., cols] = (X[..., cols] + 1) * span / 2 + self.min
35 |
36 |
37 | # min max mean scaler
class MinMaxMeanScaler():
    '''Centre a slice of the last axis by its mean and divide by ``(max - min)``.

    For torch tensors. NOTE: ``__call__`` and ``inverse`` are in-place
    operations that return ``None``. If ``mean`` is not supplied, the
    per-column mean of the first tensor passed to ``__call__`` is stored and
    reused. ``inverse`` requires ``mean`` to be set.
    '''

    def __init__(self, dim_range: List[float], min: float = -1, max: float = 1, mean: torch.Tensor = None):
        self.min = min
        self.max = max
        self.mean = mean
        self.dim_range = dim_range
        self.dim = dim_range[1] - dim_range[0]

    def __call__(self, X: torch.Tensor):
        cols = slice(self.dim_range[0], self.dim_range[1])
        block = X[..., cols]
        if self.mean is None:
            # flatten every leading axis and average per column
            self.mean = block.view((-1, self.dim)).mean(0)
        X[..., cols] = (block - self.mean) / (self.max - self.min)

    def inverse(self, X: torch.Tensor):
        # NOTE: the input is not clamped before undoing the scaling
        cols = slice(self.dim_range[0], self.dim_range[1])
        X[..., cols] = X[..., cols] * (self.max - self.min) + self.mean
56 |
57 |
58 | # STANDARD SCALER
class StandardScaler():
    """Standardise with fixed statistics: ``(X - mean) / std`` and its inverse."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        centred = X - self.mean
        return centred / self.std

    def inverse(self, X: torch.Tensor) -> torch.Tensor:
        scaled = X * self.std
        return scaled + self.mean
69 |
--------------------------------------------------------------------------------
/mpot/utils/polytopes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Optional, Tuple
3 | import numpy as np
4 | from itertools import product
5 | import math
6 | from mpot.utils.probe import get_random_probe_points, get_probe_points
7 | from mpot.utils.rotation import get_random_maximal_torus_matrix
8 |
9 |
def get_cube_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor:
    """Vertices of a hypercube inscribed in the sphere of ``radius`` around ``origin``.

    Returns a ``[2 ** dim, dim]`` tensor. The rows are all sign patterns
    (``+1`` first, in ``itertools.product`` order), normalised to unit length
    and then scaled and shifted.
    """
    dim = origin.shape[-1]
    signs = torch.tensor(list(product([1, -1], repeat=dim)))
    unit_vertices = signs / np.sqrt(dim)
    return unit_vertices.type_as(origin) * radius + origin
15 |
16 |
def get_orthoplex_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor:
    """Vertices of the cross-polytope (orthoplex) of ``radius`` around ``origin``.

    Returns a ``[2 * dim, dim]`` tensor. Rows ``0..dim-1`` are ``+radius``
    along each axis, and rows ``dim..2*dim-1`` are ``-radius`` along the same
    axes.
    """
    dim = origin.shape[-1]
    axis = torch.arange(dim)
    vertices = torch.zeros((2 * dim, dim)).type_as(origin)
    vertices[axis, axis] = radius
    vertices[axis + dim, axis] = -radius
    return vertices + origin
26 |
27 |
def get_simplex_vertices(origin: torch.Tensor, radius: float = 1., **kwargs) -> torch.Tensor:
    '''
    Vertices of a regular simplex (``dim + 1`` points) of ``radius`` around ``origin``.

    Simplex coordinates: https://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_a_regular_n-dimensional_simplex_in_Rn
    The ``n`` points ``sqrt(1 + 1/n) e_i - (sqrt(n + 1) + 1) / n^(3/2) * 1``
    and the point ``n^(-1/2) * 1`` have unit norm and centroid at 0. They are
    then scaled by ``radius`` and shifted to ``origin``.
    '''
    n = origin.shape[-1]
    axis_scale = math.sqrt(1 + 1 / n)
    shift = (math.sqrt(n + 1) + 1) / math.sqrt(n ** 3)
    first_n = axis_scale * torch.eye(n) - shift * torch.ones((n, n))
    last = (1 / math.sqrt(n)) * torch.ones((1, n))
    vertices = torch.cat([first_n, last], dim=0)
    return vertices.type_as(origin) * radius + origin
37 |
38 |
def get_sampled_polytope_vertices(origin: torch.Tensor,
                                  polytope_vertices: torch.Tensor,
                                  step_radius: float = 1.,
                                  probe_radius: float = 2.,
                                  num_probe: int = 5,
                                  **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Randomly rotate a template polytope around each origin and derive step / probe points.

    Args:
        origin: ``[batch, dim]`` centres. A 1-D ``[dim]`` origin is treated as batch 1.
        polytope_vertices: template vertices; presumably ``[num_vertices, dim]``
            centred at 0 (TODO confirm). They are repeated over the batch.
        step_radius: scale applied to the rotated vertices to get step points.
        probe_radius: forwarded to ``get_probe_points``.
        num_probe: number of probe points per vertex.
        **kwargs: ignored.

    Returns:
        ``(step_points, probe_points, polytope_vertices)``:
        ``[batch, num_vertices, dim]``, ``[batch, num_vertices, num_probe, dim]``
        and the rotated vertices ``[batch, num_vertices, dim]``.

    NOTE(review): the rotation helper may require an even ``dim`` — check
    ``get_random_maximal_torus_matrix``.
    """
    if origin.ndim == 1:
        origin = origin.unsqueeze(0)
    batch, dim = origin.shape
    polytope_vertices = polytope_vertices.repeat(batch, 1, 1)  # [batch, num_vertices, dim]

    # batch random polytope: one random rotation per origin
    maximal_torus_mat = get_random_maximal_torus_matrix(origin)
    polytope_vertices = polytope_vertices @ maximal_torus_mat
    step_points = polytope_vertices * step_radius + origin.unsqueeze(1)  # [batch, num_vertices, dim]
    probe_points = get_probe_points(origin, polytope_vertices, probe_radius, num_probe)  # [batch, num_vertices, num_probe, dim]
    return step_points, probe_points, polytope_vertices
56 |
57 |
def get_sampled_points_on_sphere(origin: torch.Tensor,
                                 step_radius: float = 1.,
                                 probe_radius: float = 2.,
                                 num_probe: int = 5,
                                 num_sphere_point: int = 50,
                                 random_probe: bool = False, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample random unit directions around each origin and derive step / probe points.

    Args:
        origin: ``[batch, dim]`` centres. A 1-D ``[dim]`` origin is treated as batch 1.
        step_radius: distance of the step points from the origin.
        probe_radius: forwarded to the probe-point helper.
        num_probe: number of probe points per direction.
        num_sphere_point: number of random directions per origin.
        random_probe: use randomly placed probes (``get_random_probe_points``)
            instead of evenly spaced ones (``get_probe_points``).
        **kwargs: ignored.

    Returns:
        ``(step_points, probe_points, points)``:
        ``[batch, num_sphere_point, dim]``,
        ``[batch, num_sphere_point, num_probe, dim]`` and the unit directions
        ``[batch, num_sphere_point, dim]``.
    """
    if origin.ndim == 1:
        origin = origin.unsqueeze(0)
    batch, dim = origin.shape
    # marsaglia method: normalised Gaussian samples are uniform on the unit sphere
    points = torch.randn(batch, num_sphere_point, dim).type_as(origin)  # [batch, num_sphere_point, dim]
    points = points / points.norm(dim=-1, keepdim=True)
    step_points = points * step_radius + origin.unsqueeze(1)  # [batch, num_sphere_point, dim]
    if random_probe:
        probe_points = get_random_probe_points(origin, points, probe_radius, num_probe)
    else:
        probe_points = get_probe_points(origin, points, probe_radius, num_probe)  # [batch, num_sphere_point, num_probe, dim]
    return step_points, probe_points, points
76 |
77 |
# polytope name -> function returning its vertices around an origin
POLYTOPE_MAP = dict(
    cube=get_cube_vertices,
    orthoplex=get_orthoplex_vertices,
    simplex=get_simplex_vertices,
)
83 |
# polytope name -> vertex count as a function of the dimension
POLYTOPE_NUM_VERTICES_MAP = dict(
    cube=lambda n: 2 ** n,  # one vertex per sign pattern
    orthoplex=lambda n: 2 * n,  # +/- along each axis
    simplex=lambda n: n + 1,
)
89 |
# sampling mode -> function producing (step_points, probe_points, directions)
SAMPLE_POLYTOPE_MAP = dict(
    polytope=get_sampled_polytope_vertices,
    random=get_sampled_points_on_sphere,
)
94 |
--------------------------------------------------------------------------------
/mpot/utils/probe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def get_random_probe_points(origin: torch.Tensor,
                            points: torch.Tensor,
                            probe_radius: float = 2.,
                            num_probe: int = 5) -> torch.Tensor:
    """Place ``num_probe`` random probes on each ray ``origin + t * probe_radius * point``, t in [0, 1).

    Args:
        origin: ``[batch, dim]`` ray starts.
        points: ``[batch, num_points, dim]`` ray directions.

    Returns:
        ``[batch, num_points, num_probe, dim]`` probe points.
    """
    n_batch, n_points, _ = points.shape
    fractions = torch.rand(n_batch, n_points, num_probe, 1).type_as(points)
    rays = (points * probe_radius).unsqueeze(-2)
    return rays * fractions + origin[:, None, None]
13 |
14 |
def get_probe_points(origin: torch.Tensor,
                     points: torch.Tensor,
                     probe_radius: float = 2.,
                     num_probe: int = 5) -> torch.Tensor:
    """Place ``num_probe`` evenly spaced probes strictly inside each ray segment.

    The probes sit at fractions ``k / (num_probe + 1)``, k = 1..num_probe, of
    ``probe_radius * point``, measured from ``origin``.

    Args:
        origin: ``[batch, dim]`` ray starts.
        points: ``[batch, num_points, dim]`` ray directions.

    Returns:
        ``[batch, num_points, num_probe, dim]`` probe points.
    """
    fractions = torch.linspace(0, 1, num_probe + 2).type_as(points)[1:num_probe + 1]
    fractions = fractions.view(1, 1, -1, 1)
    rays = (points * probe_radius).unsqueeze(-2)
    return rays * fractions + origin[:, None, None]
23 |
24 |
def get_shifted_points(new_origins: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    '''
    Translate one point set to every origin in ``new_origins``.

    Args:
        new_origins: [no, dim]
        points: [nb, dim], assumed to be centred at the origin
    Returns:
        shifted_points: [no, nb, dim]
    '''
    return new_origins[:, None] + points
36 |
37 |
def get_projecting_points(X1: torch.Tensor, X2: torch.Tensor, probe_step_size: float, num_probe: int = 5) -> torch.Tensor:
    '''
    Place probes along the segments from X1 points toward X2 points.

    The k-th probe (k = 1..num_probe) is ``X1 + k * probe_step_size * (X2 - X1)``.

    Args:
        X1: [nb1 x dim]
        X2: [nb2 x dim] (every X1 paired with every X2) or
            [nb1 x nb2 x dim] (a separate target set per X1 point)
        probe_step_size: fraction of the segment between consecutive probes.
        num_probe: number of probes per segment.
    Returns:
        [nb1 x nb2 x num_probe x dim]
    Raises:
        ValueError: if X2 is not 2-D or 3-D, or a 3-D X2 does not have nb1 leading entries.
    '''
    if X2.ndim == 2:
        X2 = X2.unsqueeze(0)  # [1, nb2, dim]: shared by every X1 point
    elif X2.ndim == 3:
        if X2.shape[0] != X1.shape[0]:
            raise ValueError(
                f"3-D X2 must have {X1.shape[0]} leading entries (one per X1 point), got {X2.shape[0]}"
            )
    else:
        raise ValueError(f"X2 must be 2-D or 3-D, got {X2.ndim}-D")
    X1 = X1.unsqueeze(1).unsqueeze(-2)  # [nb1, 1, 1, dim]
    X2 = X2.unsqueeze(-2)  # [1 or nb1, nb2, 1, dim]
    alpha = torch.arange(1, num_probe + 1).type_as(X1) * probe_step_size
    alpha = alpha.view(1, 1, -1, 1)
    return X1 + (X2 - X1) * alpha
55 |
56 |
if __name__ == '__main__':
    # Smoke test: one probe halfway between every pair (X1[i], X2[j]).
    src = torch.tensor([[0, 0], [2, 2]], dtype=torch.float32)
    dst = torch.tensor([[0, 2], [2, 4]], dtype=torch.float32)
    print(get_projecting_points(src, dst, 0.5, 1))
67 |
--------------------------------------------------------------------------------
/mpot/utils/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def rotation_matrix(theta: torch.Tensor) -> torch.Tensor:
    """Stack 2-D rotation matrices ``[[cos, -sin], [sin, cos]]``, one per angle.

    For a 1-D ``theta`` of length ``batch`` the result is ``[batch, 2, 2]``.
    """
    angle = theta.unsqueeze(1).unsqueeze(2)
    cos, sin = torch.cos(angle), torch.sin(angle)
    top_row = torch.cat([cos, -sin], dim=2)
    bottom_row = torch.cat([sin, cos], dim=2)
    return torch.cat([top_row, bottom_row], dim=1)
10 |
11 |
def get_random_maximal_torus_matrix(origin: torch.Tensor,
                                    angle_range=(0, 2 * torch.pi), **kwargs) -> torch.Tensor:
    """Sample a random element of the maximal torus of SO(dim), one per batch entry.

    Each consecutive coordinate pair ``(2i, 2i + 1)`` is rotated by an
    independent angle drawn uniformly from ``angle_range``, so the result is
    block diagonal with 2x2 rotation blocks. For odd ``dim`` the last
    coordinate is left fixed (a trailing 1 on the diagonal).

    Args:
        origin: ``[batch, dim]``; only its shape, dtype and device are used.
        angle_range: ``(low, high)`` bounds of the sampled angles.
        **kwargs: ignored.

    Returns:
        ``[batch, dim, dim]`` rotation matrices.
    """
    batch, dim = origin.shape
    num_planes = dim // 2
    # draw order [dim // 2, batch] is kept so seeded results match for even dims
    theta = torch.rand(num_planes, batch).type_as(origin) * (angle_range[1] - angle_range[0]) + angle_range[0]  # [dim // 2, batch]
    theta = theta.transpose(0, 1)  # [batch, dim // 2]
    cos, sin = torch.cos(theta), torch.sin(theta)
    # diagonal entries (cos_0, cos_0, cos_1, cos_1, ...) of the 2x2 blocks
    diag = torch.stack([cos, cos], dim=-1).flatten(-2, -1)  # [batch, 2 * (dim // 2)]
    if dim % 2 == 1:
        diag = torch.cat([diag, diag.new_ones(batch, 1)], dim=-1)
    # make batch block diag
    max_torus_mat = torch.diag_embed(diag)
    even = torch.arange(0, 2 * num_planes, 2)
    odd = even + 1
    max_torus_mat[:, even, odd] = -sin
    max_torus_mat[:, odd, even] = sin
    return max_torus_mat
24 |
25 |
def get_random_uniform_rot_matrix(origin: torch.Tensor,
                                  **kwargs) -> torch.Tensor:
    """Compute a uniformly random rotation matrix drawn from the Haar distribution
    (the only uniform distribution on SO(n)). This is less efficient than maximal torus.
    See: Stewart, G.W., "The efficient generation of random orthogonal
    matrices with an application to condition estimators", SIAM Journal
    on Numerical Analysis, 17(3), pp. 403-409, 1980.
    For more information see
    http://en.wikipedia.org/wiki/Orthogonal_matrix#Randomization

    ``origin`` is ``[batch, dim]`` and only supplies batch size, dtype and
    device. Returns ``[batch, dim, dim]`` matrices. The last sign is chosen so
    that the determinant is 1.
    """
    batch, dim = origin.shape
    H = torch.eye(dim).repeat(batch, 1, 1).type_as(origin)
    signs = torch.ones((batch, dim)).type_as(origin)
    for k in range(1, dim):
        size = dim - k + 1
        v = torch.normal(0, 1., size=(batch, size)).type_as(origin)
        signs[:, k - 1] = torch.sign(v[:, 0])
        v[:, 0] -= signs[:, k - 1] * torch.norm(v, dim=-1)
        # Householder reflector I - 2 v v^T / (v^T v) acting on the trailing block
        outer = v.unsqueeze(-2) * v.unsqueeze(-1)
        squared_norm = torch.square(v).sum(dim=-1)[..., None, None]
        reflector = torch.eye(size).repeat(batch, 1, 1).type_as(origin) - 2 * outer / squared_norm
        embed = torch.eye(dim).repeat(batch, 1, 1).type_as(origin)
        embed[:, k - 1:, k - 1:] = reflector
        H = H @ embed
    # fix the last sign so that det == 1
    signs[:, -1] = (-1)**(1 - dim % 2) * signs.prod(dim=-1)
    return (signs.unsqueeze(-1) * H.mT).mT
53 |
54 |
if __name__ == '__main__':
    # Time both samplers on a CUDA batch of 100 8-D rotations.
    from torch_robotics.torch_utils.torch_timer import TimerCUDA
    origin = torch.zeros(100, 8).cuda()
    for sampler in (get_random_maximal_torus_matrix, get_random_uniform_rot_matrix):
        with TimerCUDA() as t:
            rot_mat = sampler(origin)
        print(t.elapsed)
        print(rot_mat.shape)
    print(torch.det(rot_mat))
67 |
--------------------------------------------------------------------------------
/mpot/utils/trajectory.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def interpolate_trajectory(trajs: torch.Tensor, num_interpolation: int = 3) -> torch.Tensor:
    '''Linearly interpolate points strictly between consecutive waypoints.

    Args:
        trajs: ``[..., T, dim]`` waypoints (at least 2-D when ``num_interpolation > 0``).
        num_interpolation: number of evenly spaced points inserted in each of
            the ``T - 1`` segments. Values ``<= 0`` return ``trajs`` unchanged.

    Returns:
        ``[..., (T - 1) * num_interpolation, dim]`` interpolated points, ordered
        from the first waypoint toward the last. The original waypoints are
        not included.
    '''
    dim = trajs.shape[-1]
    if num_interpolation <= 0:
        return trajs
    assert trajs.ndim > 1
    traj_dim = trajs.shape
    # interior fractions 1/(n+1), ..., n/(n+1); endpoints 0 and 1 are excluded
    alpha = torch.linspace(0, 1, num_interpolation + 2).type_as(trajs)[1:num_interpolation + 1]
    alpha = alpha.view((1,) * len(traj_dim[:-1]) + (-1, 1))
    seg_start = trajs[..., 0:traj_dim[-2] - 1, None, :]
    seg_end = trajs[..., 1:traj_dim[-2], None, :]
    # the weight on seg_end grows with alpha, so points run from each waypoint
    # toward the next one (the previous weighting reversed every segment)
    interpolated_trajs = seg_start * (1 - alpha) + seg_end * alpha
    return interpolated_trajs.view(traj_dim[:-2] + (-1, dim))
18 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | numpy
3 | matplotlib
4 | labmaze
5 | torch
6 | einops
7 | torch_robotics @ git+https://github.com/anindex/torch_robotics.git@68d28399ff67f34ca3be0130acdc7cba2fabd61b
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | from codecs import open
3 | from os import path
4 |
5 |
ext_modules = []

here = path.abspath(path.dirname(__file__))

# Read install requirements: strip newlines, skip blank lines and comments.
requires_list = []
with open(path.join(here, 'requirements.txt'), encoding='utf-8') as f:
    for line in f:
        requirement = line.strip()
        if requirement and not requirement.startswith('#'):
            requires_list.append(requirement)

setuptools.setup(
    name='mpot',
    description="Implementation of MPOT in PyTorch",
    author="An T. Le",
    author_email="an@robot-learning.de",
    packages=setuptools.find_namespace_packages(),
    install_requires=requires_list,
)
22 |
--------------------------------------------------------------------------------