├── .devcontainer
└── devcontainer.json
├── .gitignore
├── .vscode
└── settings.json
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── app
├── catpole.py
├── mountaincar.py
├── mujoco_cartpole.py
├── navigation2d.py
└── pendulum.py
├── media
├── cartpole.gif
├── mountaincar.gif
├── navigation_2d.gif
└── pendulum.gif
├── poetry.lock
├── pyproject.toml
├── src
├── __init__.py
├── controller
│ ├── __init__.py
│ └── mppi.py
└── envs
│ ├── __init__.py
│ ├── navigation_2d.py
│ └── obstacle_map_2d.py
└── tests
├── __init__.py
├── test_brax.py
├── test_gui.py
├── test_mujoco.py
└── test_torch.py
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dev",
3 | "dockerFile": "../Dockerfile",
4 | "settings": {
5 | "terminal.integrated.shell.linux": "/bin/bash"
6 | },
7 | "extensions": [
8 | "ms-python.python",
9 | "ms-vscode-remote.remote-containers",
10 | "ms-python.vscode-pylance",
11 | "GitHub.copilot",
12 | "ms-python.black-formatter",
13 | "ms-python.flake8",
14 | ],
15 | "runArgs": [
16 | "--gpus", "all",
17 | "--shm-size", "10G",
18 | ]
19 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | # project specific
164 | video/
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "C_Cpp.default.configurationProvider": "ms-vscode.makefile-tools",
3 | "python.analysis.include": [
4 | "src/**",
5 | "tests/**"
6 | ],
7 | }
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
2 | ENV DEBIAN_FRONTEND=noninteractive
3 |
4 | RUN apt-get -y update && apt-get -y install --no-install-recommends\
5 | software-properties-common\
6 | libgl1-mesa-dev\
7 | libgl1-mesa-glx \
8 | libglew-dev \
9 | libosmesa6-dev \
10 | wget\
11 | libssl-dev\
12 | curl\
13 | git\
14 | x11-apps \
15 | swig \
16 | patchelf
17 |
18 | # Python (version 3.10)
19 | RUN add-apt-repository ppa:deadsnakes/ppa && \
20 | apt-get update && apt-get install -y \
21 | python3.10 \
22 | python3.10-dev \
23 | python3.10-venv \
24 | python3.10-distutils \
25 | python3.10-tk
26 |
27 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
28 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
29 | RUN pip3 install --upgrade pip
30 | RUN pip3 install -U pip distlib setuptools wheel
31 |
32 | # vnc
33 | RUN apt-get install -y xvfb x11vnc icewm lsof net-tools
34 | RUN echo "alias vnc='PASSWORD=\$(openssl rand -hex 24); for i in {99..0}; do export DISPLAY=:\$i; if ! xdpyinfo &>/dev/null; then break; fi; done; for i in {5999..5900}; do if ! netstat -tuln | grep -q \":\$i \"; then PORT=\$i; break; fi; done; Xvfb \$DISPLAY -screen 0 1400x900x24 & until xdpyinfo > /dev/null 2>&1; do sleep 0.1; done; x11vnc -forever -noxdamage -display \$DISPLAY -rfbport \$PORT -passwd \$PASSWORD > /dev/null 2>&1 & until lsof -i :\$PORT > /dev/null; do sleep 0.1; done; icewm-session & echo DISPLAY=\$DISPLAY, PORT=\$PORT, PASSWORD=\$PASSWORD'" >> ~/.bashrc
35 |
36 | # utils
37 | RUN apt-get update && apt-get install -y htop vim ffmpeg
38 | # RUN pip3 install jupyterlab ipywidgets && \
39 | # echo 'alias jup="jupyter lab --ip 0.0.0.0 --port 8888 --allow-root &"' >> /root/.bashrc
40 |
41 | # clear cache
42 | RUN rm -rf /var/lib/apt/lists/*
43 |
44 | # pytorch 2.0
45 | RUN pip3 install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118
46 |
47 | # mujoco 210
48 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \
49 | && chmod +x /usr/local/bin/patchelf
50 | RUN mkdir -p /root/.mujoco \
51 | && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \
52 | && tar -xf mujoco.tar.gz -C /root/.mujoco \
53 | && rm mujoco.tar.gz
54 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
55 |
56 | WORKDIR /workspace
57 | COPY src/ src/
58 | COPY pyproject.toml .
59 |
60 | RUN pip3 install -e .[dev]
61 |
62 | CMD ["bash"]
63 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 kohonda
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | NAME=mppi_playground
2 | VERSION=0.0.1
3 | DOCKER_IMAGE_NAME=$(NAME):$(VERSION)
4 | CONTAINER_NAME=$(NAME)
5 | GPU_ID=all
6 |
7 | build:
8 | docker build -t $(DOCKER_IMAGE_NAME) .
9 |
10 | bash:
11 | xhost +local:docker && \
12 | docker run -it \
13 | --gpus '"device=${GPU_ID}"' \
14 | -v ${PWD}/workspace \
15 | -v ${PWD}:/workspace/$(NAME) \
16 | --rm \
17 | --shm-size 10G \
18 | -v /tmp/.X11-unix:/tmp/.X11-unix \
19 | -e DISPLAY \
20 | -p 5900:5900 \
21 | --name $(CONTAINER_NAME)-bash \
22 | $(DOCKER_IMAGE_NAME) \
23 | bash
24 |
25 | bash-wo-gpu:
26 | docker run -it \
27 | -v ${PWD}/workspace \
28 | -v ${PWD}:/workspace/$(NAME) \
29 | --rm \
30 | --shm-size 10G \
31 | --name $(CONTAINER_NAME)-bash \
32 | $(DOCKER_IMAGE_NAME) \
33 | bash
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MPPI Playground
2 | This repository contains an implementation of [Model Predictive Path Integral Control (MPPI)](https://arxiv.org/abs/1707.02342) with PyTorch to accelerate computations on the GPU.
3 |
4 | ## Tested Native Environment
5 | - Ubuntu Focal 20.04 (LTS)
6 | - NVIDIA Driver 510 or later due to PyTorch 2.x
7 |
8 | ## Dependencies
9 | - cuda 11.8
10 | - Python 3.10
11 | - PyTorch 2.0
12 |
13 |
14 | ## Docker Setup
15 |
16 | ### Install Docker
17 |
18 | [Installation guide](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository)
19 |
20 | ```bash
21 | # Install from get.docker.com
22 | curl -fsSL https://get.docker.com -o get-docker.sh
23 | sudo sh get-docker.sh
24 | sudo groupadd docker
25 | sudo usermod -aG docker $USER
26 | ```
27 |
28 |
29 | ### Setup GPU for Docker
30 | [Installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
31 | ```bash
32 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
33 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
34 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
35 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
36 |
37 | sudo apt-get update
38 |
39 | sudo apt-get install -y nvidia-container-toolkit nvidia-container-runtime
40 |
41 | sudo nvidia-ctk runtime configure --runtime=docker
42 |
43 | sudo systemctl restart docker
44 | ```
45 |
46 |
47 | ## Installation
48 |
49 | ### With Docker (Recommended)
50 |
51 | ```bash
52 | # build container
53 | make build
54 |
55 | # Open the remote container via VS Code (Recommended)
56 | # 1. Open the folder in VS Code
57 | # 2. Press Ctrl+Shift+P and select 'Dev Containers: Rebuild and Reopen in Container'
58 | # Then, you can skip the following commands
59 |
60 | # Or Run container via terminal
61 | make bash
62 | ```
63 |
64 | ### With venv
65 |
66 | ```bash
67 | python3 -m venv .venv
68 | source .venv/bin/activate
69 | pip3 install -e .[dev]
70 | ```
71 |
72 | ## Examples
73 |
74 | ### Navigation 2D
75 | ```bash
76 | python3 app/navigation2d.py
77 | ```
78 |
79 |
80 |
81 |
82 | ### Pendulum
83 | ```bash
84 | python3 app/pendulum.py
85 | ```
86 |
87 |
88 |
89 |
90 | ### Cartpole
91 | ```bash
92 | python3 app/catpole.py  # note: the script file is named catpole.py
93 | ```
94 |
95 |
96 |
97 |
98 | ### Mountain car
99 | ```bash
100 | python3 app/mountaincar.py
101 | ```
102 |
103 |
104 |
105 |
106 | ## Reference
107 | - [pytorch_mppi](https://github.com/UM-ARM-Lab/pytorch_mppi)
--------------------------------------------------------------------------------
/app/catpole.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import time
4 | import gymnasium
5 | import fire
6 | import numpy as np
7 |
8 | from controller.mppi import MPPI
9 |
10 |
@torch.jit.script
def angle_normalize(x: torch.Tensor) -> torch.Tensor:
    """Wrap angles (radians) element-wise into the interval [-pi, pi).

    Args:
        x: tensor of angles in radians.

    Returns:
        Tensor of the same shape with every angle wrapped to [-pi, pi).
    """
    two_pi = 2 * torch.pi
    # remainder() with a positive divisor lands in [0, 2*pi); shift back by pi.
    shifted = torch.remainder(x + torch.pi, two_pi)
    return shifted - torch.pi
14 |
15 |
def main(save_mode: bool = False):
    """Swing up and balance gymnasium's CartPole-v1 with MPPI.

    CartPole's action is discrete, so MPPI plans over a continuous control
    whose sign selects the push direction (in the model and in the env).

    Args:
        save_mode: if True, render off-screen and record the run to
            ``video/`` via ``RecordVideo``; otherwise open a window.
    """

    @torch.jit.script
    def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """One explicit-Euler step of the gymnasium CartPole equations, batched.

        Args:
            state (torch.Tensor): rows of [x, x_dt, theta, theta_dt]
            action (torch.Tensor): rows whose first column is a continuous
                control in [-1, 1]; only its sign is used.

        Returns:
            torch.Tensor: next states, same layout as ``state``.
        """
        x = state[:, 0].view(-1, 1)
        x_dt = state[:, 1].view(-1, 1)
        theta = state[:, 2].view(-1, 1)
        theta_dt = state[:, 3].view(-1, 1)

        gravity = 9.8
        masscart = 1.0
        masspole = 0.1
        total_mass = masspole + masscart
        length = 0.5  # actually half the pole's length
        polemass_length = masspole * length
        force_mag = 10.0
        tau = 0.02  # seconds between state updates

        # MPPI samples continuous controls, but CartPole only pushes with a
        # fixed magnitude: map the control's sign to +-force_mag.
        continuous_action = action[:, 0].view(-1, 1)
        force = torch.zeros_like(continuous_action)
        force[continuous_action >= 0] = force_mag
        force[continuous_action < 0] = -force_mag

        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        temp = (force + polemass_length * theta_dt**2 * sintheta) / total_mass
        thetaacc = (gravity * sintheta - costheta * temp) / (
            length * (4.0 / 3.0 - masspole * costheta**2 / total_mass)
        )
        xacc = temp - polemass_length * thetaacc * costheta / total_mass

        newx = x + tau * x_dt
        newx_dt = x_dt + tau * xacc
        newtheta = theta + tau * theta_dt
        newtheta_dt = theta_dt + tau * thetaacc

        # Clamp position to +-2.4 and angle to +-12 degrees.
        x_threshold = 2.4
        theta_threshold_radians = 12 * 2 * torch.pi / 360
        newx = torch.clamp(newx, -x_threshold, x_threshold)
        newtheta = torch.clamp(
            newtheta, -theta_threshold_radians, theta_threshold_radians
        )

        return torch.cat((newx, newx_dt, newtheta, newtheta_dt), dim=1)

    @torch.jit.script
    def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Running cost per sample: wrapped angle^2 + 0.1 theta_dt^2 + 0.1 x^2."""
        x = state[:, 0]
        theta = state[:, 2]
        theta_dt = state[:, 3]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 + 0.1 * x**2

    @torch.jit.script
    def terminal_cost(state: torch.Tensor) -> torch.Tensor:
        """Terminal cost per sample; same form as the running cost."""
        x = state[:, 0]
        theta = state[:, 2]
        theta_dt = state[:, 3]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 + 0.1 * x**2

    # simulator
    if save_mode:
        env = gymnasium.make("CartPole-v1", render_mode="rgb_array")
        env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video")
    else:
        env = gymnasium.make("CartPole-v1", render_mode="human")
    observation, _ = env.reset(seed=42)

    # Start from theta = pi. This is far outside the +-12 degree band used
    # above, so the env presumably reports termination right away. The
    # terminated/truncated flags are deliberately ignored below so the
    # swing-up keeps running.
    env.unwrapped.state = np.array([0.0, 0.0, np.pi, 0.0])
    observation, _, _, _, _ = env.step(0)

    # solver
    solver = MPPI(
        horizon=1000,
        num_samples=5000,
        dim_state=4,
        dim_control=1,
        dynamics=dynamics,
        stage_cost=stage_cost,
        terminal_cost=terminal_cost,
        u_min=torch.tensor([-1.0]),
        u_max=torch.tensor([1.0]),
        sigmas=torch.tensor([1.0]),
        lambda_=0.001,
    )

    average_time = 0.0
    for i in range(500):
        start = time.time()
        with torch.no_grad():  # planning only; no gradients needed
            action_seq, _ = solver.forward(state=observation)
        elapsed_time = time.time() - start
        # incremental running mean of the solve time
        average_time = i / (i + 1) * average_time + elapsed_time / (i + 1)

        action_seq_np = action_seq.cpu().numpy()

        # Negative first control -> discrete action 0, otherwise 1, mirroring
        # the sign -> force mapping in `dynamics`.
        discrete_action = 0 if action_seq_np[0, 0] < 0 else 1

        # Rendering happens inside step() in "human" mode and is captured by
        # RecordVideo in save mode. An extra env.render() would draw each
        # frame twice.
        observation, _, _, _, _ = env.step(discrete_action)

    print("average solve time: {}".format(average_time * 1000), " [ms]")
    env.close()


if __name__ == "__main__":
    fire.Fire(main)
149 |
--------------------------------------------------------------------------------
/app/mountaincar.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import time
4 | import gymnasium
5 | import fire
6 |
7 | from controller.mppi import MPPI
8 |
9 |
@torch.jit.script
def angle_normalize(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` (radians) wrapped element-wise into [-pi, pi)."""
    # Shift so that -pi maps to 0, fold into one period, then shift back.
    folded = (x + torch.pi).remainder(2 * torch.pi)
    return folded - torch.pi
13 |
14 |
def main(save_mode: bool = False):
    """Drive gymnasium's MountainCarContinuous-v0 toward the goal with MPPI.

    Args:
        save_mode: if True, render off-screen and record the run to
            ``video/`` via ``RecordVideo``; otherwise open a window.
    """

    @torch.jit.script
    def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """One batched MountainCarContinuous step, ported from gymnasium.

        Args:
            state (torch.Tensor): rows of [position, velocity]
            action (torch.Tensor): rows whose first column is the force;
                clipped to [-1, 1].

        Returns:
            torch.Tensor: next states, same layout as ``state``. The input
            tensor is not modified.
        """
        min_action = -1.0
        max_action = 1.0
        min_position = -1.2
        max_position = 0.6
        max_speed = 0.07
        power = 0.0015

        position = state[:, 0].view(-1, 1)
        velocity = state[:, 1].view(-1, 1)

        force = torch.clamp(action[:, 0].view(-1, 1), min_action, max_action)

        # Out-of-place on purpose: `position` and `velocity` are views of
        # `state`, so `+=` would silently overwrite the caller's tensor.
        velocity = velocity + (force * power - 0.0025 * torch.cos(3 * position))
        velocity = torch.clamp(velocity, -max_speed, max_speed)
        position = torch.clamp(position + velocity, min_position, max_position)

        # Vectorized form of gymnasium's inelastic left wall:
        # `if position == min_position and velocity < 0: velocity = 0`.
        hit_left_wall = (position <= min_position) & (velocity < 0)
        velocity = torch.where(hit_left_wall, torch.zeros_like(velocity), velocity)

        return torch.cat((position, velocity), dim=1)

    @torch.jit.script
    def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Running cost per sample: squared distance to the goal position 0.45."""
        goal_position = 0.45
        position = state[:, 0]
        return (goal_position - position) ** 2

    @torch.jit.script
    def terminal_cost(state: torch.Tensor) -> torch.Tensor:
        """Terminal cost per sample: squared distance to the goal position 0.45."""
        goal_position = 0.45
        position = state[:, 0]
        return (goal_position - position) ** 2

    # simulator
    if save_mode:
        env = gymnasium.make("MountainCarContinuous-v0", render_mode="rgb_array")
        env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video")
    else:
        env = gymnasium.make("MountainCarContinuous-v0", render_mode="human")
    env.reset(seed=42)

    # solver
    solver = MPPI(
        horizon=1000,
        num_samples=1000,
        dim_state=2,
        dim_control=1,
        dynamics=dynamics,
        stage_cost=stage_cost,
        terminal_cost=terminal_cost,
        u_min=torch.tensor([-1.0]),
        u_max=torch.tensor([1.0]),
        sigmas=torch.tensor([1.0]),
        lambda_=0.1,
    )

    average_time = 0.0
    for i in range(300):
        # plan from the simulator's internal [position, velocity] state
        state = env.unwrapped.state.copy()

        start = time.time()
        with torch.no_grad():  # planning only; no gradients needed
            action_seq, _ = solver.forward(state=state)
        elapsed_time = time.time() - start
        # incremental running mean of the solve time
        average_time = i / (i + 1) * average_time + elapsed_time / (i + 1)

        action_seq_np = action_seq.cpu().numpy()

        # Apply only the first planned control. Rendering happens inside
        # step() in "human" mode and via RecordVideo in save mode.
        env.step(action_seq_np[0, :])

    print("average solve time: {}".format(average_time * 1000), " [ms]")
    env.close()


if __name__ == "__main__":
    fire.Fire(main)
116 |
--------------------------------------------------------------------------------
/app/mujoco_cartpole.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import time
4 | import gymnasium as gym
5 | import fire
6 | import numpy as np
7 |
8 | from controller.mppi import MPPI
9 |
10 |
@torch.jit.script
def angle_normalize(x: torch.Tensor) -> torch.Tensor:
    """Map each angle in ``x`` (radians) to its equivalent in [-pi, pi).

    Args:
        x: tensor of angles in radians.

    Returns:
        Wrapped angles, same shape as ``x``.
    """
    offset = torch.pi
    period = 2 * torch.pi
    return torch.remainder(x + offset, period) - offset
14 |
15 |
# Does not work well: `dynamics` below is the classic-control cart-pole model,
# not MuJoCo's dynamics. Using the true MuJoCo dynamics would help, e.g.:
# https://github.com/mohakbhardwaj/mjmpc/blob/master/examples/example_mpc.py#L112
def main(save_mode: bool = False):
    """Balance gymnasium's MuJoCo InvertedPendulum-v4 with MPPI.

    Args:
        save_mode: if True, render off-screen and record the run to
            ``video/`` via ``RecordVideo``; otherwise open a window.
    """

    @torch.jit.script
    def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """One explicit-Euler cart-pole step, batched (not MuJoCo's dynamics).

        The simulated model is described in
        https://github.com/openai/gym/blob/master/gym/envs/mujoco/assets/inverted_pendulum.xml

        Args:
            state (torch.Tensor): rows of [x, x_dt, theta, theta_dt]
            action (torch.Tensor): rows whose first column is the force.

        Returns:
            torch.Tensor: next states, same layout as ``state``.
        """
        x = state[:, 0].view(-1, 1)
        x_dt = state[:, 1].view(-1, 1)
        theta = state[:, 2].view(-1, 1)
        theta_dt = state[:, 3].view(-1, 1)

        force = action[:, 0].view(-1, 1)

        gravity = 9.8
        masscart = 1.0
        masspole = 1.0
        total_mass = masspole + masscart
        length = 0.5  # actually half the pole's length
        polemass_length = masspole * length
        tau = 0.02  # seconds between state updates

        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        temp = (force + polemass_length * theta_dt**2 * sintheta) / total_mass
        thetaacc = (gravity * sintheta - costheta * temp) / (
            length * (4.0 / 3.0 - masspole * costheta**2 / total_mass)
        )
        xacc = temp - polemass_length * thetaacc * costheta / total_mass

        newx = x + tau * x_dt
        newx_dt = x_dt + tau * xacc
        newtheta = theta + tau * theta_dt
        newtheta_dt = theta_dt + tau * thetaacc

        # Clamp position to +-1.0 and angle to +-12 degrees.
        x_threshold = 1.0
        theta_threshold_radians = 12 * 2 * torch.pi / 360
        newx = torch.clamp(newx, -x_threshold, x_threshold)
        newtheta = torch.clamp(
            newtheta, -theta_threshold_radians, theta_threshold_radians
        )

        return torch.cat((newx, newx_dt, newtheta, newtheta_dt), dim=1)

    @torch.jit.script
    def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Running cost per sample: wrapped angle^2 + 0.1 theta_dt^2 + 0.1 x^2."""
        x = state[:, 0]
        theta = state[:, 2]
        theta_dt = state[:, 3]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 + 0.1 * x**2

    @torch.jit.script
    def terminal_cost(state: torch.Tensor) -> torch.Tensor:
        """Terminal cost per sample; same form as the running cost."""
        x = state[:, 0]
        theta = state[:, 2]
        theta_dt = state[:, 3]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 + 0.1 * x**2

    # simulator
    if save_mode:
        env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
        env = gym.wrappers.RecordVideo(env=env, video_folder="video")
    else:
        env = gym.make("InvertedPendulum-v4", render_mode="human")

    observation, _ = env.reset(seed=42)

    # InvertedPendulum-v4 observes [qpos, qvel] = [x, theta, x_dt, theta_dt],
    # while `dynamics` and the costs use [x, x_dt, theta, theta_dt].
    # Reorder every observation before handing it to the planner.
    obs_to_model = [0, 2, 1, 3]

    # solver
    solver = MPPI(
        horizon=50,
        num_samples=1000,
        dim_state=4,
        dim_control=1,
        dynamics=dynamics,
        stage_cost=stage_cost,
        terminal_cost=terminal_cost,
        u_min=torch.tensor([-3.0]),
        u_max=torch.tensor([3.0]),
        sigmas=torch.tensor([1.0]),
        lambda_=1.0,
    )

    average_time = 0.0
    for i in range(500):
        start = time.time()
        with torch.no_grad():  # planning only; no gradients needed
            action_seq, _ = solver.forward(state=observation[obs_to_model])
        elapsed_time = time.time() - start
        # incremental running mean of the solve time
        average_time = i / (i + 1) * average_time + elapsed_time / (i + 1)

        action_seq_np = action_seq.cpu().numpy()

        # Apply only the first planned control. Rendering happens inside
        # step() in "human" mode and via RecordVideo in save mode.
        observation, _, _, _, _ = env.step(action_seq_np[0, :])

    print("average solve time: {}".format(average_time * 1000), " [ms]")
    env.close()


if __name__ == "__main__":
    fire.Fire(main)
146 |
--------------------------------------------------------------------------------
/app/navigation2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import time
4 |
5 | # import gymnasium
6 | import fire
7 | import tqdm
8 |
9 | from controller.mppi import MPPI
10 | from envs.navigation_2d import Navigation2DEnv
11 |
12 |
def main(save_mode: bool = False):
    """Drive Navigation2DEnv to its goal with MPPI, rendering every step.

    Args:
        save_mode: if True, render in "rgb_array" mode and show a tqdm
            progress bar; otherwise render to a window ("human" mode).
    """
    env = Navigation2DEnv()

    # solver
    solver = MPPI(
        horizon=50,
        num_samples=10000,
        dim_state=3,
        dim_control=2,
        dynamics=env.dynamics,
        stage_cost=env.stage_cost,
        terminal_cost=env.terminal_cost,
        u_min=env.u_min,
        u_max=env.u_max,
        sigmas=torch.tensor([0.5, 0.5]),
        lambda_=1.0,
    )

    state = env.reset()
    max_steps = 500
    average_time = 0.0
    render_mode = "rgb_array" if save_mode else "human"
    pbar = tqdm.tqdm(total=max_steps, desc="recording video") if save_mode else None

    for i in range(max_steps):
        start = time.time()
        with torch.no_grad():
            action_seq, state_seq = solver.forward(state=state)
        elapsed_time = time.time() - start
        # Running mean over the steps actually executed. The loop may stop
        # early at the goal, so dividing by max_steps would under-report.
        average_time = i / (i + 1) * average_time + elapsed_time / (i + 1)

        state, is_goal_reached = env.step(action_seq[0, :])

        is_collisions = env.collision_check(state=state_seq)

        top_samples, top_weights = solver.get_top_samples(num_samples=300)

        env.render(
            predicted_trajectory=state_seq,
            is_collisions=is_collisions,
            top_samples=(top_samples, top_weights),
            mode=render_mode,
        )
        if pbar is not None:
            pbar.update(1)

        if is_goal_reached:
            print("Goal Reached!")
            break

    if pbar is not None:
        pbar.close()

    print("average solve time: {}".format(average_time * 1000), " [ms]")
    env.close()  # close window and save video if save_mode is True


if __name__ == "__main__":
    fire.Fire(main)
76 |
--------------------------------------------------------------------------------
/app/pendulum.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import time
4 | import gymnasium
5 | import fire
6 |
7 | from controller.mppi import MPPI
8 |
9 |
@torch.jit.script
def angle_normalize(x: torch.Tensor) -> torch.Tensor:
    """Wrap angles (radians) element-wise into [-pi, pi).

    Args:
        x: tensor of angles in radians.

    Returns:
        Tensor of wrapped angles with the same shape as ``x``.
    """
    wrapped = torch.remainder(x + torch.pi, 2 * torch.pi)
    wrapped = wrapped - torch.pi
    return wrapped
13 |
14 |
def main(save_mode: bool = False):
    """Swing up gymnasium's Pendulum-v1 with MPPI.

    Args:
        save_mode: if True, render off-screen and record the run to
            ``video/`` via ``RecordVideo``; otherwise open a window.
    """

    @torch.jit.script
    def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """One batched Pendulum-v1 step, ported from gymnasium.

        Args:
            state (torch.Tensor): rows of [theta, theta_dt]
            action (torch.Tensor): rows whose first column is the torque;
                clipped to [-2, 2].

        Returns:
            torch.Tensor: next states, same layout as ``state``.
        """
        th = state[:, 0].view(-1, 1)
        thdot = state[:, 1].view(-1, 1)
        g = 10
        m = 1
        length = 1
        dt = 0.05
        u = torch.clamp(action[:, 0].view(-1, 1), -2, 2)
        newthdot = (
            thdot
            + (
                -3 * g / (2 * length) * torch.sin(th + torch.pi)
                + 3.0 / (m * length**2) * u
            )
            * dt
        )
        # Clip the angular velocity BEFORE integrating the angle, matching
        # gymnasium's Pendulum-v1. Clipping afterwards let the angle advance
        # with an unclipped speed, so the model drifted from the simulator.
        newthdot = torch.clamp(newthdot, -8, 8)
        newth = th + newthdot * dt

        return torch.cat((newth, newthdot), dim=1)

    @torch.jit.script
    def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Running cost per sample: wrapped angle^2 + 0.1 theta_dt^2 (no control term)."""
        theta = state[:, 0]
        theta_dt = state[:, 1]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2

    @torch.jit.script
    def terminal_cost(state: torch.Tensor) -> torch.Tensor:
        """Terminal cost per sample: wrapped angle^2 + 0.1 theta_dt^2."""
        theta = state[:, 0]
        theta_dt = state[:, 1]
        return angle_normalize(theta) ** 2 + 0.1 * theta_dt**2

    # simulator
    if save_mode:
        env = gymnasium.make("Pendulum-v1", render_mode="rgb_array")
        env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video")
    else:
        env = gymnasium.make("Pendulum-v1", render_mode="human")
    env.reset(seed=42)

    # solver
    solver = MPPI(
        horizon=15,
        num_samples=1000,
        dim_state=2,
        dim_control=1,
        dynamics=dynamics,
        stage_cost=stage_cost,
        terminal_cost=terminal_cost,
        u_min=torch.tensor([-2.0]),
        u_max=torch.tensor([2.0]),
        sigmas=torch.tensor([1.0]),
        lambda_=1.0,
    )

    average_time = 0.0
    for i in range(200):
        # Plan from the simulator's internal [theta, theta_dt] state, not
        # from the observation.
        state = env.unwrapped.state.copy()

        start = time.time()
        with torch.no_grad():  # planning only; no gradients needed
            action_seq, _ = solver.forward(state=state)
        elapsed_time = time.time() - start
        # incremental running mean of the solve time
        average_time = i / (i + 1) * average_time + elapsed_time / (i + 1)

        action_seq_np = action_seq.cpu().numpy()

        # Apply only the first planned control. Rendering happens inside
        # step() in "human" mode and via RecordVideo in save mode.
        env.step(action_seq_np[0, :])

    print("average solve time: {}".format(average_time * 1000), " [ms]")
    env.close()


if __name__ == "__main__":
    fire.Fire(main)
100 |
--------------------------------------------------------------------------------
/media/cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/cartpole.gif
--------------------------------------------------------------------------------
/media/mountaincar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/mountaincar.gif
--------------------------------------------------------------------------------
/media/navigation_2d.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/navigation_2d.gif
--------------------------------------------------------------------------------
/media/pendulum.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/pendulum.gif
--------------------------------------------------------------------------------
/poetry.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
2 |
3 | [[package]]
4 | name = "black"
5 | version = "23.3.0"
6 | description = "The uncompromising code formatter."
7 | category = "main"
8 | optional = false
9 | python-versions = ">=3.7"
10 | files = [
11 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"},
12 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"},
13 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"},
14 | {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"},
15 | {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"},
16 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"},
17 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"},
18 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"},
19 | {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"},
20 | {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"},
21 | {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"},
22 | {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"},
23 | {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"},
24 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"},
25 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"},
26 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"},
27 | {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"},
28 | {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"},
29 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"},
30 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"},
31 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"},
32 | {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"},
33 | {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"},
34 | {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"},
35 | {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"},
36 | ]
37 |
38 | [package.dependencies]
39 | click = ">=8.0.0"
40 | mypy-extensions = ">=0.4.3"
41 | packaging = ">=22.0"
42 | pathspec = ">=0.9.0"
43 | platformdirs = ">=2"
44 | tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
45 |
46 | [package.extras]
47 | colorama = ["colorama (>=0.4.3)"]
48 | d = ["aiohttp (>=3.7.4)"]
49 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
50 | uvloop = ["uvloop (>=0.15.2)"]
51 |
52 | [[package]]
53 | name = "click"
54 | version = "8.1.3"
55 | description = "Composable command line interface toolkit"
56 | category = "main"
57 | optional = false
58 | python-versions = ">=3.7"
59 | files = [
60 | {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
61 | {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
62 | ]
63 |
64 | [package.dependencies]
65 | colorama = {version = "*", markers = "platform_system == \"Windows\""}
66 |
67 | [[package]]
68 | name = "colorama"
69 | version = "0.4.6"
70 | description = "Cross-platform colored terminal text."
71 | category = "main"
72 | optional = false
73 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
74 | files = [
75 | {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
76 | {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
77 | ]
78 |
79 | [[package]]
80 | name = "contourpy"
81 | version = "1.0.7"
82 | description = "Python library for calculating contours of 2D quadrilateral grids"
83 | category = "main"
84 | optional = false
85 | python-versions = ">=3.8"
86 | files = [
87 | {file = "contourpy-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:95c3acddf921944f241b6773b767f1cbce71d03307270e2d769fd584d5d1092d"},
88 | {file = "contourpy-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fc1464c97579da9f3ab16763c32e5c5d5bb5fa1ec7ce509a4ca6108b61b84fab"},
89 | {file = "contourpy-1.0.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8acf74b5d383414401926c1598ed77825cd530ac7b463ebc2e4f46638f56cce6"},
90 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c71fdd8f1c0f84ffd58fca37d00ca4ebaa9e502fb49825484da075ac0b0b803"},
91 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f99e9486bf1bb979d95d5cffed40689cb595abb2b841f2991fc894b3452290e8"},
92 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87f4d8941a9564cda3f7fa6a6cd9b32ec575830780677932abdec7bcb61717b0"},
93 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9e20e5a1908e18aaa60d9077a6d8753090e3f85ca25da6e25d30dc0a9e84c2c6"},
94 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a877ada905f7d69b2a31796c4b66e31a8068b37aa9b78832d41c82fc3e056ddd"},
95 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6381fa66866b0ea35e15d197fc06ac3840a9b2643a6475c8fff267db8b9f1e69"},
96 | {file = "contourpy-1.0.7-cp310-cp310-win32.whl", hash = "sha256:3c184ad2433635f216645fdf0493011a4667e8d46b34082f5a3de702b6ec42e3"},
97 | {file = "contourpy-1.0.7-cp310-cp310-win_amd64.whl", hash = "sha256:3caea6365b13119626ee996711ab63e0c9d7496f65641f4459c60a009a1f3e80"},
98 | {file = "contourpy-1.0.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ed33433fc3820263a6368e532f19ddb4c5990855e4886088ad84fd7c4e561c71"},
99 | {file = "contourpy-1.0.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:38e2e577f0f092b8e6774459317c05a69935a1755ecfb621c0a98f0e3c09c9a5"},
100 | {file = "contourpy-1.0.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ae90d5a8590e5310c32a7630b4b8618cef7563cebf649011da80874d0aa8f414"},
101 | {file = "contourpy-1.0.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130230b7e49825c98edf0b428b7aa1125503d91732735ef897786fe5452b1ec2"},
102 | {file = "contourpy-1.0.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58569c491e7f7e874f11519ef46737cea1d6eda1b514e4eb5ac7dab6aa864d02"},
103 | {file = "contourpy-1.0.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54d43960d809c4c12508a60b66cb936e7ed57d51fb5e30b513934a4a23874fae"},
104 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:152fd8f730c31fd67fe0ffebe1df38ab6a669403da93df218801a893645c6ccc"},
105 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9056c5310eb1daa33fc234ef39ebfb8c8e2533f088bbf0bc7350f70a29bde1ac"},
106 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a9d7587d2fdc820cc9177139b56795c39fb8560f540bba9ceea215f1f66e1566"},
107 | {file = "contourpy-1.0.7-cp311-cp311-win32.whl", hash = "sha256:4ee3ee247f795a69e53cd91d927146fb16c4e803c7ac86c84104940c7d2cabf0"},
108 | {file = "contourpy-1.0.7-cp311-cp311-win_amd64.whl", hash = "sha256:5caeacc68642e5f19d707471890f037a13007feba8427eb7f2a60811a1fc1350"},
109 | {file = "contourpy-1.0.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fd7dc0e6812b799a34f6d12fcb1000539098c249c8da54f3566c6a6461d0dbad"},
110 | {file = "contourpy-1.0.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0f9d350b639db6c2c233d92c7f213d94d2e444d8e8fc5ca44c9706cf72193772"},
111 | {file = "contourpy-1.0.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e96a08b62bb8de960d3a6afbc5ed8421bf1a2d9c85cc4ea73f4bc81b4910500f"},
112 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:031154ed61f7328ad7f97662e48660a150ef84ee1bc8876b6472af88bf5a9b98"},
113 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e9ebb4425fc1b658e13bace354c48a933b842d53c458f02c86f371cecbedecc"},
114 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efb8f6d08ca7998cf59eaf50c9d60717f29a1a0a09caa46460d33b2924839dbd"},
115 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6c180d89a28787e4b73b07e9b0e2dac7741261dbdca95f2b489c4f8f887dd810"},
116 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b8d587cc39057d0afd4166083d289bdeff221ac6d3ee5046aef2d480dc4b503c"},
117 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:769eef00437edf115e24d87f8926955f00f7704bede656ce605097584f9966dc"},
118 | {file = "contourpy-1.0.7-cp38-cp38-win32.whl", hash = "sha256:62398c80ef57589bdbe1eb8537127321c1abcfdf8c5f14f479dbbe27d0322e66"},
119 | {file = "contourpy-1.0.7-cp38-cp38-win_amd64.whl", hash = "sha256:57119b0116e3f408acbdccf9eb6ef19d7fe7baf0d1e9aaa5381489bc1aa56556"},
120 | {file = "contourpy-1.0.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30676ca45084ee61e9c3da589042c24a57592e375d4b138bd84d8709893a1ba4"},
121 | {file = "contourpy-1.0.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e927b3868bd1e12acee7cc8f3747d815b4ab3e445a28d2e5373a7f4a6e76ba1"},
122 | {file = "contourpy-1.0.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:366a0cf0fc079af5204801786ad7a1c007714ee3909e364dbac1729f5b0849e5"},
123 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89ba9bb365446a22411f0673abf6ee1fea3b2cf47b37533b970904880ceb72f3"},
124 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71b0bf0c30d432278793d2141362ac853859e87de0a7dee24a1cea35231f0d50"},
125 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7281244c99fd7c6f27c1c6bfafba878517b0b62925a09b586d88ce750a016d2"},
126 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b6d0f9e1d39dbfb3977f9dd79f156c86eb03e57a7face96f199e02b18e58d32a"},
127 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7f6979d20ee5693a1057ab53e043adffa1e7418d734c1532e2d9e915b08d8ec2"},
128 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5dd34c1ae752515318224cba7fc62b53130c45ac6a1040c8b7c1a223c46e8967"},
129 | {file = "contourpy-1.0.7-cp39-cp39-win32.whl", hash = "sha256:c5210e5d5117e9aec8c47d9156d1d3835570dd909a899171b9535cb4a3f32693"},
130 | {file = "contourpy-1.0.7-cp39-cp39-win_amd64.whl", hash = "sha256:60835badb5ed5f4e194a6f21c09283dd6e007664a86101431bf870d9e86266c4"},
131 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ce41676b3d0dd16dbcfabcc1dc46090aaf4688fd6e819ef343dbda5a57ef0161"},
132 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a011cf354107b47c58ea932d13b04d93c6d1d69b8b6dce885e642531f847566"},
133 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31a55dccc8426e71817e3fe09b37d6d48ae40aae4ecbc8c7ad59d6893569c436"},
134 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69f8ff4db108815addd900a74df665e135dbbd6547a8a69333a68e1f6e368ac2"},
135 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:efe99298ba37e37787f6a2ea868265465410822f7bea163edcc1bd3903354ea9"},
136 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a1e97b86f73715e8670ef45292d7cc033548266f07d54e2183ecb3c87598888f"},
137 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc331c13902d0f50845099434cd936d49d7a2ca76cb654b39691974cb1e4812d"},
138 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:24847601071f740837aefb730e01bd169fbcaa610209779a78db7ebb6e6a7051"},
139 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abf298af1e7ad44eeb93501e40eb5a67abbf93b5d90e468d01fc0c4451971afa"},
140 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:64757f6460fc55d7e16ed4f1de193f362104285c667c112b50a804d482777edd"},
141 | {file = "contourpy-1.0.7.tar.gz", hash = "sha256:d8165a088d31798b59e91117d1f5fc3df8168d8b48c4acc10fc0df0d0bdbcc5e"},
142 | ]
143 |
144 | [package.dependencies]
145 | numpy = ">=1.16"
146 |
147 | [package.extras]
148 | bokeh = ["bokeh", "chromedriver", "selenium"]
149 | docs = ["furo", "sphinx-copybutton"]
150 | mypy = ["contourpy[bokeh]", "docutils-stubs", "mypy (==0.991)", "types-Pillow"]
151 | test = ["Pillow", "matplotlib", "pytest"]
152 | test-no-images = ["pytest"]
153 |
154 | [[package]]
155 | name = "cycler"
156 | version = "0.11.0"
157 | description = "Composable style cycles"
158 | category = "main"
159 | optional = false
160 | python-versions = ">=3.6"
161 | files = [
162 | {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
163 | {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
164 | ]
165 |
166 | [[package]]
167 | name = "fonttools"
168 | version = "4.39.3"
169 | description = "Tools to manipulate font files"
170 | category = "main"
171 | optional = false
172 | python-versions = ">=3.8"
173 | files = [
174 | {file = "fonttools-4.39.3-py3-none-any.whl", hash = "sha256:64c0c05c337f826183637570ac5ab49ee220eec66cf50248e8df527edfa95aeb"},
175 | {file = "fonttools-4.39.3.zip", hash = "sha256:9234b9f57b74e31b192c3fc32ef1a40750a8fbc1cd9837a7b7bfc4ca4a5c51d7"},
176 | ]
177 |
178 | [package.extras]
179 | all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.0.0)", "xattr", "zopfli (>=0.1.4)"]
180 | graphite = ["lz4 (>=1.7.4.2)"]
181 | interpolatable = ["munkres", "scipy"]
182 | lxml = ["lxml (>=4.0,<5)"]
183 | pathops = ["skia-pathops (>=0.5.0)"]
184 | plot = ["matplotlib"]
185 | repacker = ["uharfbuzz (>=0.23.0)"]
186 | symfont = ["sympy"]
187 | type1 = ["xattr"]
188 | ufo = ["fs (>=2.2.0,<3)"]
189 | unicode = ["unicodedata2 (>=15.0.0)"]
190 | woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
191 |
192 | [[package]]
193 | name = "kiwisolver"
194 | version = "1.4.4"
195 | description = "A fast implementation of the Cassowary constraint solver"
196 | category = "main"
197 | optional = false
198 | python-versions = ">=3.7"
199 | files = [
200 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2f5e60fabb7343a836360c4f0919b8cd0d6dbf08ad2ca6b9cf90bf0c76a3c4f6"},
201 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:10ee06759482c78bdb864f4109886dff7b8a56529bc1609d4f1112b93fe6423c"},
202 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c79ebe8f3676a4c6630fd3f777f3cfecf9289666c84e775a67d1d358578dc2e3"},
203 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:abbe9fa13da955feb8202e215c4018f4bb57469b1b78c7a4c5c7b93001699938"},
204 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7577c1987baa3adc4b3c62c33bd1118c3ef5c8ddef36f0f2c950ae0b199e100d"},
205 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ad8285b01b0d4695102546b342b493b3ccc6781fc28c8c6a1bb63e95d22f09"},
206 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ed58b8acf29798b036d347791141767ccf65eee7f26bde03a71c944449e53de"},
207 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a68b62a02953b9841730db7797422f983935aeefceb1679f0fc85cbfbd311c32"},
208 | {file = "kiwisolver-1.4.4-cp310-cp310-win32.whl", hash = "sha256:e92a513161077b53447160b9bd8f522edfbed4bd9759e4c18ab05d7ef7e49408"},
209 | {file = "kiwisolver-1.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:3fe20f63c9ecee44560d0e7f116b3a747a5d7203376abeea292ab3152334d004"},
210 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ea21f66820452a3f5d1655f8704a60d66ba1191359b96541eaf457710a5fc6"},
211 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bc9db8a3efb3e403e4ecc6cd9489ea2bac94244f80c78e27c31dcc00d2790ac2"},
212 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5b61785a9ce44e5a4b880272baa7cf6c8f48a5180c3e81c59553ba0cb0821ca"},
213 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c2dbb44c3f7e6c4d3487b31037b1bdbf424d97687c1747ce4ff2895795c9bf69"},
214 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295ecd49304dcf3bfbfa45d9a081c96509e95f4b9d0eb7ee4ec0530c4a96514"},
215 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bd472dbe5e136f96a4b18f295d159d7f26fd399136f5b17b08c4e5f498cd494"},
216 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf7d9fce9bcc4752ca4a1b80aabd38f6d19009ea5cbda0e0856983cf6d0023f5"},
217 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d6601aed50c74e0ef02f4204da1816147a6d3fbdc8b3872d263338a9052c51"},
218 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:877272cf6b4b7e94c9614f9b10140e198d2186363728ed0f701c6eee1baec1da"},
219 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:db608a6757adabb32f1cfe6066e39b3706d8c3aa69bbc353a5b61edad36a5cb4"},
220 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:5853eb494c71e267912275e5586fe281444eb5e722de4e131cddf9d442615626"},
221 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f0a1dbdb5ecbef0d34eb77e56fcb3e95bbd7e50835d9782a45df81cc46949750"},
222 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:283dffbf061a4ec60391d51e6155e372a1f7a4f5b15d59c8505339454f8989e4"},
223 | {file = "kiwisolver-1.4.4-cp311-cp311-win32.whl", hash = "sha256:d06adcfa62a4431d404c31216f0f8ac97397d799cd53800e9d3efc2fbb3cf14e"},
224 | {file = "kiwisolver-1.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:e7da3fec7408813a7cebc9e4ec55afed2d0fd65c4754bc376bf03498d4e92686"},
225 | {file = "kiwisolver-1.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:62ac9cc684da4cf1778d07a89bf5f81b35834cb96ca523d3a7fb32509380cbf6"},
226 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41dae968a94b1ef1897cb322b39360a0812661dba7c682aa45098eb8e193dbdf"},
227 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02f79693ec433cb4b5f51694e8477ae83b3205768a6fb48ffba60549080e295b"},
228 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0611a0a2a518464c05ddd5a3a1a0e856ccc10e67079bb17f265ad19ab3c7597"},
229 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:db5283d90da4174865d520e7366801a93777201e91e79bacbac6e6927cbceede"},
230 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1041feb4cda8708ce73bb4dcb9ce1ccf49d553bf87c3954bdfa46f0c3f77252c"},
231 | {file = "kiwisolver-1.4.4-cp37-cp37m-win32.whl", hash = "sha256:a553dadda40fef6bfa1456dc4be49b113aa92c2a9a9e8711e955618cd69622e3"},
232 | {file = "kiwisolver-1.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:03baab2d6b4a54ddbb43bba1a3a2d1627e82d205c5cf8f4c924dc49284b87166"},
233 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:841293b17ad704d70c578f1f0013c890e219952169ce8a24ebc063eecf775454"},
234 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f4f270de01dd3e129a72efad823da90cc4d6aafb64c410c9033aba70db9f1ff0"},
235 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f9f39e2f049db33a908319cf46624a569b36983c7c78318e9726a4cb8923b26c"},
236 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c97528e64cb9ebeff9701e7938653a9951922f2a38bd847787d4a8e498cc83ae"},
237 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d1573129aa0fd901076e2bfb4275a35f5b7aa60fbfb984499d661ec950320b0"},
238 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad881edc7ccb9d65b0224f4e4d05a1e85cf62d73aab798943df6d48ab0cd79a1"},
239 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b428ef021242344340460fa4c9185d0b1f66fbdbfecc6c63eff4b7c29fad429d"},
240 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c"},
241 | {file = "kiwisolver-1.4.4-cp38-cp38-win32.whl", hash = "sha256:75facbe9606748f43428fc91a43edb46c7ff68889b91fa31f53b58894503a191"},
242 | {file = "kiwisolver-1.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:5bce61af018b0cb2055e0e72e7d65290d822d3feee430b7b8203d8a855e78766"},
243 | {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8c808594c88a025d4e322d5bb549282c93c8e1ba71b790f539567932722d7bd8"},
244 | {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f0a71d85ecdd570ded8ac3d1c0f480842f49a40beb423bb8014539a9f32a5897"},
245 | {file = "kiwisolver-1.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b533558eae785e33e8c148a8d9921692a9fe5aa516efbdff8606e7d87b9d5824"},
246 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:efda5fc8cc1c61e4f639b8067d118e742b812c930f708e6667a5ce0d13499e29"},
247 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7c43e1e1206cd421cd92e6b3280d4385d41d7166b3ed577ac20444b6995a445f"},
248 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc8d3bd6c72b2dd9decf16ce70e20abcb3274ba01b4e1c96031e0c4067d1e7cd"},
249 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ea39b0ccc4f5d803e3337dd46bcce60b702be4d86fd0b3d7531ef10fd99a1ac"},
250 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968f44fdbf6dd757d12920d63b566eeb4d5b395fd2d00d29d7ef00a00582aac9"},
251 | {file = "kiwisolver-1.4.4-cp39-cp39-win32.whl", hash = "sha256:da7e547706e69e45d95e116e6939488d62174e033b763ab1496b4c29b76fabea"},
252 | {file = "kiwisolver-1.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:ba59c92039ec0a66103b1d5fe588fa546373587a7d68f5c96f743c3396afc04b"},
253 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:91672bacaa030f92fc2f43b620d7b337fd9a5af28b0d6ed3f77afc43c4a64b5a"},
254 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:787518a6789009c159453da4d6b683f468ef7a65bbde796bcea803ccf191058d"},
255 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da152d8cdcab0e56e4f45eb08b9aea6455845ec83172092f09b0e077ece2cf7a"},
256 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ecb1fa0db7bf4cff9dac752abb19505a233c7f16684c5826d1f11ebd9472b871"},
257 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:28bc5b299f48150b5f822ce68624e445040595a4ac3d59251703779836eceff9"},
258 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:81e38381b782cc7e1e46c4e14cd997ee6040768101aefc8fa3c24a4cc58e98f8"},
259 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2a66fdfb34e05b705620dd567f5a03f239a088d5a3f321e7b6ac3239d22aa286"},
260 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:872b8ca05c40d309ed13eb2e582cab0c5a05e81e987ab9c521bf05ad1d5cf5cb"},
261 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:70e7c2e7b750585569564e2e5ca9845acfaa5da56ac46df68414f29fea97be9f"},
262 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9f85003f5dfa867e86d53fac6f7e6f30c045673fa27b603c397753bebadc3008"},
263 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e307eb9bd99801f82789b44bb45e9f541961831c7311521b13a6c85afc09767"},
264 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1792d939ec70abe76f5054d3f36ed5656021dcad1322d1cc996d4e54165cef9"},
265 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6cb459eea32a4e2cf18ba5fcece2dbdf496384413bc1bae15583f19e567f3b2"},
266 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36dafec3d6d6088d34e2de6b85f9d8e2324eb734162fba59d2ba9ed7a2043d5b"},
267 | {file = "kiwisolver-1.4.4.tar.gz", hash = "sha256:d41997519fcba4a1e46eb4a2fe31bc12f0ff957b2b81bac28db24744f333e955"},
268 | ]
269 |
270 | [[package]]
271 | name = "matplotlib"
272 | version = "3.7.1"
273 | description = "Python plotting package"
274 | category = "main"
275 | optional = false
276 | python-versions = ">=3.8"
277 | files = [
278 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:95cbc13c1fc6844ab8812a525bbc237fa1470863ff3dace7352e910519e194b1"},
279 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:08308bae9e91aca1ec6fd6dda66237eef9f6294ddb17f0d0b3c863169bf82353"},
280 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:544764ba51900da4639c0f983b323d288f94f65f4024dc40ecb1542d74dc0500"},
281 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d94989191de3fcc4e002f93f7f1be5da476385dde410ddafbb70686acf00ea"},
282 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99bc9e65901bb9a7ce5e7bb24af03675cbd7c70b30ac670aa263240635999a4"},
283 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb7d248c34a341cd4c31a06fd34d64306624c8cd8d0def7abb08792a5abfd556"},
284 | {file = "matplotlib-3.7.1-cp310-cp310-win32.whl", hash = "sha256:ce463ce590f3825b52e9fe5c19a3c6a69fd7675a39d589e8b5fbe772272b3a24"},
285 | {file = "matplotlib-3.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:3d7bc90727351fb841e4d8ae620d2d86d8ed92b50473cd2b42ce9186104ecbba"},
286 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:770a205966d641627fd5cf9d3cb4b6280a716522cd36b8b284a8eb1581310f61"},
287 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f67bfdb83a8232cb7a92b869f9355d677bce24485c460b19d01970b64b2ed476"},
288 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2bf092f9210e105f414a043b92af583c98f50050559616930d884387d0772aba"},
289 | {file = "matplotlib-3.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89768d84187f31717349c6bfadc0e0d8c321e8eb34522acec8a67b1236a66332"},
290 | {file = "matplotlib-3.7.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83111e6388dec67822e2534e13b243cc644c7494a4bb60584edbff91585a83c6"},
291 | {file = "matplotlib-3.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a867bf73a7eb808ef2afbca03bcdb785dae09595fbe550e1bab0cd023eba3de0"},
292 | {file = "matplotlib-3.7.1-cp311-cp311-win32.whl", hash = "sha256:fbdeeb58c0cf0595efe89c05c224e0a502d1aa6a8696e68a73c3efc6bc354304"},
293 | {file = "matplotlib-3.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:c0bd19c72ae53e6ab979f0ac6a3fafceb02d2ecafa023c5cca47acd934d10be7"},
294 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_10_12_universal2.whl", hash = "sha256:6eb88d87cb2c49af00d3bbc33a003f89fd9f78d318848da029383bfc08ecfbfb"},
295 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:cf0e4f727534b7b1457898c4f4ae838af1ef87c359b76dcd5330fa31893a3ac7"},
296 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:46a561d23b91f30bccfd25429c3c706afe7d73a5cc64ef2dfaf2b2ac47c1a5dc"},
297 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8704726d33e9aa8a6d5215044b8d00804561971163563e6e6591f9dcf64340cc"},
298 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4cf327e98ecf08fcbb82685acaf1939d3338548620ab8dfa02828706402c34de"},
299 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617f14ae9d53292ece33f45cba8503494ee199a75b44de7717964f70637a36aa"},
300 | {file = "matplotlib-3.7.1-cp38-cp38-win32.whl", hash = "sha256:7c9a4b2da6fac77bcc41b1ea95fadb314e92508bf5493ceff058e727e7ecf5b0"},
301 | {file = "matplotlib-3.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:14645aad967684e92fc349493fa10c08a6da514b3d03a5931a1bac26e6792bd1"},
302 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_10_12_universal2.whl", hash = "sha256:81a6b377ea444336538638d31fdb39af6be1a043ca5e343fe18d0f17e098770b"},
303 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:28506a03bd7f3fe59cd3cd4ceb2a8d8a2b1db41afede01f66c42561b9be7b4b7"},
304 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8c587963b85ce41e0a8af53b9b2de8dddbf5ece4c34553f7bd9d066148dc719c"},
305 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8bf26ade3ff0f27668989d98c8435ce9327d24cffb7f07d24ef609e33d582439"},
306 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:def58098f96a05f90af7e92fd127d21a287068202aa43b2a93476170ebd99e87"},
307 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f883a22a56a84dba3b588696a2b8a1ab0d2c3d41be53264115c71b0a942d8fdb"},
308 | {file = "matplotlib-3.7.1-cp39-cp39-win32.whl", hash = "sha256:4f99e1b234c30c1e9714610eb0c6d2f11809c9c78c984a613ae539ea2ad2eb4b"},
309 | {file = "matplotlib-3.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:3ba2af245e36990facf67fde840a760128ddd71210b2ab6406e640188d69d136"},
310 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3032884084f541163f295db8a6536e0abb0db464008fadca6c98aaf84ccf4717"},
311 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a2cb34336110e0ed8bb4f650e817eed61fa064acbefeb3591f1b33e3a84fd96"},
312 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b867e2f952ed592237a1828f027d332d8ee219ad722345b79a001f49df0936eb"},
313 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:57bfb8c8ea253be947ccb2bc2d1bb3862c2bccc662ad1b4626e1f5e004557042"},
314 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:438196cdf5dc8d39b50a45cb6e3f6274edbcf2254f85fa9b895bf85851c3a613"},
315 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:21e9cff1a58d42e74d01153360de92b326708fb205250150018a52c70f43c290"},
316 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d4725d70b7c03e082bbb8a34639ede17f333d7247f56caceb3801cb6ff703d"},
317 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97cc368a7268141afb5690760921765ed34867ffb9655dd325ed207af85c7529"},
318 | {file = "matplotlib-3.7.1.tar.gz", hash = "sha256:7b73305f25eab4541bd7ee0b96d87e53ae9c9f1823be5659b806cd85786fe882"},
319 | ]
320 |
321 | [package.dependencies]
322 | contourpy = ">=1.0.1"
323 | cycler = ">=0.10"
324 | fonttools = ">=4.22.0"
325 | kiwisolver = ">=1.0.1"
326 | numpy = ">=1.20"
327 | packaging = ">=20.0"
328 | pillow = ">=6.2.0"
329 | pyparsing = ">=2.3.1"
330 | python-dateutil = ">=2.7"
331 |
332 | [[package]]
333 | name = "mypy-extensions"
334 | version = "1.0.0"
335 | description = "Type system extensions for programs checked with the mypy type checker."
336 | category = "main"
337 | optional = false
338 | python-versions = ">=3.5"
339 | files = [
340 | {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
341 | {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
342 | ]
343 |
344 | [[package]]
345 | name = "numpy"
346 | version = "1.24.2"
347 | description = "Fundamental package for array computing in Python"
348 | category = "main"
349 | optional = false
350 | python-versions = ">=3.8"
351 | files = [
352 | {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"},
353 | {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"},
354 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6524630f71631be2dabe0c541e7675db82651eb998496bbe16bc4f77f0772253"},
355 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a51725a815a6188c662fb66fb32077709a9ca38053f0274640293a14fdd22978"},
356 | {file = "numpy-1.24.2-cp310-cp310-win32.whl", hash = "sha256:2620e8592136e073bd12ee4536149380695fbe9ebeae845b81237f986479ffc9"},
357 | {file = "numpy-1.24.2-cp310-cp310-win_amd64.whl", hash = "sha256:97cf27e51fa078078c649a51d7ade3c92d9e709ba2bfb97493007103c741f1d0"},
358 | {file = "numpy-1.24.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7de8fdde0003f4294655aa5d5f0a89c26b9f22c0a58790c38fae1ed392d44a5a"},
359 | {file = "numpy-1.24.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4173bde9fa2a005c2c6e2ea8ac1618e2ed2c1c6ec8a7657237854d42094123a0"},
360 | {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cecaed30dc14123020f77b03601559fff3e6cd0c048f8b5289f4eeabb0eb281"},
361 | {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a23f8440561a633204a67fb44617ce2a299beecf3295f0d13c495518908e910"},
362 | {file = "numpy-1.24.2-cp311-cp311-win32.whl", hash = "sha256:e428c4fbfa085f947b536706a2fc349245d7baa8334f0c5723c56a10595f9b95"},
363 | {file = "numpy-1.24.2-cp311-cp311-win_amd64.whl", hash = "sha256:557d42778a6869c2162deb40ad82612645e21d79e11c1dc62c6e82a2220ffb04"},
364 | {file = "numpy-1.24.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d0a2db9d20117bf523dde15858398e7c0858aadca7c0f088ac0d6edd360e9ad2"},
365 | {file = "numpy-1.24.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c72a6b2f4af1adfe193f7beb91ddf708ff867a3f977ef2ec53c0ffb8283ab9f5"},
366 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c29e6bd0ec49a44d7690ecb623a8eac5ab8a923bce0bea6293953992edf3a76a"},
367 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2eabd64ddb96a1239791da78fa5f4e1693ae2dadc82a76bc76a14cbb2b966e96"},
368 | {file = "numpy-1.24.2-cp38-cp38-win32.whl", hash = "sha256:e3ab5d32784e843fc0dd3ab6dcafc67ef806e6b6828dc6af2f689be0eb4d781d"},
369 | {file = "numpy-1.24.2-cp38-cp38-win_amd64.whl", hash = "sha256:76807b4063f0002c8532cfeac47a3068a69561e9c8715efdad3c642eb27c0756"},
370 | {file = "numpy-1.24.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4199e7cfc307a778f72d293372736223e39ec9ac096ff0a2e64853b866a8e18a"},
371 | {file = "numpy-1.24.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:adbdce121896fd3a17a77ab0b0b5eedf05a9834a18699db6829a64e1dfccca7f"},
372 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889b2cc88b837d86eda1b17008ebeb679d82875022200c6e8e4ce6cf549b7acb"},
373 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64bb98ac59b3ea3bf74b02f13836eb2e24e48e0ab0145bbda646295769bd780"},
374 | {file = "numpy-1.24.2-cp39-cp39-win32.whl", hash = "sha256:63e45511ee4d9d976637d11e6c9864eae50e12dc9598f531c035265991910468"},
375 | {file = "numpy-1.24.2-cp39-cp39-win_amd64.whl", hash = "sha256:a77d3e1163a7770164404607b7ba3967fb49b24782a6ef85d9b5f54126cc39e5"},
376 | {file = "numpy-1.24.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92011118955724465fb6853def593cf397b4a1367495e0b59a7e69d40c4eb71d"},
377 | {file = "numpy-1.24.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9006288bcf4895917d02583cf3411f98631275bc67cce355a7f39f8c14338fa"},
378 | {file = "numpy-1.24.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:150947adbdfeceec4e5926d956a06865c1c690f2fd902efede4ca6fe2e657c3f"},
379 | {file = "numpy-1.24.2.tar.gz", hash = "sha256:003a9f530e880cb2cd177cba1af7220b9aa42def9c4afc2a2fc3ee6be7eb2b22"},
380 | ]
381 |
382 | [[package]]
383 | name = "packaging"
384 | version = "23.0"
385 | description = "Core utilities for Python packages"
386 | category = "main"
387 | optional = false
388 | python-versions = ">=3.7"
389 | files = [
390 | {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"},
391 | {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"},
392 | ]
393 |
394 | [[package]]
395 | name = "pathspec"
396 | version = "0.11.1"
397 | description = "Utility library for gitignore style pattern matching of file paths."
398 | category = "main"
399 | optional = false
400 | python-versions = ">=3.7"
401 | files = [
402 | {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"},
403 | {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"},
404 | ]
405 |
406 | [[package]]
407 | name = "pillow"
408 | version = "9.5.0"
409 | description = "Python Imaging Library (Fork)"
410 | category = "main"
411 | optional = false
412 | python-versions = ">=3.7"
413 | files = [
414 | {file = "Pillow-9.5.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:ace6ca218308447b9077c14ea4ef381ba0b67ee78d64046b3f19cf4e1139ad16"},
415 | {file = "Pillow-9.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d3d403753c9d5adc04d4694d35cf0391f0f3d57c8e0030aac09d7678fa8030aa"},
416 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba1b81ee69573fe7124881762bb4cd2e4b6ed9dd28c9c60a632902fe8db8b38"},
417 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe7e1c262d3392afcf5071df9afa574544f28eac825284596ac6db56e6d11062"},
418 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f36397bf3f7d7c6a3abdea815ecf6fd14e7fcd4418ab24bae01008d8d8ca15e"},
419 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:252a03f1bdddce077eff2354c3861bf437c892fb1832f75ce813ee94347aa9b5"},
420 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85ec677246533e27770b0de5cf0f9d6e4ec0c212a1f89dfc941b64b21226009d"},
421 | {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b416f03d37d27290cb93597335a2f85ed446731200705b22bb927405320de903"},
422 | {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1781a624c229cb35a2ac31cc4a77e28cafc8900733a864870c49bfeedacd106a"},
423 | {file = "Pillow-9.5.0-cp310-cp310-win32.whl", hash = "sha256:8507eda3cd0608a1f94f58c64817e83ec12fa93a9436938b191b80d9e4c0fc44"},
424 | {file = "Pillow-9.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:d3c6b54e304c60c4181da1c9dadf83e4a54fd266a99c70ba646a9baa626819eb"},
425 | {file = "Pillow-9.5.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:7ec6f6ce99dab90b52da21cf0dc519e21095e332ff3b399a357c187b1a5eee32"},
426 | {file = "Pillow-9.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:560737e70cb9c6255d6dcba3de6578a9e2ec4b573659943a5e7e4af13f298f5c"},
427 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96e88745a55b88a7c64fa49bceff363a1a27d9a64e04019c2281049444a571e3"},
428 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d9c206c29b46cfd343ea7cdfe1232443072bbb270d6a46f59c259460db76779a"},
429 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfcc2c53c06f2ccb8976fb5c71d448bdd0a07d26d8e07e321c103416444c7ad1"},
430 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0f9bb6c80e6efcde93ffc51256d5cfb2155ff8f78292f074f60f9e70b942d99"},
431 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8d935f924bbab8f0a9a28404422da8af4904e36d5c33fc6f677e4c4485515625"},
432 | {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fed1e1cf6a42577953abbe8e6cf2fe2f566daebde7c34724ec8803c4c0cda579"},
433 | {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c1170d6b195555644f0616fd6ed929dfcf6333b8675fcca044ae5ab110ded296"},
434 | {file = "Pillow-9.5.0-cp311-cp311-win32.whl", hash = "sha256:54f7102ad31a3de5666827526e248c3530b3a33539dbda27c6843d19d72644ec"},
435 | {file = "Pillow-9.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfa4561277f677ecf651e2b22dc43e8f5368b74a25a8f7d1d4a3a243e573f2d4"},
436 | {file = "Pillow-9.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:965e4a05ef364e7b973dd17fc765f42233415974d773e82144c9bbaaaea5d089"},
437 | {file = "Pillow-9.5.0-cp312-cp312-win32.whl", hash = "sha256:22baf0c3cf0c7f26e82d6e1adf118027afb325e703922c8dfc1d5d0156bb2eeb"},
438 | {file = "Pillow-9.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:432b975c009cf649420615388561c0ce7cc31ce9b2e374db659ee4f7d57a1f8b"},
439 | {file = "Pillow-9.5.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:5d4ebf8e1db4441a55c509c4baa7a0587a0210f7cd25fcfe74dbbce7a4bd1906"},
440 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:375f6e5ee9620a271acb6820b3d1e94ffa8e741c0601db4c0c4d3cb0a9c224bf"},
441 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99eb6cafb6ba90e436684e08dad8be1637efb71c4f2180ee6b8f940739406e78"},
442 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dfaaf10b6172697b9bceb9a3bd7b951819d1ca339a5ef294d1f1ac6d7f63270"},
443 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:763782b2e03e45e2c77d7779875f4432e25121ef002a41829d8868700d119392"},
444 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:35f6e77122a0c0762268216315bf239cf52b88865bba522999dc38f1c52b9b47"},
445 | {file = "Pillow-9.5.0-cp37-cp37m-win32.whl", hash = "sha256:aca1c196f407ec7cf04dcbb15d19a43c507a81f7ffc45b690899d6a76ac9fda7"},
446 | {file = "Pillow-9.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322724c0032af6692456cd6ed554bb85f8149214d97398bb80613b04e33769f6"},
447 | {file = "Pillow-9.5.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a0aa9417994d91301056f3d0038af1199eb7adc86e646a36b9e050b06f526597"},
448 | {file = "Pillow-9.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8286396b351785801a976b1e85ea88e937712ee2c3ac653710a4a57a8da5d9c"},
449 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c830a02caeb789633863b466b9de10c015bded434deb3ec87c768e53752ad22a"},
450 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbd359831c1657d69bb81f0db962905ee05e5e9451913b18b831febfe0519082"},
451 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc330c3370a81bbf3f88557097d1ea26cd8b019d6433aa59f71195f5ddebbf"},
452 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:7002d0797a3e4193c7cdee3198d7c14f92c0836d6b4a3f3046a64bd1ce8df2bf"},
453 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:229e2c79c00e85989a34b5981a2b67aa079fd08c903f0aaead522a1d68d79e51"},
454 | {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9adf58f5d64e474bed00d69bcd86ec4bcaa4123bfa70a65ce72e424bfb88ed96"},
455 | {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:662da1f3f89a302cc22faa9f14a262c2e3951f9dbc9617609a47521c69dd9f8f"},
456 | {file = "Pillow-9.5.0-cp38-cp38-win32.whl", hash = "sha256:6608ff3bf781eee0cd14d0901a2b9cc3d3834516532e3bd673a0a204dc8615fc"},
457 | {file = "Pillow-9.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:e49eb4e95ff6fd7c0c402508894b1ef0e01b99a44320ba7d8ecbabefddcc5569"},
458 | {file = "Pillow-9.5.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:482877592e927fd263028c105b36272398e3e1be3269efda09f6ba21fd83ec66"},
459 | {file = "Pillow-9.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3ded42b9ad70e5f1754fb7c2e2d6465a9c842e41d178f262e08b8c85ed8a1d8e"},
460 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c446d2245ba29820d405315083d55299a796695d747efceb5717a8b450324115"},
461 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8aca1152d93dcc27dc55395604dcfc55bed5f25ef4c98716a928bacba90d33a3"},
462 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:608488bdcbdb4ba7837461442b90ea6f3079397ddc968c31265c1e056964f1ef"},
463 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:60037a8db8750e474af7ffc9faa9b5859e6c6d0a50e55c45576bf28be7419705"},
464 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:07999f5834bdc404c442146942a2ecadd1cb6292f5229f4ed3b31e0a108746b1"},
465 | {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a127ae76092974abfbfa38ca2d12cbeddcdeac0fb71f9627cc1135bedaf9d51a"},
466 | {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:489f8389261e5ed43ac8ff7b453162af39c3e8abd730af8363587ba64bb2e865"},
467 | {file = "Pillow-9.5.0-cp39-cp39-win32.whl", hash = "sha256:9b1af95c3a967bf1da94f253e56b6286b50af23392a886720f563c547e48e964"},
468 | {file = "Pillow-9.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:77165c4a5e7d5a284f10a6efaa39a0ae8ba839da344f20b111d62cc932fa4e5d"},
469 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:833b86a98e0ede388fa29363159c9b1a294b0905b5128baf01db683672f230f5"},
470 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaf305d6d40bd9632198c766fb64f0c1a83ca5b667f16c1e79e1661ab5060140"},
471 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0852ddb76d85f127c135b6dd1f0bb88dbb9ee990d2cd9aa9e28526c93e794fba"},
472 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:91ec6fe47b5eb5a9968c79ad9ed78c342b1f97a091677ba0e012701add857829"},
473 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cb841572862f629b99725ebaec3287fc6d275be9b14443ea746c1dd325053cbd"},
474 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c380b27d041209b849ed246b111b7c166ba36d7933ec6e41175fd15ab9eb1572"},
475 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c9af5a3b406a50e313467e3565fc99929717f780164fe6fbb7704edba0cebbe"},
476 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5671583eab84af046a397d6d0ba25343c00cd50bce03787948e0fff01d4fd9b1"},
477 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:84a6f19ce086c1bf894644b43cd129702f781ba5751ca8572f08aa40ef0ab7b7"},
478 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1e7723bd90ef94eda669a3c2c19d549874dd5badaeefabefd26053304abe5799"},
479 | {file = "Pillow-9.5.0.tar.gz", hash = "sha256:bf548479d336726d7a0eceb6e767e179fbde37833ae42794602631a070d630f1"},
480 | ]
481 |
482 | [package.extras]
483 | docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"]
484 | tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
485 |
486 | [[package]]
487 | name = "platformdirs"
488 | version = "3.2.0"
489 | description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
490 | category = "main"
491 | optional = false
492 | python-versions = ">=3.7"
493 | files = [
494 | {file = "platformdirs-3.2.0-py3-none-any.whl", hash = "sha256:ebe11c0d7a805086e99506aa331612429a72ca7cd52a1f0d277dc4adc20cb10e"},
495 | {file = "platformdirs-3.2.0.tar.gz", hash = "sha256:d5b638ca397f25f979350ff789db335903d7ea010ab28903f57b27e1b16c2b08"},
496 | ]
497 |
498 | [package.extras]
499 | docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
500 | test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
501 |
502 | [[package]]
503 | name = "pyparsing"
504 | version = "3.0.9"
505 | description = "pyparsing module - Classes and methods to define and execute parsing grammars"
506 | category = "main"
507 | optional = false
508 | python-versions = ">=3.6.8"
509 | files = [
510 | {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
511 | {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
512 | ]
513 |
514 | [package.extras]
515 | diagrams = ["jinja2", "railroad-diagrams"]
516 |
517 | [[package]]
518 | name = "python-dateutil"
519 | version = "2.8.2"
520 | description = "Extensions to the standard Python datetime module"
521 | category = "main"
522 | optional = false
523 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
524 | files = [
525 | {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
526 | {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
527 | ]
528 |
529 | [package.dependencies]
530 | six = ">=1.5"
531 |
532 | [[package]]
533 | name = "six"
534 | version = "1.16.0"
535 | description = "Python 2 and 3 compatibility utilities"
536 | category = "main"
537 | optional = false
538 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
539 | files = [
540 | {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
541 | {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
542 | ]
543 |
544 | [[package]]
545 | name = "tomli"
546 | version = "2.0.1"
547 | description = "A lil' TOML parser"
548 | category = "main"
549 | optional = false
550 | python-versions = ">=3.7"
551 | files = [
552 | {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
553 | {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
554 | ]
555 |
556 | [metadata]
557 | lock-version = "2.0"
558 | python-versions = "^3.10.8"
559 | content-hash = "e8a823794405d03282a7c42f616cf76a51891ffb94a44505cf6d6eddc854fc7e"
560 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=64", "wheel"]
3 | build-backend = 'setuptools.build_meta'
4 |
5 | [tool.setuptools]
6 | package-dir = {"" = "src"}
7 |
8 | [project]
9 | name = "mppi_playground"
10 | version = "0.1.0"
11 | description = ""
12 | requires-python = ">=3.10"
13 | dependencies = [
14 | "matplotlib==3.8.2",
15 | "fire==0.5.0",
16 | "numpy==1.26.2",
17 | "torch==2.00",
18 | "torchvision==0.15.1",
19 | "gymnasium[all]==0.29.1",
20 | "mujoco==2.3.7",
21 | # "pybullet==3.2.5",
22 | ]
23 |
24 | [project.optional-dependencies]
25 | dev = [
26 | "pytest",
27 | "pysen",
28 | "black",
29 | "flake8",
30 | "isort",
31 | "mypy",
32 | ]
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/__init__.py
--------------------------------------------------------------------------------
/src/controller/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/controller/__init__.py
--------------------------------------------------------------------------------
/src/controller/mppi.py:
--------------------------------------------------------------------------------
1 | """
2 | Kohei Honda, 2023.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from typing import Callable, Tuple
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.distributions.multivariate_normal import MultivariateNormal
12 |
13 |
class MPPI(nn.Module):
    """
    Model Predictive Path Integral Control,
    J. Williams et al., T-RO, 2017.

    Each call to ``forward`` samples ``num_samples`` Gaussian-perturbed control
    sequences around the previous solution, rolls them out through the batched
    ``dynamics`` model, scores them, and returns the softmax(-cost / lambda)
    weighted average of the sampled control sequences.
    """

    def __init__(
        self,
        horizon: int,
        num_samples: int,
        dim_state: int,
        dim_control: int,
        dynamics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        stage_cost: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        terminal_cost: Callable[[torch.Tensor], torch.Tensor],
        u_min: torch.Tensor,
        u_max: torch.Tensor,
        sigmas: torch.Tensor,
        lambda_: float,
        device=torch.device("cuda"),
        dtype=torch.float32,
        seed: int = 42,
    ) -> None:
        """
        :param horizon: Predictive horizon length (number of steps).
        :param num_samples: Number of sampled control sequences.
        :param dim_state: Dimension of state.
        :param dim_control: Dimension of control.
        :param dynamics: Batched dynamics model called as
            ``dynamics(states, actions)`` with states of shape
            (batch, dim_state) and actions of shape (batch, dim_control);
            must return the next states with shape (batch, dim_state).
        :param stage_cost: Batched stage cost ``stage_cost(states, actions)``
            returning one cost per batch element.
        :param terminal_cost: Batched terminal cost ``terminal_cost(states)``
            returning one cost per batch element.
        :param u_min: Minimum control, shape (dim_control,).
        :param u_max: Maximum control, shape (dim_control,).
        :param sigmas: Noise standard deviation for each control dimension,
            shape (dim_control,).
        :param lambda_: Temperature parameter.
        :param device: Requested device. A CUDA device (including an explicit
            index such as ``cuda:1``) is used only when CUDA is available;
            every other case falls back to CPU.
        :param dtype: Data type to run the solver.
        :param seed: Seed for the global torch RNG.
        """

        super().__init__()

        # Seed the global torch RNG so the noise samples are reproducible.
        torch.manual_seed(seed)

        # check dimensions
        assert u_min.shape == (dim_control,)
        assert u_max.shape == (dim_control,)
        assert sigmas.shape == (dim_control,)

        # device and dtype: compare the device *type* so that an indexed CUDA
        # device like "cuda:1" is honoured instead of silently falling back to CPU.
        requested_device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        if requested_device.type == "cuda" and torch.cuda.is_available():
            self._device = requested_device
        else:
            self._device = torch.device("cpu")
        self._dtype = dtype

        # set parameters
        self._horizon = horizon
        self._num_samples = num_samples
        self._dim_state = dim_state
        self._dim_control = dim_control
        self._dynamics = dynamics
        self._stage_cost = stage_cost
        self._terminal_cost = terminal_cost
        self._u_min = u_min.clone().detach().to(self._device, self._dtype)
        self._u_max = u_max.clone().detach().to(self._device, self._dtype)
        self._sigmas = sigmas.clone().detach().to(self._device, self._dtype)
        self._lambda = lambda_

        # noise distribution: zero-mean Gaussian with covariance diag(sigmas^2)
        zero_mean = torch.zeros(dim_control, device=self._device, dtype=self._dtype)
        initial_covariance = torch.diag(self._sigmas**2)
        self._inv_covariance = torch.inverse(initial_covariance)

        self._noise_distribution = MultivariateNormal(
            loc=zero_mean, covariance_matrix=initial_covariance
        )
        self._sample_shape = torch.Size([self._num_samples, self._horizon])

        # sampling with the reparameterization trick
        self._action_noises = self._noise_distribution.rsample(
            sample_shape=self._sample_shape
        )

        zero_mean_seq = torch.zeros(
            self._horizon, self._dim_control, device=self._device, dtype=self._dtype
        )
        self._perturbed_action_seqs = torch.clamp(
            zero_mean_seq + self._action_noises, self._u_min, self._u_max
        )

        # the first call to forward() samples around an all-zero control sequence
        self._previous_action_seq = zero_mean_seq

        # inner variables: rollout buffer (samples, horizon + 1, dim_state)
        # and the per-sample weights of the latest solve
        self._state_seq_batch = torch.zeros(
            self._num_samples,
            self._horizon + 1,
            self._dim_state,
            device=self._device,
            dtype=self._dtype,
        )
        self._weights = torch.zeros(
            self._num_samples, device=self._device, dtype=self._dtype
        )

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Solve the optimal control problem.
        Args:
            state (torch.Tensor): Current state of shape (dim_state,). Array-likes
                (numpy arrays, lists) are converted to a tensor on the solver's
                device and dtype.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The optimal control sequence of
            shape (horizon, dim_control) and the state sequence predicted by
            rolling it out, of shape (1, horizon + 1, dim_state).
        """
        # Convert first, then validate: checking ``.shape`` before conversion
        # made non-tensor inputs (e.g. lists) fail before they could be converted.
        state = torch.as_tensor(state, device=self._device, dtype=self._dtype)
        assert state.shape == (self._dim_state,)

        mean_action_seq = self._previous_action_seq.clone().detach()

        # random sampling with the reparameterization trick, then clamp the
        # perturbed sequences to the control bounds
        self._action_noises = self._noise_distribution.rsample(
            sample_shape=self._sample_shape
        )
        self._perturbed_action_seqs = torch.clamp(
            mean_action_seq + self._action_noises, self._u_min, self._u_max
        )

        # rollout samples in parallel; every sample starts from the current state
        self._state_seq_batch[:, 0, :] = state
        for t in range(self._horizon):
            self._state_seq_batch[:, t + 1, :] = self._dynamics(
                self._state_seq_batch[:, t, :], self._perturbed_action_seqs[:, t, :]
            )

        # compute sample costs
        stage_costs = torch.zeros(
            self._num_samples, self._horizon, device=self._device, dtype=self._dtype
        )
        for t in range(self._horizon):
            stage_costs[:, t] = self._stage_cost(
                self._state_seq_batch[:, t, :], self._perturbed_action_seqs[:, t, :]
            )

        # control cost u_t^T Sigma^{-1} v_t between the mean sequence u and every
        # perturbed sequence v, for all samples and steps at once:
        # (horizon, dim) x (dim, dim) x (samples, horizon, dim) -> (samples, horizon)
        action_costs = torch.einsum(
            "td,de,ste->st",
            mean_action_seq,
            self._inv_covariance,
            self._perturbed_action_seqs,
        )

        terminal_costs = self._terminal_cost(self._state_seq_batch[:, -1, :])

        costs = (
            torch.sum(stage_costs, dim=1)
            + terminal_costs
            + torch.sum(self._lambda * action_costs, dim=1)
        )

        # calculate weights (softmax is numerically stable for large costs)
        self._weights = torch.softmax(-costs / self._lambda, dim=0)

        # find optimal control by weighted average over samples
        optimal_action_seq = torch.sum(
            self._weights.view(self._num_samples, 1, 1) * self._perturbed_action_seqs,
            dim=0,
        )

        # predictive state sequence obtained by rolling out the optimal controls
        optimal_state_seq = torch.zeros(
            1,
            self._horizon + 1,
            self._dim_state,
            device=self._device,
            dtype=self._dtype,
        )
        optimal_state_seq[:, 0, :] = state
        batched_optimal_action_seq = optimal_action_seq.unsqueeze(0)
        for t in range(self._horizon):
            optimal_state_seq[:, t + 1, :] = self._dynamics(
                optimal_state_seq[:, t, :], batched_optimal_action_seq[:, t, :]
            )

        # update previous actions (warm start for the next call).
        # NOTE(review): the sequence is reused as-is rather than shifted by one
        # step; confirm this is intended for receding-horizon use.
        self._previous_action_seq = optimal_action_seq

        return optimal_action_seq, optimal_state_seq

    def get_top_samples(self, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the highest-weighted sampled state rollouts from the latest solve.
        Args:
            num_samples (int): Number of state samples to get
                (must not exceed the number of samples of the solver).
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: State rollouts of shape
            (num_samples, horizon + 1, dim_state) and their weights, both
            ordered from the largest weight to the smallest.
        """
        assert num_samples <= self._num_samples

        # large weights are better; torch.topk already returns them sorted
        # in descending order, so no extra sort is needed
        top_weights, top_indices = torch.topk(self._weights, num_samples)
        top_samples = self._state_seq_batch[top_indices]

        return top_samples, top_weights
235 |
--------------------------------------------------------------------------------
/src/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/envs/__init__.py
--------------------------------------------------------------------------------
/src/envs/navigation_2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Kohei Honda, 2023.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from typing import Tuple, Union
8 | from matplotlib import pyplot as plt
9 |
10 | import torch
11 | import numpy as np
12 | import os
13 |
14 |
15 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
16 |
17 | from envs.obstacle_map_2d import ObstacleMap, generate_random_obstacles
18 |
19 |
@torch.jit.script
def angle_normalize(x: torch.Tensor) -> torch.Tensor:
    """Wrap angles (radians) elementwise into [-pi, pi), up to floating-point rounding."""
    shifted = x + torch.pi
    # torch.remainder follows Python's modulo sign convention, the same as `%`.
    return torch.remainder(shifted, 2 * torch.pi) - torch.pi
23 |
24 |
class Navigation2DEnv:
    """
    2D navigation environment for a differential-drive robot.

    The robot state is [x, y, theta] and the control is u = [v, omega]
    (m/s, rad/s). The robot starts at (-9, -9) facing the goal at (9, 9) on a
    20 m x 20 m obstacle map filled with random circles and squares generated
    from ``seed``. CUDA is used only when requested and available; otherwise
    everything runs on the CPU.
    """

    def __init__(
        self, device=torch.device("cuda"), dtype=torch.float32, seed: int = 42
    ) -> None:
        # device and dtype
        if torch.cuda.is_available() and device == torch.device("cuda"):
            self._device = torch.device("cuda")
        else:
            self._device = torch.device("cpu")
        self._dtype = dtype

        self._obstacle_map = ObstacleMap(
            map_size=(20, 20), cell_size=0.1, device=self._device, dtype=self._dtype
        )
        self._seed = seed

        generate_random_obstacles(
            obstacle_map=self._obstacle_map,
            random_x_range=(-7.5, 7.5),
            random_y_range=(-7.5, 7.5),
            num_circle_obs=7,
            radius_range=(1, 1),
            num_rectangle_obs=7,
            width_range=(2, 2),
            height_range=(2, 2),
            max_iteration=1000,
            seed=seed,
        )
        self._obstacle_map.convert_to_torch()

        self._start_pos = torch.tensor(
            [-9.0, -9.0], device=self._device, dtype=self._dtype
        )
        self._goal_pos = torch.tensor(
            [9.0, 9.0], device=self._device, dtype=self._dtype
        )

        self._robot_state = torch.zeros(3, device=self._device, dtype=self._dtype)
        self._place_robot_at_start()

        # u: [v, omega] (m/s, rad/s)
        self.u_min = torch.tensor([0.0, -1.0], device=self._device, dtype=self._dtype)
        self.u_max = torch.tensor([2.0, 1.0], device=self._device, dtype=self._dtype)

        # Rendering state; reset() creates the figure, close() releases it.
        self._fig = None
        self._ax = None
        self._rendered_frames = []

    def _place_robot_at_start(self) -> None:
        """Write the start pose, heading straight at the goal, into the state tensor in place."""
        self._robot_state[:2] = self._start_pos
        self._robot_state[2] = angle_normalize(
            torch.atan2(
                self._goal_pos[1] - self._start_pos[1],
                self._goal_pos[0] - self._start_pos[0],
            )
        )

    def reset(self) -> torch.Tensor:
        """
        Reset robot state and start a fresh figure and frame buffer for rendering.
        Returns:
            torch.Tensor: shape (3,) [x, y, theta]
        """
        self._place_robot_at_start()

        # Close the previous episode's figure so repeated resets do not leak figures.
        if self._fig is not None:
            plt.close(self._fig)
        self._fig = plt.figure(layout="tight")
        self._ax = self._fig.add_subplot()
        self._ax.set_xlim(self._obstacle_map.x_lim)
        self._ax.set_ylim(self._obstacle_map.y_lim)
        self._ax.set_aspect("equal")

        self._rendered_frames = []

        return self._robot_state

    def step(self, u: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Update robot state based on differential drive dynamics.
        Args:
            u (torch.Tensor): control tensor, shape (2,) [v, omega]; clamped to
                [u_min, u_max] before it is applied.
        Returns:
            Tuple[torch.Tensor, bool]: Tuple of robot state and whether the robot
            is within 0.5 m of the goal (as a 0-dim boolean tensor).
        """
        u = torch.clamp(u, self.u_min, self.u_max)

        self._robot_state = self.dynamics(
            state=self._robot_state.unsqueeze(0), action=u.unsqueeze(0)
        ).squeeze(0)

        # goal check
        goal_threshold = 0.5
        is_goal_reached = (
            torch.norm(self._robot_state[:2] - self._goal_pos) < goal_threshold
        )

        return self._robot_state, is_goal_reached

    def render(
        self,
        predicted_trajectory: torch.Tensor = None,
        is_collisions: torch.Tensor = None,
        top_samples: Tuple[torch.Tensor, torch.Tensor] = None,
        mode: str = "human",
    ) -> None:
        """
        Draw the map, start, goal, robot and optional planner outputs.

        Requires reset() to have been called first (it creates the figure).
        Args:
            predicted_trajectory: indexed as [0, :, 0] / [0, :, 1] for x / y, so
                presumably shape (1, T, >=2) — TODO confirm with callers.
            is_collisions: reduced with any() over axis 0 to one flag per
                predicted point; flagged points are drawn in red.
            top_samples: (samples, weights); each samples[i] is drawn as a line
                of x = [:, 0], y = [:, 1] with alpha scaled by its weight.
            mode: "human" shows the frame interactively; "rgb_array" stores an
                RGB frame for close() to write as a GIF.
        """
        self._ax.set_xlabel("x [m]")
        self._ax.set_ylabel("y [m]")

        # obstacle map
        self._obstacle_map.render(self._ax, zorder=10)

        # start and goal
        self._ax.scatter(
            self._start_pos[0].item(),
            self._start_pos[1].item(),
            marker="o",
            color="red",
            zorder=10,
        )
        self._ax.scatter(
            self._goal_pos[0].item(),
            self._goal_pos[1].item(),
            marker="o",
            color="orange",
            zorder=10,
        )

        # robot
        self._ax.scatter(
            self._robot_state[0].item(),
            self._robot_state[1].item(),
            marker="o",
            color="green",
            zorder=100,
        )

        # visualize top samples with different alpha based on weights
        if top_samples is not None:
            top_samples, top_weights = top_samples
            top_samples = top_samples.cpu().numpy()
            top_weights = top_weights.cpu().numpy()
            top_weights = 0.7 * top_weights / np.max(top_weights)
            top_weights = np.clip(top_weights, 0.1, 0.7)
            for i in range(top_samples.shape[0]):
                self._ax.plot(
                    top_samples[i, :, 0],
                    top_samples[i, :, 1],
                    color="lightblue",
                    alpha=top_weights[i],
                    zorder=1,
                )

        # predicted trajectory
        if predicted_trajectory is not None:
            # if is collision color is red
            colors = np.array(["darkblue"] * predicted_trajectory.shape[1])
            if is_collisions is not None:
                is_collisions = is_collisions.cpu().numpy()
                is_collisions = np.any(is_collisions, axis=0)
                colors[is_collisions] = "red"

            self._ax.scatter(
                predicted_trajectory[0, :, 0].cpu().numpy(),
                predicted_trajectory[0, :, 1].cpu().numpy(),
                color=colors,
                marker="o",
                s=3,
                zorder=2,
            )

        if mode == "human":
            # online rendering
            plt.pause(0.001)
            # Clear this env's own axes, not whatever axes pyplot considers current.
            self._ax.cla()
        elif mode == "rgb_array":
            # offline rendering for video
            # TODO: high resolution rendering
            self._fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8)
            # and already has the physical (H, W, 4) pixel shape, so HiDPI canvases
            # need no manual reshape. copy() detaches the frame from the canvas
            # buffer, which the next draw() reuses.
            frame = np.asarray(self._fig.canvas.buffer_rgba())[..., :3].copy()
            self._ax.cla()
            self._rendered_frames.append(frame)

    def close(self, path: str = None) -> None:
        """
        Save frames rendered in "rgb_array" mode as a GIF and release the figure.
        Args:
            path: output GIF path; defaults to video/navigation_2d_<seed>.gif.
                Its parent directory is created when frames are written.
        """
        if path is None:
            path = f"video/navigation_2d_{self._seed}.gif"

        if self._rendered_frames:
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            # save animation
            clip = ImageSequenceClip(self._rendered_frames, fps=10)
            clip.write_gif(path, fps=10)

        if self._fig is not None:
            plt.close(self._fig)
            self._fig = None
            self._ax = None

    def dynamics(
        self, state: torch.Tensor, action: torch.Tensor, delta_t: float = 0.1
    ) -> torch.Tensor:
        """
        Update robot state based on differential drive dynamics.
        Args:
            state (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta]
            action (torch.Tensor): control batch tensor, shape (batch_size, 2) [v, omega]
            delta_t (float): time step interval [s]
        Returns:
            torch.Tensor: shape (batch_size, 3) [x, y, theta]; x and y are clamped
            to the map limits and theta is wrapped into [-pi, pi).
        """
        x = state[:, 0].view(-1, 1)
        y = state[:, 1].view(-1, 1)
        theta = angle_normalize(state[:, 2].view(-1, 1))
        v = torch.clamp(action[:, 0].view(-1, 1), self.u_min[0], self.u_max[0])
        omega = torch.clamp(action[:, 1].view(-1, 1), self.u_min[1], self.u_max[1])

        new_x = x + v * torch.cos(theta) * delta_t
        new_y = y + v * torch.sin(theta) * delta_t
        new_theta = angle_normalize(theta + omega * delta_t)

        # Clamp x and y to the map boundary. The limits are stored as Python
        # numbers, so clamp with them directly instead of building new tensors
        # on every call.
        x_min, x_max = self._obstacle_map.x_lim
        y_min, y_max = self._obstacle_map.y_lim
        clamped_x = torch.clamp(new_x, x_min, x_max)
        clamped_y = torch.clamp(new_y, y_min, y_max)

        return torch.cat([clamped_x, clamped_y, new_theta], dim=1)

    def stage_cost(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Calculate stage cost: distance to the goal plus a heavily weighted
        obstacle term (weight 10000); ``action`` does not contribute.
        Args:
            state (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta]
            action (torch.Tensor): control batch tensor, shape (batch_size, 2) [v, omega]
        Returns:
            torch.Tensor: shape (batch_size,)
        """
        goal_cost = torch.norm(state[:, :2] - self._goal_pos, dim=1)

        pos_batch = state[:, :2].unsqueeze(1)  # (batch_size, 1, 2)

        obstacle_cost = self._obstacle_map.compute_cost(pos_batch).squeeze(
            1
        )  # (batch_size,)

        cost = goal_cost + 10000 * obstacle_cost

        return cost

    def terminal_cost(self, state: torch.Tensor) -> torch.Tensor:
        """
        Calculate terminal cost as the stage cost with a zero action.
        Args:
            state (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta]
        Returns:
            torch.Tensor: shape (batch_size,)
        """
        zero_action = torch.zeros_like(state[:, :2])
        return self.stage_cost(state=state, action=zero_action)

    def collision_check(self, state: torch.Tensor) -> torch.Tensor:
        """
        Look up the obstacle-map cost of every position along each trajectory.
        Args:
            state (torch.Tensor): state batch tensor, shape (batch_size, traj_size, 3) [x, y, theta]
        Returns:
            torch.Tensor: shape (batch_size, traj_size); the trajectory axis is
            squeezed away only when traj_size == 1, giving (batch_size,).
        """
        pos_batch = state[:, :, :2]
        is_collisions = self._obstacle_map.compute_cost(pos_batch).squeeze(1)
        return is_collisions
303 |
--------------------------------------------------------------------------------
/src/envs/obstacle_map_2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Kohei Honda, 2023.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from typing import Callable, Tuple, List, Union
8 | from dataclasses import dataclass
9 | from math import ceil
10 | from matplotlib import pyplot as plt
11 | import torch
12 | import numpy as np
13 |
14 |
@dataclass
class CircleObstacle:
    """
    Circle obstacle used in the obstacle map (kept for visualization).

    The explicit constructor is unnecessary: @dataclass generates
    ``__init__(self, center, radius)`` from the fields below.

    Attributes:
        center: (x, y) center of the circle [m].
        radius: radius of the circle [m].
    """

    center: np.ndarray
    radius: float
27 |
28 |
@dataclass
class RectangleObstacle:
    """
    Axis-aligned rectangle obstacle used in the obstacle map (kept for visualization).
    Not consider angle for now.

    The explicit constructor is unnecessary: @dataclass generates
    ``__init__(self, center, width, height)`` from the fields below.

    Attributes:
        center: (x, y) center of the rectangle [m].
        width: extent of the rectangle [m].
        height: extent of the rectangle [m].
    """

    center: np.ndarray
    width: float
    height: float
44 |
45 |
46 | class ObstacleMap:
47 | """
48 | Obstacle map represented by a grid.
49 | """
50 |
51 | def __init__(
52 | self,
53 | map_size: Tuple[int, int] = (20, 20),
54 | cell_size: float = 0.01,
55 | device=torch.device("cuda"),
56 | dtype=torch.float32,
57 | ) -> None:
58 | """
59 | map_size: (width, height) [m], origin is at the center
60 | cell_size: (m)
61 | """
62 | # device and dtype
63 | if torch.cuda.is_available() and device == torch.device("cuda"):
64 | self._device = torch.device("cuda")
65 | else:
66 | self._device = torch.device("cpu")
67 | self._dtype = dtype
68 |
69 | assert len(map_size) == 2
70 | assert cell_size > 0
71 | assert map_size[0] % 2 == 0
72 | assert map_size[1] % 2 == 0
73 |
74 | cell_map_dim = [0, 0]
75 | cell_map_dim[0] = ceil(map_size[0] / cell_size)
76 | cell_map_dim[1] = ceil(map_size[1] / cell_size)
77 |
78 | self._map = np.zeros(cell_map_dim)
79 | self._cell_size = cell_size
80 |
81 | # cell map center
82 | self._cell_map_origin = np.zeros(2)
83 | self._cell_map_origin = np.array(
84 | [cell_map_dim[0] / 2, cell_map_dim[1] / 2]
85 | ).astype(int)
86 |
87 | self._torch_cell_map_origin = torch.from_numpy(self._cell_map_origin).to(
88 | self._device, self._dtype
89 | )
90 |
91 | # limit of the map
92 | x_range = self._cell_size * self._map.shape[0]
93 | y_range = self._cell_size * self._map.shape[1]
94 | self.x_lim = [-x_range / 2, x_range / 2] # [m]
95 | self.y_lim = [-y_range / 2, y_range / 2] # [m]
96 |
97 | # Inner variables
98 | self._map_torch: torch.Tensor = None # use to collision check on GPU
99 | self.circle_obs_list: List[CircleObstacle] = [] # use to visualize
100 | self.rectangle_obs_list: List[RectangleObstacle] = [] # use to visualize
101 |
102 | def add_circle_obstacle(self, center: np.ndarray, radius: float) -> None:
103 | """
104 | Add a circle obstacle to the map.
105 | :param center: Center of the circle obstacle.
106 | :param radius: Radius of the circle obstacle.
107 | """
108 | assert len(center) == 2
109 | assert radius > 0
110 |
111 | # convert to cell map
112 | center_occ = (center / self._cell_size) + self._cell_map_origin
113 | center_occ = np.round(center_occ).astype(int)
114 | radius_occ = ceil(radius / self._cell_size)
115 |
116 | # add to occ map
117 | for i in range(-radius_occ, radius_occ + 1):
118 | for j in range(-radius_occ, radius_occ + 1):
119 | if i**2 + j**2 <= radius_occ**2:
120 | i_bounded = np.clip(center_occ[0] + i, 0, self._map.shape[0] - 1)
121 | j_bounded = np.clip(center_occ[1] + j, 0, self._map.shape[1] - 1)
122 | self._map[i_bounded, j_bounded] = 1
123 |
124 | # add to circle obstacle list to use visualize
125 | self.circle_obs_list.append(CircleObstacle(center, radius))
126 |
127 | def add_rectangle_obstacle(
128 | self, center: np.ndarray, width: float, height: float
129 | ) -> None:
130 | """
131 | Add a rectangle obstacle to the map.
132 | :param center: Center of the rectangle obstacle.
133 | :param width: Width of the rectangle obstacle.
134 | :param height: Height of the rectangle obstacle.
135 | """
136 | assert len(center) == 2
137 | assert width > 0
138 | assert height > 0
139 |
140 | # convert to cell map
141 | center_occ = (center / self._cell_size) + self._cell_map_origin
142 | center_occ = np.ceil(center_occ).astype(int)
143 | width_occ = ceil(width / self._cell_size)
144 | height_occ = ceil(height / self._cell_size)
145 |
146 | # add to occ map
147 | x_init = center_occ[0] - ceil(height_occ / 2)
148 | x_end = center_occ[0] + ceil(height_occ / 2)
149 | y_init = center_occ[1] - ceil(width_occ / 2)
150 | y_end = center_occ[1] + ceil(width_occ / 2)
151 |
152 | # # deal with out of bound
153 | x_init = np.clip(x_init, 0, self._map.shape[0] - 1)
154 | x_end = np.clip(x_end, 0, self._map.shape[0] - 1)
155 | y_init = np.clip(y_init, 0, self._map.shape[1] - 1)
156 | y_end = np.clip(y_end, 0, self._map.shape[1] - 1)
157 |
158 | self._map[x_init:x_end, y_init:y_end] = 1
159 |
160 | # add to rectangle obstacle list to use visualize
161 | self.rectangle_obs_list.append(RectangleObstacle(center, width, height))
162 |
163 | def convert_to_torch(self) -> torch.Tensor:
164 | self._map_torch = torch.from_numpy(self._map).to(self._device, self._dtype)
165 | return self._map_torch
166 |
167 | def compute_cost(self, x: torch.Tensor) -> torch.Tensor:
168 | """
169 | Check collision in a batch of trajectories.
170 | :param x: Tensor of shape (batch_size, traj_length, position_dim).
171 | :return: collsion costs on the trajectories.
172 | """
173 | assert self._map_torch is not None
174 | if x.device != self._device or x.dtype != self._dtype:
175 | x = x.to(self._device, self._dtype)
176 |
177 | # project to cell map
178 | x_occ = (x / self._cell_size) + self._torch_cell_map_origin
179 | x_occ = torch.round(x_occ).long().to(self._device)
180 |
181 | # deal with out of bound
182 | is_out_of_bound = torch.logical_or(
183 | torch.logical_or(
184 | x_occ[..., 0] < 0, x_occ[..., 0] >= self._map_torch.shape[0]
185 | ),
186 | torch.logical_or(
187 | x_occ[..., 1] < 0, x_occ[..., 1] >= self._map_torch.shape[1]
188 | ),
189 | )
190 | x_occ[..., 0] = torch.clamp(x_occ[..., 0], 0, self._map_torch.shape[0] - 1)
191 | x_occ[..., 1] = torch.clamp(x_occ[..., 1], 0, self._map_torch.shape[1] - 1)
192 |
193 | # collision check
194 | collisions = self._map_torch[x_occ[..., 0], x_occ[..., 1]]
195 |
196 | # out of bound cost
197 | collisions[is_out_of_bound] = 1.0
198 |
199 | return collisions
200 |
201 | def render_occupancy(self, ax, cmap="binary") -> None:
202 | ax.imshow(self._map, cmap=cmap)
203 |
204 | def render(self, ax, zorder: int = 0) -> None:
205 | """
206 | Render in continuous space.
207 | """
208 | ax.set_xlim(self.x_lim)
209 | ax.set_ylim(self.y_lim)
210 | ax.set_aspect("equal")
211 |
212 | # render circle obstacles
213 | for circle_obs in self.circle_obs_list:
214 | ax.add_patch(
215 | plt.Circle(
216 | circle_obs.center, circle_obs.radius, color="gray", zorder=zorder
217 | )
218 | )
219 |
220 | # render rectangle obstacles
221 | for rectangle_obs in self.rectangle_obs_list:
222 | ax.add_patch(
223 | plt.Rectangle(
224 | rectangle_obs.center
225 | - np.array([rectangle_obs.width / 2, rectangle_obs.height / 2]),
226 | rectangle_obs.width,
227 | rectangle_obs.height,
228 | color="gray",
229 | zorder=zorder,
230 | )
231 | )
232 |
233 |
def _circle_overlaps_rectangle(
    circle_center: np.ndarray,
    radius: float,
    rect_center: np.ndarray,
    width: float,
    height: float,
) -> bool:
    """Return True if a circle touches an axis-aligned rectangle (width along x, height along y)."""
    circle_center = np.asarray(circle_center, dtype=float)
    rect_center = np.asarray(rect_center, dtype=float)
    half_extent = np.array([width / 2, height / 2])
    # Closest point of the rectangle to the circle center.
    closest = np.clip(circle_center, rect_center - half_extent, rect_center + half_extent)
    return bool(np.linalg.norm(circle_center - closest) <= radius)


def _rectangles_overlap(
    center_a: np.ndarray,
    width_a: float,
    height_a: float,
    center_b: np.ndarray,
    width_b: float,
    height_b: float,
) -> bool:
    """Return True if two axis-aligned rectangles (width along x, height along y) touch."""
    dx, dy = np.abs(np.asarray(center_a, dtype=float) - np.asarray(center_b, dtype=float))
    return bool(dx <= (width_a + width_b) / 2 and dy <= (height_a + height_b) / 2)


def generate_random_obstacles(
    obstacle_map: "ObstacleMap",
    random_x_range: Tuple[float, float],
    random_y_range: Tuple[float, float],
    num_circle_obs: int,
    radius_range: Tuple[float, float],
    num_rectangle_obs: int,
    width_range: Tuple[float, float],
    height_range: Tuple[float, float],
    max_iteration: int,
    seed: int,
) -> None:
    """
    Generate random non-overlapping obstacles and add them to ``obstacle_map``.

    Circles are placed first, then rectangles. Every candidate is resampled
    until it does not touch any obstacle already on the map.

    Args:
        obstacle_map: map that receives the obstacles; its ``x_lim``/``y_lim``
            bound the sampling ranges.
        random_x_range, random_y_range: (low, high) ranges for obstacle centers
            [m]; clipped to the map limits without mutating the arguments.
        num_circle_obs, radius_range: number of circles and (low, high) radius.
        num_rectangle_obs, width_range, height_range: number of rectangles and
            (low, high) ranges of their x extent (width) and y extent (height).
        max_iteration: sampling attempts allowed per obstacle.
        seed: seed for ``np.random.default_rng``.
    Raises:
        RuntimeError: if an obstacle cannot be placed within ``max_iteration`` attempts.
    """
    rng = np.random.default_rng(seed)

    # If the random range is larger than the map, use the map size. Callers pass
    # tuples, so compute clipped bounds instead of assigning into the arguments.
    x_low = max(random_x_range[0], obstacle_map.x_lim[0])
    x_high = min(random_x_range[1], obstacle_map.x_lim[1])
    y_low = max(random_y_range[0], obstacle_map.y_lim[0])
    y_high = min(random_y_range[1], obstacle_map.y_lim[1])

    def sample_center() -> np.ndarray:
        # x is drawn before y; size parameters are drawn by the caller afterwards.
        center_x = rng.uniform(x_low, x_high)
        center_y = rng.uniform(y_low, y_high)
        return np.array([center_x, center_y])

    for _ in range(num_circle_obs):
        for _trial in range(max_iteration):
            center = sample_center()
            radius = rng.uniform(radius_range[0], radius_range[1])

            # overlap check
            is_overlap = any(
                np.linalg.norm(np.asarray(obs.center) - center) <= obs.radius + radius
                for obs in obstacle_map.circle_obs_list
            ) or any(
                _circle_overlaps_rectangle(center, radius, obs.center, obs.width, obs.height)
                for obs in obstacle_map.rectangle_obs_list
            )
            if not is_overlap:
                break
        else:
            raise RuntimeError(
                "Cannot generate random obstacles due to reach max iteration."
            )

        obstacle_map.add_circle_obstacle(center, radius)

    for _ in range(num_rectangle_obs):
        for _trial in range(max_iteration):
            center = sample_center()
            width = rng.uniform(width_range[0], width_range[1])
            height = rng.uniform(height_range[0], height_range[1])

            # overlap check
            is_overlap = any(
                _circle_overlaps_rectangle(obs.center, obs.radius, center, width, height)
                for obs in obstacle_map.circle_obs_list
            ) or any(
                _rectangles_overlap(obs.center, obs.width, obs.height, center, width, height)
                for obs in obstacle_map.rectangle_obs_list
            )
            if not is_overlap:
                break
        else:
            raise RuntimeError(
                "Cannot generate random obstacles due to reach max iteration."
            )

        obstacle_map.add_rectangle_obstacle(center, width, height)
345 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_brax.py:
--------------------------------------------------------------------------------
"""
Because we want to use a GPU-accelerated simulator for MPPI, I tried the brax simulator.

# jax cuda install
pip3 install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# brax install
pip3 install brax

# How to run
# https://tech.yellowback.net/posts/jax-oom
XLA_PYTHON_CLIENT_MEM_FRACTION=.8 python3 tests/test_brax.py
"""

from brax.io import image
from brax import envs

import jax

rng = jax.random.PRNGKey(0)
ant = envs.create("ant")

# Calling reset/step without jax.jit executes them eagerly, op by op, which is
# what made this rollout slow. Compile them once with jax.jit (the usage shown in
# the brax examples); only the first call of each pays the compilation cost.
jit_reset = jax.jit(ant.reset)
jit_step = jax.jit(ant.step)

rng, rng_use = jax.random.split(rng)
state = jit_reset(rng_use)

qps = [state.pipeline_state]
for _ in range(20):
    rng, rng_use = jax.random.split(rng)
    state = jit_step(state, jax.random.uniform(rng_use, (ant.action_size,)))
    qps.append(state.pipeline_state)

# https://github.com/google/brax/issues/47
# How can I get the rendered image without a notebook?
image.render(sys=ant.sys, states=qps, width=320, height=240)
36 |
--------------------------------------------------------------------------------
/tests/test_gui.py:
--------------------------------------------------------------------------------
import gymnasium
import matplotlib

matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import numpy as np


def _run_random_policy(env, num_steps: int) -> None:
    """Step ``env`` with random actions, rendering every step and resetting when an episode ends."""
    for _ in range(num_steps):
        random_action = env.action_space.sample()
        _obs, _reward, terminated, truncated, _info = env.step(random_action)
        env.render()
        if terminated or truncated:
            _obs, _info = env.reset()


# Run gymnasium with real-time rendering
human_env = gymnasium.make("Pendulum-v1", render_mode="human")
_, _ = human_env.reset(seed=42)
_run_random_policy(human_env, 100)
human_env.close()

# video recording mode
video_env = gymnasium.wrappers.RecordVideo(
    env=gymnasium.make("Pendulum-v1", render_mode="rgb_array"), video_folder="video"
)
_, _ = video_env.reset(seed=42)
video_env.start_video_recorder()
_run_random_policy(video_env, 100)
video_env.close()

# Run matplotlib
plt.style.use("ggplot")
plt.figure(figsize=(8, 6))
noise = np.random.randn(1000)
plt.plot(np.arange(1000), noise)
plt.show()
37 |
--------------------------------------------------------------------------------
/tests/test_mujoco.py:
--------------------------------------------------------------------------------
import gymnasium as gym

# Random-action rollout of the MuJoCo humanoid with on-screen rendering.
NUM_STEPS = 1000

humanoid = gym.make("Humanoid-v4", render_mode="human")
_, _ = humanoid.reset(seed=42)
for _step in range(NUM_STEPS):
    random_action = humanoid.action_space.sample()
    _obs, _reward, terminated, truncated, _info = humanoid.step(random_action)
    humanoid.render()
    if terminated or truncated:
        _obs, _info = humanoid.reset()
humanoid.close()
12 |
--------------------------------------------------------------------------------
/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 |
# torch.compile was not supported on Python 3.11 when this was written,
# so TorchScript is used instead.
# @torch.compile
@torch.jit.script
def matmul_jit(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.matmul(x, y)


def _synchronize(target: torch.device) -> None:
    """Wait until all queued work on ``target`` has finished.

    CUDA kernels are launched asynchronously, so without this the timers below
    would measure only the kernel launch, not the matrix multiplication.
    """
    if target.type == "cuda":
        torch.cuda.synchronize(target)


if torch.cuda.is_available():
    print(torch.__version__)
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU is not available. CPU is used")


matrix_size = 10000

# Calculate on CPU (random input generation is kept out of the timed region)
input_matrix = torch.randn(matrix_size, matrix_size).to("cpu")
start_time = time.perf_counter()
result_cpu = torch.matmul(input_matrix, input_matrix)
cpu_time = time.perf_counter() - start_time

# Calculate on GPU
# NOTE: the first CUDA matmul also pays one-time library initialization.
input_matrix = input_matrix.to(device)
_synchronize(device)
start_time = time.perf_counter()
result_gpu = torch.matmul(input_matrix, input_matrix)
_synchronize(device)
gpu_time = time.perf_counter() - start_time

# Calculate on GPU with jit
start_time = time.perf_counter()
result_gpu_jit = matmul_jit(input_matrix, input_matrix)
_synchronize(device)
gpu_time_jit = time.perf_counter() - start_time

print("CPU time: ", cpu_time)
print("GPU time: ", gpu_time)
print("GPU time with jit: ", gpu_time_jit)
print("Speed up w/o jit: ", cpu_time / gpu_time)
print("Speed up with jit: ", cpu_time / gpu_time_jit)
assert torch.allclose(result_cpu[:2, :2], result_gpu[:2, :2].to("cpu"), atol=1e-3)
assert torch.allclose(result_cpu[:2, :2], result_gpu_jit[:2, :2].to("cpu"), atol=1e-3)
58 |
--------------------------------------------------------------------------------