├── python
    ├── hid
    │   ├── __init__.py
    │   └── gamepad.py
    ├── spot
    │   ├── __init__.py
    │   ├── mock_spot.py
    │   ├── constants.py
    │   └── spot.py
    ├── orbit
    │   ├── __init__.py
    │   ├── orbit_constants.py
    │   ├── orbit_configuration.py
    │   ├── observations.py
    │   └── onnx_command_generator.py
    ├── utils
    │   ├── __init__.py
    │   ├── event_divider.py
    │   ├── env_convert.py
    │   ├── history.py
    │   ├── test_controller.py
    │   └── dict_tools.py
    ├── requirements.txt
    ├── gamepad_config.json
    └── spot_rl_demo.py
├── entrypoint.sh
├── gitman.yml
├── Dockerfile
├── LICENSE
└── README.md

/python/hid/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/python/spot/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/python/orbit/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/python/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # entrypoint.sh
3 | 
4 | # Execute the Python script with any provided arguments.
5 | # `exec` replaces this shell with python3, so signals delivered to the container
6 | # (e.g. `docker kill --signal=SIGINT`) reach the demo instead of stopping at bash.
7 | exec python3 /spot-rl/python/spot_rl_demo.py "$@"
8 | 
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | bosdyn_api-4.0.0-py3-none-any.whl
2 | bosdyn_core-4.0.0-py3-none-any.whl
3 | bosdyn_client-4.0.0-py3-none-any.whl
4 | pygame
5 | pyPS4Controller
6 | spatialmath-python
7 | onnxruntime
8 | 
--------------------------------------------------------------------------------
/python/orbit/orbit_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | """spots base joints in order expected/used by orbit libraries"""
4 | 
5 | ordered_joint_names_orbit = [
6 |     "fl_hx",
7 |     "fr_hx",
8 |     "hl_hx",
9 |     "hr_hx",
10 |     "fl_hy",
11 |     "fr_hy",
12 |     "hl_hy",
13 |     "hr_hy",
14 |     "fl_kn",
15 |     "fr_kn",
16 |     "hl_kn",
17 |     "hr_kn",
18 | ]
19 | 
--------------------------------------------------------------------------------
/gitman.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | location: external
4 | sources_locked:
5 |   - repo: https://github.com/bdaiinstitute/spot-sdk-private
6 |     name: spot-python-sdk
7 |     rev: 10f033efe4f4f0233077fade1f175fc51debc4e9
8 |     type: git
9 |     params:
10 |     sparse_paths:
11 |       -
12 |     links:
13 |       -
14 |     scripts:
15 |       -
16 | default_group: ''
17 | groups:
18 |   -
19 | sources:
20 |   -
21 | 
--------------------------------------------------------------------------------
/python/utils/event_divider.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import time
4 | from threading import Event
5 | 
6 | 
7 | class EventDivider:
8 |     """frequency divider for a threading.Event: each call returns after the event has been set `factor` times"""
9 | 
10 |     def __init__(self, event: Event, factor: int):
11 |         self._event = event
12 |         self._factor = factor
13 | 
14 |     def __call__(self):
15 |         """block until the event has been set `factor` times, clearing it after each wake-up
16 | 
17 |         return True once `factor` events were seen, False if any single wait timed out after 1 second
18 |         """
19 |         count = 0
20 | 
21 |         while count < self._factor:
22 |             if not self._event.wait(1):
23 |                 return False
24 | 
25 |             count += 1
26 |             self._event.clear()
27 |             time.sleep(0.001)
28 | 
29 |         return True
30 | 
--------------------------------------------------------------------------------
/python/gamepad_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "forward": {
3 |         "max_velocity": 1.0,
4 |         "min_velocity": 0.25
5 |     },
6 |     "backward": {
7 |         "max_velocity": 1.0,
8 |         "min_velocity": 0.25
9 |     },
10 |     "lateral": {
11 |         "max_velocity": 1.0,
12 |         "min_velocity": 0.25
13 |     },
14 |     "yaw": {
15 |         "max_velocity": 1.0,
16 |         "min_velocity": 0.0
17 |     },
18 |     "axis_mapping":{
19 |         "forward_backward": {
20 |             "index": 1,
21 |             "inverted": true
22 |         },
23 |         "lateral":{
24 |             "index": 0,
25 |             "inverted": true
26 |         },
27 |         "yaw":{
28 |             "index": 3,
29 |             "inverted": true
30 |         }
31 |     },
32 |     "median_filter":{
33 |         "window_size": 10
34 |     },
35 |     "deadband": 0.2
36 | }
37 | 
--------------------------------------------------------------------------------
/python/utils/env_convert.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | 
4 | import yaml
5 | 
6 | 
7 | def remove_slice(dictionary):
8 |     """recursively replace every non-dict value whose string form contains "slice" with None
9 | 
10 |     presumably needed because json cannot serialize the python slice objects the
11 |     yaml loader builds from env.yaml -- TODO confirm against a real env.yaml
12 | 
13 |     arguments
14 |     dictionary -- (possibly nested) dict to clean; it is modified in place
15 | 
16 |     return the same dictionary object
17 |     """
18 |     for key, value in dictionary.items():
19 |         if isinstance(value, dict):
20 |             remove_slice(value)
21 |         elif "slice" in str(value):
22 |             dictionary[key] = None
23 |     return dictionary
24 | 
25 | 
26 | def load_local_cfg(resume_path: str) -> dict:
27 |     """load <resume_path>/env.yaml and strip values json cannot represent
28 | 
29 |     arguments
30 |     resume_path -- directory containing the env.yaml saved during training
31 | 
32 |     return the configuration as a (nested) dict
33 |     """
34 |     # use the argument, not a module-level global, so the function works when imported
35 |     env_cfg_yaml_path = os.path.join(resume_path, "env.yaml")
36 |     # NOTE: yaml.Loader can construct arbitrary python objects; only run this on trusted env.yaml files
37 |     with open(env_cfg_yaml_path) as yaml_in:
38 |         env_cfg = yaml.load(yaml_in, Loader=yaml.Loader)
39 | 
40 |     return remove_slice(env_cfg)
41 | 
42 | 
43 | def main():
44 |     """prompt for a directory and write its env.yaml as env_cfg.json next to it"""
45 |     print("Enter the path to the env.yaml directory")
46 |     cfg_dir = input()
47 |     env_cfg = load_local_cfg(cfg_dir)
48 | 
49 |     cfg_save_path = os.path.join(cfg_dir, "env_cfg.json")
50 |     with open(cfg_save_path, "w") as fp:
51 |         json.dump(env_cfg, fp, indent=4)
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     main()
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 | 
3 | # Install required apt packages
4 | RUN apt-get update && apt-get install -y \
5 |     git \
6 |     python3-pip
7 | 
8 | # Copy the repository into the image
9 | COPY . /spot-rl
10 | WORKDIR /spot-rl
11 | 
12 | COPY external /spot-rl/external
13 | WORKDIR /spot-rl/external
14 | 
15 | # Put spot-private-sdk wheels in spot-rl/external/spot-python-sdk/prebuilt
16 | # ^ This can be done with a `gitman update` or by manual intervention
17 | # Install Python dependencies for low level spot API
18 | WORKDIR /spot-rl/external/spot-python-sdk/prebuilt
19 | RUN pip3 install bosdyn_api-4.0.0-py3-none-any.whl \
20 |     bosdyn_core-4.0.0-py3-none-any.whl \
21 |     bosdyn_client-4.0.0-py3-none-any.whl
22 | 
23 | RUN pip3 install pygame \
24 |     pyPS4Controller \
25 |     spatialmath-python \
26 |     onnxruntime
27 | 
28 | # Copy the entrypoint script to the container
29 | COPY entrypoint.sh /entrypoint.sh
30 | RUN chmod +x /entrypoint.sh
31 | 
32 | # Set the entrypoint script as the entrypoint
33 | ENTRYPOINT ["/entrypoint.sh"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Boston Dynamics AI Institute
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spot-RL
2 | Code & Dockerfile for Spot Reinforcement Learning demo
3 | 
4 | # Import our image from .tar
5 | ```bash
6 | docker load -i spot-rl-demo-<version>.tar
7 | docker tag spot-rl-demo:<version> spot-rl-demo:latest
8 | ```
9 | 
10 | # Default Model 15k Steps (don't forget to set the IP)
11 | ```bash
12 | docker run --privileged --rm -it -v /dev/input:/dev/input spot-rl-demo:latest 192.168.x.y /spot-rl/external/models/
13 | ```
14 | 
15 | # Bring your own model (don't forget to set the IP)
16 | ```bash
17 | docker run --privileged --rm -it -v /dev/input:/dev/input -v /path/to/folder/with/onnx:/models spot-rl-demo:latest 192.168.x.y /models
18 | ```
19 | 
20 | # Example with local directory ./Model_Under_Test (don't forget to set the IP)
21 | ```bash
22 | docker run --privileged --rm -it -v /dev/input:/dev/input -v ./Model_Under_Test/:/mut spot-rl-demo:latest 192.168.x.y /mut
23 | ```
24 | 
25 | # Installing without docker from locally cloned repo
26 | ```bash
27 | sudo apt update
28 | sudo apt install python3-pip
29 | pip3 install gitman
30 | gitman
update
31 | cd external/spot-python-sdk/prebuilt
32 | pip3 install bosdyn_api-4.0.0-py3-none-any.whl
33 | pip3 install bosdyn_core-4.0.0-py3-none-any.whl
34 | pip3 install bosdyn_client-4.0.0-py3-none-any.whl
35 | pip3 install pygame
36 | pip3 install pyPS4Controller
37 | pip3 install spatialmath-python
38 | pip3 install onnxruntime
39 | ```
40 | 
--------------------------------------------------------------------------------
/python/utils/history.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import numpy
4 | 
5 | 
6 | class History:
7 |     """Convenience class to record past values of an array of data and calculate statistics on it"""
8 | 
9 |     def __init__(self) -> None:
10 |         self._data = []
11 | 
12 |     def record(self, datum):
13 |         """adds a new data entry to the history. i.e. adds a row to the history table
14 | 
15 |         arguments
16 |         datum -- list of new values to store
17 |         """
18 |         self._data.append(datum)
19 | 
20 |     def data(self, index):
21 |         """return the list values for a single piece of data in each entry.
22 |         i.e.
return one column of the history table
23 | 
24 |         arguments
25 |         index -- position of desired data in each data entry
26 |         """
27 |         return [row[index] for row in self._data]
28 | 
29 |     @property
30 |     def mean(self):
31 |         """calculate the mean of each piece of data over the entire history
32 | 
33 |         return numpy array containing the mean value of each column
34 |         """
35 |         return numpy.mean(self._data, axis=0)
36 | 
37 |     @property
38 |     def standard_deviation(self):
39 |         """calculate the standard deviation of each piece of data over the entire history
40 | 
41 |         return numpy array containing the standard deviations of each column
42 |         """
43 |         return numpy.std(self._data, axis=0)
44 | 
--------------------------------------------------------------------------------
/python/utils/test_controller.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | 
3 | def print_controller_inputs():
4 |     """print the state of every axis, button and hat of the first joystick in a loop until ctrl-c"""
5 |     # Initialize the joystick module
6 |     pygame.init()
7 |     pygame.joystick.init()
8 | 
9 |     # Check for connected joysticks
10 |     joystick_count = pygame.joystick.get_count()
11 | 
12 |     if joystick_count == 0:
13 |         print("No controller found.")
14 |         return  # NOTE(review): pygame.quit() is not called on this path
15 | 
16 |     # Initialize the first joystick
17 |     joystick = pygame.joystick.Joystick(0)
18 |     joystick.init()
19 | 
20 |     print("Name:", joystick.get_name())
21 |     print("Number of Axes:", joystick.get_numaxes())
22 |     print("Number of Buttons:", joystick.get_numbuttons())
23 |     print("Number of Hats:", joystick.get_numhats())
24 | 
25 |     try:
26 |         while True:
27 |             pygame.event.pump()  # let pygame process its event queue so joystick readings refresh
28 | 
29 |             # Print the state of each axis
30 |             for i in range(joystick.get_numaxes()):
31 |                 axis_value = joystick.get_axis(i)
32 |                 print(f"Axis {i}: {axis_value:.2f}")
33 | 
34 |             # Print the state of each button
35 |             for i in range(joystick.get_numbuttons()):
36 |                 button_state = joystick.get_button(i)
37 |                 print(f"Button {i}: {button_state}")
38 | 
39 |             # Print the state of each hat
40 |             for i in range(joystick.get_numhats()):
41 |                 hat_state = joystick.get_hat(i)
42 |                 print(f"Hat {i}: {hat_state}")
43 | 
44 |             print("\n---\n")
45 | 
46 |             pygame.time.wait(10)  # Delay to reduce console output frequency
47 | 
48 |     except KeyboardInterrupt:
49 |         print("Program terminated.")
50 | 
51 |     finally:
52 |         pygame.quit()
53 | 
54 | if __name__ == "__main__":
55 |     print_controller_inputs()
56 | 
--------------------------------------------------------------------------------
/python/utils/dict_tools.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | from typing import Any, List
4 | 
5 | 
6 | def dict_from_lists(keys: List, values: List):
7 |     """construct a dict from two lists where keys and associated values appear at same index
8 | 
9 |     arguments
10 |     keys -- list of keys
11 |     values -- list of values in same order as keys
12 | 
13 |     return dictionary mapping keys to values (zip stops at the shorter of the two lists)
14 |     """
15 | 
16 |     return dict(zip(keys, values))
17 | 
18 | 
19 | def dict_to_list(data: dict, keys: List):
20 |     """construct a list of values from dictionary in the order specified by keys
21 | 
22 |     arguments
23 |     data -- dictionary to look up keys in
24 |     keys -- list of keys to retrieve in order
25 | 
26 |     return list of values from data in same order as keys (None for keys missing from data)
27 |     """
28 |     return [data.get(key) for key in keys]
29 | 
30 | 
31 | def set_matching(data: dict, regex, value):
32 |     """set values in dict with keys matching regex
33 | 
34 |     arguments
35 |     data -- dictionary to set keys in (modified in place)
36 |     regex -- compiled regex; keys for which regex.match(key) succeeds (anchored at the start only) are set
37 |     value -- value to set keys matching regex to
38 |     """
39 |     for key in data:
40 |         if regex.match(key):
41 |             data[key] = value
42 | 
43 | 
44 | def reorder(inputs: List[Any], ordering: List[int]) -> List[Any]:
45 |     """rearrange values in a list to a given order.
46 | 
47 |     arguments
48 |     inputs -- list of values
49 |     ordering -- list of len(inputs) containing ints from 0 - len(inputs)-1 in desired order
50 | 
51 |     return list of values in new order
52 |     """
53 |     return [inputs[i] for i in ordering]
54 | 
55 | 
56 | def find_ordering(input: List[Any], output: List[Any]) -> List[int]:
57 |     """given two lists containing the same values return a list of indices mapping input->output
58 | 
59 |     arguments
60 |     input -- first list of values
61 |     output -- list with same values as input in a different ordering
62 | 
63 |     return list such that the nth value is the index of output[n] in the input list (ValueError if absent)
64 |     """
65 |     return [input.index(key) for key in output]
66 | 
--------------------------------------------------------------------------------
/python/spot/mock_spot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import math
4 | import time
5 | from contextlib import nullcontext
6 | from threading import Thread
7 | from typing import Any, Callable, List, Optional
8 | 
9 | from bosdyn.api.robot_command_pb2 import JointControlStreamRequest
10 | from bosdyn.api.robot_state_pb2 import RobotStateStreamResponse
11 | 
12 | 
13 | class RepeatedTimer(Thread):
14 |     """thread that calls target(*args) on a fixed dt grid until stop() is called"""
15 | 
16 |     def __init__(self, dt_seconds: float, target: Callable, args: Optional[List[Any]] = None) -> None:
17 |         super().__init__()
18 |         self._dt_seconds = dt_seconds
19 |         self._target = target
20 |         # None default instead of [] so instances never share one mutable list
21 |         self._args = args if args is not None else []
22 |         self._stopping = False
23 | 
24 |     def run(self):
25 |         run_time = time.monotonic()
26 |         while not self._stopping:
27 |             now = time.monotonic()
28 |             # jump to the next tick on the dt grid; ticks already missed are skipped, not replayed
29 |             num_dt = math.ceil((now - run_time) / self._dt_seconds)
30 |             run_time += num_dt * self._dt_seconds
31 |             time.sleep(run_time - now)
32 |             self._target(*self._args)
33 | 
34 |     def stop(self):
35 |         """ask run() to exit; the flag is checked once per tick, so target may be called one more time"""
36 |         self._stopping = True
37 | 
38 | 
39 | class MockSpot:
40 |     """robot-free stand-in for Spot: publishes a static state at ~333 Hz and runs the command loop"""
41 | 
42 |     def __init__(self):
43 |         self._state_stream_stopping = False
44 |         self._command_stream_stopping = False
45 |         # set by start_*_stream; None lets stop_*_stream run safely before (or without) a start
46 |         self._stateUpdates = None
47 |         self._command_thread = None
48 | 
49 |     def start_state_stream(self, on_state_update: Callable[[RobotStateStreamResponse], None]):
50 |         """call on_state_update ~333 times per second with zeroed joints and an identity body orientation"""
51 |         self._state_msg = RobotStateStreamResponse()
52 |         self._state_msg.kinematic_state.odom_tform_body.rotation.w = 1
53 |         self._state_msg.joint_states.position.extend([0] * 12)
54 |         self._state_msg.joint_states.velocity.extend([0] * 12)
55 |         self._state_msg.joint_states.load.extend([0] * 12)
56 | 
57 |         self._stateUpdates = RepeatedTimer(1 / 333, on_state_update, args=[self._state_msg])
58 |         self._stateUpdates.start()
59 | 
60 |     def start_command_stream(
61 |         self, command_policy: Callable[[None], JointControlStreamRequest], timing_policy: Callable[[None], None]
62 |     ):
63 |         """call command_policy in a background thread, each call preceded by timing_policy()"""
64 |         self._timing_policy = timing_policy
65 |         self._command_generator = command_policy
66 | 
67 |         self._command_thread = Thread(target=self._commandUpdate)
68 |         self._command_thread.start()
69 | 
70 |     def lease_keep_alive(self):
71 |         return nullcontext()  # no lease to hold without a robot
72 | 
73 |     def _commandUpdate(self):
74 |         while not self._command_stream_stopping:
75 |             self._timing_policy()
76 |             self._command_generator()
77 | 
78 |     def power_on(self):
79 |         pass
80 | 
81 |     def stand(self, body_height: float):
82 |         pass
83 | 
84 |     def stop_state_stream(self):
85 |         if self._stateUpdates is not None:
86 |             self._stateUpdates.stop()
87 |             self._stateUpdates.join()
88 | 
89 |     def stop_command_stream(self):
90 |         if self._command_thread is not None:
91 |             self._command_stream_stopping = True
92 |             self._command_thread.join()
93 | 
--------------------------------------------------------------------------------
/python/spot_rl_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import argparse
4 | import sys
5 | from pathlib import Path
6 | 
7 | import bosdyn.client.util
8 | import orbit.orbit_configuration
9 | from hid.gamepad import (
10 |     Gamepad,
11 |     GamepadConfig,
12 |     joystick_connected,
13 |     load_gamepad_configuration,
14 | )
15 | from orbit.onnx_command_generator import (
16 |     OnnxCommandGenerator,
17 |     OnnxControllerContext,
18 |     StateHandler,
19 | )
20 | from spot.mock_spot import MockSpot
21 | from spot.spot import Spot
22 | from utils.event_divider import EventDivider
23 | 
24 | 
25 | def main():
26 |     """command line interface: run an orbit-trained ONNX policy on Spot (MockSpot with --mock)
27 | 
28 |     press Enter once to start streaming policy commands and again to shut down;
29 |     ctrl-c while streaming also runs the shutdown sequence
30 | 
31 |     return True after shutdown, False if the policy directory does not contain
32 |     exactly one .json training configuration and one .onnx policy
33 |     """
34 |     parser = argparse.ArgumentParser()
35 |     bosdyn.client.util.add_base_arguments(parser)
36 |     parser.add_argument("policy_file_path", type=Path)
37 |     parser.add_argument("-m", "--mock", action="store_true")
38 |     parser.add_argument("--gamepad-config", type=Path)
39 |     options = parser.parse_args()
40 | 
41 |     conf_file = orbit.orbit_configuration.detect_config_file(options.policy_file_path)
42 |     policy_file = orbit.orbit_configuration.detect_policy_file(options.policy_file_path)
43 |     if conf_file is None or policy_file is None:
44 |         print(f"[ERROR] expected exactly one .json and one .onnx file in {options.policy_file_path}")
45 |         return False
46 | 
47 |     context = OnnxControllerContext()
48 |     config = orbit.orbit_configuration.load_configuration(conf_file)
49 |     print(config)
50 | 
51 |     state_handler = StateHandler(context)
52 |     command_generator = OnnxCommandGenerator(context, config, policy_file, options.verbose)
53 | 
54 |     # 333 Hz state update / 6 => ~56 Hz control updates
55 |     timing_policy = EventDivider(context.event, 6)
56 | 
57 |     gamepad = None
58 |     if joystick_connected():
59 |         if options.gamepad_config is not None:
60 |             print("[INFO] loading gamepad config from file")
61 |             gamepad_config = load_gamepad_configuration(options.gamepad_config)
62 |         else:
63 |             print("[INFO] using default gamepad configuration")
64 |             gamepad_config = GamepadConfig()
65 | 
66 |         gamepad = Gamepad(context, gamepad_config)
67 |         gamepad.start_listening()
68 | 
69 |     if options.mock:
70 |         spot = MockSpot()
71 |     else:
72 |         spot = Spot(options)
73 | 
74 |     with spot.lease_keep_alive():
75 |         try:
76 |             spot.power_on()
77 |             spot.stand(0.0)
78 |             spot.start_state_stream(state_handler)
79 | 
80 |             # first Enter starts the policy command stream, second Enter shuts down
81 |             input()
82 |             spot.start_command_stream(command_generator, timing_policy)
83 |             input()
84 | 
85 |         except KeyboardInterrupt:
86 |             print("killed with ctrl-c")
87 | 
88 |         finally:
89 |             print("stop command stream")
90 |             spot.stop_command_stream()
91 |             print("stop state stream")
92 |             spot.stop_state_stream()
93 |             print("stop game pad")
94 |             if gamepad is not None:
95 |                 gamepad.stop_listening()
96 |             print("all stopped")
97 | 
98 |     return True
99 | 
100 | 
101 | if __name__ == "__main__":
102 |     if not main():
103 |         sys.exit(1)
104 | 
--------------------------------------------------------------------------------
/python/orbit/orbit_configuration.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import json
4 | import os
5 | import re
6 | from dataclasses import dataclass
7 | from typing import Dict, Optional
8 | 
9 | from orbit.orbit_constants import ordered_joint_names_orbit
10 | from utils.dict_tools import dict_from_lists, set_matching
11 | 
12 | 
13 | @dataclass
14 | class OrbitConfig:
15 |     """dataclass holding data extracted from orbits training configuration"""
16 | 
17 |     kp: Dict[str, float]  # joint name -> actuator stiffness
18 |     kd: Dict[str, float]  # joint name -> actuator damping
19 |     default_joints: Dict[str, float]  # joint name -> initial joint position from init_state
20 |     standing_height: float  # z of the robot's initial base position
21 |     action_scale: float
22 | 
23 | 
24 | def _detect_single_file(directory: os.PathLike, suffix: str) -> Optional[os.PathLike]:
25 |     """return the path of the only file in directory ending with suffix, or None if there are zero or several"""
26 |     files = [f for f in os.listdir(directory) if f.endswith(suffix)]
27 |     if len(files) == 1:
28 |         return os.path.join(directory, files[0])
29 |     return None
30 | 
31 | 
32 | def detect_config_file(directory: os.PathLike) -> Optional[os.PathLike]:
33 |     """find json file in policy directory
34 | 
35 |     arguments
36 |     directory -- path where policy and training configuration can be found
37 | 
38 |     return filepath to json file, or None unless exactly one json file is present
39 |     """
40 |     return _detect_single_file(directory, ".json")
41 | 
42 | 
43 | def detect_policy_file(directory: os.PathLike) -> Optional[os.PathLike]:
44 |     """find onnx file in policy directory
45 | 
46 |     arguments
47 |     directory -- path where policy and training configuration can be found
48 | 
49 |     return filepath to onnx file, or None unless exactly one onnx file is present
50 |     """
51 |     return _detect_single_file(directory, ".onnx")
52 | 
53 | 
54 | def load_configuration(file: os.PathLike) -> OrbitConfig:
55 |     """parse json file and populate an OrbitConfig dataclass
56 | 
57 |     arguments
58 |     file -- the path to the json file containing training configuration
59 | 
60 |     return OrbitConfig containing needed training configuration
61 |     """
62 |     num_joints = len(ordered_joint_names_orbit)
63 |     joint_kp = dict_from_lists(ordered_joint_names_orbit, [None] * num_joints)
64 |     joint_kd = dict_from_lists(ordered_joint_names_orbit, [None] * num_joints)
65 |     joint_offsets = dict_from_lists(ordered_joint_names_orbit, [None] * num_joints)
66 | 
67 |     with open(file) as f:
68 |         env_config = json.load(f)
69 | 
70 |     robot_config = env_config["scene"]["robot"]
71 |     # assumes stiffness/damping are scalars per actuator group (orbit can also
72 |     # store per-joint dicts) -- TODO confirm for new training configs
73 |     for actuator in robot_config["actuators"].values():
74 |         # every expression in joint_names_expr selects joints for this group, not just the first
75 |         for expression in actuator["joint_names_expr"]:
76 |             regex = re.compile(expression)
77 |             set_matching(joint_kp, regex, actuator["stiffness"])
78 |             set_matching(joint_kd, regex, actuator["damping"])
79 | 
80 |     default_joint_data = robot_config["init_state"]["joint_pos"]
81 |     for expression, position in default_joint_data.items():
82 |         set_matching(joint_offsets, re.compile(expression), position)
83 | 
84 |     action_scale = env_config["actions"]["joint_pos"]["scale"]
85 |     standing_height = robot_config["init_state"]["pos"][2]
86 | 
87 |     return OrbitConfig(
88 |         kp=joint_kp,
89 |         kd=joint_kd,
90 |         default_joints=joint_offsets,
91 |         standing_height=standing_height,
92 |         action_scale=action_scale,
93 |     )
94 | 
--------------------------------------------------------------------------------
/python/spot/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Boston Dynamics, Inc. All rights reserved.
2 | #
3 | # Downloading, reproducing, distributing or otherwise using the SDK Software
4 | # is subject to the terms and conditions of the Boston Dynamics Software
5 | # Development Kit License (20191101-BDSDK-SL).
6 | 
7 | from enum import IntEnum
8 | 
9 | ordered_joint_names_bosdyn = [
10 |     "fl_hx",
11 |     "fl_hy",
12 |     "fl_kn",
13 |     "fr_hx",
14 |     "fr_hy",
15 |     "fr_kn",
16 |     "hl_hx",
17 |     "hl_hy",
18 |     "hl_kn",
19 |     "hr_hx",
20 |     "hr_hy",
21 |     "hr_kn",
22 | ]
23 | 
24 | 
25 | # Link index and order
26 | class DOF(IntEnum):
27 |     FL_HX = 0
28 |     FL_HY = 1
29 |     FL_KN = 2
30 |     FR_HX = 3
31 |     FR_HY = 4
32 |     FR_KN = 5
33 |     HL_HX = 6
34 |     HL_HY = 7
35 |     HL_KN = 8
36 |     HR_HX = 9
37 |     HR_HY = 10
38 |     HR_KN = 11
39 |     # Arm
40 |     A0_SH0 = 12
41 |     A0_SH1 = 13
42 |     A0_EL0 = 14
43 |     A0_EL1 = 15
44 |     A0_WR0 = 16
45 |     A0_WR1 = 17
46 |     # Hand
47 |     A0_F1X = 18
48 | 
49 |     # DOF count for strictly the legs. (value 12 duplicates A0_SH0, so IntEnum makes this an alias of A0_SH0)
50 |     N_DOF_LEGS = 12
51 |     # DOF count for all DOF on robot (arms and legs).
52 |     N_DOF = 19
53 | 
54 | 
55 | # Default joint gains indexed by DOF (presumably K_Q_P = position gain, K_QD_P = velocity gain); filled by set_default_gains()
56 | DEFAULT_K_Q_P = [0] * DOF.N_DOF
57 | DEFAULT_K_QD_P = [0] * DOF.N_DOF
58 | 
59 | 
60 | def set_default_gains():
61 |     # All legs have the same gains
62 |     HX_K_Q_P = 624
63 |     HX_K_QD_P = 5.20
64 |     HY_K_Q_P = 936
65 |     HY_K_QD_P = 5.20
66 |     KN_K_Q_P = 286
67 |     KN_K_QD_P = 2.04
68 | 
69 |     # Leg gains
70 |     DEFAULT_K_Q_P[DOF.FL_HX] = HX_K_Q_P
71 |     DEFAULT_K_QD_P[DOF.FL_HX] = HX_K_QD_P
72 |     DEFAULT_K_Q_P[DOF.FL_HY] = HY_K_Q_P
73 |     DEFAULT_K_QD_P[DOF.FL_HY] = HY_K_QD_P
74 |     DEFAULT_K_Q_P[DOF.FL_KN] = KN_K_Q_P
75 |     DEFAULT_K_QD_P[DOF.FL_KN] = KN_K_QD_P
76 |     DEFAULT_K_Q_P[DOF.FR_HX] = HX_K_Q_P
77 |     DEFAULT_K_QD_P[DOF.FR_HX] = HX_K_QD_P
78 |     DEFAULT_K_Q_P[DOF.FR_HY] = HY_K_Q_P
79 |     DEFAULT_K_QD_P[DOF.FR_HY] = HY_K_QD_P
80 |     DEFAULT_K_Q_P[DOF.FR_KN] = KN_K_Q_P
81 |     DEFAULT_K_QD_P[DOF.FR_KN] = KN_K_QD_P
82 |     DEFAULT_K_Q_P[DOF.HL_HX] = HX_K_Q_P
83 |     DEFAULT_K_QD_P[DOF.HL_HX] = HX_K_QD_P
84 |     DEFAULT_K_Q_P[DOF.HL_HY] = HY_K_Q_P
85 |     DEFAULT_K_QD_P[DOF.HL_HY] = HY_K_QD_P
86 |     DEFAULT_K_Q_P[DOF.HL_KN] = KN_K_Q_P
87 |     DEFAULT_K_QD_P[DOF.HL_KN] = KN_K_QD_P
88 |     DEFAULT_K_Q_P[DOF.HR_HX] = HX_K_Q_P
89 |     DEFAULT_K_QD_P[DOF.HR_HX] = HX_K_QD_P
90 |     DEFAULT_K_Q_P[DOF.HR_HY] = HY_K_Q_P
91 |     DEFAULT_K_QD_P[DOF.HR_HY] = HY_K_QD_P
92 |     DEFAULT_K_Q_P[DOF.HR_KN] = KN_K_Q_P
93 |     DEFAULT_K_QD_P[DOF.HR_KN] = KN_K_QD_P
94 | 
95 |     # Arm gains
96 |     DEFAULT_K_Q_P[DOF.A0_SH0] = 1020
97 |     DEFAULT_K_QD_P[DOF.A0_SH0] = 10.2
98 |     DEFAULT_K_Q_P[DOF.A0_SH1] = 255
99 |     DEFAULT_K_QD_P[DOF.A0_SH1] = 15.3
100 |     DEFAULT_K_Q_P[DOF.A0_EL0] = 204
101 |     DEFAULT_K_QD_P[DOF.A0_EL0] = 10.2
102 |     DEFAULT_K_Q_P[DOF.A0_EL1] = 102
103 |     DEFAULT_K_QD_P[DOF.A0_EL1] = 2.04
104 |     DEFAULT_K_Q_P[DOF.A0_WR0] = 102
105 |     DEFAULT_K_QD_P[DOF.A0_WR0] = 2.04
106 |     DEFAULT_K_Q_P[DOF.A0_WR1] = 102
107 |     DEFAULT_K_QD_P[DOF.A0_WR1] = 2.04
108 |     DEFAULT_K_Q_P[DOF.A0_F1X] = 16.0
109 |     DEFAULT_K_QD_P[DOF.A0_F1X] = 0.32
110 | 
111 | 
112 | # Initialize default gains at import time (fills the two lists above in place)
113 | set_default_gains()
114 | 
--------------------------------------------------------------------------------
/python/orbit/observations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | from operator import sub
4 | 
5 | from bosdyn.api import robot_state_pb2
6 | from orbit.orbit_configuration import OrbitConfig
7 | from orbit.orbit_constants import ordered_joint_names_orbit
8 | from spatialmath import UnitQuaternion
9 | from spot.constants import ordered_joint_names_bosdyn
10 | from utils.dict_tools import dict_to_list, find_ordering, reorder
11 | 
12 | 
13 | def get_base_linear_velocity(state: robot_state_pb2.RobotStateStreamResponse):
14 |     """calculate linear velocity of spot's base in the base frame from data
15 |     available in spot's state update. note spot gives velocity in odom frame
16 |     so it is rotated by the inverse of the estimated odom->base orientation
17 | 
18 |     arguments
19 |     state -- proto msg from spot containing data on the robot's state
20 |     """
21 |     msg = state.kinematic_state.velocity_of_body_in_odom.linear
22 | 
23 |     odom_r_base_msg = state.kinematic_state.odom_tform_body.rotation
24 |     scalar = odom_r_base_msg.w
25 |     vector = [odom_r_base_msg.x, odom_r_base_msg.y, odom_r_base_msg.z]
26 |     odom_r_base = UnitQuaternion(scalar, vector)
27 | 
28 |     velocity_odom = [msg.x, msg.y, msg.z]
29 |     velocity_base = odom_r_base.inv() * velocity_odom
30 | 
31 |     return velocity_base.tolist()
32 | 
33 | 
34 | def get_base_angular_velocity(state: robot_state_pb2.RobotStateStreamResponse):
35 |     """calculate angular velocity of spot's base in the base frame from data
36 |     available in spot's state update. note spot gives velocity in odom frame
37 |     so it is rotated by the inverse of the estimated odom->base orientation
38 | 
39 |     arguments
40 |     state -- proto msg from spot containing data on the robot's state
41 |     """
42 |     msg = state.kinematic_state.velocity_of_body_in_odom.angular
43 | 
44 |     odom_r_base_msg = state.kinematic_state.odom_tform_body.rotation
45 |     scalar = odom_r_base_msg.w
46 |     vector = [odom_r_base_msg.x, odom_r_base_msg.y, odom_r_base_msg.z]
47 |     odom_r_base = UnitQuaternion(scalar, vector)
48 | 
49 |     angular_velocity_odom = [msg.x, msg.y, msg.z]
50 |     angular_velocity_base = odom_r_base.inv() * angular_velocity_odom
51 | 
52 |     return angular_velocity_base.tolist()
53 | 
54 | 
55 | def get_projected_gravity(state: robot_state_pb2.RobotStateStreamResponse):
56 |     """calculate direction of gravity in spot's base frame
57 |     the assumption here is that the odom frame Z axis is opposite gravity
58 |     this is the case if spot's body is parallel to the floor when turned on
59 | 
60 |     arguments
61 |     state -- proto msg from spot containing data on the robot's state
62 |     """
63 |     odom_r_base_msg = state.kinematic_state.odom_tform_body.rotation
64 | 
65 |     scalar = odom_r_base_msg.w
66 |     vector = [odom_r_base_msg.x, odom_r_base_msg.y, odom_r_base_msg.z]
67 |     odom_r_base = UnitQuaternion(scalar, vector)
68 | 
69 |     gravity_odom = [0, 0, -1]
70 |     gravity_base = odom_r_base.inv() * gravity_odom
71 |     return gravity_base.tolist()
72 | 
73 | 
74 | def get_joint_positions(state: robot_state_pb2.RobotStateStreamResponse, config: OrbitConfig):
75 |     """get joint position from spot's state update and reformat for orbit by
76 |     reordering to match orbit's expectation and shifting so 0 position is the
77 |     same as was used in training
78 | 
79 |     arguments
80 |     state -- proto msg from spot containing data on the robot's state
81 |     config -- dataclass with values loaded from orbit's training data
82 |     """
83 | 
84 |     spot_to_orbit = find_ordering(ordered_joint_names_bosdyn, ordered_joint_names_orbit)
85 |     pos =
reorder(state.joint_states.position, spot_to_orbit)
86 |     default_joints = dict_to_list(config.default_joints, ordered_joint_names_orbit)
87 |     pos = list(map(sub, pos, default_joints))
88 |     return pos
89 | 
90 | 
91 | def get_joint_velocity(state: robot_state_pb2.RobotStateStreamResponse):
92 |     """get joint velocity from spot's state update and reformat for orbit by
93 |     reordering to match orbit's expectation
94 | 
95 |     arguments
96 |     state -- proto msg from spot containing data on the robot's state
97 |     """
98 |     spot_to_orbit = find_ordering(ordered_joint_names_bosdyn, ordered_joint_names_orbit)
99 |     vel = reorder(state.joint_states.velocity, spot_to_orbit)
100 |     return vel
101 | 
--------------------------------------------------------------------------------
/python/hid/gamepad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2 | 
3 | import json
4 | import os
5 | from collections import deque
6 | from dataclasses import dataclass, field
7 | from threading import Thread
8 | 
9 | import numpy as np
10 | import pygame
11 | 
12 | 
13 | @dataclass
14 | class AxisConfig:
15 |     """dataclass holding configuration data for a single axis"""
16 | 
17 |     index: int  # pygame axis index to control this axis
18 |     inverted: bool  # true if the pygame axis input should be inverted
19 |     deadband: float  # percentage of joystick travel that should return zero
20 |     min_forward_dir: float  # magnitude of output when joystick is at forward edge of deadband
21 |     max_forward_dir: float  # magnitude of output when joystick is fully forward
22 |     min_reverse_dir: float  # magnitude of output when joystick is at reverse edge of deadband
23 |     max_reverse_dir: float  # magnitude of output when joystick is fully reversed
24 | 
25 | 
26 | @dataclass
27 | class GamepadConfig:
28 |     """gamepad configuration, defaults are for a PS4 controller; axis defaults use default_factory so instances never share one AxisConfig (a plain AxisConfig default raises ValueError on Python >= 3.11)"""
29 | 
30 |     x_axis_config: AxisConfig = field(default_factory=lambda: AxisConfig(1, True, 0.2, 0.25, 1, 0.25, 1))
31 |     y_axis_config: AxisConfig = field(default_factory=lambda: AxisConfig(0, True, 0.2, 0.25, 1, 0.25, 1))
32 |     yaw_axis_config: AxisConfig = field(default_factory=lambda: AxisConfig(3, True, 0.2, 0.0, 1, 0.0, 1))
33 |     median_filter_window: int = 10  # size of buffer used for salt and pepper filtering
34 | 
35 | 
36 | def load_gamepad_configuration(file: os.PathLike) -> GamepadConfig:
37 |     """parse json file and populate an Gamepad config dataclass
38 | 
39 |     arguments
40 |     file -- the path to the json file containing gamepad configuration
41 | 
42 |     return GamepadConfig containing needed data
43 |     """
44 | 
45 |     with open(file) as f:
46 |         config_file = json.load(f)
47 | 
48 |     axis_mapping = config_file["axis_mapping"]
49 |     filter = config_file["median_filter"]
50 | 
51 |     forward_backward = AxisConfig(
52 |         index=axis_mapping["forward_backward"]["index"],
53 |         inverted=axis_mapping["forward_backward"]["inverted"],
54 |         deadband=config_file["deadband"],
55 |         min_forward_dir=config_file["forward"]["min_velocity"],
56 |         max_forward_dir=config_file["forward"]["max_velocity"],
57 |         min_reverse_dir=config_file["backward"]["min_velocity"],
58 |         max_reverse_dir=config_file["backward"]["max_velocity"],
59 |     )
60 | 
61 |     lateral = AxisConfig(
62 |         index=axis_mapping["lateral"]["index"],
63 |         inverted=axis_mapping["lateral"]["inverted"],
64 |         deadband=config_file["deadband"],
65 |         min_forward_dir=config_file["lateral"]["min_velocity"],
66 |         max_forward_dir=config_file["lateral"]["max_velocity"],
67 |         min_reverse_dir=config_file["lateral"]["min_velocity"],
68 |         max_reverse_dir=config_file["lateral"]["max_velocity"],
69 |     )
70 | 
71 |     yaw = AxisConfig(
72 |         index=axis_mapping["yaw"]["index"],
73 |         inverted=axis_mapping["yaw"]["inverted"],
74 |         deadband=config_file["deadband"],
75 |         min_forward_dir=config_file["yaw"]["min_velocity"],
76 |         max_forward_dir=config_file["yaw"]["max_velocity"],
77 |         min_reverse_dir=config_file["yaw"]["min_velocity"],
78 |         max_reverse_dir=config_file["yaw"]["max_velocity"],
79
| ) 80 | 81 | return GamepadConfig( 82 | x_axis_config=forward_backward, 83 | y_axis_config=lateral, 84 | yaw_axis_config=yaw, 85 | median_filter_window=filter["window_size"], 86 | ) 87 | 88 | 89 | def interpolate(start, end, percent): 90 | return start + percent * (end - start) 91 | 92 | 93 | def joystick_connected(): 94 | if not pygame.get_init(): 95 | pygame.init() 96 | if not pygame.joystick.get_init(): 97 | pygame.joystick.init() 98 | 99 | return pygame.joystick.get_count() > 0 100 | 101 | 102 | class Gamepad: 103 | def __init__(self, context, config: GamepadConfig): 104 | if not pygame.get_init(): 105 | pygame.init() 106 | if not pygame.joystick.get_init(): 107 | pygame.joystick.init() 108 | 109 | self._context = context 110 | self.x_vel = 0 111 | self.y_vel = 0 112 | self.yaw = 0 113 | self.joystick = pygame.joystick.Joystick(0) 114 | self._stopping = False 115 | self._listening_thread = None 116 | self._config = config 117 | 118 | buffer_length = config.median_filter_window 119 | self._x_buffer = deque([0] * buffer_length, buffer_length) 120 | self._y_buffer = deque([0] * buffer_length, buffer_length) 121 | self._yaw_buffer = deque([0] * buffer_length, buffer_length) 122 | 123 | print(f"[INFO] Initialized {self.joystick.get_name()}") 124 | print(f"[INFO] Joystick power level {self.joystick.get_power_level()}") 125 | 126 | def _apply_curve(self, value: float, cfg: AxisConfig): 127 | if cfg.inverted: 128 | value = -value 129 | 130 | if abs(value) < cfg.deadband: 131 | return 0 132 | elif value > 0: 133 | slope = (cfg.max_forward_dir - cfg.min_forward_dir) / (1.0 - cfg.deadband) 134 | shift = cfg.max_forward_dir - slope 135 | value = slope * value + shift 136 | return np.clip(value, cfg.min_forward_dir, cfg.max_forward_dir) 137 | else: 138 | slope = (cfg.max_reverse_dir - cfg.min_reverse_dir) / (1.0 - cfg.deadband) 139 | shift = cfg.max_reverse_dir - slope 140 | value = slope * value + shift 141 | return np.clip(value, -cfg.max_reverse_dir, 
-cfg.min_reverse_dir) 142 | 143 | def start_listening(self): 144 | self._listening_thread = Thread(target=self.listen) 145 | self._listening_thread.start() 146 | 147 | def listen(self): 148 | while not self._stopping: 149 | # Handle events 150 | pygame.event.pump() 151 | x_input = self.joystick.get_axis(self._config.x_axis_config.index) 152 | y_input = self.joystick.get_axis(self._config.y_axis_config.index) 153 | yaw_input = self.joystick.get_axis(self._config.yaw_axis_config.index) 154 | 155 | x_vel = self._apply_curve(x_input, self._config.x_axis_config) 156 | y_vel = self._apply_curve(y_input, self._config.y_axis_config) 157 | yaw = self._apply_curve(yaw_input, self._config.yaw_axis_config) 158 | 159 | self._x_buffer.append(x_vel) 160 | self._y_buffer.append(y_vel) 161 | self._yaw_buffer.append(yaw) 162 | 163 | self.x_vel = np.median(self._x_buffer) 164 | self.y_vel = np.median(self._y_buffer) 165 | self.yaw = np.median(self._yaw_buffer) 166 | 167 | self._context.velocity_cmd = [self.x_vel, self.y_vel, self.yaw] 168 | # update inputs at 100hz this should mean we always have a new data for the 50Hz command update 169 | pygame.time.wait(10) 170 | 171 | def stop_listening(self): 172 | if self._listening_thread is not None: 173 | self._stopping = True 174 | self._listening_thread.join() 175 | -------------------------------------------------------------------------------- /python/spot/spot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
2 | 3 | import time 4 | from threading import Thread 5 | from typing import Callable 6 | 7 | import bosdyn.client 8 | import bosdyn.client.lease 9 | import bosdyn.client.util 10 | from bosdyn import geometry 11 | from bosdyn.api.robot_command_pb2 import JointControlStreamRequest 12 | from bosdyn.api.robot_state_pb2 import RobotStateStreamResponse 13 | from bosdyn.client.robot_command import ( 14 | RobotCommandBuilder, 15 | RobotCommandClient, 16 | RobotCommandStreamingClient, 17 | blocking_stand, 18 | ) 19 | from bosdyn.client.robot_state import RobotStateStreamingClient 20 | 21 | 22 | class Spot: 23 | """wrapper around bosdyn API""" 24 | 25 | def __init__(self, config) -> None: 26 | """setup spot sdk, connect to robot and confirm estop 27 | 28 | arguments 29 | config -- arguments for connecting to spot must contain hostname 30 | """ 31 | 32 | self._started_streaming = False 33 | self._activate_thread_stopping = False 34 | self._command_stream_stopping = False 35 | self._state_stream_stopping = False 36 | 37 | self._command_thread = None 38 | self._state_thread = None 39 | 40 | bosdyn.client.util.setup_logging(config.verbose) 41 | 42 | self.sdk = bosdyn.client.create_standard_sdk("JointControlClient") 43 | 44 | # Register the non standard api clients 45 | self.sdk.register_service_client(RobotCommandStreamingClient) 46 | self.sdk.register_service_client(RobotStateStreamingClient) 47 | self.robot = self.sdk.create_robot(config.hostname) 48 | bosdyn.client.util.authenticate(self.robot) 49 | self.robot.time_sync.wait_for_sync() 50 | assert not self.robot.is_estopped(), "Robot is estopped. Please use an external E-Stop client." 
51 | 52 | def __del__(self): 53 | """clean up active streams and threads if spot goes out of scope or is deleted""" 54 | self.stop_command_stream() 55 | self.stop_state_stream() 56 | 57 | def lease_keep_alive(self): 58 | """acquire lease and keep it alive as long as return is in scope 59 | use this function in a with statement around other commands 60 | 61 | return scoped lease 62 | """ 63 | lease_client = self.robot.ensure_client(bosdyn.client.lease.LeaseClient.default_service_name) 64 | return bosdyn.client.lease.LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) 65 | 66 | def power_on(self): 67 | """Turn on power to robot's motors.""" 68 | self.robot.logger.info("Powering on robot... This may take several seconds.") 69 | self.robot.power_on(timeout_sec=20) 70 | if not self.robot.is_powered_on(): 71 | raise RuntimeError("Robot power on failed.") 72 | self.robot.logger.info("Robot powered on.") 73 | 74 | def stand(self, body_height: float = 0.0): 75 | """Tell spot to stand and block until it is standing 76 | 77 | arguments 78 | body_height -- controls height of standing as delta from default or 0.525m 79 | 80 | """ 81 | self._command_client = self.robot.ensure_client(RobotCommandClient.default_service_name) 82 | # Stand the robot 83 | params = RobotCommandBuilder.mobility_params(body_height, footprint_R_body=geometry.EulerZXY()) 84 | 85 | blocking_stand(self._command_client, 10, 1.0, params) 86 | 87 | def start_state_stream(self, on_state_update: Callable[[RobotStateStreamResponse], None]): 88 | """The robot state streaming client will allow us to get the robot's joint and imu information. 
89 | 90 | arguments 91 | on_state_update -- Callable that will be called at ~333Hz with latest state data 92 | 93 | """ 94 | self.robot_state_streaming_client = self.robot.ensure_client(RobotStateStreamingClient.default_service_name) 95 | self._state_thread = Thread(target=self._handle_state_stream, args=[on_state_update]) 96 | self._state_thread.start() 97 | 98 | def stop_state_stream(self): 99 | """stop listening to state stream updates this will also tell spot to stop sending them""" 100 | if self._state_thread is not None: 101 | self._state_stream_stopping = True 102 | self._state_thread.join() 103 | 104 | def start_command_stream( 105 | self, command_policy: Callable[[None], JointControlStreamRequest], timing_policy: Callable[[None], None] 106 | ): 107 | """create command streamt to send joint level commands to spot 108 | 109 | command stream will repeatedly call timing_policy to block until a command should be sent 110 | and then call command_policy to create one command 111 | 112 | arguments 113 | command_policy -- Callable that will create one joint command 114 | timing_policy -- Callable that blocks until the next time a command should be generated 115 | 116 | """ 117 | # Async activate once streaming has started 118 | self._activate_thread = Thread(target=self.activate) 119 | self._activate_thread.start() 120 | 121 | self._command_thread = Thread(target=self._run_command_stream, args=[command_policy, timing_policy]) 122 | self._command_thread.start() 123 | 124 | def stop_command_stream(self): 125 | """Stop sending joint commands robot will revert to bosdyns standing controller""" 126 | if self._command_thread is not None: 127 | self._command_stream_stopping = True 128 | self._command_thread.join() 129 | 130 | if self._activate_thread is not None: 131 | self._activate_thread_stopping = True 132 | self._activate_thread.join() 133 | 134 | def _handle_state_stream(self, on_state_update: Callable[[RobotStateStreamResponse], None]): 135 | """private function to 
be run in state stream thread 136 | listens for state steam events and calls users callback 137 | 138 | arguments 139 | on_state_update -- callback supplied to start_state_stream 140 | """ 141 | for state in self.robot_state_streaming_client.get_robot_state_stream(): 142 | on_state_update(state) 143 | 144 | if self._state_stream_stopping: 145 | return 146 | 147 | def _run_command_stream( 148 | self, command_policy: Callable[[None], JointControlStreamRequest], timing_policy: Callable[[None], None] 149 | ): 150 | """private function to be run in command stream thread handles opening grpc 151 | stream 152 | 153 | arguments 154 | command_policy -- callback supplied to start_command_stream to create commands 155 | timing_policy -- callback supplied to start_command_stream to control timing 156 | """ 157 | 158 | self._command_streaming_client = self.robot.ensure_client(RobotCommandStreamingClient.default_service_name) 159 | 160 | try: 161 | self.robot.logger.info("Starting command stream") 162 | res = self._command_streaming_client.send_joint_control_commands( 163 | self._command_stream_loop(command_policy, timing_policy) 164 | ) 165 | print(res) 166 | finally: 167 | self._activate_thread_stopping = True 168 | 169 | if self._activate_thread: 170 | self._activate_thread.join() 171 | 172 | # Power the robot off. By specifying "cut_immediately=False", a safe power off command 173 | # is issued to the robot. This will attempt to sit the robot before powering off. 174 | self.robot.power_off(cut_immediately=False, timeout_sec=20) 175 | assert not self.robot.is_powered_on(), "Robot power off failed." 176 | self.robot.logger.info("Robot safely powered off.") 177 | 178 | def _command_stream_loop( 179 | self, command_policy: Callable[[None], JointControlStreamRequest], timing_policy: Callable[[None], None] 180 | ): 181 | """coroutine needed for command stream. 
repeatedly calls timing_policty 182 | to block until next dt and then yields the result of command_policy once 183 | 184 | arguments 185 | command_policy -- callback supplied to start_command_stream to create commands 186 | timing_policy -- callback supplied to start_command_stream to control timing 187 | """ 188 | 189 | while not self._command_stream_stopping: 190 | if timing_policy(): 191 | yield command_policy() 192 | self._started_streaming = True 193 | else: 194 | print("timing policy timeout") 195 | return 196 | print("stopping is True") 197 | 198 | # Method to activate full body joint control through RobotCommand 199 | def activate(self): 200 | self._command_client = self.robot.ensure_client(RobotCommandClient.default_service_name) 201 | 202 | # Wait for streaming to start 203 | while not self._started_streaming: 204 | time.sleep(0.001) 205 | 206 | if self._activate_thread_stopping: 207 | return 208 | 209 | # Activate joint control 210 | self.robot.logger.info("Activating joint control") 211 | joint_command = RobotCommandBuilder.joint_command() 212 | 213 | try: 214 | self._command_client.robot_command(joint_command) 215 | finally: 216 | # Signal everything else to stop too 217 | self._activate_thread_stopping = True 218 | -------------------------------------------------------------------------------- /python/orbit/onnx_command_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
2 | 3 | import os 4 | from dataclasses import dataclass 5 | from operator import add, mul 6 | from threading import Event 7 | from typing import List 8 | 9 | import numpy as np 10 | import onnxruntime as ort 11 | import orbit.observations as ob 12 | from bosdyn.api import robot_command_pb2 13 | from bosdyn.api.robot_command_pb2 import JointControlStreamRequest 14 | from bosdyn.api.robot_state_pb2 import RobotStateStreamResponse 15 | from bosdyn.util import seconds_to_timestamp, set_timestamp_from_now, timestamp_to_sec 16 | from orbit.orbit_configuration import OrbitConfig 17 | from orbit.orbit_constants import ordered_joint_names_orbit 18 | from spot.constants import DEFAULT_K_Q_P, DEFAULT_K_QD_P, ordered_joint_names_bosdyn 19 | from utils.dict_tools import dict_to_list, find_ordering, reorder 20 | 21 | @dataclass 22 | class OnnxControllerContext: 23 | """data class to hold runtime data needed by the controller""" 24 | 25 | event = Event() 26 | latest_state = None 27 | velocity_cmd = [0, 0, 0] 28 | count = 0 29 | 30 | 31 | class StateHandler: 32 | """Class to be used as callback for state stream to put state date 33 | into the controllers context 34 | """ 35 | 36 | def __init__(self, context: OnnxControllerContext) -> None: 37 | self._context = context 38 | 39 | def __call__(self, state: RobotStateStreamResponse): 40 | """make class a callable and handle incoming state stream when called 41 | 42 | arguments 43 | state -- proto msg from spot containing most recent data on the robots state""" 44 | self._context.latest_state = state 45 | self._context.event.set() 46 | 47 | 48 | def print_observations(observations: List[float]): 49 | """debug function to print out the observation data used as model input 50 | 51 | arguments 52 | observations -- list of float values ready to be passed into the model 53 | """ 54 | print("base_linear_velocity:", observations[0:3]) 55 | print("base_angular_velocity:", observations[3:6]) 56 | print("projected_gravity:", observations[6:9]) 
57 | print("joint_positions", observations[12:24]) 58 | print("joint_velocity", observations[24:36]) 59 | print("last_action", observations[36:48]) 60 | 61 | 62 | class OnnxCommandGenerator: 63 | """class to be used as generator for spots command stream that executes 64 | an onnx model and converts the output to a spot command""" 65 | 66 | def __init__( 67 | self, context: OnnxControllerContext, config: OrbitConfig, policy_file_name: os.PathLike, verbose: bool 68 | ): 69 | self._context = context 70 | self._config = config 71 | self._inference_session = ort.InferenceSession(policy_file_name) 72 | self._last_action = [0] * 12 73 | self._count = 1 74 | self._init_pos = None 75 | self._init_load = None 76 | self.verbose = verbose 77 | 78 | def __call__(self): 79 | """makes class a callable and computes model output for latest controller context 80 | 81 | return proto message to be used in spots command stream 82 | """ 83 | 84 | # cache initial joint position when command stream starts 85 | if self._init_pos is None: 86 | self._init_pos = self._context.latest_state.joint_states.position 87 | self._init_load = self._context.latest_state.joint_states.load 88 | 89 | # extract observation data from latest spot state data 90 | input_list = self.collect_inputs(self._context.latest_state, self._config) 91 | # print("observations", input_list) 92 | 93 | # execute model from onnx file 94 | input = [np.array(input_list).astype("float32")] 95 | output = self._inference_session.run(None, {"obs": input})[0].tolist()[0] 96 | 97 | # post process model output apply action scaling and return to spots 98 | # joint order and offset 99 | test_scale = min(0.1 * self._count, 1) 100 | 101 | scaled_output = list(map(mul, [self._config.action_scale] * 12, output)) 102 | test_scaled = list(map(mul, [test_scale] * 12, scaled_output)) 103 | 104 | default_joints = dict_to_list(self._config.default_joints, ordered_joint_names_orbit) 105 | shifted_output = list(map(add, test_scaled, default_joints)) 
106 | 107 | orbit_to_spot = find_ordering(ordered_joint_names_orbit, ordered_joint_names_bosdyn) 108 | reordered_output = reorder(shifted_output, orbit_to_spot) 109 | 110 | # generate proto message from target joint positions 111 | proto = self.create_proto(reordered_output) 112 | 113 | # cache data for history and logging 114 | self._last_action = output 115 | self._count += 1 116 | self._context.count += 1 117 | 118 | return proto 119 | 120 | def collect_inputs(self, state: JointControlStreamRequest, config: OrbitConfig): 121 | """extract observation data from spots current state and format for onnx 122 | 123 | arguments 124 | state -- proto msg with spots latest state 125 | config -- model configuration data from orbit 126 | 127 | return list of float values ready to be passed into the model 128 | """ 129 | observations = [] 130 | observations += ob.get_base_linear_velocity(state) 131 | observations += ob.get_base_angular_velocity(state) 132 | observations += ob.get_projected_gravity(state) 133 | observations += self._context.velocity_cmd 134 | if self.verbose: 135 | print("[INFO] cmd", self._context.velocity_cmd) 136 | observations += ob.get_joint_positions(state, config) 137 | observations += ob.get_joint_velocity(state) 138 | observations += self._last_action 139 | return observations 140 | 141 | def create_proto(self, pos_command: List[float]): 142 | """generate a proto msg for spot with a given pos_command 143 | 144 | arguments 145 | pos_command -- list of joint positions see spot.constants for order 146 | 147 | return proto message to send in spots command stream 148 | """ 149 | update_proto = robot_command_pb2.JointControlStreamRequest() 150 | set_timestamp_from_now(update_proto.header.request_timestamp) 151 | update_proto.header.client_name = "rl_example_client" 152 | 153 | k_q_p = dict_to_list(self._config.kp, ordered_joint_names_bosdyn) 154 | k_qd_p = dict_to_list(self._config.kd, ordered_joint_names_bosdyn) 155 | 156 | N_DOF = len(pos_command) 157 | 
pos_cmd = [0] * N_DOF 158 | vel_cmd = [0] * N_DOF 159 | load_cmd = [0] * N_DOF 160 | 161 | for joint_ind in range(N_DOF): 162 | pos_cmd[joint_ind] = pos_command[joint_ind] 163 | vel_cmd[joint_ind] = 0 164 | load_cmd[joint_ind] = 0 165 | 166 | # Fill in gains the first dt 167 | if self._count == 1: 168 | update_proto.joint_command.gains.k_q_p.extend(k_q_p) 169 | update_proto.joint_command.gains.k_qd_p.extend(k_qd_p) 170 | 171 | update_proto.joint_command.position.extend(pos_cmd) 172 | update_proto.joint_command.velocity.extend(vel_cmd) 173 | update_proto.joint_command.load.extend(load_cmd) 174 | 175 | observation_time = self._context.latest_state.joint_states.acquisition_timestamp 176 | end_time = seconds_to_timestamp(timestamp_to_sec(observation_time) + 0.1) 177 | update_proto.joint_command.end_time.CopyFrom(end_time) 178 | 179 | # Let it extrapolate the command a little 180 | update_proto.joint_command.extrapolation_duration.nanos = int(5 * 1e6) 181 | 182 | # Set user key for latency tracking 183 | update_proto.joint_command.user_command_key = self._count 184 | return update_proto 185 | 186 | def create_proto_hold(self): 187 | """generate a proto msg that holds spots current pose useful for debugging 188 | 189 | return proto message to send in spots command stream 190 | """ 191 | update_proto = robot_command_pb2.JointControlStreamRequest() 192 | update_proto.Clear() 193 | set_timestamp_from_now(update_proto.header.request_timestamp) 194 | update_proto.header.client_name = "rl_example_client" 195 | 196 | k_q_p = DEFAULT_K_Q_P[0:12] 197 | k_qd_p = DEFAULT_K_QD_P[0:12] 198 | 199 | N_DOF = 12 200 | pos_cmd = [0] * N_DOF 201 | vel_cmd = [0] * N_DOF 202 | load_cmd = [0] * N_DOF 203 | 204 | for joint_ind in range(N_DOF): 205 | pos_cmd[joint_ind] = self._init_pos[joint_ind] 206 | vel_cmd[joint_ind] = 0 207 | load_cmd[joint_ind] = self._init_load[joint_ind] 208 | 209 | # Fill in gains the first dt 210 | if self._count == 1: 211 | 
update_proto.joint_command.gains.k_q_p.extend(k_q_p) 212 | update_proto.joint_command.gains.k_qd_p.extend(k_qd_p) 213 | 214 | update_proto.joint_command.position.extend(pos_cmd) 215 | update_proto.joint_command.velocity.extend(vel_cmd) 216 | update_proto.joint_command.load.extend(load_cmd) 217 | 218 | observation_time = self._context.latest_state.joint_states.acquisition_timestamp 219 | end_time = seconds_to_timestamp(timestamp_to_sec(observation_time) + 0.1) 220 | update_proto.joint_command.end_time.CopyFrom(end_time) 221 | 222 | # Let it extrapolate the command a little 223 | update_proto.joint_command.extrapolation_duration.nanos = int(5 * 1e6) 224 | 225 | # Set user key for latency tracking 226 | update_proto.joint_command.user_command_key = self._count 227 | return update_proto 228 | --------------------------------------------------------------------------------