├── .devcontainer └── devcontainer.json ├── .gitignore ├── .vscode └── settings.json ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── app ├── catpole.py ├── mountaincar.py ├── mujoco_cartpole.py ├── navigation2d.py └── pendulum.py ├── media ├── cartpole.gif ├── mountaincar.gif ├── navigation_2d.gif └── pendulum.gif ├── poetry.lock ├── pyproject.toml ├── src ├── __init__.py ├── controller │ ├── __init__.py │ └── mppi.py └── envs │ ├── __init__.py │ ├── navigation_2d.py │ └── obstacle_map_2d.py └── tests ├── __init__.py ├── test_brax.py ├── test_gui.py ├── test_mujoco.py └── test_torch.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dev", 3 | "dockerFile": "../Dockerfile", 4 | "settings": { 5 | "terminal.integrated.shell.linux": "/bin/bash" 6 | }, 7 | "extensions": [ 8 | "ms-python.python", 9 | "ms-vscode-remote.remote-containers", 10 | "ms-python.vscode-pylance", 11 | "GitHub.copilot", 12 | "ms-python.black-formatter", 13 | "ms-python.flake8", 14 | ], 15 | "runArgs": [ 16 | "--gpus", "all", 17 | "--shm-size", "10G", 18 | ] 19 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | 163 | # project specific 164 | video/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "C_Cpp.default.configurationProvider": "ms-vscode.makefile-tools", 3 | "python.analysis.include": [ 4 | "src/**", 5 | "tests/**" 6 | ], 7 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-devel-ubuntu20.04 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | 4 | RUN apt-get -y update && apt-get -y install --no-install-recommends\ 5 | software-properties-common\ 6 | libgl1-mesa-dev\ 7 | libgl1-mesa-glx \ 8 | libglew-dev \ 9 | libosmesa6-dev \ 10 | wget\ 11 | libssl-dev\ 12 | curl\ 13 | git\ 14 | x11-apps \ 15 | swig \ 16 | patchelf 17 | 18 | # Python (version 3.10) 19 | RUN add-apt-repository ppa:deadsnakes/ppa && \ 20 | apt-get update && apt-get install -y \ 21 | python3.10 \ 22 | python3.10-dev \ 23 | python3.10-venv \ 24 | python3.10-distutils \ 25 | python3.10-tk 26 | 27 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 28 | RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 29 | RUN pip3 install --upgrade pip 30 | RUN pip3 install -U pip distlib setuptools wheel 31 | 32 | # vnc (the alias below needs xdpyinfo from x11-utils and the openssl CLI; it scans ports upward from 5900 so the Makefile's -p 5900:5900 matches) 33 | RUN apt-get install -y xvfb x11vnc x11-utils openssl icewm lsof net-tools 34 | RUN echo "alias vnc='PASSWORD=\$(openssl rand -hex 24); for i in {99..0}; do export DISPLAY=:\$i; if ! xdpyinfo &>/dev/null; then break; fi; done; for i in {5900..5999}; do if !
netstat -tuln | grep -q \":\$i \"; then PORT=\$i; break; fi; done; Xvfb \$DISPLAY -screen 0 1400x900x24 & until xdpyinfo > /dev/null 2>&1; do sleep 0.1; done; x11vnc -forever -noxdamage -display \$DISPLAY -rfbport \$PORT -passwd \$PASSWORD > /dev/null 2>&1 & until lsof -i :\$PORT > /dev/null; do sleep 0.1; done; icewm-session & echo DISPLAY=\$DISPLAY, PORT=\$PORT, PASSWORD=\$PASSWORD'" >> ~/.bashrc 35 | 36 | # utils 37 | RUN apt-get update && apt-get install -y htop vim ffmpeg 38 | # RUN pip3 install jupyterlab ipywidgets && \ 39 | # echo 'alias jup="jupyter lab --ip 0.0.0.0 --port 8888 --allow-root &"' >> /root/.bashrc 40 | 41 | # clear cache 42 | RUN rm -rf /var/lib/apt/lists/* 43 | 44 | # pytorch 2.0 45 | RUN pip3 install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118 46 | 47 | # mujoco 210 48 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \ 49 | && chmod +x /usr/local/bin/patchelf 50 | RUN mkdir -p /root/.mujoco \ 51 | && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ 52 | && tar -xf mujoco.tar.gz -C /root/.mujoco \ 53 | && rm mujoco.tar.gz 54 | ENV LD_LIBRARY_PATH=/root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} 55 | 56 | WORKDIR /workspace 57 | COPY src/ src/ 58 | COPY pyproject.toml .
59 | 60 | RUN pip3 install -e .[dev] 61 | 62 | CMD ["bash"] 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 kohonda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NAME=mppi_playground 2 | VERSION=0.0.1 3 | DOCKER_IMAGE_NAME=$(NAME):$(VERSION) 4 | CONTAINER_NAME=$(NAME) 5 | GPU_ID=all 6 | 7 | .PHONY: build bash bash-wo-gpu 8 | 9 | build: 10 | docker build -t $(DOCKER_IMAGE_NAME) .
11 | 12 | bash: 13 | xhost +local:docker && \ 14 | docker run -it \ 15 | --gpus '"device=${GPU_ID}"' \ 16 | -v ${PWD}:/workspace/$(NAME) \ 17 | --rm \ 18 | --shm-size 10G \ 19 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 20 | -e DISPLAY \ 21 | -p 5900:5900 \ 22 | --name $(CONTAINER_NAME)-bash \ 23 | $(DOCKER_IMAGE_NAME) \ 24 | bash 25 | 26 | bash-wo-gpu: 27 | docker run -it \ 28 | -v ${PWD}:/workspace/$(NAME) \ 29 | --rm \ 30 | --shm-size 10G \ 31 | --name $(CONTAINER_NAME)-bash \ 32 | $(DOCKER_IMAGE_NAME) \ 33 | bash 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPPI Playground 2 | This repository contains an implementation of [Model Predictive Path Integral Control (MPPI)](https://arxiv.org/abs/1707.02342) with PyTorch to accelerate computations on the GPU. 3 | 4 | ## Tested Native Environment 5 | - Ubuntu Focal 20.04 (LTS) 6 | - NVIDIA Driver 510 or later due to PyTorch 2.x 7 | 8 | ## Dependencies 9 | - cuda 11.8 10 | - Python 3.10 11 | - PyTorch 2.0 12 | 13 |
14 | Docker Setup 15 | 16 | ### Install Docker 17 | 18 | [Installation guide](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository) 19 | 20 | ```bash 21 | # Install from get.docker.com 22 | curl -fsSL https://get.docker.com -o get-docker.sh 23 | sudo sh get-docker.sh 24 | sudo groupadd docker 25 | sudo usermod -aG docker $USER 26 | ``` 27 | 28 | 29 | ### Setup GPU for Docker 30 | [Installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 31 | ```bash 32 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ 33 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ 34 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ 35 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 36 | 37 | sudo apt-get update 38 | 39 | sudo apt-get install -y nvidia-container-toolkit nvidia-container-runtime 40 | 41 | sudo nvidia-ctk runtime configure --runtime=docker 42 | 43 | sudo systemctl restart docker 44 | ``` 45 |
46 | 47 | ## Installation 48 | 49 | ### With Docker (recommended) 50 | 51 | ```bash 52 | # build the image 53 | make build 54 | 55 | # Open the container via VS Code (recommended) 56 | # 1. Open the folder in VS Code 57 | # 2. Press Ctrl+Shift+P and select 'Dev Containers: Rebuild and Reopen in Container' 58 | # Then you can skip the following command 59 | 60 | # Or run the container from a terminal 61 | make bash 62 | ``` 63 | 64 | ### With venv 65 | 66 | ```bash 67 | python3 -m venv .venv 68 | source .venv/bin/activate 69 | pip3 install -e .[dev] 70 | ``` 71 | 72 | ## Examples 73 | 74 | ### Navigation 2D 75 | ```bash 76 | python3 app/navigation2d.py 77 | ``` 78 |

79 | navigation2d 80 |

81 | 82 | ### Pendulum 83 | ```bash 84 | python3 app/pendulum.py 85 | ``` 86 |

87 | pendulum 88 |

90 | ### Cartpole 91 | ```bash 92 | python3 app/catpole.py 93 | ``` 94 |

95 | cartpole 96 |

97 | 98 | ### Mountain car 99 | ```bash 100 | python3 app/mountaincar.py 101 | ``` 102 |

103 | mountaincar 104 |

105 | 106 | ## Reference 107 | - [pytorch_mppi](https://github.com/UM-ARM-Lab/pytorch_mppi) -------------------------------------------------------------------------------- /app/catpole.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import time 4 | import gymnasium 5 | import fire 6 | import numpy as np 7 | 8 | from controller.mppi import MPPI 9 | 10 | 11 | @torch.jit.script 12 | def angle_normalize(x): 13 | return ((x + torch.pi) % (2 * torch.pi)) - torch.pi 14 | 15 | 16 | def main(save_mode: bool = False): 17 | # dynamics and cost 18 | @torch.jit.script 19 | def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 20 | """ 21 | Args: 22 | state (torch.Tensor): [x, x_dt, theta, theta_dt] 23 | action (torch.Tensor): [-1, 1] 24 | """ 25 | # dynamics from gymnasium 26 | x = state[:, 0].view(-1, 1) 27 | x_dt = state[:, 1].view(-1, 1) 28 | theta = state[:, 2].view(-1, 1) 29 | theta_dt = state[:, 3].view(-1, 1) 30 | 31 | gravity = 9.8 32 | masscart = 1.0 33 | masspole = 0.1 34 | total_mass = masspole + masscart 35 | length = 0.5 # actually half the pole's length 36 | polemass_length = masspole * length 37 | force_mag = 10.0 38 | tau = 0.02 # seconds between state updates 39 | 40 | # convert continuous action to discrete action 41 | # because MPPI only can handle continuous action 42 | continuous_action = action[:, 0].view(-1, 1) 43 | force = torch.zeros_like(continuous_action) 44 | force[continuous_action >= 0] = force_mag 45 | force[continuous_action < 0] = -force_mag 46 | 47 | costheta = torch.cos(theta) 48 | sintheta = torch.sin(theta) 49 | 50 | temp = (force + polemass_length * theta_dt**2 * sintheta) / total_mass 51 | thetaacc = (gravity * sintheta - costheta * temp) / ( 52 | length * (4.0 / 3.0 - masspole * costheta**2 / total_mass) 53 | ) 54 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 55 | 56 | newx = x + tau * x_dt 57 | newx_dt = x_dt + tau * xacc 58 | newtheta 
= theta + tau * theta_dt 59 | newtheta_dt = theta_dt + tau * thetaacc 60 | 61 | x_threshold = 2.4 62 | theta_threshold_radians = 12 * 2 * torch.pi / 360 63 | newx = torch.clamp(newx, -x_threshold, x_threshold) 64 | newtheta = torch.clamp( 65 | newtheta, -theta_threshold_radians, theta_threshold_radians 66 | ) 67 | 68 | new_state = torch.cat((newx, newx_dt, newtheta, newtheta_dt), dim=1) 69 | 70 | return new_state 71 | 72 | @torch.jit.script 73 | def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 74 | x = state[:, 0] 75 | x_dt = state[:, 1] 76 | theta = state[:, 2] 77 | theta_dt = state[:, 3] 78 | 79 | normlized_theta = angle_normalize(theta) 80 | 81 | cost = normlized_theta**2 + 0.1 * theta_dt**2 + 0.1 * x**2 82 | 83 | return cost 84 | 85 | @torch.jit.script 86 | def terminal_cost(state: torch.Tensor) -> torch.Tensor: 87 | x = state[:, 0] 88 | x_dt = state[:, 1] 89 | theta = state[:, 2] 90 | theta_dt = state[:, 3] 91 | 92 | normlized_theta = angle_normalize(theta) 93 | 94 | cost = normlized_theta**2 + 0.1 * theta_dt**2 + 0.1 * x**2 95 | 96 | return cost 97 | 98 | # simulator 99 | if save_mode: 100 | env = gymnasium.make("CartPole-v1", render_mode="rgb_array") 101 | env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video") 102 | else: 103 | env = gymnasium.make("CartPole-v1", render_mode="human") 104 | observation, _ = env.reset(seed=42) 105 | 106 | # start from the inverted position 107 | env.unwrapped.state = np.array([0.0, 0.0, np.pi, 0.0]) 108 | observation, _, _, _, _ = env.step(0) 109 | 110 | # solver 111 | solver = MPPI( 112 | horizon=1000, 113 | num_samples=5000, 114 | dim_state=4, 115 | dim_control=1, 116 | dynamics=dynamics, 117 | stage_cost=stage_cost, 118 | terminal_cost=terminal_cost, 119 | u_min=torch.tensor([-1.0]), 120 | u_max=torch.tensor([1.0]), 121 | sigmas=torch.tensor([1.0]), 122 | lambda_=0.001, 123 | ) 124 | 125 | average_time = 0 126 | for i in range(500): 127 | # solve 128 | start = time.time() 129 | 
action_seq, state_seq = solver.forward(state=observation) 130 | elipsed_time = time.time() - start 131 | average_time = i / (i + 1) * average_time + elipsed_time / (i + 1) 132 | 133 | action_seq_np = action_seq.cpu().numpy() 134 | state_seq_np = state_seq.cpu().numpy() 135 | 136 | # convert continuous action to discrete action 137 | discrete_action = 0 if action_seq_np[0, 0] < 0 else 1 138 | 139 | # update simulator 140 | observation, reward, terminated, truncated, info = env.step(discrete_action) 141 | env.render() 142 | 143 | print("average solve time: {}".format(average_time * 1000), " [ms]") 144 | env.close() 145 | 146 | 147 | if __name__ == "__main__": 148 | fire.Fire(main) 149 | -------------------------------------------------------------------------------- /app/mountaincar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import time 4 | import gymnasium 5 | import fire 6 | 7 | from controller.mppi import MPPI 8 | 9 | 10 | @torch.jit.script 11 | def angle_normalize(x): 12 | return ((x + torch.pi) % (2 * torch.pi)) - torch.pi 13 | 14 | 15 | def main(save_mode: bool = False): 16 | # dynamics and cost 17 | @torch.jit.script 18 | def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 19 | # dynamics from gymnasium 20 | min_action = -1.0 21 | max_action = 1.0 22 | min_position = -1.2 23 | max_position = 0.6 24 | max_speed = 0.07 25 | goal_position = 0.45 26 | goal_velocity = 0.0 27 | power = 0.0015 28 | 29 | position = state[:, 0].view(-1, 1) 30 | velocity = state[:, 1].view(-1, 1) 31 | 32 | force = torch.clamp(action[:, 0].view(-1, 1), min_action, max_action) 33 | 34 | velocity += force * power - 0.0025 * torch.cos(3 * position) 35 | velocity = torch.clamp(velocity, -max_speed, max_speed) 36 | position += velocity 37 | position = torch.clamp(position, min_position, max_position) 38 | # if (position == min_position and velocity < 0): 39 | # velocity = torch.zeros_like(velocity) 40 | 41 
| new_state = torch.cat((position, velocity), dim=1) 42 | 43 | return new_state 44 | 45 | @torch.jit.script 46 | def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 47 | goal_position = 0.45 48 | goal_velocity = 0.0 49 | 50 | position = state[:, 0] 51 | velocity = state[:, 1] 52 | 53 | cost = (goal_position - position) ** 2 54 | # + 0.01 * (velocity-goal_velocity)**2 55 | 56 | return cost 57 | 58 | @torch.jit.script 59 | def terminal_cost(state: torch.Tensor) -> torch.Tensor: 60 | goal_position = 0.45 61 | goal_velocity = 0.0 62 | 63 | position = state[:, 0] 64 | velocity = state[:, 1] 65 | 66 | cost = (goal_position - position) ** 2 67 | # + (velocity-goal_velocity)**2 68 | return cost 69 | 70 | # simulator 71 | if save_mode: 72 | env = gymnasium.make("MountainCarContinuous-v0", render_mode="rgb_array") 73 | env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video") 74 | else: 75 | env = gymnasium.make("MountainCarContinuous-v0", render_mode="human") 76 | observation, _ = env.reset(seed=42) 77 | 78 | # solver 79 | solver = MPPI( 80 | horizon=1000, 81 | num_samples=1000, 82 | dim_state=2, 83 | dim_control=1, 84 | dynamics=dynamics, 85 | stage_cost=stage_cost, 86 | terminal_cost=terminal_cost, 87 | u_min=torch.tensor([-1.0]), 88 | u_max=torch.tensor([1.0]), 89 | sigmas=torch.tensor([1.0]), 90 | lambda_=0.1, 91 | ) 92 | 93 | average_time = 0 94 | for i in range(300): 95 | state = env.unwrapped.state.copy() 96 | 97 | # solve 98 | start = time.time() 99 | action_seq, state_seq = solver.forward(state=state) 100 | elipsed_time = time.time() - start 101 | average_time = i / (i + 1) * average_time + elipsed_time / (i + 1) 102 | 103 | action_seq_np = action_seq.cpu().numpy() 104 | state_seq_np = state_seq.cpu().numpy() 105 | 106 | # update simulator 107 | observation, reward, terminated, truncated, info = env.step(action_seq_np[0, :]) 108 | env.render() 109 | 110 | print("average solve time: {}".format(average_time * 1000), " [ms]") 111 | 
env.close() 112 | 113 | 114 | if __name__ == "__main__": 115 | fire.Fire(main) 116 | -------------------------------------------------------------------------------- /app/mujoco_cartpole.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import time 4 | import gymnasium as gym 5 | import fire 6 | import numpy as np 7 | 8 | from controller.mppi import MPPI 9 | 10 | 11 | @torch.jit.script 12 | def angle_normalize(x): 13 | return ((x + torch.pi) % (2 * torch.pi)) - torch.pi 14 | 15 | 16 | # Not work well because of the difference of dynamics 17 | # I should use the true dynamics from mujoco like: 18 | # https://github.com/mohakbhardwaj/mjmpc/blob/master/examples/example_mpc.py#L112 19 | def main(save_mode: bool = False): 20 | # dynamics and cost 21 | @torch.jit.script 22 | def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 23 | """ 24 | Args: 25 | state (torch.Tensor): [x, x_dt, theta, theta_dt] 26 | action (torch.Tensor): [-1, 1] 27 | """ 28 | # dynamics not from mujoco 29 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/assets/inverted_pendulum.xml 30 | x = state[:, 0].view(-1, 1) 31 | x_dt = state[:, 1].view(-1, 1) 32 | theta = state[:, 2].view(-1, 1) 33 | theta_dt = state[:, 3].view(-1, 1) 34 | 35 | force = action[:, 0].view(-1, 1) 36 | 37 | gravity = 9.8 38 | masscart = 1.0 39 | masspole = 1.0 40 | total_mass = masspole + masscart 41 | length = 0.5 # actually half the pole's length 42 | polemass_length = masspole * length 43 | tau = 0.02 # seconds between state updates 44 | 45 | costheta = torch.cos(theta) 46 | sintheta = torch.sin(theta) 47 | 48 | temp = (force + polemass_length * theta_dt**2 * sintheta) / total_mass 49 | thetaacc = (gravity * sintheta - costheta * temp) / ( 50 | length * (4.0 / 3.0 - masspole * costheta**2 / total_mass) 51 | ) 52 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 53 | 54 | newx = x + tau * x_dt 55 | newx_dt = x_dt + tau 
* xacc 56 | newtheta = theta + tau * theta_dt 57 | newtheta_dt = theta_dt + tau * thetaacc 58 | 59 | x_threshold = 1.0 60 | theta_threshold_radians = 12 * 2 * torch.pi / 360 61 | newx = torch.clamp(newx, -x_threshold, x_threshold) 62 | newtheta = torch.clamp( 63 | newtheta, -theta_threshold_radians, theta_threshold_radians 64 | ) 65 | 66 | new_state = torch.cat((newx, newx_dt, newtheta, newtheta_dt), dim=1) 67 | 68 | return new_state 69 | 70 | @torch.jit.script 71 | def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 72 | x = state[:, 0] 73 | x_dt = state[:, 1] 74 | theta = state[:, 2] 75 | theta_dt = state[:, 3] 76 | 77 | normlized_theta = angle_normalize(theta) 78 | 79 | cost = normlized_theta**2 + 0.1 * theta_dt**2 + 0.1 * x**2 80 | 81 | return cost 82 | 83 | @torch.jit.script 84 | def terminal_cost(state: torch.Tensor) -> torch.Tensor: 85 | x = state[:, 0] 86 | x_dt = state[:, 1] 87 | theta = state[:, 2] 88 | theta_dt = state[:, 3] 89 | 90 | normlized_theta = angle_normalize(theta) 91 | 92 | cost = normlized_theta**2 + 0.1 * theta_dt**2 + 0.1 * x**2 93 | 94 | return cost 95 | 96 | # simulator 97 | if save_mode: 98 | env = gym.make("InvertedPendulum-v4", render_mode="rgb_array") 99 | env = gym.wrappers.RecordVideo(env=env, video_folder="video") 100 | else: 101 | env = gym.make("InvertedPendulum-v4", render_mode="human") 102 | 103 | observation, _ = env.reset(seed=42) 104 | 105 | # start from the inverted position 106 | # env.unwrapped.state = np.array([0.0, 0.0, np.pi / 8, 0.0]) 107 | # observation, _, _, _, _ = env.step(0) 108 | 109 | # solver 110 | solver = MPPI( 111 | horizon=50, 112 | num_samples=1000, 113 | dim_state=4, 114 | dim_control=1, 115 | dynamics=dynamics, 116 | stage_cost=stage_cost, 117 | terminal_cost=terminal_cost, 118 | u_min=torch.tensor([-3.0]), 119 | u_max=torch.tensor([3.0]), 120 | sigmas=torch.tensor([1.0]), 121 | lambda_=1.0, 122 | ) 123 | 124 | average_time = 0 125 | for i in range(500): 126 | # solve 127 | start 
= time.time() 128 | action_seq, state_seq = solver.forward(state=observation) 129 | 130 | elipsed_time = time.time() - start 131 | average_time = i / (i + 1) * average_time + elipsed_time / (i + 1) 132 | 133 | action_seq_np = action_seq.cpu().numpy() 134 | state_seq_np = state_seq.cpu().numpy() 135 | 136 | # update simulator 137 | observation, reward, terminated, truncated, info = env.step(action_seq_np[0, :]) 138 | env.render() 139 | 140 | print("average solve time: {}".format(average_time * 1000), " [ms]") 141 | env.close() 142 | 143 | 144 | if __name__ == "__main__": 145 | fire.Fire(main) 146 | -------------------------------------------------------------------------------- /app/navigation2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import time 4 | 5 | # import gymnasium 6 | import fire 7 | import tqdm 8 | 9 | from controller.mppi import MPPI 10 | from envs.navigation_2d import Navigation2DEnv 11 | 12 | 13 | def main(save_mode: bool = False): 14 | env = Navigation2DEnv() 15 | 16 | # solver 17 | solver = MPPI( 18 | horizon=50, 19 | num_samples=10000, 20 | dim_state=3, 21 | dim_control=2, 22 | dynamics=env.dynamics, 23 | stage_cost=env.stage_cost, 24 | terminal_cost=env.terminal_cost, 25 | u_min=env.u_min, 26 | u_max=env.u_max, 27 | sigmas=torch.tensor([0.5, 0.5]), 28 | lambda_=1.0, 29 | ) 30 | 31 | state = env.reset() 32 | max_steps = 500 33 | average_time = 0 34 | for i in range(max_steps): 35 | start = time.time() 36 | with torch.no_grad(): 37 | action_seq, state_seq = solver.forward(state=state) 38 | end = time.time() 39 | average_time += (end - start) / max_steps 40 | 41 | state, is_goal_reached = env.step(action_seq[0, :]) 42 | 43 | is_collisions = env.collision_check(state=state_seq) 44 | 45 | top_samples, top_weights = solver.get_top_samples(num_samples=300) 46 | 47 | if save_mode: 48 | env.render( 49 | predicted_trajectory=state_seq, 50 | is_collisions=is_collisions, 51 | 
top_samples=(top_samples, top_weights), 52 | mode="rgb_array", 53 | ) 54 | # progress bar 55 | if i == 0: 56 | pbar = tqdm.tqdm(total=max_steps, desc="recording video") 57 | pbar.update(1) 58 | 59 | else: 60 | env.render( 61 | predicted_trajectory=state_seq, 62 | is_collisions=is_collisions, 63 | top_samples=(top_samples, top_weights), 64 | mode="human", 65 | ) 66 | if is_goal_reached: 67 | print("Goal Reached!") 68 | break 69 | 70 | print("average solve time: {}".format(average_time * 1000), " [ms]") 71 | env.close() # close window and save video if save_mode is True 72 | 73 | 74 | if __name__ == "__main__": 75 | fire.Fire(main) 76 | -------------------------------------------------------------------------------- /app/pendulum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import time 4 | import gymnasium 5 | import fire 6 | 7 | from controller.mppi import MPPI 8 | 9 | 10 | @torch.jit.script 11 | def angle_normalize(x): 12 | return ((x + torch.pi) % (2 * torch.pi)) - torch.pi 13 | 14 | 15 | def main(save_mode: bool = False): 16 | # dynamics and cost 17 | @torch.jit.script 18 | def dynamics(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 19 | # dynamics from gymnasium 20 | th = state[:, 0].view(-1, 1) 21 | thdot = state[:, 1].view(-1, 1) 22 | g = 10 23 | m = 1 24 | l = 1 25 | dt = 0.05 26 | u = action[:, 0].view(-1, 1) 27 | u = torch.clamp(u, -2, 2) 28 | newthdot = ( 29 | thdot 30 | + (-3 * g / (2 * l) * torch.sin(th + torch.pi) + 3.0 / (m * l**2) * u) 31 | * dt 32 | ) 33 | newth = th + newthdot * dt 34 | newthdot = torch.clamp(newthdot, -8, 8) 35 | 36 | state = torch.cat((newth, newthdot), dim=1) 37 | return state 38 | 39 | @torch.jit.script 40 | def stage_cost(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 41 | theta = state[:, 0] 42 | theta_dt = state[:, 1] 43 | # u = action[:, 0] 44 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 45 | return cost 46 | 47 | 
@torch.jit.script 48 | def terminal_cost(state: torch.Tensor) -> torch.Tensor: 49 | theta = state[:, 0] 50 | theta_dt = state[:, 1] 51 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt**2 52 | return cost 53 | 54 | # simulator 55 | if save_mode: 56 | env = gymnasium.make("Pendulum-v1", render_mode="rgb_array") 57 | env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video") 58 | else: 59 | env = gymnasium.make("Pendulum-v1", render_mode="human") 60 | observation, _ = env.reset(seed=42) 61 | 62 | # solver 63 | solver = MPPI( 64 | horizon=15, 65 | num_samples=1000, 66 | dim_state=2, 67 | dim_control=1, 68 | dynamics=dynamics, 69 | stage_cost=stage_cost, 70 | terminal_cost=terminal_cost, 71 | u_min=torch.tensor([-2.0]), 72 | u_max=torch.tensor([2.0]), 73 | sigmas=torch.tensor([1.0]), 74 | lambda_=1.0, 75 | ) 76 | 77 | average_time = 0 78 | for i in range(200): 79 | state = env.unwrapped.state.copy() 80 | 81 | # solve 82 | start = time.time() 83 | action_seq, state_seq = solver.forward(state=state) 84 | elipsed_time = time.time() - start 85 | average_time = i / (i + 1) * average_time + elipsed_time / (i + 1) 86 | 87 | action_seq_np = action_seq.cpu().numpy() 88 | state_seq_np = state_seq.cpu().numpy() 89 | 90 | # update simulator 91 | observation, reward, terminated, truncated, info = env.step(action_seq_np[0, :]) 92 | env.render() 93 | 94 | print("average solve time: {}".format(average_time * 1000), " [ms]") 95 | env.close() 96 | 97 | 98 | if __name__ == "__main__": 99 | fire.Fire(main) 100 | -------------------------------------------------------------------------------- /media/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/cartpole.gif -------------------------------------------------------------------------------- /media/mountaincar.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/mountaincar.gif -------------------------------------------------------------------------------- /media/navigation_2d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/navigation_2d.gif -------------------------------------------------------------------------------- /media/pendulum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/media/pendulum.gif -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. 2 | 3 | [[package]] 4 | name = "black" 5 | version = "23.3.0" 6 | description = "The uncompromising code formatter." 
7 | category = "main" 8 | optional = false 9 | python-versions = ">=3.7" 10 | files = [ 11 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, 12 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, 13 | {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, 14 | {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, 15 | {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, 16 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, 17 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, 18 | {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, 19 | {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, 20 | {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, 21 | {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, 22 | {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, 23 | {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = 
"sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, 24 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, 25 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, 26 | {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, 27 | {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, 28 | {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, 29 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, 30 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, 31 | {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, 32 | {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, 33 | {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, 34 | {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, 35 | {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, 36 | ] 37 | 38 | [package.dependencies] 39 | click = ">=8.0.0" 40 | mypy-extensions = ">=0.4.3" 41 | packaging = ">=22.0" 42 | pathspec = ">=0.9.0" 43 | platformdirs = ">=2" 44 | 
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} 45 | 46 | [package.extras] 47 | colorama = ["colorama (>=0.4.3)"] 48 | d = ["aiohttp (>=3.7.4)"] 49 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] 50 | uvloop = ["uvloop (>=0.15.2)"] 51 | 52 | [[package]] 53 | name = "click" 54 | version = "8.1.3" 55 | description = "Composable command line interface toolkit" 56 | category = "main" 57 | optional = false 58 | python-versions = ">=3.7" 59 | files = [ 60 | {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, 61 | {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, 62 | ] 63 | 64 | [package.dependencies] 65 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 66 | 67 | [[package]] 68 | name = "colorama" 69 | version = "0.4.6" 70 | description = "Cross-platform colored terminal text." 71 | category = "main" 72 | optional = false 73 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" 74 | files = [ 75 | {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, 76 | {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, 77 | ] 78 | 79 | [[package]] 80 | name = "contourpy" 81 | version = "1.0.7" 82 | description = "Python library for calculating contours of 2D quadrilateral grids" 83 | category = "main" 84 | optional = false 85 | python-versions = ">=3.8" 86 | files = [ 87 | {file = "contourpy-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:95c3acddf921944f241b6773b767f1cbce71d03307270e2d769fd584d5d1092d"}, 88 | {file = "contourpy-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fc1464c97579da9f3ab16763c32e5c5d5bb5fa1ec7ce509a4ca6108b61b84fab"}, 89 | {file = "contourpy-1.0.7-cp310-cp310-macosx_11_0_arm64.whl", 
hash = "sha256:8acf74b5d383414401926c1598ed77825cd530ac7b463ebc2e4f46638f56cce6"}, 90 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c71fdd8f1c0f84ffd58fca37d00ca4ebaa9e502fb49825484da075ac0b0b803"}, 91 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f99e9486bf1bb979d95d5cffed40689cb595abb2b841f2991fc894b3452290e8"}, 92 | {file = "contourpy-1.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87f4d8941a9564cda3f7fa6a6cd9b32ec575830780677932abdec7bcb61717b0"}, 93 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9e20e5a1908e18aaa60d9077a6d8753090e3f85ca25da6e25d30dc0a9e84c2c6"}, 94 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a877ada905f7d69b2a31796c4b66e31a8068b37aa9b78832d41c82fc3e056ddd"}, 95 | {file = "contourpy-1.0.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6381fa66866b0ea35e15d197fc06ac3840a9b2643a6475c8fff267db8b9f1e69"}, 96 | {file = "contourpy-1.0.7-cp310-cp310-win32.whl", hash = "sha256:3c184ad2433635f216645fdf0493011a4667e8d46b34082f5a3de702b6ec42e3"}, 97 | {file = "contourpy-1.0.7-cp310-cp310-win_amd64.whl", hash = "sha256:3caea6365b13119626ee996711ab63e0c9d7496f65641f4459c60a009a1f3e80"}, 98 | {file = "contourpy-1.0.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ed33433fc3820263a6368e532f19ddb4c5990855e4886088ad84fd7c4e561c71"}, 99 | {file = "contourpy-1.0.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:38e2e577f0f092b8e6774459317c05a69935a1755ecfb621c0a98f0e3c09c9a5"}, 100 | {file = "contourpy-1.0.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ae90d5a8590e5310c32a7630b4b8618cef7563cebf649011da80874d0aa8f414"}, 101 | {file = "contourpy-1.0.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130230b7e49825c98edf0b428b7aa1125503d91732735ef897786fe5452b1ec2"}, 102 | {file = 
"contourpy-1.0.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58569c491e7f7e874f11519ef46737cea1d6eda1b514e4eb5ac7dab6aa864d02"}, 103 | {file = "contourpy-1.0.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54d43960d809c4c12508a60b66cb936e7ed57d51fb5e30b513934a4a23874fae"}, 104 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:152fd8f730c31fd67fe0ffebe1df38ab6a669403da93df218801a893645c6ccc"}, 105 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9056c5310eb1daa33fc234ef39ebfb8c8e2533f088bbf0bc7350f70a29bde1ac"}, 106 | {file = "contourpy-1.0.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a9d7587d2fdc820cc9177139b56795c39fb8560f540bba9ceea215f1f66e1566"}, 107 | {file = "contourpy-1.0.7-cp311-cp311-win32.whl", hash = "sha256:4ee3ee247f795a69e53cd91d927146fb16c4e803c7ac86c84104940c7d2cabf0"}, 108 | {file = "contourpy-1.0.7-cp311-cp311-win_amd64.whl", hash = "sha256:5caeacc68642e5f19d707471890f037a13007feba8427eb7f2a60811a1fc1350"}, 109 | {file = "contourpy-1.0.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fd7dc0e6812b799a34f6d12fcb1000539098c249c8da54f3566c6a6461d0dbad"}, 110 | {file = "contourpy-1.0.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0f9d350b639db6c2c233d92c7f213d94d2e444d8e8fc5ca44c9706cf72193772"}, 111 | {file = "contourpy-1.0.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e96a08b62bb8de960d3a6afbc5ed8421bf1a2d9c85cc4ea73f4bc81b4910500f"}, 112 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:031154ed61f7328ad7f97662e48660a150ef84ee1bc8876b6472af88bf5a9b98"}, 113 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e9ebb4425fc1b658e13bace354c48a933b842d53c458f02c86f371cecbedecc"}, 114 | {file = "contourpy-1.0.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:efb8f6d08ca7998cf59eaf50c9d60717f29a1a0a09caa46460d33b2924839dbd"}, 115 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6c180d89a28787e4b73b07e9b0e2dac7741261dbdca95f2b489c4f8f887dd810"}, 116 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b8d587cc39057d0afd4166083d289bdeff221ac6d3ee5046aef2d480dc4b503c"}, 117 | {file = "contourpy-1.0.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:769eef00437edf115e24d87f8926955f00f7704bede656ce605097584f9966dc"}, 118 | {file = "contourpy-1.0.7-cp38-cp38-win32.whl", hash = "sha256:62398c80ef57589bdbe1eb8537127321c1abcfdf8c5f14f479dbbe27d0322e66"}, 119 | {file = "contourpy-1.0.7-cp38-cp38-win_amd64.whl", hash = "sha256:57119b0116e3f408acbdccf9eb6ef19d7fe7baf0d1e9aaa5381489bc1aa56556"}, 120 | {file = "contourpy-1.0.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:30676ca45084ee61e9c3da589042c24a57592e375d4b138bd84d8709893a1ba4"}, 121 | {file = "contourpy-1.0.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e927b3868bd1e12acee7cc8f3747d815b4ab3e445a28d2e5373a7f4a6e76ba1"}, 122 | {file = "contourpy-1.0.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:366a0cf0fc079af5204801786ad7a1c007714ee3909e364dbac1729f5b0849e5"}, 123 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89ba9bb365446a22411f0673abf6ee1fea3b2cf47b37533b970904880ceb72f3"}, 124 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71b0bf0c30d432278793d2141362ac853859e87de0a7dee24a1cea35231f0d50"}, 125 | {file = "contourpy-1.0.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7281244c99fd7c6f27c1c6bfafba878517b0b62925a09b586d88ce750a016d2"}, 126 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b6d0f9e1d39dbfb3977f9dd79f156c86eb03e57a7face96f199e02b18e58d32a"}, 127 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:7f6979d20ee5693a1057ab53e043adffa1e7418d734c1532e2d9e915b08d8ec2"}, 128 | {file = "contourpy-1.0.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5dd34c1ae752515318224cba7fc62b53130c45ac6a1040c8b7c1a223c46e8967"}, 129 | {file = "contourpy-1.0.7-cp39-cp39-win32.whl", hash = "sha256:c5210e5d5117e9aec8c47d9156d1d3835570dd909a899171b9535cb4a3f32693"}, 130 | {file = "contourpy-1.0.7-cp39-cp39-win_amd64.whl", hash = "sha256:60835badb5ed5f4e194a6f21c09283dd6e007664a86101431bf870d9e86266c4"}, 131 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ce41676b3d0dd16dbcfabcc1dc46090aaf4688fd6e819ef343dbda5a57ef0161"}, 132 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a011cf354107b47c58ea932d13b04d93c6d1d69b8b6dce885e642531f847566"}, 133 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31a55dccc8426e71817e3fe09b37d6d48ae40aae4ecbc8c7ad59d6893569c436"}, 134 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69f8ff4db108815addd900a74df665e135dbbd6547a8a69333a68e1f6e368ac2"}, 135 | {file = "contourpy-1.0.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:efe99298ba37e37787f6a2ea868265465410822f7bea163edcc1bd3903354ea9"}, 136 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a1e97b86f73715e8670ef45292d7cc033548266f07d54e2183ecb3c87598888f"}, 137 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc331c13902d0f50845099434cd936d49d7a2ca76cb654b39691974cb1e4812d"}, 138 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:24847601071f740837aefb730e01bd169fbcaa610209779a78db7ebb6e6a7051"}, 139 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:abf298af1e7ad44eeb93501e40eb5a67abbf93b5d90e468d01fc0c4451971afa"}, 140 | {file = "contourpy-1.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:64757f6460fc55d7e16ed4f1de193f362104285c667c112b50a804d482777edd"}, 141 | {file = "contourpy-1.0.7.tar.gz", hash = "sha256:d8165a088d31798b59e91117d1f5fc3df8168d8b48c4acc10fc0df0d0bdbcc5e"}, 142 | ] 143 | 144 | [package.dependencies] 145 | numpy = ">=1.16" 146 | 147 | [package.extras] 148 | bokeh = ["bokeh", "chromedriver", "selenium"] 149 | docs = ["furo", "sphinx-copybutton"] 150 | mypy = ["contourpy[bokeh]", "docutils-stubs", "mypy (==0.991)", "types-Pillow"] 151 | test = ["Pillow", "matplotlib", "pytest"] 152 | test-no-images = ["pytest"] 153 | 154 | [[package]] 155 | name = "cycler" 156 | version = "0.11.0" 157 | description = "Composable style cycles" 158 | category = "main" 159 | optional = false 160 | python-versions = ">=3.6" 161 | files = [ 162 | {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, 163 | {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, 164 | ] 165 | 166 | [[package]] 167 | name = "fonttools" 168 | version = "4.39.3" 169 | description = "Tools to manipulate font files" 170 | category = "main" 171 | optional = false 172 | python-versions = ">=3.8" 173 | files = [ 174 | {file = "fonttools-4.39.3-py3-none-any.whl", hash = "sha256:64c0c05c337f826183637570ac5ab49ee220eec66cf50248e8df527edfa95aeb"}, 175 | {file = "fonttools-4.39.3.zip", hash = "sha256:9234b9f57b74e31b192c3fc32ef1a40750a8fbc1cd9837a7b7bfc4ca4a5c51d7"}, 176 | ] 177 | 178 | [package.extras] 179 | all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.0.0)", "xattr", "zopfli (>=0.1.4)"] 180 | graphite = ["lz4 (>=1.7.4.2)"] 181 | 
interpolatable = ["munkres", "scipy"] 182 | lxml = ["lxml (>=4.0,<5)"] 183 | pathops = ["skia-pathops (>=0.5.0)"] 184 | plot = ["matplotlib"] 185 | repacker = ["uharfbuzz (>=0.23.0)"] 186 | symfont = ["sympy"] 187 | type1 = ["xattr"] 188 | ufo = ["fs (>=2.2.0,<3)"] 189 | unicode = ["unicodedata2 (>=15.0.0)"] 190 | woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] 191 | 192 | [[package]] 193 | name = "kiwisolver" 194 | version = "1.4.4" 195 | description = "A fast implementation of the Cassowary constraint solver" 196 | category = "main" 197 | optional = false 198 | python-versions = ">=3.7" 199 | files = [ 200 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2f5e60fabb7343a836360c4f0919b8cd0d6dbf08ad2ca6b9cf90bf0c76a3c4f6"}, 201 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:10ee06759482c78bdb864f4109886dff7b8a56529bc1609d4f1112b93fe6423c"}, 202 | {file = "kiwisolver-1.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c79ebe8f3676a4c6630fd3f777f3cfecf9289666c84e775a67d1d358578dc2e3"}, 203 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:abbe9fa13da955feb8202e215c4018f4bb57469b1b78c7a4c5c7b93001699938"}, 204 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7577c1987baa3adc4b3c62c33bd1118c3ef5c8ddef36f0f2c950ae0b199e100d"}, 205 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ad8285b01b0d4695102546b342b493b3ccc6781fc28c8c6a1bb63e95d22f09"}, 206 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ed58b8acf29798b036d347791141767ccf65eee7f26bde03a71c944449e53de"}, 207 | {file = "kiwisolver-1.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a68b62a02953b9841730db7797422f983935aeefceb1679f0fc85cbfbd311c32"}, 208 | {file = 
"kiwisolver-1.4.4-cp310-cp310-win32.whl", hash = "sha256:e92a513161077b53447160b9bd8f522edfbed4bd9759e4c18ab05d7ef7e49408"}, 209 | {file = "kiwisolver-1.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:3fe20f63c9ecee44560d0e7f116b3a747a5d7203376abeea292ab3152334d004"}, 210 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ea21f66820452a3f5d1655f8704a60d66ba1191359b96541eaf457710a5fc6"}, 211 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bc9db8a3efb3e403e4ecc6cd9489ea2bac94244f80c78e27c31dcc00d2790ac2"}, 212 | {file = "kiwisolver-1.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5b61785a9ce44e5a4b880272baa7cf6c8f48a5180c3e81c59553ba0cb0821ca"}, 213 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c2dbb44c3f7e6c4d3487b31037b1bdbf424d97687c1747ce4ff2895795c9bf69"}, 214 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6295ecd49304dcf3bfbfa45d9a081c96509e95f4b9d0eb7ee4ec0530c4a96514"}, 215 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bd472dbe5e136f96a4b18f295d159d7f26fd399136f5b17b08c4e5f498cd494"}, 216 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf7d9fce9bcc4752ca4a1b80aabd38f6d19009ea5cbda0e0856983cf6d0023f5"}, 217 | {file = "kiwisolver-1.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d6601aed50c74e0ef02f4204da1816147a6d3fbdc8b3872d263338a9052c51"}, 218 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:877272cf6b4b7e94c9614f9b10140e198d2186363728ed0f701c6eee1baec1da"}, 219 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:db608a6757adabb32f1cfe6066e39b3706d8c3aa69bbc353a5b61edad36a5cb4"}, 220 | {file = 
"kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:5853eb494c71e267912275e5586fe281444eb5e722de4e131cddf9d442615626"}, 221 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f0a1dbdb5ecbef0d34eb77e56fcb3e95bbd7e50835d9782a45df81cc46949750"}, 222 | {file = "kiwisolver-1.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:283dffbf061a4ec60391d51e6155e372a1f7a4f5b15d59c8505339454f8989e4"}, 223 | {file = "kiwisolver-1.4.4-cp311-cp311-win32.whl", hash = "sha256:d06adcfa62a4431d404c31216f0f8ac97397d799cd53800e9d3efc2fbb3cf14e"}, 224 | {file = "kiwisolver-1.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:e7da3fec7408813a7cebc9e4ec55afed2d0fd65c4754bc376bf03498d4e92686"}, 225 | {file = "kiwisolver-1.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:62ac9cc684da4cf1778d07a89bf5f81b35834cb96ca523d3a7fb32509380cbf6"}, 226 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41dae968a94b1ef1897cb322b39360a0812661dba7c682aa45098eb8e193dbdf"}, 227 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02f79693ec433cb4b5f51694e8477ae83b3205768a6fb48ffba60549080e295b"}, 228 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0611a0a2a518464c05ddd5a3a1a0e856ccc10e67079bb17f265ad19ab3c7597"}, 229 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:db5283d90da4174865d520e7366801a93777201e91e79bacbac6e6927cbceede"}, 230 | {file = "kiwisolver-1.4.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1041feb4cda8708ce73bb4dcb9ce1ccf49d553bf87c3954bdfa46f0c3f77252c"}, 231 | {file = "kiwisolver-1.4.4-cp37-cp37m-win32.whl", hash = "sha256:a553dadda40fef6bfa1456dc4be49b113aa92c2a9a9e8711e955618cd69622e3"}, 232 | {file = "kiwisolver-1.4.4-cp37-cp37m-win_amd64.whl", hash = 
"sha256:03baab2d6b4a54ddbb43bba1a3a2d1627e82d205c5cf8f4c924dc49284b87166"}, 233 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:841293b17ad704d70c578f1f0013c890e219952169ce8a24ebc063eecf775454"}, 234 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f4f270de01dd3e129a72efad823da90cc4d6aafb64c410c9033aba70db9f1ff0"}, 235 | {file = "kiwisolver-1.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f9f39e2f049db33a908319cf46624a569b36983c7c78318e9726a4cb8923b26c"}, 236 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c97528e64cb9ebeff9701e7938653a9951922f2a38bd847787d4a8e498cc83ae"}, 237 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d1573129aa0fd901076e2bfb4275a35f5b7aa60fbfb984499d661ec950320b0"}, 238 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad881edc7ccb9d65b0224f4e4d05a1e85cf62d73aab798943df6d48ab0cd79a1"}, 239 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b428ef021242344340460fa4c9185d0b1f66fbdbfecc6c63eff4b7c29fad429d"}, 240 | {file = "kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c"}, 241 | {file = "kiwisolver-1.4.4-cp38-cp38-win32.whl", hash = "sha256:75facbe9606748f43428fc91a43edb46c7ff68889b91fa31f53b58894503a191"}, 242 | {file = "kiwisolver-1.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:5bce61af018b0cb2055e0e72e7d65290d822d3feee430b7b8203d8a855e78766"}, 243 | {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8c808594c88a025d4e322d5bb549282c93c8e1ba71b790f539567932722d7bd8"}, 244 | {file = "kiwisolver-1.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f0a71d85ecdd570ded8ac3d1c0f480842f49a40beb423bb8014539a9f32a5897"}, 245 | {file = 
"kiwisolver-1.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b533558eae785e33e8c148a8d9921692a9fe5aa516efbdff8606e7d87b9d5824"}, 246 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:efda5fc8cc1c61e4f639b8067d118e742b812c930f708e6667a5ce0d13499e29"}, 247 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7c43e1e1206cd421cd92e6b3280d4385d41d7166b3ed577ac20444b6995a445f"}, 248 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc8d3bd6c72b2dd9decf16ce70e20abcb3274ba01b4e1c96031e0c4067d1e7cd"}, 249 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ea39b0ccc4f5d803e3337dd46bcce60b702be4d86fd0b3d7531ef10fd99a1ac"}, 250 | {file = "kiwisolver-1.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968f44fdbf6dd757d12920d63b566eeb4d5b395fd2d00d29d7ef00a00582aac9"}, 251 | {file = "kiwisolver-1.4.4-cp39-cp39-win32.whl", hash = "sha256:da7e547706e69e45d95e116e6939488d62174e033b763ab1496b4c29b76fabea"}, 252 | {file = "kiwisolver-1.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:ba59c92039ec0a66103b1d5fe588fa546373587a7d68f5c96f743c3396afc04b"}, 253 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:91672bacaa030f92fc2f43b620d7b337fd9a5af28b0d6ed3f77afc43c4a64b5a"}, 254 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:787518a6789009c159453da4d6b683f468ef7a65bbde796bcea803ccf191058d"}, 255 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da152d8cdcab0e56e4f45eb08b9aea6455845ec83172092f09b0e077ece2cf7a"}, 256 | {file = "kiwisolver-1.4.4-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ecb1fa0db7bf4cff9dac752abb19505a233c7f16684c5826d1f11ebd9472b871"}, 257 | {file = 
"kiwisolver-1.4.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:28bc5b299f48150b5f822ce68624e445040595a4ac3d59251703779836eceff9"}, 258 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:81e38381b782cc7e1e46c4e14cd997ee6040768101aefc8fa3c24a4cc58e98f8"}, 259 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2a66fdfb34e05b705620dd567f5a03f239a088d5a3f321e7b6ac3239d22aa286"}, 260 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:872b8ca05c40d309ed13eb2e582cab0c5a05e81e987ab9c521bf05ad1d5cf5cb"}, 261 | {file = "kiwisolver-1.4.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:70e7c2e7b750585569564e2e5ca9845acfaa5da56ac46df68414f29fea97be9f"}, 262 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9f85003f5dfa867e86d53fac6f7e6f30c045673fa27b603c397753bebadc3008"}, 263 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e307eb9bd99801f82789b44bb45e9f541961831c7311521b13a6c85afc09767"}, 264 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1792d939ec70abe76f5054d3f36ed5656021dcad1322d1cc996d4e54165cef9"}, 265 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6cb459eea32a4e2cf18ba5fcece2dbdf496384413bc1bae15583f19e567f3b2"}, 266 | {file = "kiwisolver-1.4.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36dafec3d6d6088d34e2de6b85f9d8e2324eb734162fba59d2ba9ed7a2043d5b"}, 267 | {file = "kiwisolver-1.4.4.tar.gz", hash = "sha256:d41997519fcba4a1e46eb4a2fe31bc12f0ff957b2b81bac28db24744f333e955"}, 268 | ] 269 | 270 | [[package]] 271 | name = "matplotlib" 272 | version = "3.7.1" 273 | description = "Python plotting package" 274 | category = "main" 275 | 
optional = false 276 | python-versions = ">=3.8" 277 | files = [ 278 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:95cbc13c1fc6844ab8812a525bbc237fa1470863ff3dace7352e910519e194b1"}, 279 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:08308bae9e91aca1ec6fd6dda66237eef9f6294ddb17f0d0b3c863169bf82353"}, 280 | {file = "matplotlib-3.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:544764ba51900da4639c0f983b323d288f94f65f4024dc40ecb1542d74dc0500"}, 281 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d94989191de3fcc4e002f93f7f1be5da476385dde410ddafbb70686acf00ea"}, 282 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99bc9e65901bb9a7ce5e7bb24af03675cbd7c70b30ac670aa263240635999a4"}, 283 | {file = "matplotlib-3.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb7d248c34a341cd4c31a06fd34d64306624c8cd8d0def7abb08792a5abfd556"}, 284 | {file = "matplotlib-3.7.1-cp310-cp310-win32.whl", hash = "sha256:ce463ce590f3825b52e9fe5c19a3c6a69fd7675a39d589e8b5fbe772272b3a24"}, 285 | {file = "matplotlib-3.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:3d7bc90727351fb841e4d8ae620d2d86d8ed92b50473cd2b42ce9186104ecbba"}, 286 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:770a205966d641627fd5cf9d3cb4b6280a716522cd36b8b284a8eb1581310f61"}, 287 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f67bfdb83a8232cb7a92b869f9355d677bce24485c460b19d01970b64b2ed476"}, 288 | {file = "matplotlib-3.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2bf092f9210e105f414a043b92af583c98f50050559616930d884387d0772aba"}, 289 | {file = "matplotlib-3.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89768d84187f31717349c6bfadc0e0d8c321e8eb34522acec8a67b1236a66332"}, 290 | {file = 
"matplotlib-3.7.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83111e6388dec67822e2534e13b243cc644c7494a4bb60584edbff91585a83c6"}, 291 | {file = "matplotlib-3.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a867bf73a7eb808ef2afbca03bcdb785dae09595fbe550e1bab0cd023eba3de0"}, 292 | {file = "matplotlib-3.7.1-cp311-cp311-win32.whl", hash = "sha256:fbdeeb58c0cf0595efe89c05c224e0a502d1aa6a8696e68a73c3efc6bc354304"}, 293 | {file = "matplotlib-3.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:c0bd19c72ae53e6ab979f0ac6a3fafceb02d2ecafa023c5cca47acd934d10be7"}, 294 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_10_12_universal2.whl", hash = "sha256:6eb88d87cb2c49af00d3bbc33a003f89fd9f78d318848da029383bfc08ecfbfb"}, 295 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:cf0e4f727534b7b1457898c4f4ae838af1ef87c359b76dcd5330fa31893a3ac7"}, 296 | {file = "matplotlib-3.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:46a561d23b91f30bccfd25429c3c706afe7d73a5cc64ef2dfaf2b2ac47c1a5dc"}, 297 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8704726d33e9aa8a6d5215044b8d00804561971163563e6e6591f9dcf64340cc"}, 298 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4cf327e98ecf08fcbb82685acaf1939d3338548620ab8dfa02828706402c34de"}, 299 | {file = "matplotlib-3.7.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617f14ae9d53292ece33f45cba8503494ee199a75b44de7717964f70637a36aa"}, 300 | {file = "matplotlib-3.7.1-cp38-cp38-win32.whl", hash = "sha256:7c9a4b2da6fac77bcc41b1ea95fadb314e92508bf5493ceff058e727e7ecf5b0"}, 301 | {file = "matplotlib-3.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:14645aad967684e92fc349493fa10c08a6da514b3d03a5931a1bac26e6792bd1"}, 302 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_10_12_universal2.whl", hash = 
"sha256:81a6b377ea444336538638d31fdb39af6be1a043ca5e343fe18d0f17e098770b"}, 303 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:28506a03bd7f3fe59cd3cd4ceb2a8d8a2b1db41afede01f66c42561b9be7b4b7"}, 304 | {file = "matplotlib-3.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8c587963b85ce41e0a8af53b9b2de8dddbf5ece4c34553f7bd9d066148dc719c"}, 305 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8bf26ade3ff0f27668989d98c8435ce9327d24cffb7f07d24ef609e33d582439"}, 306 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:def58098f96a05f90af7e92fd127d21a287068202aa43b2a93476170ebd99e87"}, 307 | {file = "matplotlib-3.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f883a22a56a84dba3b588696a2b8a1ab0d2c3d41be53264115c71b0a942d8fdb"}, 308 | {file = "matplotlib-3.7.1-cp39-cp39-win32.whl", hash = "sha256:4f99e1b234c30c1e9714610eb0c6d2f11809c9c78c984a613ae539ea2ad2eb4b"}, 309 | {file = "matplotlib-3.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:3ba2af245e36990facf67fde840a760128ddd71210b2ab6406e640188d69d136"}, 310 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3032884084f541163f295db8a6536e0abb0db464008fadca6c98aaf84ccf4717"}, 311 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a2cb34336110e0ed8bb4f650e817eed61fa064acbefeb3591f1b33e3a84fd96"}, 312 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b867e2f952ed592237a1828f027d332d8ee219ad722345b79a001f49df0936eb"}, 313 | {file = "matplotlib-3.7.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:57bfb8c8ea253be947ccb2bc2d1bb3862c2bccc662ad1b4626e1f5e004557042"}, 314 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:438196cdf5dc8d39b50a45cb6e3f6274edbcf2254f85fa9b895bf85851c3a613"}, 315 
| {file = "matplotlib-3.7.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:21e9cff1a58d42e74d01153360de92b326708fb205250150018a52c70f43c290"}, 316 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d4725d70b7c03e082bbb8a34639ede17f333d7247f56caceb3801cb6ff703d"}, 317 | {file = "matplotlib-3.7.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97cc368a7268141afb5690760921765ed34867ffb9655dd325ed207af85c7529"}, 318 | {file = "matplotlib-3.7.1.tar.gz", hash = "sha256:7b73305f25eab4541bd7ee0b96d87e53ae9c9f1823be5659b806cd85786fe882"}, 319 | ] 320 | 321 | [package.dependencies] 322 | contourpy = ">=1.0.1" 323 | cycler = ">=0.10" 324 | fonttools = ">=4.22.0" 325 | kiwisolver = ">=1.0.1" 326 | numpy = ">=1.20" 327 | packaging = ">=20.0" 328 | pillow = ">=6.2.0" 329 | pyparsing = ">=2.3.1" 330 | python-dateutil = ">=2.7" 331 | 332 | [[package]] 333 | name = "mypy-extensions" 334 | version = "1.0.0" 335 | description = "Type system extensions for programs checked with the mypy type checker." 
336 | category = "main" 337 | optional = false 338 | python-versions = ">=3.5" 339 | files = [ 340 | {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, 341 | {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, 342 | ] 343 | 344 | [[package]] 345 | name = "numpy" 346 | version = "1.24.2" 347 | description = "Fundamental package for array computing in Python" 348 | category = "main" 349 | optional = false 350 | python-versions = ">=3.8" 351 | files = [ 352 | {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"}, 353 | {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"}, 354 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6524630f71631be2dabe0c541e7675db82651eb998496bbe16bc4f77f0772253"}, 355 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a51725a815a6188c662fb66fb32077709a9ca38053f0274640293a14fdd22978"}, 356 | {file = "numpy-1.24.2-cp310-cp310-win32.whl", hash = "sha256:2620e8592136e073bd12ee4536149380695fbe9ebeae845b81237f986479ffc9"}, 357 | {file = "numpy-1.24.2-cp310-cp310-win_amd64.whl", hash = "sha256:97cf27e51fa078078c649a51d7ade3c92d9e709ba2bfb97493007103c741f1d0"}, 358 | {file = "numpy-1.24.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7de8fdde0003f4294655aa5d5f0a89c26b9f22c0a58790c38fae1ed392d44a5a"}, 359 | {file = "numpy-1.24.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4173bde9fa2a005c2c6e2ea8ac1618e2ed2c1c6ec8a7657237854d42094123a0"}, 360 | {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cecaed30dc14123020f77b03601559fff3e6cd0c048f8b5289f4eeabb0eb281"}, 361 | 
{file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a23f8440561a633204a67fb44617ce2a299beecf3295f0d13c495518908e910"}, 362 | {file = "numpy-1.24.2-cp311-cp311-win32.whl", hash = "sha256:e428c4fbfa085f947b536706a2fc349245d7baa8334f0c5723c56a10595f9b95"}, 363 | {file = "numpy-1.24.2-cp311-cp311-win_amd64.whl", hash = "sha256:557d42778a6869c2162deb40ad82612645e21d79e11c1dc62c6e82a2220ffb04"}, 364 | {file = "numpy-1.24.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d0a2db9d20117bf523dde15858398e7c0858aadca7c0f088ac0d6edd360e9ad2"}, 365 | {file = "numpy-1.24.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c72a6b2f4af1adfe193f7beb91ddf708ff867a3f977ef2ec53c0ffb8283ab9f5"}, 366 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c29e6bd0ec49a44d7690ecb623a8eac5ab8a923bce0bea6293953992edf3a76a"}, 367 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2eabd64ddb96a1239791da78fa5f4e1693ae2dadc82a76bc76a14cbb2b966e96"}, 368 | {file = "numpy-1.24.2-cp38-cp38-win32.whl", hash = "sha256:e3ab5d32784e843fc0dd3ab6dcafc67ef806e6b6828dc6af2f689be0eb4d781d"}, 369 | {file = "numpy-1.24.2-cp38-cp38-win_amd64.whl", hash = "sha256:76807b4063f0002c8532cfeac47a3068a69561e9c8715efdad3c642eb27c0756"}, 370 | {file = "numpy-1.24.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4199e7cfc307a778f72d293372736223e39ec9ac096ff0a2e64853b866a8e18a"}, 371 | {file = "numpy-1.24.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:adbdce121896fd3a17a77ab0b0b5eedf05a9834a18699db6829a64e1dfccca7f"}, 372 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889b2cc88b837d86eda1b17008ebeb679d82875022200c6e8e4ce6cf549b7acb"}, 373 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64bb98ac59b3ea3bf74b02f13836eb2e24e48e0ab0145bbda646295769bd780"}, 374 | {file = 
"numpy-1.24.2-cp39-cp39-win32.whl", hash = "sha256:63e45511ee4d9d976637d11e6c9864eae50e12dc9598f531c035265991910468"}, 375 | {file = "numpy-1.24.2-cp39-cp39-win_amd64.whl", hash = "sha256:a77d3e1163a7770164404607b7ba3967fb49b24782a6ef85d9b5f54126cc39e5"}, 376 | {file = "numpy-1.24.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92011118955724465fb6853def593cf397b4a1367495e0b59a7e69d40c4eb71d"}, 377 | {file = "numpy-1.24.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9006288bcf4895917d02583cf3411f98631275bc67cce355a7f39f8c14338fa"}, 378 | {file = "numpy-1.24.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:150947adbdfeceec4e5926d956a06865c1c690f2fd902efede4ca6fe2e657c3f"}, 379 | {file = "numpy-1.24.2.tar.gz", hash = "sha256:003a9f530e880cb2cd177cba1af7220b9aa42def9c4afc2a2fc3ee6be7eb2b22"}, 380 | ] 381 | 382 | [[package]] 383 | name = "packaging" 384 | version = "23.0" 385 | description = "Core utilities for Python packages" 386 | category = "main" 387 | optional = false 388 | python-versions = ">=3.7" 389 | files = [ 390 | {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"}, 391 | {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, 392 | ] 393 | 394 | [[package]] 395 | name = "pathspec" 396 | version = "0.11.1" 397 | description = "Utility library for gitignore style pattern matching of file paths." 
398 | category = "main" 399 | optional = false 400 | python-versions = ">=3.7" 401 | files = [ 402 | {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, 403 | {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, 404 | ] 405 | 406 | [[package]] 407 | name = "pillow" 408 | version = "9.5.0" 409 | description = "Python Imaging Library (Fork)" 410 | category = "main" 411 | optional = false 412 | python-versions = ">=3.7" 413 | files = [ 414 | {file = "Pillow-9.5.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:ace6ca218308447b9077c14ea4ef381ba0b67ee78d64046b3f19cf4e1139ad16"}, 415 | {file = "Pillow-9.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d3d403753c9d5adc04d4694d35cf0391f0f3d57c8e0030aac09d7678fa8030aa"}, 416 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba1b81ee69573fe7124881762bb4cd2e4b6ed9dd28c9c60a632902fe8db8b38"}, 417 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe7e1c262d3392afcf5071df9afa574544f28eac825284596ac6db56e6d11062"}, 418 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f36397bf3f7d7c6a3abdea815ecf6fd14e7fcd4418ab24bae01008d8d8ca15e"}, 419 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:252a03f1bdddce077eff2354c3861bf437c892fb1832f75ce813ee94347aa9b5"}, 420 | {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85ec677246533e27770b0de5cf0f9d6e4ec0c212a1f89dfc941b64b21226009d"}, 421 | {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b416f03d37d27290cb93597335a2f85ed446731200705b22bb927405320de903"}, 422 | {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1781a624c229cb35a2ac31cc4a77e28cafc8900733a864870c49bfeedacd106a"}, 423 | 
{file = "Pillow-9.5.0-cp310-cp310-win32.whl", hash = "sha256:8507eda3cd0608a1f94f58c64817e83ec12fa93a9436938b191b80d9e4c0fc44"}, 424 | {file = "Pillow-9.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:d3c6b54e304c60c4181da1c9dadf83e4a54fd266a99c70ba646a9baa626819eb"}, 425 | {file = "Pillow-9.5.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:7ec6f6ce99dab90b52da21cf0dc519e21095e332ff3b399a357c187b1a5eee32"}, 426 | {file = "Pillow-9.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:560737e70cb9c6255d6dcba3de6578a9e2ec4b573659943a5e7e4af13f298f5c"}, 427 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96e88745a55b88a7c64fa49bceff363a1a27d9a64e04019c2281049444a571e3"}, 428 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d9c206c29b46cfd343ea7cdfe1232443072bbb270d6a46f59c259460db76779a"}, 429 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfcc2c53c06f2ccb8976fb5c71d448bdd0a07d26d8e07e321c103416444c7ad1"}, 430 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0f9bb6c80e6efcde93ffc51256d5cfb2155ff8f78292f074f60f9e70b942d99"}, 431 | {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8d935f924bbab8f0a9a28404422da8af4904e36d5c33fc6f677e4c4485515625"}, 432 | {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fed1e1cf6a42577953abbe8e6cf2fe2f566daebde7c34724ec8803c4c0cda579"}, 433 | {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c1170d6b195555644f0616fd6ed929dfcf6333b8675fcca044ae5ab110ded296"}, 434 | {file = "Pillow-9.5.0-cp311-cp311-win32.whl", hash = "sha256:54f7102ad31a3de5666827526e248c3530b3a33539dbda27c6843d19d72644ec"}, 435 | {file = "Pillow-9.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfa4561277f677ecf651e2b22dc43e8f5368b74a25a8f7d1d4a3a243e573f2d4"}, 436 | {file = 
"Pillow-9.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:965e4a05ef364e7b973dd17fc765f42233415974d773e82144c9bbaaaea5d089"}, 437 | {file = "Pillow-9.5.0-cp312-cp312-win32.whl", hash = "sha256:22baf0c3cf0c7f26e82d6e1adf118027afb325e703922c8dfc1d5d0156bb2eeb"}, 438 | {file = "Pillow-9.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:432b975c009cf649420615388561c0ce7cc31ce9b2e374db659ee4f7d57a1f8b"}, 439 | {file = "Pillow-9.5.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:5d4ebf8e1db4441a55c509c4baa7a0587a0210f7cd25fcfe74dbbce7a4bd1906"}, 440 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:375f6e5ee9620a271acb6820b3d1e94ffa8e741c0601db4c0c4d3cb0a9c224bf"}, 441 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99eb6cafb6ba90e436684e08dad8be1637efb71c4f2180ee6b8f940739406e78"}, 442 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dfaaf10b6172697b9bceb9a3bd7b951819d1ca339a5ef294d1f1ac6d7f63270"}, 443 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:763782b2e03e45e2c77d7779875f4432e25121ef002a41829d8868700d119392"}, 444 | {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:35f6e77122a0c0762268216315bf239cf52b88865bba522999dc38f1c52b9b47"}, 445 | {file = "Pillow-9.5.0-cp37-cp37m-win32.whl", hash = "sha256:aca1c196f407ec7cf04dcbb15d19a43c507a81f7ffc45b690899d6a76ac9fda7"}, 446 | {file = "Pillow-9.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322724c0032af6692456cd6ed554bb85f8149214d97398bb80613b04e33769f6"}, 447 | {file = "Pillow-9.5.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a0aa9417994d91301056f3d0038af1199eb7adc86e646a36b9e050b06f526597"}, 448 | {file = "Pillow-9.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8286396b351785801a976b1e85ea88e937712ee2c3ac653710a4a57a8da5d9c"}, 449 | {file = 
"Pillow-9.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c830a02caeb789633863b466b9de10c015bded434deb3ec87c768e53752ad22a"}, 450 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbd359831c1657d69bb81f0db962905ee05e5e9451913b18b831febfe0519082"}, 451 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc330c3370a81bbf3f88557097d1ea26cd8b019d6433aa59f71195f5ddebbf"}, 452 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:7002d0797a3e4193c7cdee3198d7c14f92c0836d6b4a3f3046a64bd1ce8df2bf"}, 453 | {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:229e2c79c00e85989a34b5981a2b67aa079fd08c903f0aaead522a1d68d79e51"}, 454 | {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9adf58f5d64e474bed00d69bcd86ec4bcaa4123bfa70a65ce72e424bfb88ed96"}, 455 | {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:662da1f3f89a302cc22faa9f14a262c2e3951f9dbc9617609a47521c69dd9f8f"}, 456 | {file = "Pillow-9.5.0-cp38-cp38-win32.whl", hash = "sha256:6608ff3bf781eee0cd14d0901a2b9cc3d3834516532e3bd673a0a204dc8615fc"}, 457 | {file = "Pillow-9.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:e49eb4e95ff6fd7c0c402508894b1ef0e01b99a44320ba7d8ecbabefddcc5569"}, 458 | {file = "Pillow-9.5.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:482877592e927fd263028c105b36272398e3e1be3269efda09f6ba21fd83ec66"}, 459 | {file = "Pillow-9.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3ded42b9ad70e5f1754fb7c2e2d6465a9c842e41d178f262e08b8c85ed8a1d8e"}, 460 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c446d2245ba29820d405315083d55299a796695d747efceb5717a8b450324115"}, 461 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8aca1152d93dcc27dc55395604dcfc55bed5f25ef4c98716a928bacba90d33a3"}, 462 | 
{file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:608488bdcbdb4ba7837461442b90ea6f3079397ddc968c31265c1e056964f1ef"}, 463 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:60037a8db8750e474af7ffc9faa9b5859e6c6d0a50e55c45576bf28be7419705"}, 464 | {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:07999f5834bdc404c442146942a2ecadd1cb6292f5229f4ed3b31e0a108746b1"}, 465 | {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a127ae76092974abfbfa38ca2d12cbeddcdeac0fb71f9627cc1135bedaf9d51a"}, 466 | {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:489f8389261e5ed43ac8ff7b453162af39c3e8abd730af8363587ba64bb2e865"}, 467 | {file = "Pillow-9.5.0-cp39-cp39-win32.whl", hash = "sha256:9b1af95c3a967bf1da94f253e56b6286b50af23392a886720f563c547e48e964"}, 468 | {file = "Pillow-9.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:77165c4a5e7d5a284f10a6efaa39a0ae8ba839da344f20b111d62cc932fa4e5d"}, 469 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:833b86a98e0ede388fa29363159c9b1a294b0905b5128baf01db683672f230f5"}, 470 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaf305d6d40bd9632198c766fb64f0c1a83ca5b667f16c1e79e1661ab5060140"}, 471 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0852ddb76d85f127c135b6dd1f0bb88dbb9ee990d2cd9aa9e28526c93e794fba"}, 472 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:91ec6fe47b5eb5a9968c79ad9ed78c342b1f97a091677ba0e012701add857829"}, 473 | {file = "Pillow-9.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cb841572862f629b99725ebaec3287fc6d275be9b14443ea746c1dd325053cbd"}, 474 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c380b27d041209b849ed246b111b7c166ba36d7933ec6e41175fd15ab9eb1572"}, 475 | {file 
= "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c9af5a3b406a50e313467e3565fc99929717f780164fe6fbb7704edba0cebbe"}, 476 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5671583eab84af046a397d6d0ba25343c00cd50bce03787948e0fff01d4fd9b1"}, 477 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:84a6f19ce086c1bf894644b43cd129702f781ba5751ca8572f08aa40ef0ab7b7"}, 478 | {file = "Pillow-9.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1e7723bd90ef94eda669a3c2c19d549874dd5badaeefabefd26053304abe5799"}, 479 | {file = "Pillow-9.5.0.tar.gz", hash = "sha256:bf548479d336726d7a0eceb6e767e179fbde37833ae42794602631a070d630f1"}, 480 | ] 481 | 482 | [package.extras] 483 | docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] 484 | tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] 485 | 486 | [[package]] 487 | name = "platformdirs" 488 | version = "3.2.0" 489 | description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
490 | category = "main" 491 | optional = false 492 | python-versions = ">=3.7" 493 | files = [ 494 | {file = "platformdirs-3.2.0-py3-none-any.whl", hash = "sha256:ebe11c0d7a805086e99506aa331612429a72ca7cd52a1f0d277dc4adc20cb10e"}, 495 | {file = "platformdirs-3.2.0.tar.gz", hash = "sha256:d5b638ca397f25f979350ff789db335903d7ea010ab28903f57b27e1b16c2b08"}, 496 | ] 497 | 498 | [package.extras] 499 | docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] 500 | test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] 501 | 502 | [[package]] 503 | name = "pyparsing" 504 | version = "3.0.9" 505 | description = "pyparsing module - Classes and methods to define and execute parsing grammars" 506 | category = "main" 507 | optional = false 508 | python-versions = ">=3.6.8" 509 | files = [ 510 | {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, 511 | {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, 512 | ] 513 | 514 | [package.extras] 515 | diagrams = ["jinja2", "railroad-diagrams"] 516 | 517 | [[package]] 518 | name = "python-dateutil" 519 | version = "2.8.2" 520 | description = "Extensions to the standard Python datetime module" 521 | category = "main" 522 | optional = false 523 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 524 | files = [ 525 | {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, 526 | {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, 527 | ] 528 | 529 | [package.dependencies] 530 | six = ">=1.5" 531 | 532 | [[package]] 533 | name = "six" 534 | version = "1.16.0" 535 | description = "Python 2 and 3 compatibility utilities" 536 | 
category = "main" 537 | optional = false 538 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 539 | files = [ 540 | {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, 541 | {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, 542 | ] 543 | 544 | [[package]] 545 | name = "tomli" 546 | version = "2.0.1" 547 | description = "A lil' TOML parser" 548 | category = "main" 549 | optional = false 550 | python-versions = ">=3.7" 551 | files = [ 552 | {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, 553 | {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, 554 | ] 555 | 556 | [metadata] 557 | lock-version = "2.0" 558 | python-versions = "^3.10.8" 559 | content-hash = "e8a823794405d03282a7c42f616cf76a51891ffb94a44505cf6d6eddc854fc7e" 560 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "wheel"] 3 | build-backend = 'setuptools.build_meta' 4 | 5 | [tool.setuptools] 6 | package-dir = {"" = "src"} 7 | 8 | [project] 9 | name = "mppi_playground" 10 | version = "0.1.0" 11 | description = "" 12 | requires-python = ">=3.10" 13 | dependencies = [ 14 | "matplotlib==3.8.2", 15 | "fire==0.5.0", 16 | "numpy==1.26.2", 17 | "torch==2.00", 18 | "torchvision==0.15.1", 19 | "gymnasium[all]==0.29.1", 20 | "mujoco==2.3.7", 21 | # "pybullet==3.2.5", 22 | ] 23 | 24 | [project.optional-dependencies] 25 | dev = [ 26 | "pytest", 27 | "pysen", 28 | "black", 29 | "flake8", 30 | "isort", 31 | "mypy", 32 | ] -------------------------------------------------------------------------------- /src/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/__init__.py -------------------------------------------------------------------------------- /src/controller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/controller/__init__.py -------------------------------------------------------------------------------- /src/controller/mppi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kohei Honda, 2023. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Callable, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.distributions.multivariate_normal import MultivariateNormal 12 | 13 | 14 | class MPPI(nn.Module): 15 | """ 16 | Model Predictive Path Integral Control, 17 | J. Williams et al., T-RO, 2017. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | horizon: int, 23 | num_samples: int, 24 | dim_state: int, 25 | dim_control: int, 26 | dynamics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 27 | stage_cost: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 28 | terminal_cost: Callable[[torch.Tensor], torch.Tensor], 29 | u_min: torch.Tensor, 30 | u_max: torch.Tensor, 31 | sigmas: torch.Tensor, 32 | lambda_: float, 33 | device=torch.device("cuda"), 34 | dtype=torch.float32, 35 | seed: int = 42, 36 | ) -> None: 37 | """ 38 | :param horizon: Predictive horizon length. 39 | :param delta: predictive horizon step size (seconds). 40 | :param num_samples: Number of samples. 41 | :param dim_state: Dimension of state. 42 | :param dim_control: Dimension of control. 43 | :param dynamics: Dynamics model. 44 | :param stage_cost: Stage cost. 45 | :param terminal_cost: Terminal cost. 
46 | :param u_min: Minimum control. 47 | :param u_max: Maximum control. 48 | :param sigmas: Noise standard deviation for each control dimension. 49 | :param lambda_: temperature parameter. 50 | :param device: Device to run the solver. 51 | :param dtype: Data type to run the solver. 52 | :param seed: Seed for torch. 53 | """ 54 | 55 | super().__init__() 56 | 57 | # torch seed 58 | torch.manual_seed(seed) 59 | 60 | # check dimensions 61 | assert u_min.shape == (dim_control,) 62 | assert u_max.shape == (dim_control,) 63 | assert sigmas.shape == (dim_control,) 64 | # assert num_samples % batch_size == 0 and num_samples >= batch_size 65 | 66 | # device and dtype 67 | if torch.cuda.is_available() and device == torch.device("cuda"): 68 | self._device = torch.device("cuda") 69 | else: 70 | self._device = torch.device("cpu") 71 | self._dtype = dtype 72 | 73 | # set parameters 74 | self._horizon = horizon 75 | self._num_samples = num_samples 76 | self._dim_state = dim_state 77 | self._dim_control = dim_control 78 | self._dynamics = dynamics 79 | self._stage_cost = stage_cost 80 | self._terminal_cost = terminal_cost 81 | self._u_min = u_min.clone().detach().to(self._device, self._dtype) 82 | self._u_max = u_max.clone().detach().to(self._device, self._dtype) 83 | self._sigmas = sigmas.clone().detach().to(self._device, self._dtype) 84 | self._lambda = lambda_ 85 | 86 | # noise distribution 87 | zero_mean = torch.zeros(dim_control, device=self._device, dtype=self._dtype) 88 | initial_covariance = torch.diag(sigmas**2).to(self._device, self._dtype) 89 | self._inv_covariance = torch.inverse(initial_covariance).to( 90 | self._device, self._dtype 91 | ) 92 | 93 | self._noise_distribution = MultivariateNormal( 94 | loc=zero_mean, covariance_matrix=initial_covariance 95 | ) 96 | self._sample_shape = torch.Size([self._num_samples, self._horizon]) 97 | 98 | # sampling with reparameting trick 99 | self._action_noises = self._noise_distribution.rsample( 100 | 
# NOTE(review): this part of the dump resumes inside MPPI.__init__ (original lines
# 100-122). Line 100-101 closes a call opened on the preceding line (presumably the
# initial noise draw, since line 107 reads self._action_noises). __init__ then stores
# a zero action sequence as the first warm start and allocates the rollout-state
# buffer (num_samples, horizon + 1, dim_state) and the per-sample weight buffer.
# MPPI.forward (lines 124-213) and MPPI.get_top_samples (lines 215-234) follow.
sample_shape=self._sample_shape 101 | ) 102 | 103 | zero_mean_seq = torch.zeros( 104 | self._horizon, self._dim_control, device=self._device, dtype=self._dtype 105 | ) 106 | self._perturbed_action_seqs = torch.clamp( 107 | zero_mean_seq + self._action_noises, self._u_min, self._u_max 108 | ) 109 | 110 | self._previous_action_seq = zero_mean_seq 111 | 112 | # inner variables 113 | self._state_seq_batch = torch.zeros( 114 | self._num_samples, 115 | self._horizon + 1, 116 | self._dim_state, 117 | device=self._device, 118 | dtype=self._dtype, 119 | ) 120 | self._weights = torch.zeros( 121 | self._num_samples, device=self._device, dtype=self._dtype 121 | ) 123 | 124 | def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """ 126 | Solve the optimal control problem. 127 | Args: 128 | state (torch.Tensor): Current state. 129 | Returns: 130 | Tuple[torch.Tensor, torch.Tensor]: Tuple of predictive control and state sequence. 131 | """ 132 | assert state.shape == (self._dim_state,) 133 | 134 | if not torch.is_tensor(state): 135 | state = torch.tensor(state, device=self._device, dtype=self._dtype) 136 | else: 137 | if state.device != self._device or state.dtype != self._dtype: 138 | state = state.to(self._device, self._dtype) 139 | 140 | mean_action_seq = self._previous_action_seq.clone().detach() 141 | 142 | # random sampling with reparametrization trick 143 | self._action_noises = self._noise_distribution.rsample( 144 | sample_shape=self._sample_shape 145 | ) 146 | self._perturbed_action_seqs = mean_action_seq + self._action_noises 147 | 148 | # clamp actions 149 | self._perturbed_action_seqs = torch.clamp( 150 | self._perturbed_action_seqs, self._u_min, self._u_max 151 | ) 152 | 153 | # rollout samples in parallel 154 | self._state_seq_batch[:, 0, :] = state.repeat(self._num_samples, 1) 155 | 156 | for t in range(self._horizon): 157 | self._state_seq_batch[:, t + 1, :] = self._dynamics( 158 | self._state_seq_batch[:, t, :], 
# NOTE(review) MPPI.forward, overall flow:
#   1. Perturb the previous solution with fresh noise (lines 143-146).
#   2. Clamp the perturbed actions to [u_min, u_max].
#   3. Roll every sample through self._dynamics.
#   4. Score each rollout as the sum of stage costs (states 0..H-1), the terminal
#      cost (state H), and lambda-weighted action costs.
#   5. Weight the samples with softmax(-cost / lambda).
#   6. Return the weighted-mean action sequence together with its own rollout,
#      which has shape (1, H + 1, dim_state).
# NOTE(review): the shape assert on line 132 runs before the list/array -> tensor
# conversion on lines 134-138. A plain Python list therefore raises AttributeError
# instead of being converted. Consider asserting after the conversion.
# NOTE(review): line 211 keeps the optimal sequence as the next warm start without
# shifting it forward by one time step. Confirm this is intended.
# NOTE(review) MPPI.get_top_samples: torch.topk already returns its results in
# descending order (sorted=True by default). The argsort re-ordering on lines
# 231-232 is therefore redundant, and it is computed twice.
self._perturbed_action_seqs[:, t, :] 159 | ) 160 | 161 | # compute sample costs 162 | stage_costs = torch.zeros( 163 | self._num_samples, self._horizon, device=self._device, dtype=self._dtype 164 | ) 165 | action_costs = torch.zeros( 166 | self._num_samples, self._horizon, device=self._device, dtype=self._dtype 167 | ) 168 | for t in range(self._horizon): 169 | stage_costs[:, t] = self._stage_cost( 170 | self._state_seq_batch[:, t, :], self._perturbed_action_seqs[:, t, :] 171 | ) 172 | action_costs[:, t] = ( 173 | mean_action_seq[t] 174 | @ self._inv_covariance 175 | @ self._perturbed_action_seqs[:, t].T 176 | ) 177 | 178 | terminal_costs = self._terminal_cost(self._state_seq_batch[:, -1, :]) 179 | 180 | costs = ( 181 | torch.sum(stage_costs, dim=1) 182 | + terminal_costs 183 | + torch.sum(self._lambda * action_costs, dim=1) 184 | ) 185 | 186 | # calculate weights 187 | self._weights = torch.softmax(-costs / self._lambda, dim=0) 188 | 189 | # find optimal control by weighted average 190 | optimal_action_seq = torch.sum( 191 | self._weights.view(self._num_samples, 1, 1) * self._perturbed_action_seqs, 192 | dim=0, 193 | ) 194 | 195 | # predictive state seq 196 | optimal_state_seq = torch.zeros( 197 | 1, 198 | self._horizon + 1, 199 | self._dim_state, 200 | device=self._device, 201 | dtype=self._dtype, 202 | ) 203 | optimal_state_seq[:, 0, :] = state 204 | expanded_optimal_action_seq = optimal_action_seq.repeat(1, 1, 1) 205 | for t in range(self._horizon): 206 | optimal_state_seq[:, t + 1, :] = self._dynamics( 207 | optimal_state_seq[:, t, :], expanded_optimal_action_seq[:, t, :] 208 | ) 209 | 210 | # update previous actions 211 | self._previous_action_seq = optimal_action_seq 212 | 213 | return optimal_action_seq, optimal_state_seq 214 | 215 | def get_top_samples(self, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: 216 | """ 217 | Get top samples. 218 | Args: 219 | num_samples (int): Number of state samples to get. 
220 | Returns: 221 | Tuple[torch.Tensor, torch.Tensor]: Tuple of top samples and their weights. 222 | """ 223 | assert num_samples <= self._num_samples 224 | 225 | # large weights are better 226 | top_indices = torch.topk(self._weights, num_samples).indices 227 | 228 | top_samples = self._state_seq_batch[top_indices] 229 | top_weights = self._weights[top_indices] 230 | 231 | top_samples = top_samples[torch.argsort(top_weights, descending=True)] 232 | top_weights = top_weights[torch.argsort(top_weights, descending=True)] 233 | 234 | return top_samples, top_weights 235 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/src/envs/__init__.py -------------------------------------------------------------------------------- /src/envs/navigation_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kohei Honda, 2023. 
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Tuple, Union 8 | from matplotlib import pyplot as plt 9 | 10 | import torch 11 | import numpy as np 12 | import os 13 | 14 | 15 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 16 | 17 | from envs.obstacle_map_2d import ObstacleMap, generate_random_obstacles 18 | 19 | 20 | @torch.jit.script 21 | def angle_normalize(x): 22 | return ((x + torch.pi) % (2 * torch.pi)) - torch.pi 23 | 24 | 25 | class Navigation2DEnv: 26 | def __init__( 27 | self, device=torch.device("cuda"), dtype=torch.float32, seed: int = 42 28 | ) -> None: 29 | # device and dtype 30 | if torch.cuda.is_available() and device == torch.device("cuda"): 31 | self._device = torch.device("cuda") 32 | else: 33 | self._device = torch.device("cpu") 34 | self._dtype = dtype 35 | 36 | self._obstacle_map = ObstacleMap( 37 | map_size=(20, 20), cell_size=0.1, device=self._device, dtype=self._dtype 38 | ) 39 | self._seed = seed 40 | 41 | generate_random_obstacles( 42 | obstacle_map=self._obstacle_map, 43 | random_x_range=(-7.5, 7.5), 44 | random_y_range=(-7.5, 7.5), 45 | num_circle_obs=7, 46 | radius_range=(1, 1), 47 | num_rectangle_obs=7, 48 | width_range=(2, 2), 49 | height_range=(2, 2), 50 | max_iteration=1000, 51 | seed=seed, 52 | ) 53 | self._obstacle_map.convert_to_torch() 54 | 55 | self._start_pos = torch.tensor( 56 | [-9.0, -9.0], device=self._device, dtype=self._dtype 57 | ) 58 | self._goal_pos = torch.tensor( 59 | [9.0, 9.0], device=self._device, dtype=self._dtype 60 | ) 61 | 62 | self._robot_state = torch.zeros(3, device=self._device, dtype=self._dtype) 63 | self._robot_state[:2] = self._start_pos 64 | self._robot_state[2] = angle_normalize( 65 | torch.atan2( 66 | self._goal_pos[1] - self._start_pos[1], 67 | self._goal_pos[0] - self._start_pos[0], 68 | ) 69 | ) 70 | 71 | # u: [v, omega] (m/s, rad/s) 72 | self.u_min = torch.tensor([0.0, -1.0], device=self._device, dtype=self._dtype) 73 | self.u_max = 
torch.tensor([2.0, 1.0], device=self._device, dtype=self._dtype) 74 | 75 | def reset(self) -> torch.Tensor: 76 | """ 77 | Reset robot state. 78 | Returns: 79 | torch.Tensor: shape (3,) [x, y, theta] 80 | """ 81 | self._robot_state[:2] = self._start_pos 82 | self._robot_state[2] = angle_normalize( 83 | torch.atan2( 84 | self._goal_pos[1] - self._start_pos[1], 85 | self._goal_pos[0] - self._start_pos[0], 86 | ) 87 | ) 88 | 89 | self._fig = plt.figure(layout="tight") 90 | self._ax = self._fig.add_subplot() 91 | self._ax.set_xlim(self._obstacle_map.x_lim) 92 | self._ax.set_ylim(self._obstacle_map.y_lim) 93 | self._ax.set_aspect("equal") 94 | 95 | self._rendered_frames = [] 96 | 97 | return self._robot_state 98 | 99 | def step(self, u: torch.Tensor) -> Tuple[torch.Tensor, bool]: 100 | """ 101 | Update robot state based on differential drive dynamics. 102 | Args: 103 | u (torch.Tensor): control batch tensor, shape (2) [v, omega] 104 | Returns: 105 | Tuple[torch.Tensor, bool]: Tuple of robot state and is goal reached. 
106 | """ 107 | u = torch.clamp(u, self.u_min, self.u_max) 108 | 109 | self._robot_state = self.dynamics( 110 | state=self._robot_state.unsqueeze(0), action=u.unsqueeze(0) 111 | ).squeeze(0) 112 | 113 | # goal check 114 | goal_threshold = 0.5 115 | is_goal_reached = ( 116 | torch.norm(self._robot_state[:2] - self._goal_pos) < goal_threshold 117 | ) 118 | 119 | return self._robot_state, is_goal_reached 120 | 121 | def render( 122 | self, 123 | predicted_trajectory: torch.Tensor = None, 124 | is_collisions: torch.Tensor = None, 125 | top_samples: Tuple[torch.Tensor, torch.Tensor] = None, 126 | mode: str = "human", 127 | ) -> None: 128 | self._ax.set_xlabel("x [m]") 129 | self._ax.set_ylabel("y [m]") 130 | 131 | # obstacle map 132 | self._obstacle_map.render(self._ax, zorder=10) 133 | 134 | # start and goal 135 | self._ax.scatter( 136 | self._start_pos[0].item(), 137 | self._start_pos[1].item(), 138 | marker="o", 139 | color="red", 140 | zorder=10, 141 | ) 142 | self._ax.scatter( 143 | self._goal_pos[0].item(), 144 | self._goal_pos[1].item(), 145 | marker="o", 146 | color="orange", 147 | zorder=10, 148 | ) 149 | 150 | # robot 151 | self._ax.scatter( 152 | self._robot_state[0].item(), 153 | self._robot_state[1].item(), 154 | marker="o", 155 | color="green", 156 | zorder=100, 157 | ) 158 | 159 | # visualize top samples with different alpha based on weights 160 | if top_samples is not None: 161 | top_samples, top_weights = top_samples 162 | top_samples = top_samples.cpu().numpy() 163 | top_weights = top_weights.cpu().numpy() 164 | top_weights = 0.7 * top_weights / np.max(top_weights) 165 | top_weights = np.clip(top_weights, 0.1, 0.7) 166 | for i in range(top_samples.shape[0]): 167 | self._ax.plot( 168 | top_samples[i, :, 0], 169 | top_samples[i, :, 1], 170 | color="lightblue", 171 | alpha=top_weights[i], 172 | zorder=1, 173 | ) 174 | 175 | # predicted trajectory 176 | if predicted_trajectory is not None: 177 | # if is collision color is red 178 | colors = 
np.array(["darkblue"] * predicted_trajectory.shape[1]) 179 | if is_collisions is not None: 180 | is_collisions = is_collisions.cpu().numpy() 181 | is_collisions = np.any(is_collisions, axis=0) 182 | colors[is_collisions] = "red" 183 | 184 | self._ax.scatter( 185 | predicted_trajectory[0, :, 0].cpu().numpy(), 186 | predicted_trajectory[0, :, 1].cpu().numpy(), 187 | color=colors, 188 | marker="o", 189 | s=3, 190 | zorder=2, 191 | ) 192 | 193 | if mode == "human": 194 | # online rendering 195 | plt.pause(0.001) 196 | plt.cla() 197 | elif mode == "rgb_array": 198 | # offline rendering for video 199 | # TODO: high resolution rendering 200 | self._fig.canvas.draw() 201 | data = np.frombuffer(self._fig.canvas.tostring_rgb(), dtype=np.uint8) 202 | data = data.reshape(self._fig.canvas.get_width_height()[::-1] + (3,)) 203 | plt.cla() 204 | self._rendered_frames.append(data) 205 | 206 | def close(self, path: str = None) -> None: 207 | if path is None: 208 | # mkdir video if not exists 209 | 210 | if not os.path.exists("video"): 211 | os.mkdir("video") 212 | path = "video/" + "navigation_2d_" + str(self._seed) + ".gif" 213 | 214 | if len(self._rendered_frames) > 0: 215 | # save animation 216 | clip = ImageSequenceClip(self._rendered_frames, fps=10) 217 | # clip.write_videofile(path, fps=10) 218 | clip.write_gif(path, fps=10) 219 | 220 | def dynamics( 221 | self, state: torch.Tensor, action: torch.Tensor, delta_t: float = 0.1 222 | ) -> torch.Tensor: 223 | """ 224 | Update robot state based on differential drive dynamics. 
225 | Args: 226 | state (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta] 227 | action (torch.Tensor): control batch tensor, shape (batch_size, 2) [v, omega] 228 | delta_t (float): time step interval [s] 229 | Returns: 230 | torch.Tensor: shape (batch_size, 3) [x, y, theta] 231 | """ 232 | 233 | # Perform calculations as before 234 | x = state[:, 0].view(-1, 1) 235 | y = state[:, 1].view(-1, 1) 236 | theta = state[:, 2].view(-1, 1) 237 | v = torch.clamp(action[:, 0].view(-1, 1), self.u_min[0], self.u_max[0]) 238 | omega = torch.clamp(action[:, 1].view(-1, 1), self.u_min[1], self.u_max[1]) 239 | theta = angle_normalize(theta) 240 | 241 | new_x = x + v * torch.cos(theta) * delta_t 242 | new_y = y + v * torch.sin(theta) * delta_t 243 | new_theta = angle_normalize(theta + omega * delta_t) 244 | 245 | # Clamp x and y to the map boundary 246 | x_lim = torch.tensor( 247 | self._obstacle_map.x_lim, device=self._device, dtype=self._dtype 248 | ) 249 | y_lim = torch.tensor( 250 | self._obstacle_map.y_lim, device=self._device, dtype=self._dtype 251 | ) 252 | clamped_x = torch.clamp(new_x, x_lim[0], x_lim[1]) 253 | clamped_y = torch.clamp(new_y, y_lim[0], y_lim[1]) 254 | 255 | result = torch.cat([clamped_x, clamped_y, new_theta], dim=1) 256 | 257 | return result 258 | 259 | def stage_cost(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 260 | """ 261 | Calculate stage cost. 
262 | Args: 263 | state (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta] 264 | action (torch.Tensor): control batch tensor, shape (batch_size, 2) [v, omega] 265 | Returns: 266 | torch.Tensor: shape (batch_size,) 267 | """ 268 | 269 | goal_cost = torch.norm(state[:, :2] - self._goal_pos, dim=1) 270 | 271 | pos_batch = state[:, :2].unsqueeze(1) # (batch_size, 1, 2) 272 | 273 | obstacle_cost = self._obstacle_map.compute_cost(pos_batch).squeeze( 274 | 1 275 | ) # (batch_size,) 276 | 277 | cost = goal_cost + 10000 * obstacle_cost 278 | 279 | return cost 280 | 281 | def terminal_cost(self, state: torch.Tensor) -> torch.Tensor: 282 | """ 283 | Calculate terminal cost. 284 | Args: 285 | x (torch.Tensor): state batch tensor, shape (batch_size, 3) [x, y, theta] 286 | Returns: 287 | torch.Tensor: shape (batch_size,) 288 | """ 289 | zero_action = torch.zeros_like(state[:, :2]) 290 | return self.stage_cost(state=state, action=torch.zeros_like(zero_action)) 291 | 292 | def collision_check(self, state: torch.Tensor) -> torch.Tensor: 293 | """ 294 | 295 | Args: 296 | state (torch.Tensor): state batch tensor, shape (batch_size, traj_size , 3) [x, y, theta] 297 | Returns: 298 | torch.Tensor: shape (batch_size,) 299 | """ 300 | pos_batch = state[:, :, :2] 301 | is_collisions = self._obstacle_map.compute_cost(pos_batch).squeeze(1) 302 | return is_collisions 303 | -------------------------------------------------------------------------------- /src/envs/obstacle_map_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kohei Honda, 2023. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Callable, Tuple, List, Union 8 | from dataclasses import dataclass 9 | from math import ceil 10 | from matplotlib import pyplot as plt 11 | import torch 12 | import numpy as np 13 | 14 | 15 | @dataclass 16 | class CircleObstacle: 17 | """ 18 | Circle obstacle used in the obstacle map. 
19 | """ 20 | 21 | center: np.ndarray 22 | radius: float 23 | 24 | def __init__(self, center: np.ndarray, radius: float) -> None: 25 | self.center = center 26 | self.radius = radius 27 | 28 | 29 | @dataclass 30 | class RectangleObstacle: 31 | """ 32 | Rectangle obstacle used in the obstacle map. 33 | Not consider angle for now. 34 | """ 35 | 36 | center: np.ndarray 37 | width: float 38 | height: float 39 | 40 | def __init__(self, center: np.ndarray, width: float, height: float) -> None: 41 | self.center = center 42 | self.width = width 43 | self.height = height 44 | 45 | 46 | class ObstacleMap: 47 | """ 48 | Obstacle map represented by a grid. 49 | """ 50 | 51 | def __init__( 52 | self, 53 | map_size: Tuple[int, int] = (20, 20), 54 | cell_size: float = 0.01, 55 | device=torch.device("cuda"), 56 | dtype=torch.float32, 57 | ) -> None: 58 | """ 59 | map_size: (width, height) [m], origin is at the center 60 | cell_size: (m) 61 | """ 62 | # device and dtype 63 | if torch.cuda.is_available() and device == torch.device("cuda"): 64 | self._device = torch.device("cuda") 65 | else: 66 | self._device = torch.device("cpu") 67 | self._dtype = dtype 68 | 69 | assert len(map_size) == 2 70 | assert cell_size > 0 71 | assert map_size[0] % 2 == 0 72 | assert map_size[1] % 2 == 0 73 | 74 | cell_map_dim = [0, 0] 75 | cell_map_dim[0] = ceil(map_size[0] / cell_size) 76 | cell_map_dim[1] = ceil(map_size[1] / cell_size) 77 | 78 | self._map = np.zeros(cell_map_dim) 79 | self._cell_size = cell_size 80 | 81 | # cell map center 82 | self._cell_map_origin = np.zeros(2) 83 | self._cell_map_origin = np.array( 84 | [cell_map_dim[0] / 2, cell_map_dim[1] / 2] 85 | ).astype(int) 86 | 87 | self._torch_cell_map_origin = torch.from_numpy(self._cell_map_origin).to( 88 | self._device, self._dtype 89 | ) 90 | 91 | # limit of the map 92 | x_range = self._cell_size * self._map.shape[0] 93 | y_range = self._cell_size * self._map.shape[1] 94 | self.x_lim = [-x_range / 2, x_range / 2] # [m] 95 | self.y_lim = 
[-y_range / 2, y_range / 2] # [m] 96 | 97 | # Inner variables 98 | self._map_torch: torch.Tensor = None # use to collision check on GPU 99 | self.circle_obs_list: List[CircleObstacle] = [] # use to visualize 100 | self.rectangle_obs_list: List[RectangleObstacle] = [] # use to visualize 101 | 102 | def add_circle_obstacle(self, center: np.ndarray, radius: float) -> None: 103 | """ 104 | Add a circle obstacle to the map. 105 | :param center: Center of the circle obstacle. 106 | :param radius: Radius of the circle obstacle. 107 | """ 108 | assert len(center) == 2 109 | assert radius > 0 110 | 111 | # convert to cell map 112 | center_occ = (center / self._cell_size) + self._cell_map_origin 113 | center_occ = np.round(center_occ).astype(int) 114 | radius_occ = ceil(radius / self._cell_size) 115 | 116 | # add to occ map 117 | for i in range(-radius_occ, radius_occ + 1): 118 | for j in range(-radius_occ, radius_occ + 1): 119 | if i**2 + j**2 <= radius_occ**2: 120 | i_bounded = np.clip(center_occ[0] + i, 0, self._map.shape[0] - 1) 121 | j_bounded = np.clip(center_occ[1] + j, 0, self._map.shape[1] - 1) 122 | self._map[i_bounded, j_bounded] = 1 123 | 124 | # add to circle obstacle list to use visualize 125 | self.circle_obs_list.append(CircleObstacle(center, radius)) 126 | 127 | def add_rectangle_obstacle( 128 | self, center: np.ndarray, width: float, height: float 129 | ) -> None: 130 | """ 131 | Add a rectangle obstacle to the map. 132 | :param center: Center of the rectangle obstacle. 133 | :param width: Width of the rectangle obstacle. 134 | :param height: Height of the rectangle obstacle. 
135 | """ 136 | assert len(center) == 2 137 | assert width > 0 138 | assert height > 0 139 | 140 | # convert to cell map 141 | center_occ = (center / self._cell_size) + self._cell_map_origin 142 | center_occ = np.ceil(center_occ).astype(int) 143 | width_occ = ceil(width / self._cell_size) 144 | height_occ = ceil(height / self._cell_size) 145 | 146 | # add to occ map 147 | x_init = center_occ[0] - ceil(height_occ / 2) 148 | x_end = center_occ[0] + ceil(height_occ / 2) 149 | y_init = center_occ[1] - ceil(width_occ / 2) 150 | y_end = center_occ[1] + ceil(width_occ / 2) 151 | 152 | # # deal with out of bound 153 | x_init = np.clip(x_init, 0, self._map.shape[0] - 1) 154 | x_end = np.clip(x_end, 0, self._map.shape[0] - 1) 155 | y_init = np.clip(y_init, 0, self._map.shape[1] - 1) 156 | y_end = np.clip(y_end, 0, self._map.shape[1] - 1) 157 | 158 | self._map[x_init:x_end, y_init:y_end] = 1 159 | 160 | # add to rectangle obstacle list to use visualize 161 | self.rectangle_obs_list.append(RectangleObstacle(center, width, height)) 162 | 163 | def convert_to_torch(self) -> torch.Tensor: 164 | self._map_torch = torch.from_numpy(self._map).to(self._device, self._dtype) 165 | return self._map_torch 166 | 167 | def compute_cost(self, x: torch.Tensor) -> torch.Tensor: 168 | """ 169 | Check collision in a batch of trajectories. 170 | :param x: Tensor of shape (batch_size, traj_length, position_dim). 171 | :return: collsion costs on the trajectories. 
172 | """ 173 | assert self._map_torch is not None 174 | if x.device != self._device or x.dtype != self._dtype: 175 | x = x.to(self._device, self._dtype) 176 | 177 | # project to cell map 178 | x_occ = (x / self._cell_size) + self._torch_cell_map_origin 179 | x_occ = torch.round(x_occ).long().to(self._device) 180 | 181 | # deal with out of bound 182 | is_out_of_bound = torch.logical_or( 183 | torch.logical_or( 184 | x_occ[..., 0] < 0, x_occ[..., 0] >= self._map_torch.shape[0] 185 | ), 186 | torch.logical_or( 187 | x_occ[..., 1] < 0, x_occ[..., 1] >= self._map_torch.shape[1] 188 | ), 189 | ) 190 | x_occ[..., 0] = torch.clamp(x_occ[..., 0], 0, self._map_torch.shape[0] - 1) 191 | x_occ[..., 1] = torch.clamp(x_occ[..., 1], 0, self._map_torch.shape[1] - 1) 192 | 193 | # collision check 194 | collisions = self._map_torch[x_occ[..., 0], x_occ[..., 1]] 195 | 196 | # out of bound cost 197 | collisions[is_out_of_bound] = 1.0 198 | 199 | return collisions 200 | 201 | def render_occupancy(self, ax, cmap="binary") -> None: 202 | ax.imshow(self._map, cmap=cmap) 203 | 204 | def render(self, ax, zorder: int = 0) -> None: 205 | """ 206 | Render in continuous space. 
207 | """ 208 | ax.set_xlim(self.x_lim) 209 | ax.set_ylim(self.y_lim) 210 | ax.set_aspect("equal") 211 | 212 | # render circle obstacles 213 | for circle_obs in self.circle_obs_list: 214 | ax.add_patch( 215 | plt.Circle( 216 | circle_obs.center, circle_obs.radius, color="gray", zorder=zorder 217 | ) 218 | ) 219 | 220 | # render rectangle obstacles 221 | for rectangle_obs in self.rectangle_obs_list: 222 | ax.add_patch( 223 | plt.Rectangle( 224 | rectangle_obs.center 225 | - np.array([rectangle_obs.width / 2, rectangle_obs.height / 2]), 226 | rectangle_obs.width, 227 | rectangle_obs.height, 228 | color="gray", 229 | zorder=zorder, 230 | ) 231 | ) 232 | 233 | 234 | def generate_random_obstacles( 235 | obstacle_map: ObstacleMap, 236 | random_x_range: Tuple[float, float], 237 | random_y_range: Tuple[float, float], 238 | num_circle_obs: int, 239 | radius_range: Tuple[float, float], 240 | num_rectangle_obs: int, 241 | width_range: Tuple[float, float], 242 | height_range: Tuple[float, float], 243 | max_iteration: int, 244 | seed: int, 245 | ) -> None: 246 | """ 247 | Generate random obstacles. 
248 | """ 249 | rng = np.random.default_rng(seed) 250 | 251 | # if random range is larger than map size, use map size 252 | if random_x_range[0] < obstacle_map.x_lim[0]: 253 | random_x_range[0] = obstacle_map.x_lim[0] 254 | if random_x_range[1] > obstacle_map.x_lim[1]: 255 | random_x_range[1] = obstacle_map.x_lim[1] 256 | if random_y_range[0] < obstacle_map.y_lim[0]: 257 | random_y_range[0] = obstacle_map.y_lim[0] 258 | if random_y_range[1] > obstacle_map.y_lim[1]: 259 | random_y_range[1] = obstacle_map.y_lim[1] 260 | 261 | for i in range(num_circle_obs): 262 | num_trial = 0 263 | while num_trial < max_iteration: 264 | center_x = rng.uniform(random_x_range[0], random_x_range[1]) 265 | center_y = rng.uniform(random_y_range[0], random_y_range[1]) 266 | center = np.array([center_x, center_y]) 267 | radius = rng.uniform(radius_range[0], radius_range[1]) 268 | 269 | # overlap check 270 | is_overlap = False 271 | for circle_obs in obstacle_map.circle_obs_list: 272 | if ( 273 | np.linalg.norm(circle_obs.center - center) 274 | <= circle_obs.radius + radius 275 | ): 276 | is_overlap = True 277 | 278 | for rectangle_obs in obstacle_map.rectangle_obs_list: 279 | if ( 280 | np.linalg.norm(rectangle_obs.center - center) 281 | <= rectangle_obs.width / 2 + radius 282 | ): 283 | if ( 284 | np.linalg.norm(rectangle_obs.center - center) 285 | <= rectangle_obs.height / 2 + radius 286 | ): 287 | is_overlap = True 288 | 289 | if not is_overlap: 290 | break 291 | 292 | num_trial += 1 293 | 294 | if num_trial == max_iteration: 295 | raise RuntimeError( 296 | "Cannot generate random obstacles due to reach max iteration." 
297 | ) 298 | 299 | obstacle_map.add_circle_obstacle(center, radius) 300 | 301 | for i in range(num_rectangle_obs): 302 | num_trial = 0 303 | while num_trial < max_iteration: 304 | center_x = rng.uniform(random_x_range[0], random_x_range[1]) 305 | center_y = rng.uniform(random_y_range[0], random_y_range[1]) 306 | center = np.array([center_x, center_y]) 307 | width = rng.uniform(width_range[0], width_range[1]) 308 | height = rng.uniform(height_range[0], height_range[1]) 309 | 310 | # overlap check 311 | is_overlap = False 312 | for circle_obs in obstacle_map.circle_obs_list: 313 | if ( 314 | np.linalg.norm(circle_obs.center - center) 315 | <= circle_obs.radius + width / 2 316 | ): 317 | if ( 318 | np.linalg.norm(circle_obs.center - center) 319 | <= circle_obs.radius + height / 2 320 | ): 321 | is_overlap = True 322 | 323 | for rectangle_obs in obstacle_map.rectangle_obs_list: 324 | if ( 325 | np.linalg.norm(rectangle_obs.center - center) 326 | <= rectangle_obs.width / 2 + width / 2 327 | ): 328 | if ( 329 | np.linalg.norm(rectangle_obs.center - center) 330 | <= rectangle_obs.height / 2 + height / 2 331 | ): 332 | is_overlap = True 333 | 334 | if not is_overlap: 335 | break 336 | 337 | num_trial += 1 338 | 339 | if num_trial == max_iteration: 340 | raise RuntimeError( 341 | "Cannot generate random obstacles due to reach max iteration." 
342 | ) 343 | 344 | obstacle_map.add_rectangle_obstacle(center, width, height) 345 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proxima-technology/mppi_playground/12cb79d02c703f818cc9c49d2436f79fd4db7e9f/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_brax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Because we want to use GPU accerated simulator for MPPI, I tried to use brax simulator. 3 | 4 | # jax cuda install 5 | pip3 install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 6 | 7 | # brax install 8 | pip3 install brax 9 | 10 | # How to run 11 | # https://tech.yellowback.net/posts/jax-oom 12 | XLA_PYTHON_CLIENT_MEM_FRACTION=.8 python3 tests/test_brax.py 13 | """ 14 | 15 | from brax.io import image 16 | from brax import envs 17 | 18 | import jax 19 | 20 | rng = jax.random.PRNGKey(0) 21 | ant = envs.create("ant") 22 | 23 | rng, rng_use = jax.random.split(rng) 24 | state = ant.reset(rng_use) 25 | 26 | # Too slow, not sure why 27 | qps = [state.pipeline_state] 28 | for _ in range(20): 29 | rng, rng_use = jax.random.split(rng) 30 | state = ant.step(state, jax.random.uniform(rng_use, (ant.action_size,))) 31 | qps.append(state.pipeline_state) 32 | 33 | # https://github.com/google/brax/issues/47 34 | # How can i get the rendered image without notebook? 
35 | image.render(sys=ant.sys, states=qps, width=320, height=240) 36 | -------------------------------------------------------------------------------- /tests/test_gui.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import matplotlib 3 | 4 | matplotlib.use("TkAgg") 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | # Run gymnasium with real-time rendering 9 | env = gymnasium.make("Pendulum-v1", render_mode="human") 10 | _, _ = env.reset(seed=42) 11 | for _ in range(100): 12 | action = env.action_space.sample() 13 | observation, reward, terminated, truncated, info = env.step(action) 14 | env.render() 15 | if terminated or truncated: 16 | observation, info = env.reset() 17 | env.close() 18 | 19 | # video recording mode 20 | env = gymnasium.make("Pendulum-v1", render_mode="rgb_array") 21 | env = gymnasium.wrappers.RecordVideo(env=env, video_folder="video") 22 | _, _ = env.reset(seed=42) 23 | env.start_video_recorder() 24 | for _ in range(100): 25 | action = env.action_space.sample() 26 | observation, reward, terminated, truncated, info = env.step(action) 27 | env.render() 28 | if terminated or truncated: 29 | observation, info = env.reset() 30 | env.close() 31 | 32 | # Run matplotlib 33 | plt.style.use("ggplot") 34 | plt.figure(figsize=(8, 6)) 35 | plt.plot(np.arange(1000), np.random.randn(1000)) 36 | plt.show() 37 | -------------------------------------------------------------------------------- /tests/test_mujoco.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | env = gym.make("Humanoid-v4", render_mode="human") 4 | _, _ = env.reset(seed=42) 5 | for _ in range(1000): 6 | action = env.action_space.sample() 7 | observation, reward, terminated, truncated, info = env.step(action) 8 | env.render() 9 | if terminated or truncated: 10 | observation, info = env.reset() 11 | env.close() 12 | 
# --------------------------------------------------------------------------------
# /tests/test_torch.py:
# --------------------------------------------------------------------------------
"""
Rough CPU vs. GPU (and TorchScript) matmul timing check.

Runs as a script at import time: prints the timings and asserts that the CPU
and GPU results agree on a small corner of the product.
"""

import time

import torch


# torch.compile was not supported on Python 3.11 when this was written,
# so only the TorchScript variant is benchmarked.
@torch.jit.script
def matmul_jit(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.matmul(x, y)


def _synchronize(device: torch.device) -> None:
    # CUDA kernels are launched asynchronously. Wait for them to finish, or the
    # clock only measures the launch.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _timed(fn, *args, device: torch.device):
    """Return (fn(*args), elapsed seconds), synchronising ``device`` around the call."""
    _synchronize(device)
    start = time.perf_counter()
    result = fn(*args)
    _synchronize(device)
    return result, time.perf_counter() - start


if torch.cuda.is_available():
    print(torch.__version__)
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU is not available. CPU is used")


matrix_size = 10000

# Calculate on CPU (matrix creation is kept out of the timed region)
input_matrix = torch.randn(matrix_size, matrix_size).to("cpu")
result_cpu, cpu_time = _timed(
    torch.matmul, input_matrix, input_matrix, device=torch.device("cpu")
)

# Cheap warm-up so that one-off costs stay out of the timed GPU/JIT runs.
# These costs are CUDA/cuBLAS initialisation and TorchScript's first-call
# compilation.
warmup = torch.randn(8, 8, device=device)
torch.matmul(warmup, warmup)
matmul_jit(warmup, warmup)

# Calculate on GPU
input_matrix = input_matrix.to(device)
result_gpu, gpu_time = _timed(torch.matmul, input_matrix, input_matrix, device=device)

# Calculate on GPU with jit
result_gpu_jit, gpu_time_jit = _timed(
    matmul_jit, input_matrix, input_matrix, device=device
)

print("CPU time: ", cpu_time)
print("GPU time: ", gpu_time)
print("GPU time with jit: ", gpu_time_jit)
print("Speed up w/o jit: ", cpu_time / gpu_time)
print("Speed up with jit: ", cpu_time / gpu_time_jit)
assert torch.allclose(result_cpu[:2, :2], result_gpu[:2, :2].to("cpu"), atol=1e-3)
assert torch.allclose(result_cpu[:2, :2], result_gpu_jit[:2, :2].to("cpu"), atol=1e-3)
# --------------------------------------------------------------------------------