├── assets
└── safari_logging.png
├── examples
├── aloha
│ ├── requirements.txt
│ └── eval.py
├── model
│ ├── genai_robotics_example.py
│ ├── gemini_robotics_policy_example.py
│ ├── genai_robotics_aloha_example.py
│ └── gemini_robotics_aloha_eval_example.py
└── logging
│ ├── lerobot_data_conversion_script.py
│ └── sample_data_upload_script.py
├── safari_sdk
├── logging
│ └── python
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── message.py
│ │ ├── stream_logger.py
│ │ ├── file_handler.py
│ │ └── mcap_lerobot_logger.py
├── protos
│ ├── beamform_data.proto
│ ├── logging
│ │ ├── policy_environment_metadata.proto
│ │ ├── codec.proto
│ │ ├── dtype.proto
│ │ ├── feature_specs.proto
│ │ ├── spec.proto
│ │ ├── tracker.proto
│ │ ├── robot_base.proto
│ │ ├── imu.proto
│ │ ├── audio.proto
│ │ ├── contact_surface.proto
│ │ ├── metadata.proto
│ │ └── machine_info.proto
│ ├── vector.proto
│ ├── sensor_calibration.proto
│ ├── pose.proto
│ ├── transform.proto
│ ├── label.proto
│ ├── image.proto
│ ├── monitoring.proto
│ ├── camera_spec.proto
│ └── joints.proto
├── __init__.py
├── orchestrator
│ ├── client
│ │ ├── dataclass
│ │ │ ├── operator_event.py
│ │ │ ├── rui_workcell_state.py
│ │ │ ├── ticket.py
│ │ │ ├── robot_job.py
│ │ │ ├── current_robot_info.py
│ │ │ ├── artifact.py
│ │ │ ├── current_robot_info_test.py
│ │ │ ├── api_response.py
│ │ │ ├── visual_overlay_icon.py
│ │ │ └── artifact_dataclass_test.py
│ │ └── libs
│ │ │ ├── current_robot.py
│ │ │ ├── artifact.py
│ │ │ ├── robot_job.py
│ │ │ ├── robot_job_test.py
│ │ │ ├── current_robot_test.py
│ │ │ └── artifact_lib_test.py
│ └── example_client_sdk_robot_and_operator_info.py
├── utils.py
├── model
│ ├── additional_observations_provider.py
│ ├── constants.py
│ └── genai_robotics_test.py
├── flywheel
│ ├── upload_data.py
│ └── upload_data_test.py
└── auth.py
├── cmake
├── tests
│ ├── test.proto
│ └── CMakeLists.txt
└── protobuf-generate.cmake
├── CONTRIBUTING.md
├── CMakeLists.txt
├── scripts
└── build_wheel.sh
├── README.md
└── pyproject.toml
/assets/safari_logging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/gemini-robotics-sdk/HEAD/assets/safari_logging.png
--------------------------------------------------------------------------------
/examples/aloha/requirements.txt:
--------------------------------------------------------------------------------
1 | # Additional requirements for running the real & sim robot eval.
2 | gymnasium
3 | modern_robotics
4 | pyyaml
5 | transforms3d
6 | dm_control
7 | fastapi[standard]
8 | uvicorn
9 | gdm_robotics >= 1.0.1
--------------------------------------------------------------------------------
/safari_sdk/logging/python/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/cmake/tests/test.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto3";
16 |
17 | package safari_sdk.externalization.testing;
18 |
19 | message TestMessage {
20 | string name = 1;
21 | int32 value = 2;
22 | }
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review.
18 | Pull requests welcome.
19 |
20 | ## Community Guidelines
21 |
22 | This project follows [Google's Open Source Community
23 | Guidelines](https://opensource.google/conduct/).
24 |
--------------------------------------------------------------------------------
/safari_sdk/protos/beamform_data.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto3";
16 |
17 | package robotics;
18 |
19 | // Next ID: 6
20 | message BeamformData {
21 | repeated float doa_azimuth = 1;
22 | repeated float doa_elevation = 2;
23 | repeated float doa_intensity = 3;
24 | int32 channels = 4;
25 | repeated float audio_data = 5;
26 | }
27 |
--------------------------------------------------------------------------------
/safari_sdk/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module providing an SDK for Google DeepMind Robotics models."""
16 | # Release version of the Safari SDK. Note: This is read by cmake, which expects
17 | # the version to be a plain string surrounded by quotes. Do not evaluate any
18 | # python expressions, including string formatting, variable substitution, etc.
19 | __version__ = "2.74.0.dev0"
20 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/policy_environment_metadata.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/logging/feature_specs.proto";
20 |
21 | // Metadata related to the policy and environment.
22 | // Next ID: 2
23 | message PolicyEnvironmentMetadata {
24 | optional FeatureSpecs feature_specs = 1;
25 | }
26 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/codec.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | // Proto to specify the codec of the media being logged.
20 | // Next ID: 4
21 | enum Codec {
22 | // Used for non-image and non-video data.
23 | CODEC_NONE = 0;
24 | CODEC_VIDEO_MPEG4 = 1;
25 | CODEC_IMAGE_JPEG = 2;
26 | CODEC_IMAGE_PNG = 3;
27 | }
28 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/operator_event.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator operator event information."""
16 |
17 | import dataclasses
18 |
19 | import dataclasses_json
20 |
21 |
22 | @dataclasses_json.dataclass_json
23 | @dataclasses.dataclass(kw_only=True)
24 | class AddOperatorEventResponse:
25 | """Orchestrator AddOperatorEvent response information."""
26 |
27 | success: bool | None = False
28 |
--------------------------------------------------------------------------------
/cmake/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Simple CMake file to test the cmake functions.
16 |
17 | cmake_minimum_required(VERSION 3.10)
18 | project(safari_externalization_tests)
19 |
20 | include(../protobuf-generate.cmake)
21 |
22 | # Test protobuf generation.
23 | protobuf_generate(
24 | PROTOS test.proto
25 | LANGUAGE python
26 | OUT_VAR _PB_GENERATED_FILES
27 | )
28 | add_custom_target(proto_py ALL DEPENDS ${_PB_GENERATED_FILES})
29 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/dtype.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | // Proto to serialize Numpy dtypes used for Safari logging.
20 | // Next ID: 8
21 | enum Dtype {
22 | DTYPE_UNSPECIFIED = 0;
23 | DTYPE_UINT8 = 1;
24 | DTYPE_UINT16 = 2;
25 | DTYPE_INT32 = 3;
26 | DTYPE_INT64 = 4;
27 | DTYPE_FLOAT32 = 5;
28 | DTYPE_FLOAT64 = 6;
29 | DTYPE_STRING = 7;
30 | }
31 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/rui_workcell_state.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator RUI workcell state information."""
16 |
17 | import dataclasses
18 | import dataclasses_json
19 |
20 | # pylint: disable=invalid-name
21 |
22 |
23 | @dataclasses_json.dataclass_json
24 | @dataclasses.dataclass(kw_only=True)
25 | class LoadRuiWorkcellStateResponse:
26 | """Orchestrator LoadRuiWorkcellState response information."""
27 |
28 | workcellState: str | None = None
29 |
--------------------------------------------------------------------------------
/safari_sdk/logging/python/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Constants for logging."""
16 |
17 |
18 | # Reserved topic names.
19 | FILE_METADATA_TOPIC_NAME: str = '/file_metadata'
20 | SESSION_TOPIC_NAME: str = '/session'
21 | SYNC_TOPIC_NAME: str = '/sync'
22 | TIMESTEP_TOPIC_NAME: str = '/timestep'
23 | ACTION_TOPIC_NAME: str = '/action'
24 | POLICY_EXTRA_TOPIC_NAME: str = '/policy_extra'
25 | DEFAULT_FILE_SHARD_SIZE_LIMIT_BYTES: int = 1 * 1024 * 1024 * 1024
26 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/feature_specs.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/logging/spec.proto";
20 |
21 | // Maps each episodic key or stream topic to a Numpy array representation.
22 | // Next ID: 6
23 | message FeatureSpecs {
24 | map<string, Spec> observation = 1;
25 | map<string, Spec> reward = 2;
26 | map<string, Spec> discount = 3;
27 | map<string, Spec> action = 4;
28 | map<string, Spec> policy_extra_output = 5;
29 | }
30 |
--------------------------------------------------------------------------------
/safari_sdk/protos/vector.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | // Double precision vector of arbitrary size.
20 | message NamedVectorDouble {
21 | repeated string names = 1; // must be the same size as data if set.
22 | repeated double data = 2 [packed = true];
23 | }
24 |
25 | // Integer vector of arbitrary size.
26 | message NamedVectorInt64 {
27 | repeated string names = 1; // must be the same size as data if set.
28 | repeated int64 data = 2 [packed = true];
29 | }
30 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/ticket.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator ticket information."""
16 |
17 | import enum
18 |
19 | # pylint: disable=invalid-name
20 |
21 |
22 | class TicketType(enum.IntEnum):
23 | TICKET_TYPE_UNSPECIFIED = 0
24 | TICKET_TYPE_ROBOT_MAINTENANCE = 1
25 | TICKET_TYPE_ORCHESTRATOR_ISSUE = 2
26 |
27 |
28 | class RobotFailureReason(enum.IntEnum):
29 | ROBOT_FAILURE_REASON_UNSPECIFIED = 0
30 | HARDWARE = 1
31 | SOFTWARE = 2
32 | SOFTWARE_EVAL = 7
33 | ROBOT_BEHAVIOR = 3
34 | INVESTIGATION = 4
35 | UPGRADE = 5
36 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/robot_job.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator robot job information."""
16 |
17 | import dataclasses
18 | import dataclasses_json
19 |
20 | # pylint: disable=invalid-name
21 |
22 |
23 | @dataclasses_json.dataclass_json
24 | @dataclasses.dataclass(kw_only=True)
25 | class RobotJob:
26 | """Orchestrator robot job information."""
27 |
28 | robotJobId: str | None = None
29 |
30 |
31 | @dataclasses_json.dataclass_json
32 | @dataclasses.dataclass(kw_only=True)
33 | class RobotJobResponse:
34 | """Orchestrator robot job response, wrapping a single `RobotJob`."""
35 |
36 | robotJob: RobotJob
37 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/spec.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/logging/codec.proto";
20 | import "safari_sdk/protos/logging/dtype.proto";
21 |
22 | // Proto to serialize a dm_env.specs.Array
23 | // Next ID: 6
24 | message Spec {
25 | repeated int64 shape = 1 [packed = true];
26 | optional Dtype dtype = 2;
27 | // Fields for a dm_env.specs.BoundedArray
28 | repeated double maximum_values = 3 [packed = true];
29 | repeated double minimum_values = 4 [packed = true];
30 | optional Codec codec = 5;
31 | }
32 |
--------------------------------------------------------------------------------
/safari_sdk/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Common utility functions for Safari SDK."""
16 |
17 | import os
18 |
19 | # Environment variable name to extract the robot ID.
20 | _DEFAULT_ROBOT_ID_IN_SYSTEM_ENV = "GA_ROBOT_ID"
21 |
22 |
23 | def get_system_env_variable(var_name: str) -> str:
24 | """Returns the value of environment variable `var_name`, or '' if unset."""
25 | return os.environ.get(var_name, "")
26 |
27 |
28 | def get_robot_id_from_system_env(
29 | var_name: str = _DEFAULT_ROBOT_ID_IN_SYSTEM_ENV,
30 | ) -> str:
31 | """Returns the robot ID read from `var_name` (default GA_ROBOT_ID), or ''."""
32 | return get_system_env_variable(var_name=var_name)
33 |
--------------------------------------------------------------------------------
/safari_sdk/protos/sensor_calibration.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | import "safari_sdk/protos/camera_spec.proto";
20 | import "safari_sdk/protos/transform.proto";
21 |
22 | message SensorIntrinsics {
23 | // source_frame_id is also known as the reference frame.
24 | optional string source_frame_id = 1;
25 |
26 | oneof intrinsics_type {
27 | PinholeCamera pinhole_camera = 2;
28 | }
29 | }
30 |
31 | // This can include extrinsics to a non-sensor such as a robot part.
32 | message SensorCalibration {
33 | // Payloads
34 | repeated SensorIntrinsics sensor_intrinsics = 1;
35 | repeated Transform sensor_extrinsics = 2;
36 | }
37 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/tracker.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "google/protobuf/struct.proto";
20 | import "safari_sdk/protos/pose.proto";
21 |
22 | message Tracker {
23 | // The name of the tracker.
24 | optional string name = 1;
25 |
26 | // The pose of the tracker.
27 | optional Pose pose = 2;
28 |
29 | // The status of the tracker.
30 | enum Status {
31 | UNINITIALIZED = 0;
32 | ACTIVE = 1;
33 | }
34 | optional Status status = 3;
35 |
36 | // Metadata about the tracker for storing information such as tracking
37 | // confidence.
38 | optional google.protobuf.Struct metadata = 4;
39 | }
40 |
41 | message Trackers {
42 | repeated Tracker trackers = 1;
43 | }
44 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/robot_base.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/pose.proto";
20 |
21 | // The pose and twist data of a robot base, aka root or center of body.
22 | message RobotBase {
23 | // The pose of the robot base.
24 | optional Pose pose = 1;
25 |
26 | // The linear velocity of the root floating base, len(linear_velocity_xyz) ==
27 | // 3, in m/s.
28 | repeated double linear_velocity_xyz = 2 [packed = true];
29 |
30 | // The angular velocity of the root floating base, len(angular_velocity_xyz)
31 | // == 3, in rad/s.
32 | repeated double angular_velocity_xyz = 3 [packed = true];
33 |
34 | // The name of the base link.
35 | optional string frame_name = 4;
36 | }
37 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Cmake file for building the Safari SDK protobuf messages. This is invoked by
16 | # pip when it builds the safari_sdk package; it should not be invoked directly.
17 |
18 | cmake_minimum_required(VERSION 3.16)
19 | project(safari_sdk)
20 | include(cmake/protobuf-generate.cmake)
21 |
22 | # Generate the python protobufs and add them to the 'all' build target.
23 | # Output files are placed in the source tree, alongside their corresponding
24 | # proto files. They can be removed with `make clean`.
25 | file(GLOB_RECURSE _PROTO_FILES CONFIGURE_DEPENDS "safari_sdk/protos/*.proto")
26 | protobuf_generate(
27 | PROTOS ${_PROTO_FILES}
28 | PROTOC_OUT_DIR ${CMAKE_SOURCE_DIR}
29 | LANGUAGE python
30 | OUT_VAR _PB_GENERATED_FILES
31 | )
32 | add_custom_target(py_proto ALL
33 | DEPENDS ${_PB_GENERATED_FILES}
34 | )
35 |
--------------------------------------------------------------------------------
/safari_sdk/protos/pose.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | // A pose describes the position and orientation of an object or frame in 3D
20 | // space relative to a reference frame. It's considered as a state, describing
21 | // where something is in the reference frame.
22 | message Pose {
23 | // The position, len(position_meters_xyz) == 3, in meters.
24 | repeated double position_meters_xyz = 1 [packed = true];
25 |
26 | // The orientation as a quaternion len(orientation_xyzw) == 4.
27 | repeated double orientation_xyzw = 2 [packed = true];
28 |
29 | // source_frame_id is also known as the reference frame.
30 | optional string source_frame_id = 3;
31 | }
32 |
33 | // Wrapper message for a list of poses.
34 | message Poses {
35 | repeated Pose poses = 1;
36 | }
37 |
--------------------------------------------------------------------------------
/safari_sdk/model/additional_observations_provider.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Interface for adding additional observations to a timestep."""
16 |
17 | import abc
18 |
19 | import dm_env
20 | from dm_env import specs
21 | import numpy as np
22 |
23 |
class AdditionalObservationsProvider(abc.ABC):
  """Interface for objects that contribute extra observations to a timestep.

  Subclasses must implement both the observation getter and the matching spec
  getter. `reset` is an optional hook whose default implementation is a no-op.
  """

  @abc.abstractmethod
  def get_additional_observations(
      self, timestep: dm_env.TimeStep, should_replan: bool
  ) -> dict[str, np.ndarray]:
    """Builds the extra observations for `timestep`.

    Args:
      timestep: The existing timestep the observations are added to.
      should_replan: Replanning flag supplied by the caller; its exact
        meaning is defined by the concrete implementation.

    Returns:
      A mapping from observation name to its array value.
    """

  @abc.abstractmethod
  def get_additional_observations_spec(self) -> dict[str, specs.Array]:
    """Describes the additional observations as a name-to-spec mapping."""

  def reset(self) -> None:
    """Clears any internal state held by the provider (no-op by default)."""
39 |
--------------------------------------------------------------------------------
/safari_sdk/model/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Constants for Gemini Robotics Policy."""
16 |
17 | import enum
18 |
19 |
class RoboticsApiConnectionType(enum.Enum):
  """Enumerates the supported ways of connecting to the Robotics API."""

  # Connects to a Google Cloud-based server.
  CLOUD = "cloud"
  # Connects to a local server.
  LOCAL = "local"
  # Connects via GenAI API.
  CLOUD_GENAI = "cloud_genai"
26 |
27 |
class InferenceMode(enum.Enum):
  """Inference modes: synchronous or asynchronous."""

  SYNCHRONOUS = "synchronous"
  ASYNCHRONOUS = "asynchronous"
31 |
32 |
33 | # Prefix for image observation keys in the encoded observation.
34 | IMAGE_ENCODED_OBS_PREFIX = "images/"
35 |
36 | # Keys for the encoded observation.
37 | CONDITIONING_ENCODED_OBS_KEY = "conditioning"
38 |
39 | # Key for the task instruction in the encoded observation.
40 | TASK_INSTRUCTION_ENCODED_OBS_KEY = "task_instruction"
41 |
--------------------------------------------------------------------------------
/scripts/build_wheel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# Fail on any error.
set -e

# SAFARI_DIR must point at the source tree. Without this guard an unset
# variable expands to "" and the build would be attempted in "/build".
: "${SAFARI_DIR:?SAFARI_DIR must be set to the safari_sdk source directory}"

CMAKE_BUILD_DIR="${SAFARI_DIR}/build"  # cmake build directory
mkdir -p "${CMAKE_BUILD_DIR}"
cd "${CMAKE_BUILD_DIR}"
cmake "${SAFARI_DIR}"
make pip_wheel pip_install

# Smoke test generated package
echo "Start smoke test"
source "${CMAKE_BUILD_DIR}/safari_venv/bin/activate"

flywheel-cli help

python3 -c "from safari_sdk.logging.python import stream_logger"
python3 -c "from safari_sdk.model import saved_model_policy"
python3 -c "from safari_sdk.logging.python import mcap_episodic_logger"

deactivate

echo "Smoke test done."

# Resolve the wheel glob once. If several wheels are present, pick the most
# recently modified one so the "lastbuild" symlink has a single, valid target.
WHEEL_PATH="$(ls -t "${SAFARI_DIR}"/dist/safari_sdk-*-py3-none-any.whl | head -n 1)"
if [[ -z "${WHEEL_PATH}" ]]; then
  echo "No wheel found in ${SAFARI_DIR}/dist" >&2
  exit 1
fi
echo "Pip wheel is in ${WHEEL_PATH}"

ln -fs "${WHEEL_PATH}" /tmp/safari_sdk-lastbuild-py3-none-any.whl
41 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/current_robot_info.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Current robot information in Orchestrator."""
16 |
17 | import dataclasses
18 | import dataclasses_json
19 | from safari_sdk.orchestrator.client.dataclass import work_unit
20 |
21 | # pylint: disable=invalid-name
22 |
23 |
@dataclasses_json.dataclass_json
@dataclasses.dataclass(kw_only=True)
class CurrentRobotInfoResponse:
  """Current robot information held by Orchestrator.

  On construction, a missing `stage` becomes WORK_UNIT_STAGE_UNSPECIFIED and a
  string `stage` is converted to the matching `work_unit.WorkUnitStage` member.
  """

  robotId: str
  isOperational: bool | None = False
  operatorId: str | None = None
  robotJobId: str | None = None
  workUnitId: str | None = None
  stage: work_unit.WorkUnitStage | None = None
  robotStage: str | None = None

  def __post_init__(self):
    """Normalizes `stage` into a WorkUnitStage enum member."""
    stage = self.stage
    if isinstance(stage, str):
      stage = work_unit.WorkUnitStage(stage)
    elif stage is None:
      stage = work_unit.WorkUnitStage.WORK_UNIT_STAGE_UNSPECIFIED
    self.stage = stage
42 |
--------------------------------------------------------------------------------
/safari_sdk/logging/python/message.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Safari Message dataclass."""
16 |
17 | import dataclasses
18 |
19 | from google.protobuf import message as message_lib
20 |
21 |
@dataclasses.dataclass
class Message:
  """Safari Message dataclass.

  The fields are declared as dataclass fields so that the generated `__eq__`
  and `__repr__` take them into account. With a hand-written `__init__` and no
  annotated fields, every instance compared equal and printed as `Message()`.

  Attributes:
    topic: The safari_logging_topic of the message.
    message: The proto message to be written.
    publish_time_nsec: The timestamp of the message (this may be the time the
      message was published, or the time the data in the message was
      sampled).
    log_time_nsec: The time when the logger received the message.
  """

  topic: str
  message: "message_lib.Message"
  publish_time_nsec: int
  log_time_nsec: int = 0
47 |
--------------------------------------------------------------------------------
/safari_sdk/protos/transform.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | // A transform describes the relationship between two frames or coordinate
20 | // systems. It provides the instructions on how to transform a point or vector
21 | // from one frame to another. It is considered as an operation, describing how
22 | // to move between coordinate systems.
23 | message Transform {
24 | // The translation, len(translation_meters_xyz) == 3, in meters.
25 | repeated double translation_meters_xyz = 1 [packed = true];
26 |
27 | // The rotation as a quaternion len(rotation_xyzw) == 4.
28 | repeated double rotation_xyzw = 2 [packed = true];
29 |
30 | // translation_meters_xyz and rotation_xyzw forms the dst_transform_src
31 | // matrix. It represents a transform between two coordinate frames in free
32 | // space.
33 | // A point in the source frame: point_src
34 | // The corresponding point in the destination frame: point_dst
35 | // point_dst = dst_transform_src * point_src
36 | optional string source_frame_id = 3;
37 | optional string destination_frame_id = 4;
38 | }
39 |
40 | // Wrapper message for a list of transforms.
41 | message Transforms {
42 | repeated Transform transforms = 1;
43 | }
44 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/imu.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | // InertialMeasurement abstract reading from most IMU units.
20 | message Imu {
21 | // Reading from accelerometer, len(accelerometer_xyz) == 3, in m/s^2.
22 | repeated double accelerometer_xyz = 1 [packed = true];
23 |
24 | // Covariance matrix of the accelerometer.
25 | repeated double accelerometer_covariance = 10 [packed = true];
26 |
27 | // Reading from gyro, len(gyro_xyz) == 3, in rad/s.
28 | repeated double gyro_xyz = 2 [packed = true];
29 |
30 | // Covariance matrix of the gyro.
31 | repeated double gyro_covariance = 11 [packed = true];
32 |
33 | // Reading from magnetometer, len(magnetometer_xyz) == 3, in uT. This is
34 | // calibrated readings.
35 | repeated double magnetometer_xyz = 3;
36 |
37 | // Reading from uncalibrated magnetometer, len(raw_magnetometer_xyz) == 3, in
38 | // uT.
39 | repeated double raw_magnetometer_xyz = 4;
40 |
41 | // Reading from barometer in kPa.
42 | optional double barometer = 5;
43 |
44 | // Reading from uncalibrated barometer.
45 | optional double raw_barometer = 6;
46 |
47 | // The orientation as a quaternion len(quaternion_xyzw) == 4, w is always >=0.
48 | repeated double quaternion_xyzw = 7 [packed = true];
49 |
50 | // Covariance matrix of the pose quaternion.
51 | repeated double quaternion_covariance = 12 [packed = true];
52 |
53 | // Temperature of device (for temperature drift correction).
54 | optional double temperature = 9;
55 |
56 | reserved 8;
57 | }
58 |
--------------------------------------------------------------------------------
/safari_sdk/protos/label.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | import "google/protobuf/struct.proto";
20 |
21 | // For marking a time interval.
22 | message IntervalValue {
23 | // Inclusive. Unix timestamp in nanoseconds.
24 | optional int64 start_nsec = 1;
25 |
26 | // Exclusive. Unix timestamp in nanoseconds.
27 | optional int64 stop_nsec = 2;
28 | }
29 |
30 | message DomainTimestamp {
31 | // A robot / project specific clock domain, like "camera/acquisition" for
32 | // acquisition time on camera clock.
33 | optional string clock_domain = 1;
34 | // The unix timestamp in nanoseconds.
35 | optional int64 stamp_nsec = 2;
36 | }
37 |
38 | message IntervalAnnotation {
39 | optional IntervalValue range = 1;
40 | optional google.protobuf.Value value = 2;
41 | }
42 |
43 | message IntervalAnnotations {
44 | repeated IntervalAnnotation annotations = 1;
45 | }
46 |
47 | // For storing labels and tags.
48 | message LabelMessage {
49 | // The key name. If the value is not set, it is a tag. If the
50 | // value is set, it is the key of the label. Standard key
51 | // strings:
52 | // "success": whether the session is successful.
53 | // "task_instruction": the task instruction, descriptive instruction of the
54 | // task in natural language in English. Prefers to cap this to 300
55 | // characters.
56 | // "session_log_type": the type of the session, e.g. "teleop", "policy".
57 | optional string key = 1;
58 |
59 | oneof value {
60 | google.protobuf.Value label_value = 2;
61 | IntervalAnnotations interval_annotations = 3;
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/audio.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | enum AudioFormat {
20 | AUDIO_FORMAT_UNKNOWN = 0;
21 | AUDIO_FORMAT_WEBM = 1;
22 | AUDIO_FORMAT_MP3 = 2;
23 | // Waveform audio file format. The audio data contains the bytes of a WAV
24 | // file. The actual encoding is stored within the header bytes; typically this
25 | // will be LPCM for storing uncompressed audio.
26 | AUDIO_FORMAT_WAV = 3;
27 | // PCM audio format. Pulse Code Modulation, is a method for digitally
28 | // representing analog audio signals. PCM data has no encoding.
29 | AUDIO_FORMAT_PCM = 4;
30 | }
31 |
32 | message AudioMetadata {
33 | // required for audio transcription with SAS.
34 | optional float sample_rate_hz = 1;
35 | optional int32 channel_count = 2;
36 | optional AudioFormat format = 3;
37 | }
38 |
39 | message Duration {
40 | // Signed seconds of the span of time. Must be from -315,576,000,000
41 | // to +315,576,000,000 inclusive. Note: these bounds are computed from:
42 | // 60 sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
43 | optional int64 seconds = 1;
44 |
45 | // Signed fractions of a second at nanosecond resolution of the span
46 | // of time. Durations less than one second are represented with a 0
47 | // `seconds` field and a positive or negative `nanos` field. For durations
48 | // of one second or more, a non-zero value for the `nanos` field must be
49 | // of the same sign as the `seconds` field. Must be from -999,999,999
50 | // to +999,999,999 inclusive.
51 | optional int32 nanos = 2;
52 | }
53 |
54 | message Audio {
55 | optional AudioMetadata metadata = 1;
56 | optional Duration duration = 2;
57 | optional bytes data = 3;
58 | }
59 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/contact_surface.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | // Observed sensor data from contact surface, which could be end-effector (ie.
20 | // finger surface in a dextrous hand), electronic skin, or other surfaces.
21 | message ContactSurface {
22 | // The names of the site of the contact surface.
23 | repeated string site_names = 1;
24 |
25 | // Each of the following fields should have the same length as the number of
26 | // sites.
27 |
28 | // Force or torques; unit is N or N.m.
29 | // If the sensor is 3D, then the coordinate frame is right-handed, where Z
30 | // corresponds to the thumb and is perpendicular to the contact surface site.
31 | // If the sensor is 1D, then only fz is populated representing the force /
32 | // force torque perpendicular to the contact surface site.
33 | message Force3D {
34 | optional double fx = 1;
35 | optional double fy = 2;
36 | optional double fz = 3;
37 | }
38 |
39 | // Force or torques observed at the contact surface sites. This is different
40 | // from the force torques (observed or commanded) at the underneath joint or
41 | // actuator.
42 | repeated Force3D force_torques = 2;
43 |
44 | // Temperature observed at the contact surface sites. Unit in Celsius degree
45 | // (C). This is different from the temperature of the underneath joint or
46 | // actuator.
47 | repeated double temperature = 3 [packed = true];
48 |
49 | // With Force Sensing Resistor (FSR) based tactile sensor, the "raw" reading
50 | // is individual digit per contact surface site from which the force data for
51 | // each site is approximated. The approximation quality depends on
52 | // calibration. Hence, prefers to also log the raw digits whenever available.
53 | repeated double tactile_digits = 4 [packed = true];
54 | }
55 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/artifact.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator artifact information."""
16 |
17 | import dataclasses
18 | import enum
19 |
20 | import dataclasses_json
21 |
22 |
class ArtifactObjectType(enum.Enum):
  """Kinds of objects an Orchestrator artifact can hold.

  Each member's value is identical to its name.
  """

  ARTIFACT_OBJECT_TYPE_UNSPECIFIED = "ARTIFACT_OBJECT_TYPE_UNSPECIFIED"
  ARTIFACT_OBJECT_TYPE_IMAGE = "ARTIFACT_OBJECT_TYPE_IMAGE"
  ARTIFACT_OBJECT_TYPE_VIDEO = "ARTIFACT_OBJECT_TYPE_VIDEO"
  ARTIFACT_OBJECT_TYPE_AUDIO = "ARTIFACT_OBJECT_TYPE_AUDIO"
  ARTIFACT_OBJECT_TYPE_TEXT = "ARTIFACT_OBJECT_TYPE_TEXT"
  ARTIFACT_OBJECT_TYPE_JSON = "ARTIFACT_OBJECT_TYPE_JSON"
  ARTIFACT_OBJECT_TYPE_PROTOBUF = "ARTIFACT_OBJECT_TYPE_PROTOBUF"
  ARTIFACT_OBJECT_TYPE_DOCKER = "ARTIFACT_OBJECT_TYPE_DOCKER"
  ARTIFACT_OBJECT_TYPE_BYTE = "ARTIFACT_OBJECT_TYPE_BYTE"
  ARTIFACT_OBJECT_TYPE_OTHER = "ARTIFACT_OBJECT_TYPE_OTHER"
36 |
37 |
@dataclasses_json.dataclass_json
@dataclasses.dataclass(kw_only=True)
class Artifact:
  """An Orchestrator artifact record.

  All fields are optional. When no object type is supplied it is replaced with
  ARTIFACT_OBJECT_TYPE_UNSPECIFIED after initialization.
  """

  # pylint: disable=invalid-name
  uri: str | None = None
  artifactId: str | None = None
  name: str | None = None
  desc: str | None = None
  artifactObjectType: ArtifactObjectType | None = None
  commitTime: str | None = None
  tags: list[str] | None = None
  version: str | None = None
  isZipped: bool | None = None

  def __post_init__(self):
    """Fills in the UNSPECIFIED object type when none was given."""
    if self.artifactObjectType is not None:
      return
    unspecified = ArtifactObjectType.ARTIFACT_OBJECT_TYPE_UNSPECIFIED
    self.artifactObjectType = unspecified
  # pylint: enable=invalid-name
60 |
61 |
@dataclasses_json.dataclass_json
@dataclasses.dataclass(kw_only=True)
class LoadArtifactResponse:
  """Response wrapper that carries a single loaded `Artifact`."""

  artifact: Artifact
68 |
--------------------------------------------------------------------------------
/safari_sdk/protos/image.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | import "google/protobuf/struct.proto";
20 |
21 | // A descriptive proto for image storage. This proto explicitly splits out the
22 | // pixel description for its primitive type, number of channels, and channel
23 | // order, where the latter two are described jointly. The image producer
24 | // configures the metadata for the user interpretability. Moreover, this proto
25 | // can describe a raw image and, if/when a compression type enum is added, an
26 | // encoded image.
27 | message Image {
28 | message PixelType {
29 | enum PixelPrimitive { // Per channel.
30 | UNSPECIFIED_PIXEL_PRIMITIVE = 0;
31 | UCHAR8 = 1;
32 | UINT16 = 2;
33 | }
34 | enum ChannelType1 {
35 | UNSPECIFIED_CHANNEL_TYPE_1 = 0;
36 | MONO = 1;
37 | DEPTH = 2;
38 | }
39 | enum ChannelType3 {
40 | UNSPECIFIED_CHANNEL_TYPE_3 = 0;
41 | RGB = 1;
42 | }
43 | enum ChannelType4 {
44 | UNSPECIFIED_CHANNEL_TYPE_4 = 0;
45 | RGBA = 1;
46 | }
47 |
48 | // Image compression type.
49 | enum Compression {
50 | NO_COMPRESSION = 0;
51 | JPEG = 1;
52 | PNG = 2;
53 | }
54 |
55 | optional PixelPrimitive pixel_primitive = 1;
56 |
57 | oneof channel_oneof {
58 | ChannelType1 channel_type_1 = 2;
59 | ChannelType3 channel_type_3 = 3;
60 | ChannelType4 channel_type_4 = 4;
61 | }
62 |
63 | optional Compression compression = 5;
64 | }
65 |
66 | // The cols (width) and rows (height) of the image.
67 | optional int32 cols = 1;
68 | optional int32 rows = 2;
69 |
70 | optional PixelType pixel_type = 3;
71 |
72 | optional bytes data = 4;
73 |
74 | // Metadata about the tracker for storing information such as tracking
75 | // confidence.
76 | optional google.protobuf.Struct metadata = 5;
77 | }
78 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/metadata.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/label.proto";
20 | import "safari_sdk/protos/logging/policy_environment_metadata.proto";
21 |
22 | message KeyRange {
23 | // Name of the data stream.
24 | optional string topic = 1;
25 |
26 | // The time interval of the data stream.
27 | optional IntervalValue interval = 2;
28 | }
29 |
30 | message Session {
31 | // The session time interval.
32 | optional IntervalValue interval = 1;
33 |
34 | message StreamMetadata {
35 | optional KeyRange key_range = 1;
36 |
37 | optional bool is_required = 2; // Topic stream consistency options:
38 | }
39 | repeated StreamMetadata streams = 2;
40 |
41 | // For MDP logging only, a RLDSSpec in json format.
42 | optional string rlds_specs = 3;
43 |
44 | // The additional label metadata of the session.
45 | repeated LabelMessage labels = 4;
46 |
47 | // The string should consist of only alphanumeric and underscore.
48 | optional string task_id = 5;
49 |
50 | // Metadata related to the policy and environment.
51 | optional PolicyEnvironmentMetadata policy_environment_metadata = 6;
52 | }
53 |
54 | // Metadata of a log file. There should be one file metadata per log file.
55 | message FileMetadata {
56 | // The agent id string (typically robot id). The string should be in the
57 | // format of regex '[a-zA-Z]([a-zA-Z]|[0-9]|_)+', aka. the same restrictions
58 | // as proto field names. This should be less than 30 characters.
59 | optional string agent_id = 1;
60 |
61 | // Identifies the time coverage of the log file at a per-stream level.
62 | repeated KeyRange stream_coverages = 2;
63 | }
64 |
message TimeSynchronization {
  // The key is the topic specified to the logger. Timestamps in unix time
  // nanoseconds.
  map<string, int64> last_timestamp_by_topic = 1;
}
70 |
--------------------------------------------------------------------------------
/safari_sdk/protos/monitoring.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | import "safari_sdk/protos/vector.proto";
20 |
21 | enum LogSeverity {
22 | // (--
23 | // Gaps left to allow adding new codes if needed, without perturbing the
24 | // numeric ordering.
25 | // --)
26 |
27 | // (0) The log entry has no assigned severity level.
28 | LOG_SEVERITY_DEFAULT = 0;
29 | // (100) Debug or trace information.
30 | LOG_SEVERITY_DEBUG = 100;
31 | // (200) Routine information, such as ongoing status or performance.
32 | LOG_SEVERITY_INFO = 200;
33 | // (300) Normal but significant events, such as start up, shut down, or
34 | // a configuration change.
35 | LOG_SEVERITY_NOTICE = 300;
36 | // (400) Warning events might cause problems.
37 | LOG_SEVERITY_WARNING = 400;
38 | // (500) Error events are likely to cause problems.
39 | LOG_SEVERITY_ERROR = 500;
40 | // (600) Critical events cause more severe problems or outages.
41 | LOG_SEVERITY_CRITICAL = 600;
42 | // (700) A person must take an action immediately.
43 | LOG_SEVERITY_ALERT = 700;
44 | // (800) One or more systems are unusable.
45 | LOG_SEVERITY_EMERGENCY = 800;
46 | }
47 |
message MonitoringMeasurement {
  optional string measure_key = 1;
  // NOTE(review): the map type parameters were missing from this declaration;
  // <string, string> is assumed -- confirm against the canonical schema.
  map<string, string> labels = 2;

  optional string unit = 3;
  // Description of the measuring metric, sparsely provided in stored data,
  // GROUP BY measure_key and select the latest value for display.
  optional string description = 4;

  oneof value {
    int64 int_value = 8;
    double double_value = 9;
    NamedVectorDouble named_vector_double = 10;
  }
}
63 |
message MonitoringEvent {
  optional LogSeverity severity = 1;
  optional string message = 2;
  // NOTE(review): the map type parameters were missing from this declaration;
  // <string, string> is assumed -- confirm against the canonical schema.
  map<string, string> labels = 3;
}
69 |
70 | message MonitoringPayload {
71 | oneof payload {
72 | MonitoringMeasurement measurement = 2;
73 | MonitoringEvent event = 3;
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/examples/model/genai_robotics_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of using genai_robotics.py."""
16 |
17 | from collections.abc import Sequence
18 | import json
19 | import time
20 |
21 | from absl import app
22 | from absl import flags
23 | import numpy as np
24 |
25 | from safari_sdk.model import genai_robotics
26 |
# Required command-line flag naming the served model to query.
_SERVE_ID = flags.DEFINE_string(
    name="serve_id",
    default=None,
    help="The ID of the model to use.",
    required=True,
)
30 |
31 |
def main(argv: Sequence[str]) -> None:
  """Sends one sample request to the served model, then times repeated calls.

  Args:
    argv: Positional command-line arguments left after flag parsing. Anything
      beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Single source of truth for the benchmark length, so the printed message
  # and the loop below cannot drift apart.
  num_timed_calls = 10

  # 1. Initialize the client
  client = genai_robotics.Client()

  # 2. Prepare sample input data
  # Sample image (e.g., from a camera)
  sample_image = np.zeros(
      (64, 64, 3), dtype=np.uint8
  )  # Example: 64x64 RGB image

  # Sample observations, including image index and other data
  observations = {
      "images/overhead_cam": 0,  # Index 0 refers to sample_image
      "task_instruction": "Pick up the red block.",
      "joints_pos": [0.1, -0.2, 0.3, -0.4, 0.5, -0.6],
  }
  observations_json = json.dumps(observations)

  def call_model():
    """Issues one request; a fresh `contents` list is built for every call."""
    return client.models.generate_content(
        model=_SERVE_ID.value,
        contents=[
            sample_image,  # Can be np.array, tf.Tensor, bytes, or types.Part
            observations_json,
        ],
    )

  # 3. Call the generate_content method
  print(f"Calling model {_SERVE_ID.value}...")
  try:
    response = call_model()

    # 4. Print the response
    print("Model Response:", response.text)

  except Exception as e:  # pylint: disable=broad-exception-caught
    print(f"An error occurred: {e}")

  # 5. Call the model repeatedly and print the average latency.
  print(f"Calling model {num_timed_calls} times...")
  times = []
  for _ in range(num_timed_calls):
    # perf_counter is monotonic, so wall-clock adjustments during the run
    # cannot distort (or make negative) the measured latency.
    start = time.perf_counter()
    call_model()
    times.append(time.perf_counter() - start)
  print("times: ", times)
  print("Average time:", np.mean(times))
87 |
88 |
# Script entry point: absl's app.run is handed `main` as the function to run.
if __name__ == "__main__":
  app.run(main)
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Safari SDK: the SDK for Google DeepMind Gemini Robotics models 🦓🦄🐘🐒🐍
2 |
3 | ## Disclaimer
4 |
5 | This is not an officially supported Google product.
6 |
7 | Safari SDK provides the full-lifecycle tooling necessary for using Gemini
8 | Robotics models, including but not limited to accessing checkpoints, serving
9 | a model, evaluating the model on robot and in simulation, uploading data,
10 | fine-tuning the model, and downloading the fine-tuned checkpoint. Most of the
11 | functionality requires you to join the Gemini Robotics Trusted Tester Program.
12 | See details on the Gemini Robotics [main page](https://deepmind.google/models/gemini-robotics/).
13 |
14 | ## Installation and accessing the source code
15 |
16 | Safari SDK can be easily installed via PyPI. It is recommended to use a
17 | virtual environment to avoid dependency version conflicts.
18 |
19 | ```shell
20 | pip install safari_sdk
21 | ```
22 |
23 | The source code can be found on [GitHub](https://github.com/google-deepmind/gemini-robotics-sdk).
24 |
25 | ## Building the wheel after code change
26 |
27 | To build a Python wheel, run the following command from the root of the
28 | repository.
29 |
30 | ```shell
31 | scripts/build_wheel.sh
32 | ```
33 |
34 | This script will build a pip installable wheel for the Safari SDK, and print the
35 | file's path to stdout.
36 |
37 | ## Model support
38 |
39 | Safari SDK aims to support all models in the Gemini Robotics model series.
40 |
41 | Trusted Testers can access the Gemini Robotics On Device model from SDK v2.4.1.
42 |
43 | ## Libraries
44 |
45 | Libraries related to robot data logging are in `safari_sdk/logging`.
46 |
47 | Libraries related to model inference and interfacing with model servers are in
48 | `safari_sdk/model`.
49 |
50 | Libraries and binaries related to accessing model checkpoints, uploading data,
51 | and requesting model fine-tuning can be found in `safari_sdk/flywheel`.
52 |
53 | Examples, including robot and simulation evaluation of models, are in
54 | `examples/`. Aloha-specific eval code is in `examples/aloha`.
55 |
56 | ## Flywheel CLI
57 |
58 | The flywheel CLI is a convenient CLI tool available after installation of the
59 | pip package. It provides a set of commands to interact with the Gemini Robotics
60 | platform, such as training models, serving models, managing data, and
61 | downloading artifacts.
62 |
63 | To use the CLI:
64 |
65 | ```
66 | flywheel-cli <command> [--flags]
67 | ```
68 |
69 | Supported commands are:
70 |
71 | * `train`: Train a model. Requires specifying task ID, start date, and end
72 | date.
73 | * `serve`: Serve a model. Requires specifying the training job ID.
74 | * `list`: List available training jobs.
75 | * `list_serve`: List available serving jobs.
76 | * `data_stats`: Show data statistics available for training.
77 | * `download`: Download artifacts from a training job or a specific artifact
78 | ID.
79 | * `upload_data`: Upload data to the data ingestion service.
80 | * `version`: Show the version of the SDK.
81 | * `help`: Show this help message with all the available commands and flags.
82 |
83 | The codebase is still in active development. We will share the most up-to-date
84 | user guide with Trusted Testers of Gemini Robotics.
85 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/current_robot_info_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for current_robot_info.py."""
16 |
17 | from absl.testing import absltest
18 |
19 | from safari_sdk.orchestrator.client.dataclass import current_robot_info
20 |
21 |
class CurrentRobotInfoTest(absltest.TestCase):
  """Checks the field values CurrentRobotInfoResponse ends up with."""

  def _assert_only_robot_id_set(self, response):
    """Asserts robotId is set and the other non-stage fields are unset."""
    self.assertEqual(response.robotId, "test_robot_id")
    self.assertFalse(response.isOperational)
    for unset_value in (
        response.operatorId,
        response.robotJobId,
        response.workUnitId,
    ):
      self.assertIsNone(unset_value)

  def test_response_post_init_from_json_response(self):
    # `stage` is passed as a plain string and must come back as the enum.
    stage_enum = current_robot_info.work_unit.WorkUnitStage
    response = current_robot_info.CurrentRobotInfoResponse(
        robotId="test_robot_id",
        isOperational=True,
        operatorId="test_operator_id",
        robotJobId="test_robot_job_id",
        workUnitId="test_work_unit_id",
        stage="WORK_UNIT_STAGE_QUEUED_TO_ROBOT",
    )

    self.assertEqual(response.robotId, "test_robot_id")
    self.assertTrue(response.isOperational)
    self.assertEqual(response.operatorId, "test_operator_id")
    self.assertEqual(response.robotJobId, "test_robot_job_id")
    self.assertEqual(response.workUnitId, "test_work_unit_id")
    self.assertIsInstance(response.stage, stage_enum)
    self.assertEqual(
        response.stage, stage_enum.WORK_UNIT_STAGE_QUEUED_TO_ROBOT
    )

  def test_response_post_init_as_enum(self):
    # `stage` is already an enum member and must be kept as such.
    stage_enum = current_robot_info.work_unit.WorkUnitStage
    response = current_robot_info.CurrentRobotInfoResponse(
        robotId="test_robot_id",
        stage=stage_enum.WORK_UNIT_STAGE_QUEUED_TO_ROBOT,
    )

    self._assert_only_robot_id_set(response)
    self.assertIsInstance(response.stage, stage_enum)
    self.assertEqual(
        response.stage, stage_enum.WORK_UNIT_STAGE_QUEUED_TO_ROBOT
    )

  def test_response_post_init_as_none(self):
    # With no `stage` given, the response must report the UNSPECIFIED stage.
    stage_enum = current_robot_info.work_unit.WorkUnitStage
    response = current_robot_info.CurrentRobotInfoResponse(
        robotId="test_robot_id",
    )

    self._assert_only_robot_id_set(response)
    self.assertIsInstance(response.stage, stage_enum)
    self.assertEqual(response.stage, stage_enum.WORK_UNIT_STAGE_UNSPECIFIED)
81 |
# Runs this module's tests when the file is executed directly.
if __name__ == "__main__":
  absltest.main()
84 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/api_response.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Orchestrator API response format."""
16 |
17 | import dataclasses
18 |
19 | from googleapiclient import discovery
20 | import numpy as np
21 | from PIL import Image
22 |
23 | from safari_sdk.orchestrator.client.dataclass import artifact as artifact_data
24 | from safari_sdk.orchestrator.client.dataclass import robot_job as robot_job_data
25 | from safari_sdk.orchestrator.client.dataclass import work_unit as work_unit_data
26 |
27 |
@dataclasses.dataclass(frozen=True, kw_only=True)
class OrchestratorAPIResponse:
  """Orchestrator API response.

  All orchestrator client API calls return their result in this dataclass.
  The only exception is the disconnect() API call, which is not expected to
  return any data. The response states explicitly whether the API call was
  successful and carries the requested data, if any.

  Instances are immutable (frozen) and every field must be passed by keyword.
  `success` defaults to False and `error_message` to an empty string; data
  fields that a call does not populate keep their default (False for the
  boolean flags, None otherwise).

  Attributes:
    success: Whether the API call was successful.
    error_message: Error message if the API call was not successful.
    no_more_robot_job: If true, there are no active robot jobs available to this
      robot.
    no_more_work_unit: If true, there are no active work units available to this
      robot.
    is_visual_overlay_found: If true, visual overlay information is specified
      within the current work unit.
    project_id: Project ID for the current robot job and work unit.
    robot_id: Robot ID for the current robot job and work unit.
    robot_job: Actual RobotJob dataclass object, containing all the values
      of the current robot job.
    robot_job_id: Robot job ID for the current robot job and work unit.
    work_unit_id: Work unit ID for the current work unit.
    work_unit: Actual WorkUnit dataclass object, containing all the values
      of the current work unit.
    work_unit_stage: Current stage of the work unit, if any.
    operator_id: Current operator ID for the robot, if any.
    is_operational: Whether the robot is operational.
    server_connection: Server connection to the orchestrator server.
    visual_overlay_renderer_keys: List of keys to access specific visual overlay
      renderers.
    visual_overlay_image: Image with visual overlay drawn on it.
    artifact: Artifact information for the current work unit.
    workcell_state: Current RUI workcell state.
    robot_stage: Current Orca status of the robot.
    artifact_uri: Download URI for the specified artifact.
  """

  success: bool = False
  error_message: str = ""
  no_more_robot_job: bool = False
  no_more_work_unit: bool = False
  is_visual_overlay_found: bool = False
  project_id: str | None = None
  robot_id: str | None = None
  robot_job: robot_job_data.RobotJob | None = None
  robot_job_id: str | None = None
  work_unit_id: str | None = None
  work_unit: work_unit_data.WorkUnit | None = None
  work_unit_stage: work_unit_data.WorkUnitStage | None = None
  operator_id: str | None = None
  is_operational: bool | None = None
  server_connection: discovery.Resource | None = None
  visual_overlay_renderer_keys: list[str] | None = None
  visual_overlay_image: Image.Image | np.ndarray | bytes | None = None
  artifact: artifact_data.Artifact | None = None
  workcell_state: str | None = None
  robot_stage: str | None = None
  artifact_uri: str | None = None
88 |
--------------------------------------------------------------------------------
/safari_sdk/flywheel/upload_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Upload data library."""
16 |
17 | import datetime
18 | import json
19 | import os
20 | import time
21 |
22 | import pytz
23 | import requests
24 |
25 | from safari_sdk import auth
26 |
27 |
def _upload_file(
    *,
    api_endpoint,
    agent_id,
    filename,
    file_content_bytes,
    api_key,
    now,
):
  """Posts one file to the data ingestion service as a multipart request.

  The request body has two parts: a JSON metadata record (upload date, agent
  ID, filename) followed by the raw file bytes.

  Args:
    api_endpoint: URL the request is posted to.
    agent_id: Value sent as `agentId` in the metadata.
    filename: Value sent as `filename` in the metadata.
    file_content_bytes: Raw bytes of the file to upload.
    api_key: API key, sent as the `key` query parameter.
    now: Object whose `year`, `month` and `day` fill the metadata `date`.

  Returns:
    A `(status_code, reason)` tuple taken from the HTTP response.
  """
  boundary = b'BOUNDARY'
  delimiter = b'--' + boundary
  # Every part opens with the delimiter line and the same content-type header.
  part_prefix = b''.join([
      delimiter,
      b'\r\n',
      b'Content-Type: application/octet-stream',
      b'\r\n\r\n',
  ])

  metadata_bytes = json.dumps({
      'date': {'year': now.year, 'month': now.month, 'day': now.day},
      'agentId': agent_id,
      'filename': filename,
  }).encode()

  # One join keeps the (potentially large) file bytes from being copied more
  # than once while the payload is assembled.
  payload = b''.join([
      part_prefix,
      metadata_bytes,
      b'\r\n',
      part_prefix,
      file_content_bytes,
      b'\r\n',
      delimiter,
      b'--\r\n',
  ])

  boundary_text = boundary.decode('utf-8')
  request_headers = {
      'X-Goog-Upload-Protocol': 'multipart',
      'X-Goog-Upload-Header-Content-Type': 'text/plain',
      'Content-Type': f'multipart/related; boundary={boundary_text}',
  }

  response = requests.post(
      api_endpoint,
      params={'key': api_key},
      headers=request_headers,
      data=payload,
  )
  return (response.status_code, response.reason)
84 |
85 |
def upload_data_directory(
    api_endpoint,
    data_directory,
    robot_id,
):
  """Uploads every `.mcap` file found under a directory tree.

  Each matching file is read fully into memory and sent with `_upload_file`.
  When the service answers with HTTP 200 the file is renamed by appending
  `.uploaded`, so it no longer matches the `.mcap` filter on a later run.
  Otherwise the failure reason is printed and the file is left untouched.

  Args:
    api_endpoint: URL of the data ingestion service.
    data_directory: Root directory that is walked recursively.
    robot_id: ID sent as the agent ID with every upload.

  Raises:
    ValueError: If `auth.get_api_key()` returns no API key.
  """

  api_key = auth.get_api_key()
  if not api_key:
    raise ValueError('No API key found.')

  for root, _, files in os.walk(data_directory):
    for file in files:
      if not file.endswith('.mcap'):
        continue

      file_path = os.path.join(root, file)
      with open(file_path, 'rb') as f:
        file_content_bytes = f.read()
      file_size_mb = len(file_content_bytes) / (1024 * 1024)

      t_start = time.time()
      status_code, reason = _upload_file(
          api_endpoint=api_endpoint,
          agent_id=robot_id,
          filename=file,
          file_content_bytes=file_content_bytes,
          api_key=api_key,
          # The upload date is stamped in the America/Los_Angeles timezone.
          now=datetime.datetime.now(pytz.timezone('America/Los_Angeles')),
      )
      elapsed_s = time.time() - t_start

      if status_code != 200:
        print(f'Failed to upload {file} ({file_size_mb:.2f} MB): {reason}')
        continue

      uploaded_file_path = file_path + '.uploaded'
      os.rename(file_path, uploaded_file_path)

      # A coarse clock (or a stubbed transport) can report zero elapsed time;
      # report an infinite rate instead of raising ZeroDivisionError after the
      # file has already been uploaded and renamed.
      if elapsed_s > 0:
        upload_speed_mb_s = file_size_mb / elapsed_s
      else:
        upload_speed_mb_s = float('inf')
      print(
          f'Uploaded {file} ({file_size_mb:.2f} MB) and renamed to'
          f' {uploaded_file_path} in {elapsed_s:.2f}s'
          f' ({upload_speed_mb_s:.2f} MB/s)'
      )
130 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "safari_sdk"
3 | description = "Safari SDK: the SDK for Google DeepMind Gemini Robotics models"
4 | readme = "README.md"
5 | requires-python = ">=3.10,<3.14" # TODO
6 | license = {file = "LICENSE"}
7 | authors = [{name = "Safari SDKAuthors", email="safari-sdk-authors@google.com"}]
8 | classifiers = [ # List of https://pypi.org/classifiers/
9 | "Programming Language :: Python :: 3",
10 | "Programming Language :: Python :: 3 :: Only",
11 | # "Programming Language :: Python :: 3.9", # TODO
12 | "Programming Language :: Python :: 3.10", # EOL 2026-10
13 | "Programming Language :: Python :: 3.11", # EOL 2027-10
14 | "Programming Language :: Python :: 3.12", # EOL 2028-10
15 | # "Programming Language :: Python :: 3.13", # EOL 2029-10, no compatible tensorflow version as of TF 2.18
16 | "License :: Other/Proprietary License",
17 | "Intended Audience :: Science/Research",
18 | ]
19 | keywords = []
20 |
21 | # pip dependencies of the project
22 | dependencies = [
23 | # copybara:strip_begin(for internal changes only)
24 | # LINT.IfChange
25 | # copybara:strip_end
26 | # go/keep-sorted start
27 | "absl-py",
28 | "dataclasses_json",
29 | "dm-env",
30 | "evdev",
31 | "gdm-robotics",
32 | "google-api-python-client",
33 | "google-auth-httplib2",
34 | "google-auth-oauthlib",
35 | "google-genai",
36 | "grpcio",
37 | "imageio",
38 | "immutabledict",
39 | "lark~=1.2",
40 | "mcap-protobuf-support",
41 | "mediapy",
42 | "mujoco>=3.3.6",
43 | "opencv-python",
44 | "overrides",
45 | "protobuf<7",
46 | "psutil",
47 | "python-magic",
48 | "pytz",
49 | "scipy",
50 | "tensorflow",
51 | "watchdog",
52 | # go/keep-sorted end
53 | # copybara:strip_begin(for internal changes only)
54 | # LINT.ThenChange(//depot/google3/third_party/safari/sdk/requirements.txt)
55 | # Run `bash third_party/safari/sdk/scripts/check_dependency_version_requirements.sh`
56 | # to update the dependencies in requirements.txt.
57 | # copybara:strip_end
58 | ]
59 |
60 | # Configure scikit to pull version value from "safari_sdk/__init__.py"
61 | # See: https://scikit-build-core.readthedocs.io/en/latest/configuration/dynamic.html#regex
62 | dynamic = ["version"]
63 |
64 | [tool.scikit-build.metadata.version]
65 | provider = "scikit_build_core.metadata.regex"
66 | input = "safari_sdk/__init__.py"
67 |
68 | [tool.scikit-build.wheel]
69 | # By default, scikit-build assumes the package contains platform-specific files
70 | # and names the .whl file accordingly (eg "*-cp313-cp313-linux_x86_64.whl").
71 | # However, this package does not contain any platform-specific files, so set
72 | # platlib to false so the output whl ends with "-py3-none-any.whl". If this ever
73 | # changes (e.g. if we add pybind11 extensions), this flag should be removed.
74 | platlib = false
75 |
76 | [project.scripts]
77 | flywheel-cli = "safari_sdk.flywheel.flywheel_cli:cli_main" # 'flywheel-cli' will be installed as a command
78 |
79 | [project.urls]
80 | homepage = "https://deepmind.google/models/gemini-robotics/"
81 | repository = "https://github.com/google-deepmind/gemini-robotics-sdk"
82 | # Other: `documentation`, `changelog`
83 |
84 | [project.optional-dependencies]
85 | # Development deps (unittest, linting, formatting,...)
86 | # Installed through `pip install .[dev]`
87 | dev = [
88 | # go/keep-sorted start
89 | "immutabledict",
90 | "parameterized",
91 | "pyink",
92 | "pylint>=2.6.0",
93 | "pytest",
94 | "pytest-xdist",
95 | # go/keep-sorted end
96 | ]
97 |
98 |
99 | [tool.pyink]
100 | # Formatting configuration to follow Google style-guide
101 | pyink-indentation = 2
102 | pyink-use-majority-quotes = true
103 |
104 | # The scikit backend invokes cmake to build the proto messages.
105 | # This is done in a unique temporary build directory unless otherwise specified.
106 | [build-system]
107 | requires = ["scikit-build-core>=0.10","setuptools"]
108 | build-backend = "scikit_build_core.build"
109 |
110 | [tool.pytest.ini_options]
111 | minversion = "8.0"
112 | addopts = "-ra -q"
113 | python_files = ["*_test.py"]
114 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/current_robot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Current robot info APIs for interacting with the orchestrator server."""
16 |
17 | import json
18 |
19 | from googleapiclient import discovery
20 | from googleapiclient import errors
21 |
22 | from safari_sdk.orchestrator.client.dataclass import api_response
23 | from safari_sdk.orchestrator.client.dataclass import current_robot_info
24 |
# Short alias for the response dataclass returned by every API call below.
_RESPONSE = api_response.OrchestratorAPIResponse

# Error message returned when the client has no server connection.
_ERROR_NO_ORCHESTRATOR_CONNECTION = (
    "OrchestratorCurrentRobotInfo: Orchestrator connection is invalid."
)
# Error message prefixes; the HTTP error's reason and detail are appended.
_ERROR_GET_CURRENT_ROBOT_INFO = (
    "OrchestratorCurrentRobotInfo: Error in requesting current robot info.\n"
)
_ERROR_SET_CURRENT_ROBOT_OPERATOR_ID = (
    "OrchestratorCurrentRobotInfo: Error in setting robot operator ID.\n"
)
36 |
37 |
class OrchestratorCurrentRobotInfo:
  """Client for the orchestrator server's current-robot API calls."""

  def __init__(
      self, *, connection: discovery.Resource, robot_id: str
  ):
    """Stores the server connection and the ID of the robot to act on."""
    self._connection = connection
    self._robot_id = robot_id

  @staticmethod
  def _http_error_response(prefix: str, error: errors.HttpError) -> _RESPONSE:
    """Builds a failed response from an HTTP error and a message prefix."""
    detail = f"Reason: {error.reason}\nDetail: {error.error_details}"
    return _RESPONSE(error_message=prefix + detail)

  def disconnect(self) -> None:
    """Drops the stored connection; later calls return an error response."""
    self._connection = None

  def get_current_robot_info(self) -> _RESPONSE:
    """Requests the current info for this robot from the server.

    Returns:
      On success, a response carrying the robot job ID, work unit ID, work
      unit stage, operator ID, operational flag and robot stage reported by
      the server. Otherwise a response whose `error_message` is set.
    """
    if self._connection is None:
      return _RESPONSE(error_message=_ERROR_NO_ORCHESTRATOR_CONNECTION)

    request_body = {"robot_id": self._robot_id}
    try:
      raw_response = (
          self._connection.orchestrator()
          .currentRobotInfo(body=request_body)
          .execute()
      )
    except errors.HttpError as e:
      return self._http_error_response(_ERROR_GET_CURRENT_ROBOT_INFO, e)

    # The raw response is re-serialized so that the dataclass's own JSON
    # parsing builds the typed result.
    info = current_robot_info.CurrentRobotInfoResponse.from_json(
        json.dumps(raw_response)
    )
    return _RESPONSE(
        success=True,
        robot_id=self._robot_id,
        robot_job_id=info.robotJobId,
        work_unit_id=info.workUnitId,
        work_unit_stage=info.stage,
        operator_id=info.operatorId,
        is_operational=info.isOperational,
        robot_stage=info.robotStage,
    )

  def set_current_robot_operator_id(self, operator_id: str) -> _RESPONSE:
    """Sets the operator ID recorded for this robot on the server.

    Args:
      operator_id: Operator ID to send to the server.

    Returns:
      On success, a response echoing the robot ID and `operator_id`.
      Otherwise a response whose `error_message` is set.
    """
    if self._connection is None:
      return _RESPONSE(error_message=_ERROR_NO_ORCHESTRATOR_CONNECTION)

    request_body = {
        "robot_id": self._robot_id,
        "operator_id": operator_id,
    }
    try:
      # Only success or failure matters here; the result is discarded.
      (
          self._connection.orchestrator()
          .currentRobotSetOperatorId(body=request_body)
          .execute()
      )
    except errors.HttpError as e:
      return self._http_error_response(
          _ERROR_SET_CURRENT_ROBOT_OPERATOR_ID, e
      )

    return _RESPONSE(
        success=True,
        robot_id=self._robot_id,
        operator_id=operator_id,
    )
118 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/artifact.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Artifact APIs interacting with the orchestrator server."""
16 |
17 | import json
18 |
19 | from googleapiclient import discovery
20 | from googleapiclient import errors
21 |
22 | from safari_sdk.orchestrator.client.dataclass import api_response
23 | from safari_sdk.orchestrator.client.dataclass import artifact
24 |
# Short alias for the response dataclass returned by every API call below.
_RESPONSE = api_response.OrchestratorAPIResponse

# Error message returned when the client has no server connection.
_ERROR_NO_ORCHESTRATOR_CONNECTION = (
    "OrchestratorArtifact: Orchestrator connection is invalid."
)
# Error message prefix; the HTTP error's reason and detail are appended.
_ERROR_GET_ARTIFACT = "OrchestratorArtifact: Error in requesting artifact.\n"
# Error message returned when the server response holds no usable artifact.
_ERROR_EMPTY_RESPONSE = (
    "OrchestratorArtifact: Received empty response for get artifact request."
)
34 |
35 |
class OrchestratorArtifact:
  """Artifact API client for interacting with the orchestrator server."""

  def __init__(
      self,
      *,
      connection: discovery.Resource,
  ):
    """Stores the server connection used for artifact requests."""
    self._connection = connection

  def disconnect(self) -> None:
    """Clears current connection to the orchestrator server."""
    self._connection = None

  def _load_artifact(self, artifact_id: str) -> _RESPONSE:
    """Requests and validates one artifact; shared by the public getters.

    Keeping the request path in one place guarantees that both public methods
    apply the same checks: a non-empty ID, a live connection, an "artifact"
    entry in the response, and a non-empty URI on the parsed artifact.

    Args:
      artifact_id: ID of the artifact to load.

    Returns:
      A successful response carrying the parsed artifact, or a failed
      response whose `error_message` describes what went wrong.
    """
    if not artifact_id:
      return _RESPONSE(error_message="Artifact ID is empty.")

    if self._connection is None:
      return _RESPONSE(error_message=_ERROR_NO_ORCHESTRATOR_CONNECTION)

    body = {"artifact_id": artifact_id}

    try:
      response = (
          self._connection.orchestrator().loadArtifact(body=body).execute()
      )
    except errors.HttpError as e:
      return _RESPONSE(
          error_message=(
              _ERROR_GET_ARTIFACT
              + f"Reason: {e.reason}\nDetail: {e.error_details}"
          )
      )

    if not response or "artifact" not in response:
      return _RESPONSE(error_message=_ERROR_EMPTY_RESPONSE)

    artifact_response = artifact.LoadArtifactResponse.from_json(
        json.dumps(response)
    )

    artifact_obj = artifact_response.artifact
    # An artifact without a URI is treated like an empty response.
    if not artifact_obj or not artifact_obj.uri:
      return _RESPONSE(error_message=_ERROR_EMPTY_RESPONSE)

    return _RESPONSE(success=True, artifact=artifact_obj)

  def get_artifact(self, artifact_id: str) -> _RESPONSE:
    """Gets detailed artifact information.

    Args:
      artifact_id: ID of the artifact to load.

    Returns:
      On success, a response whose `artifact` holds the parsed artifact.
      Otherwise a response whose `error_message` is set.
    """
    return self._load_artifact(artifact_id)

  def get_artifact_uri(self, artifact_id: str) -> _RESPONSE:
    """Gets the artifact's download URI.

    Args:
      artifact_id: ID of the artifact whose URI is requested.

    Returns:
      On success, a response whose `artifact_uri` is the artifact's non-empty
      URI. Otherwise a response whose `error_message` is set.
    """
    loaded = self._load_artifact(artifact_id)
    if not loaded.success:
      return loaded

    return _RESPONSE(success=True, artifact_uri=loaded.artifact.uri)
116 |
--------------------------------------------------------------------------------
/safari_sdk/protos/camera_spec.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | // Intrinsic properties of an optical sensor with a 2D image result.
// Intrinsic properties of an optical sensor with a 2D image result.
message PinholeCamera {
  // Camera matrix:
  // | fx 0 cx |
  // | 0 fy cy |
  // | 0 0 1 |
  // Transforms points from the image plane into pixels.
  optional double fx = 1;  // Focal length in x.
  optional double fy = 2;  // Focal length in y.
  optional double cx = 3;  // Center of projection in x.
  optional double cy = 4;  // Center of projection in y.

  // Dimensions of the image array in pixels.
  optional int32 image_width = 9;  // Number of image columns.
  optional int32 image_height = 10;  // Number of image rows.

  // Field of view angle, measured radially in the canonical (distorted) image
  // plane.
  //
  // The FOV angle is a radial limit (from the camera’s principal ray) where the
  // distortion function starts failing to describe the optics present.
  optional double fov_radial_radians = 6;

  // https://en.wikipedia.org/wiki/Distortion_(optics)
  //
  // Each distortion model maps a ray (or point in the normalized image plane)
  // to a pixel in the canonical (distorted) image plane.
  //
  // (x,y) are the intersection of the ray with the z=1 plane.
  // (u,v) are the distorted coordinates in the canonical image plane.
  //
  // The field of view must be less than 180 degrees in this type of model.

  // Parameters for Brown-Conrady distortion.
  //
  // r = |(x,y)|
  // f(r) = (1 + k1 r^2 + k2 r^4 + k3 r^6) / (1 + k4 r^2 + k5 r^4 + k6 r^6)
  // u = x * f(r)
  // v = y * f(r)
  // u += 2 p1 xy + p2(r^2 + 2 x^2)
  // v += 2 p2 xy + p1(r^2 + 2 y^2)
  message BrownConradyDistortion {
    // Radial coefficients.
    optional double k1 = 1;
    optional double k2 = 2;
    optional double k3 = 3;
    optional double k4 = 4;
    optional double k5 = 5;
    optional double k6 = 6;

    // Decentering coefficients.
    optional double p1 = 7;
    optional double p2 = 8;
  }

  // Kannala-Brandt is the distortion model used in OpenCV for a fisheye camera.
  //
  // It models only radial distortion. It operates on the angle, θ, between the
  // ray and the view direction.
  //
  // r = |(x,y)|
  // θ = atan(r)
  // f(θ) = θ(1 + k1 θ^2 + k2 θ^4 + k3 θ^6 + k4 θ^8)
  // u = (f(θ)/r) x
  // v = (f(θ)/r) y
  message KannalaBrandtDistortion {
    optional double k1 = 1;
    optional double k2 = 2;
    optional double k3 = 3;
    optional double k4 = 4;
  }

  // Unified distortion model.
  //
  // This model can be used for perspective, fisheye, and catadioptric
  // cameras.
  message UnifiedDistortion {
    optional double xi = 1;
  }

  // Special Enhanced Unified distortion model.
  //
  // This model can be used for ultrawide and fisheye cameras.
  message EnhancedUnifiedDistortion {
    optional double eu = 1;
    optional double ev = 2;
    optional double alpha = 3;
    optional double beta = 4;
  }

  // The distortion model in use; at most one of these fields is set. Note
  // that the EnhancedUnifiedDistortion field is named `extended_unified`.
  oneof distortion_model {
    BrownConradyDistortion brown_conrady = 7;
    KannalaBrandtDistortion kannala_brandt = 8;
    UnifiedDistortion unified = 11;
    EnhancedUnifiedDistortion extended_unified = 12;
  }

  // Field numbers 5 and 13 and the field name `image_size` are reserved so
  // that they cannot be reused by new fields.
  reserved 5, 13;
  reserved "image_size";
}
119 |
--------------------------------------------------------------------------------
/safari_sdk/protos/joints.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos;
18 |
19 | import "safari_sdk/protos/label.proto";
20 |
// Joints abstracts information about any kinematics chain or multiple
// kinematics chains. Each repeated field covers an aspect of the joints. For a
// single repeated field, the number of elements is either zero, which denotes
// no information, or the number of simple joints (1d revolute or prismatic).
//
// When used as logging payload, it is recommended to have consistency on the
// availability of fields in the same log stream, i.e. a single log stream
// would always have some fields (e.g. position and velocity) but not others.
//
// This type can be used as both commands to and states from robot.

// Although named Joints, this message can also represent Actuators.
message Joints {
  // Position of joints. Depending on joint type, unit is meters or radians.
  repeated double positions = 1 [packed = true];

  // Velocity of joints. Depending on joint type, unit is m/s or rad/s.
  repeated double velocities = 2 [packed = true];

  // Acceleration of joints. Depending on joint type, unit is m/s^2 or rad/s^2.
  repeated double accelerations = 3 [packed = true];

  // Jerk of joints. Depending on joint type, unit is m/s^3 or rad/s^3.
  repeated double jerks = 4 [packed = true];

  // Force or torques at the joints (applied or measured). Unit is N or N.m.
  repeated double force_torques = 5 [packed = true];

  // Current to joint actuators. Unit is ampere.
  repeated double currents = 6 [packed = true];

  // Temperature of joint actuators. Unit is degrees Celsius (C).
  repeated double temperature = 7 [packed = true];

  // Proportional gain which determines the joint stiffness.
  // High kp could cause instability from overshooting and oscillation.
  // Sometimes referred to as position_gain or stiffness.
  repeated double kps = 8 [packed = true];

  // Derivative gain which is the damping effect on the joint.
  // Increasing kd reduces oscillation.
  // Sometimes referred to as velocity_gain or damping.
  repeated double kds = 9 [packed = true];

  // Input voltage (bus voltage) to each joint. Unit is Volt.
  repeated double bus_voltages = 11 [packed = true];

  // The names of the joints.
  // For platforms that publish joint names with other joint data.
  // Otherwise, consider using `joint_names` in JointsTrajectory below.
  repeated string names = 12;

  // Field number 10 is reserved so that it cannot be reused.
  reserved 10;
}
75 |
// A list of Joints messages which represent a temporal sequence of states.
message JointsTrajectory {
  // The sequence of joint states which represent the trajectory.
  repeated Joints points = 1;

  // The time in nanoseconds associated with each point in the trajectory. The
  // length of this field should be the same as the length of points. The first
  // element can be a placeholder value if not applicable.
  // NOTE(review): judging by the field name, each entry is presumably the
  // time elapsed since the previous point - confirm.
  repeated int64 time_from_previous_nsec = 2 [packed = true];

  // Optional. The name of the trajectory to distinguish different trajectories
  // in the same topic stream.
  optional string trajectory_name = 3;

  // Start time is when the first point in the trajectory should take effect.
  optional int64 start_time_nsec = 4;

  // Optional. The names of the joints, assuming identical for all points in the
  // trajectory.
  repeated string joint_names = 5;

  // Optional additional metadata.
  message Metadata {
    // DomainTimestamp is not declared in this file; presumably it comes from
    // the imported label.proto.
    optional DomainTimestamp domain_timestamp = 1;
  }
  optional Metadata metadata = 6;
}
103 |
--------------------------------------------------------------------------------
/safari_sdk/auth.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions for connecting to a discovery service endpoint.
16 |
17 | This module provides a helper function to build a discovery service with a given
18 | API key. The API key can be provided via two methods.
19 |
20 | 1. By flag, --api_key="your_api_key"
21 | 2. By file at one of these locations:
22 | $HOME/.config/safari_sdk/API_KEY
23 | /opt/safari_sdk/API_KEY
24 |
25 | The resolution order on which API key value will be used is:
26 |
27 | 1. By flag, "--api_key"
28 | 2. By path, "$HOME/.config/safari_sdk/API_KEY"
29 | 3. By path, "/opt/safari_sdk/API_KEY"
30 |
31 | Here is an example of using this module:
32 |
33 | from safari_sdk import auth
34 |
35 | # This returns a discovery.Resource object.
36 | service = auth.get_service()
37 | """
38 |
39 | import os
40 | from absl import flags
41 |
42 | from googleapiclient import discovery
43 |
44 | import httplib2
45 |
# Command-line flag through which the API key can be supplied explicitly.
_API_KEY = flags.DEFINE_string(
    "api_key",
    None,
    "API key to use for the Safari API.",
)

# Locations probed for an API key file, in lookup order.
_USER_API_KEY_FILE = os.path.join(
    os.path.expanduser("~"), ".config/safari_sdk/API_KEY"
)
_SYSTEM_API_KEY_FILE = "/opt/safari_sdk/API_KEY"
_API_KEY_FILE_PATHS = [_USER_API_KEY_FILE, _SYSTEM_API_KEY_FILE]

# Name, version and discovery document URL of the service to connect to.
_DEFAULT_SERVICE_NAME = "roboticsdeveloper"
_DEFAULT_VERSION = "v1"
_DEFAULT_DISCOVERY_SERVICE_URL = (
    "https://roboticsdeveloper.googleapis.com"
    "/$discovery/rest?version=v1"
)

# Messages used for the ValueErrors raised by this module.
_ERROR_NO_API_KEY_PROVIDED = "Auth: No API key provided by flag or file."
_ERROR_NO_API_KEY_PROVIDED_IN_FILE = "Auth: No API key provided in file:"
73 |
74 |
def _extract_api_key_from_file(file_path: str) -> str:
  """Returns the content of `file_path` with surrounding whitespace removed."""
  with open(file_path, "r") as key_file:
    raw_content = key_file.read()
  return raw_content.strip()
79 |
80 |
def _build_service(api_key: str) -> discovery.Resource:
  """Creates the discovery service client authenticated with `api_key`.

  Args:
    api_key: Developer key passed to the discovery client.

  Returns:
    The object returned by `discovery.build` for the default service name,
    version and discovery URL.
  """
  # 900 seconds, i.e. 15 minutes.
  long_timeout_http = httplib2.Http(timeout=900)
  service = discovery.build(
      http=long_timeout_http,
      developerKey=api_key,
      serviceName=_DEFAULT_SERVICE_NAME,
      version=_DEFAULT_VERSION,
      discoveryServiceUrl=_DEFAULT_DISCOVERY_SERVICE_URL,
  )
  return service
91 |
92 |
def get_service() -> discovery.Resource:
  """Builds a discovery service using the first API key that can be found.

  The API key is looked up in this order:
    1. The "--api_key" flag.
    2. The file "$HOME/.config/safari_sdk/API_KEY".
    3. The file "/opt/safari_sdk/API_KEY".

  Only the first key file that exists is consulted; if that file is empty an
  error is raised instead of falling through to the next location.

  Returns:
    The service as a discovery.Resource object.

  Raises:
    ValueError: If no API key is provided by flag or file.
  """
  flag_key = _API_KEY.value
  if flag_key:
    return _build_service(api_key=flag_key)

  # Lazily pick the first candidate path that is an existing file.
  key_file = next(filter(os.path.isfile, _API_KEY_FILE_PATHS), None)
  if key_file is None:
    raise ValueError(_ERROR_NO_API_KEY_PROVIDED)

  file_key = _extract_api_key_from_file(key_file)
  if file_key:
    return _build_service(api_key=file_key)
  raise ValueError(f"{_ERROR_NO_API_KEY_PROVIDED_IN_FILE} {key_file}")
118 |
119 |
def get_api_key() -> str | None:
  """Returns the API key from the flag or the key files, or None if absent.

  The "--api_key" flag wins when set. Otherwise the known key file locations
  are tried in order, skipping missing files and files without content.
  """
  if flag_key := _API_KEY.value:
    return flag_key

  for candidate in _API_KEY_FILE_PATHS:
    if not os.path.isfile(candidate):
      continue
    file_key = _extract_api_key_from_file(candidate)
    if file_key:
      return file_key

  # Neither the flag nor any key file provided a key.
  return None
133 |
--------------------------------------------------------------------------------
/examples/model/gemini_robotics_policy_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example script to test the GeminiRoboticsPolicy."""
16 |
17 | from collections.abc import Sequence
18 |
19 | from absl import app
20 | from absl import flags
21 | import numpy as np
22 | from safari_sdk.model import gemini_robotics_policy
23 | import tensorflow as tf
24 |
25 |
_SERVE_ID = flags.DEFINE_string(
    name="serve_id",
    default=None,
    help="The serve ID to use.",
    required=True,
)
_TASK_INSTRUCTION = flags.DEFINE_string(
    name="task_instruction",
    default="Pick up the red block.",
    help="The task instruction for the policy.",
)
_ROBOT_TYPE = flags.DEFINE_enum(
    name="robot_type",
    default="aloha",
    enum_values=["aloha", "atari"],
    help="The robot type to use.",
)

# Aloha: every camera shares one image size; a single 14-entry joint vector.
_ALOHA_IMAGE_SIZE = (480, 848)
_ALOHA_CAMERAS = dict.fromkeys(
    ("overhead_cam", "worms_eye_cam", "wrist_cam_left", "wrist_cam_right"),
    _ALOHA_IMAGE_SIZE,
)
_ALOHA_JOINTS = dict(joints_pos=14)

# Atari: the head camera and the wrist cameras use different image sizes.
_ATARI_STEREOLAB_HEADCAM_IMAGE_SIZE = (1200, 1920)
_ATARI_WRISTCAM_IMAGE_SIZE = (480, 640)
_ATARI_CAMERAS = dict(
    stereolab_headcam0=_ATARI_STEREOLAB_HEADCAM_IMAGE_SIZE,
    left_wrist_cam=_ATARI_WRISTCAM_IMAGE_SIZE,
    right_wrist_cam=_ATARI_WRISTCAM_IMAGE_SIZE,
)
_ATARI_JOINTS = dict(
    left_arm_joint_pos=7,
    right_arm_joint_pos=7,
    left_hand_command=6,
    right_hand_command=6,
    neck_joint_pos=3,
    torso_joint_pos=3,
)

# Per-robot lookup tables, keyed by the value of --robot_type.
_CAMERAS = dict(aloha=_ALOHA_CAMERAS, atari=_ATARI_CAMERAS)
_JOINTS = dict(aloha=_ALOHA_JOINTS, atari=_ATARI_JOINTS)
73 |
74 |
def main(argv: Sequence[str]) -> None:
  """Builds the policy, then steps it 100 times on a dummy observation.

  Errors during initialization or stepping are printed instead of raised.

  Args:
    argv: Positional command-line arguments; none are accepted beyond the
      program name.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Instantiate and set up the policy for the selected robot type.
  try:
    robot_type = _ROBOT_TYPE.value
    policy = gemini_robotics_policy.GeminiRoboticsPolicy(
        serve_id=_SERVE_ID.value,
        task_instruction=_TASK_INSTRUCTION.value,
        cameras=_CAMERAS[robot_type],
        joints=_JOINTS[robot_type],
    )
    policy.setup()
    print("GeminiRoboticsPolicy initialized successfully.")
  except ValueError as init_error:
    print(f"Error initializing policy: {init_error}")
    return
  except Exception as init_error:  # pylint: disable=broad-except
    print(f"An unexpected error occurred during initialization: {init_error}")
    return

  def make_placeholder(spec):
    """Returns a dummy value matching one entry of the observation spec."""
    if spec.dtype == tf.string:
      # String specs are filled with the task instruction.
      return np.array(_TASK_INSTRUCTION.value, dtype=object)
    # Everything else (e.g. images) is filled with zeros.
    return np.zeros(spec.shape, dtype=spec.dtype.as_numpy_dtype)

  dummy_observation = {
      name: make_placeholder(spec)
      for name, spec in policy.observation_spec.items()
  }

  print("\nCreated dummy observation:")
  for name, array in dummy_observation.items():
    print(f" {name}: shape={array.shape}, dtype={array.dtype}")

  # Step the policy 100 times with the same dummy observation.
  try:
    for step_index in range(100):
      print(f"\nCalling policy.step() step {step_index}...")
      action = policy.step(dummy_observation)
      print("\nReceived action from policy:")
      print(action)
  except Exception as step_error:  # pylint: disable=broad-except
    # API calls can fail in many ways, hence the broad catch.
    print(f"\nAn error occurred during policy.step(): {step_error}")
    print("Please check your API key, serve ID, and network connection.")


if __name__ == "__main__":
  app.run(main)
127 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/robot_job.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Robot job APIs interacting with the orchestrator server."""
16 |
17 | import enum
18 | import json
19 |
20 | from googleapiclient import discovery
21 | from googleapiclient import errors
22 |
23 | from safari_sdk.orchestrator.client.dataclass import api_response
24 | from safari_sdk.orchestrator.client.dataclass import robot_job
25 |
# Short alias for the response type returned by every public method.
_RESPONSE = api_response.OrchestratorAPIResponse

# Error messages reported through `_RESPONSE.error_message`.
_ERROR_ROBOT_JOB_NOT_ACQUIRED = (
    "OrchestratorRobotJob: No active robot job."
    " Please call request_robot_job() first."
)
_ERROR_NO_ORCHESTRATOR_CONNECTION = (
    "OrchestratorRobotJob: "
    "Orchestrator connection is invalid."
)
_ERROR_GET_ROBOT_JOB = (
    "OrchestratorRobotJob: Error in requesting robot job.\n"
)
_ERROR_EMPTY_RESPONSE = (
    "OrchestratorRobotJob: "
    "Received empty response for robot job request."
)
_ERROR_EMPTY_ROBOT_JOB_ID = (
    "OrchestratorRobotJob: Received empty robot job ID"
    " in response for robot job request."
)
44 |
45 |
class JobType(enum.Enum):
  """Kinds of robot job that can be requested.

  `ALL` is a special case: it maps to the default proto value "UNSPECIFIED" in
  Orchestrator, which is treated as "all job types".
  """

  # All job types.
  ALL = 0
  # Collection job only.
  COLLECTION = 1
  # Evaluation job only.
  EVALUATION = 2
53 |
54 |
class OrchestratorRobotJob:
  """Robot job API client for interacting with the orchestrator server.

  The client remembers the most recently allocated robot job. That job is
  replaced on every successful `request_robot_job()` call and cleared whenever
  a request yields no usable job.
  """

  def __init__(
      self,
      *,
      connection: discovery.Resource,
      robot_id: str,
      job_type: JobType,
  ):
    """Initializes the robot job handler.

    Args:
      connection: Discovery resource used to reach the orchestrator server.
      robot_id: ID of the robot that jobs are requested for.
      job_type: Type of job to request; its enum value is sent to the server.
    """
    self._connection = connection
    self._robot_id = robot_id
    self._job_type = job_type

    self._current_robot_job: robot_job.RobotJob | None = None

  def disconnect(self) -> None:
    """Clears current connection to the orchestrator server."""
    self._connection = None

  def get_current_robot_job(self) -> _RESPONSE:
    """Gets the current robot job.

    Returns:
      A successful response carrying the current robot job, or an error
      response if no robot job is currently held.
    """
    if self._current_robot_job is None:
      return _RESPONSE(error_message=_ERROR_ROBOT_JOB_NOT_ACQUIRED)
    return _RESPONSE(success=True, robot_job=self._current_robot_job)

  def request_robot_job(self) -> _RESPONSE:
    """Request orchestrator server for next available robot job to execute.

    Returns:
      On success, a response with the robot ID and the allocated job ID. If
      the server response carries a job whose ID is None, a successful
      response flagged with `no_more_robot_job=True`. Otherwise an error
      response describing what went wrong.
    """
    if self._connection is None:
      return _RESPONSE(error_message=_ERROR_NO_ORCHESTRATOR_CONNECTION)

    body = {
        "robot_id": self._robot_id,
        "type": self._job_type.value,
    }

    try:
      response = (
          self._connection.orchestrator().allocateRobotJob(body=body).execute()
      )
    except errors.HttpError as e:
      return _RESPONSE(
          error_message=(
              _ERROR_GET_ROBOT_JOB
              + f"Reason: {e.reason}\nDetail: {e.error_details}"
          )
      )

    # Parse into a local so that `_current_robot_job` only ever holds a
    # RobotJob or None, never the RobotJobResponse wrapper (which would
    # otherwise be left behind if reading the job below raised).
    job_response = robot_job.RobotJobResponse.from_json(json.dumps(response))

    # Any outcome other than a usable job clears the previously held job.
    self._current_robot_job = None

    if not job_response:
      return _RESPONSE(error_message=_ERROR_EMPTY_RESPONSE)

    allocated_job = job_response.robotJob
    job_id = allocated_job.robotJobId

    if job_id is None:
      # A job without an ID is reported as "no more jobs", not as a failure.
      return _RESPONSE(
          success=True,
          no_more_robot_job=True,
          error_message=_ERROR_EMPTY_RESPONSE,
      )

    if not job_id:
      return _RESPONSE(error_message=_ERROR_EMPTY_ROBOT_JOB_ID)

    self._current_robot_job = allocated_job
    return _RESPONSE(
        success=True,
        robot_id=self._robot_id,
        robot_job_id=job_id,
    )
130 |
--------------------------------------------------------------------------------
/examples/logging/lerobot_data_conversion_script.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A script to convert LeRobot datasets to MCAP format."""
16 |
17 | from collections.abc import Sequence
18 | import datetime
19 | import os
20 | import re
21 | import shutil
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | from lerobot.datasets.lerobot_dataset import LeRobotDataset
27 |
28 | from safari_sdk.logging.python import mcap_lerobot_logger
29 |
# Name of the LeRobot dataset to convert; its last path component is also used
# as the output sub-directory name.
_DATASET_NAME = flags.DEFINE_string(
    'lerobot_dataset_name',
    default=None,
    help=(
        'Name of the LeRobot dataset to load. e.g. '
        'lerobot/aloha_static_cups_open'
    ),
    required=True,
)
_TASK_ID = flags.DEFINE_string(
    'task_id',
    'lerobot_test_task',
    'Task ID for the logger, used to identify data for later finetuning.',
)

# Each value is a pair of integers separated by a colon; the second integer is
# the start time in nanoseconds. The first integer is presumably the episode
# index (it becomes the key of the mapping handed to the converter).
_EPISODE_START_TIME_NS = flags.DEFINE_multi_string(
    'episode_start_time_ns',
    default=None,
    help=(
        'Start time of the episode, in nanoseconds since UNIX epoch,'
        ' in the format <episode_index>:<start_time_ns>. It '
        'can be specified multiple times.'
    ),
    required=True,
)

_OUTPUT_DIRECTORY = flags.DEFINE_string(
    'output_directory',
    '/tmp/converted_lerobot_log',
    'Directory to save MCAP files.',
)

_NUM_EPISODES = flags.DEFINE_integer(
    'num_episodes',
    0,
    'Number of episodes to process. Default value 0 means all episodes.',
)


_PROPRIO_KEY = flags.DEFINE_string(
    'proprio_key',
    'state',
    'The key of the proprio data in the observation.',
)

_MAX_WORKERS = flags.DEFINE_integer(
    'max_workers',
    200,
    'Maximum number of threads for parallel processing and logging.',
)
81 |
82 |
def validate_episode_start_time_ns_format(values):
  """Checks that every value looks like `<digits>:<digits>` with a late time.

  Args:
    values: Strings to check, one per flag occurrence.

  Returns:
    True if every value consists of two digit groups separated by a colon and
    the second group, read as an integer, is larger than the number of
    nanoseconds in thirty 365-day years (roughly the year 2000 when counted
    from 1970). False otherwise.
  """
  # Threshold that a nanosecond timestamp exceeds but a timestamp given in
  # seconds, milliseconds or microseconds does not.
  year_in_seconds = datetime.timedelta(days=365).total_seconds()
  min_plausible_ns = (2000 - 1970) * year_in_seconds * 1e9

  entry_pattern = re.compile(r'[0-9]+:[0-9]+')

  # First pass: the textual format of every entry.
  for entry in values:
    if entry_pattern.fullmatch(entry) is None:
      return False

  # Second pass: the magnitude of every timestamp.
  for entry in values:
    _, start_ns_text = entry.split(':')
    if int(start_ns_text) <= min_plausible_ns:
      return False

  return True
96 |
97 |
# Reject malformed --episode_start_time_ns values at flag-parsing time. The
# message spells out the expected format so the user can correct the input.
flags.register_validator(
    'episode_start_time_ns',
    validate_episode_start_time_ns_format,
    message=(
        'episode_start_time_ns must be specified in '
        'the format <episode_index>:<start_time_ns>, and '
        'the time needs to be in _nanoseconds_.'
    ),
)
107 |
108 |
def main(argv: Sequence[str]) -> None:
  """Loads the LeRobot dataset and converts it to MCAP files.

  The output goes to `<output_directory>/<dataset folder name>`; an existing
  directory at that path is deleted first.

  Args:
    argv: Positional command-line arguments; none are accepted beyond the
      program name.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  dataset_name = _DATASET_NAME.value
  # Keep only the part after the last '/', e.g. 'lerobot/foo' -> 'foo'.
  dataset_folder_name = dataset_name.rpartition('/')[2]
  output_directory = os.path.join(_OUTPUT_DIRECTORY.value, dataset_folder_name)

  # Start from a clean output directory.
  if os.path.exists(output_directory):
    shutil.rmtree(output_directory)
  os.makedirs(output_directory, exist_ok=True)

  logging.info('Logs will be written to: %s', output_directory)

  logging.info('--- Loading and processing from "%s" ---', _DATASET_NAME.value)
  dataset = LeRobotDataset(_DATASET_NAME.value)

  # Turn each '<int>:<int>' flag value into one key/value pair.
  episode_start_time_ns = {}
  for item in _EPISODE_START_TIME_NS.value:
    key_text, start_text = item.split(':')
    episode_start_time_ns[int(key_text)] = int(start_text)

  mcap_lerobot_logger.convert_lerobot_data_to_mcap(
      dataset=dataset,
      task_id=_TASK_ID.value,
      output_directory=output_directory,
      proprioceptive_observation_keys=[_PROPRIO_KEY.value],
      episodes_limit=_NUM_EPISODES.value,
      max_workers=_MAX_WORKERS.value,
      episode_start_timestamps_ns=episode_start_time_ns,
  )

  logging.info('\n--- Script finished successfully!')


if __name__ == '__main__':
  app.run(main)
144 |
--------------------------------------------------------------------------------
/examples/model/genai_robotics_aloha_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of using genai_robotics.py."""
16 |
17 | from collections.abc import Sequence
18 | import json
19 | import time
20 |
21 | from absl import app
22 | from absl import flags
23 | # import cv2
24 | import numpy as np
25 |
26 | from safari_sdk.model import constants
27 | from safari_sdk.model import genai_robotics
28 |
# Short alias for the enum of supported connection types.
_CONNECTION = constants.RoboticsApiConnectionType

_SERVE_ID = flags.DEFINE_string(
    name="serve_id",
    default=None,
    help="The ID of the model to use. Required for Cloud-based inference.",
)
_ROBOTICS_API_CONNECTION = flags.DEFINE_enum_class(
    name="robotics_api_connection",
    default=_CONNECTION.LOCAL,
    enum_class=_CONNECTION,
    help="The robotics API connection type to use.",
)
_SERVER_BASE_URL = flags.DEFINE_string(
    name="server_base_url",
    default=None,
    help="The server URL to use. None means use the default.",
)
47 |
48 |
def main(argv: Sequence[str]) -> None:
  """Calls the model 20 times on random images and prints timing statistics.

  Args:
    argv: Positional command-line arguments; none are accepted beyond the
      program name.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If the selected connection type is not supported.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # 1. Initialize the client.
  base_url = _SERVER_BASE_URL.value
  if base_url:
    http_options = genai_robotics.types.HttpOptions(base_url=base_url)
  else:
    http_options = None
  robotics_api_connection = _CONNECTION(_ROBOTICS_API_CONNECTION.value)
  client = genai_robotics.Client(
      robotics_api_connection=robotics_api_connection,
      http_options=http_options,
  )

  # 2. Prepare sample input data: one random image per target model.
  # Gemini Robotics
  test_img_gemini_robotics = np.random.randint(
      0, 255, (480, 848, 3), dtype=np.uint8
  )
  # Gemini Robotics Nano
  test_img_robotics_nano = np.random.randint(
      0, 255, (224, 224, 3), dtype=np.uint8
  )

  # The image entries are integers 0..3, matching the positions of the four
  # images placed ahead of the JSON string in the content list below.
  obs = {
      "images/overhead_cam": 0,
      "images/wrist_cam_left": 1,
      "images/wrist_cam_right": 2,
      "images/worms_eye_cam": 3,
      "task_instruction": "make a fox shaped origami",
      # Placeholder joint positions: negative infinity for all 14 entries.
      "joints_pos": [-np.inf] * 14,
  }
  obs_json = json.dumps(obs)

  connection_type = _ROBOTICS_API_CONNECTION.value
  if connection_type in (_CONNECTION.CLOUD, _CONNECTION.CLOUD_GENAI):
    content = [test_img_gemini_robotics] * 4 + [obs_json]
    if connection_type == _CONNECTION.CLOUD_GENAI:
      content = genai_robotics.update_robotics_content_to_genai_format(
          contents=content
      )
  elif connection_type == _CONNECTION.LOCAL:
    content = [test_img_robotics_nano] * 4 + [obs_json]
  else:
    raise ValueError(
        "Unsupported robotics_api_connection:"
        f" {_ROBOTICS_API_CONNECTION.value}."
    )

  # 3. Call 20 times and print the average time.
  num_calls = 20
  print(f"Calling model {num_calls} times...")
  times = []
  for _ in range(num_calls):
    start = time.time()
    response = client.models.generate_content(
        model=_SERVE_ID.value,
        contents=content,
    )
    if connection_type == _CONNECTION.CLOUD_GENAI:
      # The action chunk is JSON-encoded in the first part's inline data.
      payload = json.loads(
          response.candidates[0].content.parts[0].inline_data.data
      )
      action_chunk = np.array(payload["action_chunk"])[0]
      print("action_chunk: ", action_chunk[0])
    else:
      print(response.text)
    end = time.time()
    elapsed = end - start
    times.append(elapsed)
    print("Inference time (s): ", elapsed)
    del response
  print("times: ", times)
  # The first 10 calls are left out of the average.
  print("Average time:", np.mean(times[10:]))


if __name__ == "__main__":
  app.run(main)
143 |
--------------------------------------------------------------------------------
/examples/logging/sample_data_upload_script.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sample data upload script.
16 |
17 | Google Deepmind Robotics team has set up an endpoint to upload data. Each log
18 | file is uploaded as a single HTTP POST request.
19 | """
20 |
21 | import datetime
22 | import json
23 | import os
24 | import time
25 |
26 | from absl import app
27 | from absl import flags
28 | import pytz
29 | import requests
30 |
31 |
_API_ENDPOINT = flags.DEFINE_string(
    name='api_endpoint',
    default=(
        'https://roboticsdeveloper.googleapis.com'
        '/upload/v1/dataIngestion:uploadData'
    ),
    help='Data ingestion service endpoint.',
)
_AGENT_ID = flags.DEFINE_string(
    name='agent_id',
    default=None,
    help=(
        'Typically the identifier of the robot or human collector.'
        ' Alphanumeric and fewer than 60 characters.'
    ),
    required=True,
)
_API_KEY = flags.DEFINE_string(
    name='api_key',
    default=None,
    help=(
        'Api key to call the data ingestion service, please contact'
        ' Google Deepmind Robotics team for this'
    ),
    required=True,
)
_DATA_DIRECTORY = flags.DEFINE_string(
    name='data_dir',
    default=None,
    help='Directory where the data files are stored.',
    required=True,
)
57 |
58 |
def upload(
    *,
    api_endpoint,
    agent_id,
    filename,
    file_content_bytes,
    api_key,
    now,
):
  """Uploads one file to the data ingestion service with a single POST.

  The request body is a two-part `multipart/related` payload: a JSON metadata
  part (date, agent ID and filename) followed by the raw file content.

  Args:
    api_endpoint: URL of the data ingestion endpoint.
    agent_id: Identifier sent as `agentId` in the metadata.
    filename: File name sent as `filename` in the metadata.
    file_content_bytes: Raw content of the file to upload.
    api_key: API key, sent as the `key` query parameter.
    now: Object with `year`, `month` and `day` attributes used as the date.

  Returns:
    A `(status_code, reason)` tuple taken from the HTTP response.
  """
  boundary = b'BOUNDARY'
  # Opening of each part: boundary line, content type, then an empty line.
  part_opening = b''.join([
      b'--',
      boundary,
      b'\r\n',
      b'Content-Type: application/octet-stream',
      b'\r\n\r\n',
  ])

  metadata_bytes = json.dumps({
      'date': {'year': now.year, 'month': now.month, 'day': now.day},
      'agentId': agent_id,
      'filename': filename,
  }).encode()

  payload = b''.join([
      part_opening,
      metadata_bytes,
      b'\r\n',
      part_opening,
      file_content_bytes,
      b'\r\n--',
      boundary,
      b'--\r\n',
  ])

  boundary_text = boundary.decode('utf-8')
  request_headers = {
      'X-Goog-Upload-Protocol': 'multipart',
      'X-Goog-Upload-Header-Content-Type': b'text/plain'.decode('utf-8'),
      'Content-Type': f'multipart/related; boundary={boundary_text}',
  }

  reply = requests.post(
      api_endpoint,
      params={'key': api_key},
      headers=request_headers,
      data=payload,
  )
  return (reply.status_code, reply.reason)
115 |
116 |
def main(_):
  """Uploads every `.mcap` file found under `--data_dir`.

  Each file is read fully into memory and sent with `upload`. When the service
  answers with HTTP 200 the file is renamed so that its extension becomes
  `.uploaded`; any other status is reported and the file is left untouched.
  """

  def walk_and_upload(directory):
    """Recursively uploads the `.mcap` files below `directory`."""
    for root, _dirs, files in os.walk(directory):
      for file in files:
        if not file.endswith('.mcap'):
          continue
        file_path = os.path.join(root, file)

        with open(file_path, 'rb') as f:
          file_content_bytes = f.read()
        file_size_mb = len(file_content_bytes) / (1024 * 1024)

        t_start = time.time()
        status_code, reason = upload(
            api_endpoint=_API_ENDPOINT.value,
            agent_id=_AGENT_ID.value,
            filename=file,
            file_content_bytes=file_content_bytes,
            api_key=_API_KEY.value,
            now=datetime.datetime.now(pytz.timezone('America/Los_Angeles')),
        )
        elapsed_s = time.time() - t_start

        if status_code != 200:
          print(f'Failed to upload {file} ({file_size_mb:.2f} MB): {reason}')
          continue

        uploaded_file_path = os.path.splitext(file_path)[0] + '.uploaded'
        os.rename(file_path, uploaded_file_path)

        # The wall clock may report no elapsed time (coarse timer, clock
        # adjustment); never let the speed report crash a finished upload.
        upload_speed_mb_s = (
            file_size_mb / elapsed_s if elapsed_s > 0 else float('inf')
        )
        print(
            f'Uploaded {file} ({file_size_mb:.2f} MB) and renamed to'
            f' {uploaded_file_path} in {elapsed_s:.2f}s'
            f' ({upload_speed_mb_s:.2f} MB/s)'
        )

  walk_and_upload(_DATA_DIRECTORY.value)
155 |
156 |
157 | if __name__ == '__main__':
158 | app.run(main)
159 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/robot_job_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for robot_job.py."""
16 |
17 | from unittest import mock
18 |
19 | from absl.testing import absltest
20 | from googleapiclient import errors
21 |
22 | from safari_sdk.orchestrator.client.libs import robot_job
23 |
24 |
class RobotJobTest(absltest.TestCase):
  """Exercises `OrchestratorRobotJob` against a mocked connection."""

  class _FakeHttpResponse:
    """Minimal response object handed to `errors.HttpError`."""

    def __init__(self):
      self.status = "Mock status"
      self.reason = "Mock reason"
      self.error_details = "Mock error details"

  @staticmethod
  def _connection_returning(allocate_response):
    """Returns a mock connection whose allocateRobotJob call yields a value."""
    connection = mock.MagicMock()
    connection.orchestrator().allocateRobotJob().execute.return_value = (
        allocate_response
    )
    return connection

  @staticmethod
  def _new_lib(connection, job_type):
    """Builds the library under test for robot `test_robot_id`."""
    return robot_job.OrchestratorRobotJob(
        connection=connection,
        robot_id="test_robot_id",
        job_type=job_type,
    )

  def test_get_current_robot_job_good(self):
    """A previously stored robot job is returned successfully."""
    lib = self._new_lib(mock.MagicMock(), robot_job.JobType.ALL)
    lib._current_robot_job = mock.MagicMock(
        spec=robot_job.robot_job.RobotJob
    )

    result = lib.get_current_robot_job()

    self.assertTrue(result.success)
    self.assertIsInstance(result.robot_job, robot_job.robot_job.RobotJob)

  def test_get_current_robot_job_bad(self):
    """Without an acquired robot job the call reports an error."""
    lib = self._new_lib(mock.MagicMock(), robot_job.JobType.COLLECTION)

    result = lib.get_current_robot_job()

    self.assertFalse(result.success)
    self.assertIn(
        robot_job._ERROR_ROBOT_JOB_NOT_ACQUIRED,
        result.error_message,
    )

  def test_request_robot_job_good(self):
    """A response carrying a robot job id is surfaced to the caller."""
    connection = self._connection_returning(
        {"robotJob": {"robotJobId": "test_robot_job_id"}}
    )
    lib = self._new_lib(connection, robot_job.JobType.EVALUATION)

    result = lib.request_robot_job()

    self.assertTrue(result.success)
    self.assertEqual(result.robot_id, "test_robot_id")
    self.assertEqual(result.robot_job_id, "test_robot_job_id")

  def test_request_robot_job_bad_server_call(self):
    """An HttpError raised by the server call becomes a failed response."""
    fake_response_cls = self._FakeHttpResponse

    def fail_with_http_error():
      raise errors.HttpError(fake_response_cls(), b"Mock failed HTTP call.")

    connection = mock.MagicMock()
    connection.orchestrator().allocateRobotJob().execute.side_effect = (
        fail_with_http_error
    )
    lib = self._new_lib(connection, robot_job.JobType.ALL)

    result = lib.request_robot_job()

    self.assertFalse(result.success)
    self.assertIn(robot_job._ERROR_GET_ROBOT_JOB, result.error_message)

  def test_request_robot_job_no_more_robot_job(self):
    """An empty robot job is a success flagged as `no_more_robot_job`."""
    connection = self._connection_returning({"robotJob": {}})
    lib = self._new_lib(connection, robot_job.JobType.ALL)

    result = lib.request_robot_job()

    self.assertTrue(result.success)
    self.assertTrue(result.no_more_robot_job)
    self.assertEqual(result.error_message, robot_job._ERROR_EMPTY_RESPONSE)

  def test_request_robot_job_bad_response_robot_job_id(self):
    """A robot job with an empty id is rejected."""
    connection = self._connection_returning({"robotJob": {"robotJobId": ""}})
    lib = self._new_lib(connection, robot_job.JobType.ALL)

    result = lib.request_robot_job()

    self.assertFalse(result.success)
    self.assertEqual(
        result.error_message, robot_job._ERROR_EMPTY_ROBOT_JOB_ID
    )
135 |
136 |
137 | if __name__ == "__main__":
138 | absltest.main()
139 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/visual_overlay_icon.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Visual overlay icon information for overlay rendering."""
16 |
17 | import dataclasses
18 |
19 |
@dataclasses.dataclass
class DrawCircleIcon:
  """Required information for visual overlay renderer to draw a circle.

  Attributes:
    object_id: The object id of the overlay object.
    overlay_text_label: The overlay text label of the overlay object.
    rgb_hex_color_value: The rgb hex color value of the overlay object.
    layer_order: The layer order of the overlay object.
    x: The x pixel coordinate of the circle's center.
    y: The y pixel coordinate of the circle's center.
  """

  # Identification and appearance of the overlay object.
  object_id: str
  overlay_text_label: str
  rgb_hex_color_value: str
  layer_order: int
  # Center of the circle, in pixels.
  x: int
  y: int
39 |
40 |
@dataclasses.dataclass
class DrawArrowIcon:
  """Required information for visual overlay renderer to draw an arrow.

  Attributes:
    object_id: The object id of the overlay object.
    overlay_text_label: The overlay text label of the overlay object.
    rgb_hex_color_value: The rgb hex color value of the overlay object.
    layer_order: The layer order of the overlay object.
    x: The x pixel coordinate of the arrow's head.
    y: The y pixel coordinate of the arrow's head.
    rad: The direction of the arrow in radians. Radian of 0 is right, pi/2 is
      up, pi or -pi is left, and -pi/2 is down.
  """

  # Identification and appearance of the overlay object.
  object_id: str
  overlay_text_label: str
  rgb_hex_color_value: str
  layer_order: int
  # Position of the arrow's head, in pixels.
  x: int
  y: int
  # Direction of the arrow, in radians.
  rad: float
63 |
64 |
@dataclasses.dataclass
class DrawSquareIcon:
  """Required information for visual overlay renderer to draw a square.

  Attributes:
    object_id: The object id of the overlay object.
    overlay_text_label: The overlay text label of the overlay object.
    rgb_hex_color_value: The rgb hex color value of the overlay object.
    layer_order: The layer order of the overlay object.
    x: The x pixel coordinate of the square's center.
    y: The y pixel coordinate of the square's center.
  """

  # Identification and appearance of the overlay object.
  object_id: str
  overlay_text_label: str
  rgb_hex_color_value: str
  layer_order: int
  # Center of the square, in pixels.
  x: int
  y: int
84 |
85 |
@dataclasses.dataclass
class DrawTriangleIcon:
  """Required information for visual overlay renderer to draw a triangle.

  Attributes:
    object_id: The object id of the overlay object.
    overlay_text_label: The overlay text label of the overlay object.
    rgb_hex_color_value: The rgb hex color value of the overlay object.
    layer_order: The layer order of the overlay object.
    x: The x pixel coordinate of the triangle's center.
    y: The y pixel coordinate of the triangle's center.
  """

  # Identification and appearance of the overlay object.
  object_id: str
  overlay_text_label: str
  rgb_hex_color_value: str
  layer_order: int
  # Center of the triangle, in pixels.
  x: int
  y: int
105 |
106 |
@dataclasses.dataclass
class DrawContainer:
  """Required information for visual overlay renderer to draw a container.

  A container is described either as a rectangle ("w" and "h") or as a circle
  ("radius"); the meaning of "x" and "y" depends on which form is used.

  Attributes:
    object_id: The object id of the overlay object.
    overlay_text_label: The overlay text label of the overlay object.
    rgb_hex_color_value: The rgb hex color value of the overlay object.
    layer_order: The layer order of the overlay object.
    x: The x pixel coordinate of the container. If used with "w" and "h", this
      is the x pixel coordinate of the container's top left corner. If used with
      "radius", this is the x pixel coordinate of the container's center.
    y: The y pixel coordinate of the container. If used with "w" and "h", this
      is the y pixel coordinate of the container's top left corner. If used with
      "radius", this is the y pixel coordinate of the container's center.
    w: The width of the container. If used, then this indicates that this
      container is a rectangle and "h" must also be specified.
    h: The height of the container. If used, then this indicates that this
      container is a rectangle and "w" must also be specified.
    radius: The radius of the container. If used, then this indicates that this
      container is a circle.
  """

  # Identification and appearance of the overlay object.
  object_id: str
  overlay_text_label: str
  rgb_hex_color_value: str
  layer_order: int
  # Top left corner (rectangle) or center (circle), in pixels.
  x: int
  y: int
  # Shape parameters; all default to None. NOTE(review): the "w"/"h" versus
  # "radius" pairing described above is not enforced by this class.
  w: int | None = None
  h: int | None = None
  radius: int | None = None
139 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/example_client_sdk_robot_and_operator_info.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of using Orchestrator client SDK for robot and operator info.
16 |
17 | For more details on how to use the Orchestrator client SDK, please refer to the
18 | docstring of the main integration example at:
19 | orchestrator/example_client_sdk_integration.py
20 |
21 | For more details on each of the Orchestrator client SDK API methods, please
22 | refer to the docstring of the helper file itself:
23 | orchestrator/helpers/orchestrator_helper.py.
24 | """
25 |
26 | from collections.abc import Sequence
27 | import time
28 |
29 | from absl import app
30 | from absl import flags
31 |
32 | from safari_sdk.orchestrator.helpers import orchestrator_helper
33 |
# Required flag.
_ROBOT_ID = flags.DEFINE_string(
    name="robot_id",
    default=None,
    help="This robot's ID.",
    required=True,
)

# The flags below are optional.
_JOB_TYPE = flags.DEFINE_enum_class(
    name="job_type",
    default=orchestrator_helper.JOB_TYPE.ALL,
    enum_class=orchestrator_helper.JOB_TYPE,
    help="Type of job to run.",
)

_RAISE_ERROR = flags.DEFINE_bool(
    name="raise_error",
    default=False,
    help=(
        "Whether to raise the error as an exception or just show it as a"
        " message. Default = False."
    ),
)
58 |
59 |
def _print_robot_info_response(response: orchestrator_helper.RESPONSE) -> None:
  """Writes a formatted summary of the robot info fields to stdout.

  Args:
    response: Response whose `robot_id`, `is_operational`, `operator_id`,
      `robot_job_id`, `work_unit_id` and `work_unit_stage` attributes are shown.
  """
  separator = (
      " ----------------------------------------------------------------\n"
  )

  def _report_lines():
    # Yielded lazily so that each field is read right before it is printed.
    yield "\n - Current robot information -"
    yield separator
    yield f" Robot ID: {response.robot_id}"
    yield f" Is operational: {response.is_operational}\n"
    yield f" Operator ID: {response.operator_id}\n"
    yield f" Robot job ID: {response.robot_job_id}"
    yield f" Work unit ID: {response.work_unit_id}"
    yield f" Work unit stage: {response.work_unit_stage}"
    yield separator

  for line in _report_lines():
    print(line)
72 |
73 |
def run_example(
    orchestrator_client: orchestrator_helper.OrchestratorHelper,
) -> None:
  """Shows the robot info, sets an operator ID, then clears it again.

  The robot information is printed after each step. The walkthrough stops at
  the first call whose response is not successful, after printing its error.

  Args:
    orchestrator_client: Client used to query and update the robot info.
  """

  def _failed(result) -> bool:
    """Prints the error of an unsuccessful response; returns True if failed."""
    if result.success:
      return False
    print(f"\n - ERROR: {result.error_message} -\n")
    return True

  print(" - Getting current robot info -\n")
  info = orchestrator_client.get_current_robot_info()
  if _failed(info):
    return
  _print_robot_info_response(response=info)
  time.sleep(1)

  print(" - Setting operator ID to: 'test_operator_id' -\n")
  update = orchestrator_client.set_current_robot_operator_id(
      operator_id="test_operator_id"
  )
  if _failed(update):
    return

  print(" - Getting current robot info again to verify operator ID -\n")
  info = orchestrator_client.get_current_robot_info()
  if _failed(info):
    return
  _print_robot_info_response(response=info)
  time.sleep(1)

  print(" - Clearing operator ID field in robot information -\n")
  update = orchestrator_client.set_current_robot_operator_id(operator_id="")
  if _failed(update):
    return

  print(" - Getting current robot info again to verify no operator ID -\n")
  info = orchestrator_client.get_current_robot_info()
  if _failed(info):
    return
  _print_robot_info_response(response=info)
118 |
119 |
def main(argv: Sequence[str]) -> None:
  """Connects to the orchestrator, runs the example, then disconnects.

  Args:
    argv: Command-line arguments left over after flag parsing.

  Raises:
    app.UsageError: If any positional argument is given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  print(" - Initializing and connecting to orchestrator -\n")
  client = orchestrator_helper.OrchestratorHelper(
      robot_id=_ROBOT_ID.value,
      job_type=_JOB_TYPE.value,
      raise_error=_RAISE_ERROR.value,
  )
  connection_result = client.connect()

  if connection_result.success:
    print(" - Running example of getting robot info and setting operator ID -\n")
    run_example(orchestrator_client=client)

    print(" - Disconnecting from orchestrator -\n")
    client.disconnect()

    print(" - Example run completed -\n")
  else:
    # Nothing to disconnect from when the connection attempt failed.
    print(f"\n - ERROR: {connection_result.error_message} -\n")
142 |
143 |
144 | if __name__ == "__main__":
145 | app.run(main)
146 |
--------------------------------------------------------------------------------
/safari_sdk/model/genai_robotics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import base64
16 | import json
17 | from unittest import mock
18 |
19 | from absl import flags
20 | from absl.testing import absltest
21 | from google.genai import types
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 | from safari_sdk.model import genai_robotics
26 |
27 | FLAGS = flags.FLAGS
28 | FLAGS.mark_as_parsed()
29 |
30 |
class GenaiRoboticsTest(absltest.TestCase):
  """Tests for `genai_robotics.Client` with the backing services mocked out."""

  def test_robotics_api_create_client(self):
    """Creating a robotics-API client builds the discovery service once."""
    with mock.patch("googleapiclient.discovery.build") as mock_build:
      mock_service = mock.Mock()
      mock_build.return_value = mock_service
      # NOTE(review): this mutates the global flag value and is not restored
      # after the test.
      FLAGS.api_key = "test_api_key"

      client = genai_robotics.Client(
          use_robotics_api=True,
      )
      self.assertIsNotNone(client)
      mock_build.assert_called_once_with(
          serviceName=genai_robotics.auth._DEFAULT_SERVICE_NAME,
          version=genai_robotics.auth._DEFAULT_VERSION,
          discoveryServiceUrl=(
              genai_robotics.auth._DEFAULT_DISCOVERY_SERVICE_URL
          ),
          developerKey="test_api_key",
          http=mock.ANY,
      )

  def test_robotics_api_generate_content(self):
    """generate_content sends an encoded query and decodes the model output."""
    with mock.patch("googleapiclient.discovery.build") as mock_build:
      mock_service = mock.Mock()
      mock_build.return_value = mock_service
      FLAGS.api_key = "test_api_key"

      client = genai_robotics.Client(
          use_robotics_api=True,
      )
      image = np.zeros((100, 100, 3), dtype=np.uint8)
      image_bytes = tf.io.encode_jpeg(image).numpy()
      expected_output = {"action_chunk": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}

      # The mocked service answers with the expected output as base64-encoded
      # JSON, plus an unrelated key.
      mock_cm_custom = mock_service.modelServing.return_value.cmCustom
      mock_cm_custom.return_value.execute.return_value = {
          "outputBytes": (
              base64.b64encode(
                  json.dumps(expected_output).encode("utf-8")
              ).decode("utf-8")
          ),
          "someOtherKey": "some_other_value",
      }

      # The value 0 presumably refers to the first image part in `contents`;
      # the assertions below check that the query carries the encoded image
      # under this key instead.
      obs = {
          "images/overhead_cam": 0,
          "task_instruction": "test_task_instruction",
          "joints_pos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
      }

      response = client.models.generate_content(
          model="test_model",
          contents=[
              types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
              json.dumps(obs),
          ],
      )
      self.assertEqual(response.text, json.dumps(expected_output))
      mock_cm_custom.assert_called_once()
      call_body = mock_cm_custom.call_args.kwargs["body"]
      self.assertEqual(call_body["modelId"], "test_model")
      self.assertEqual(call_body["methodName"], "sample_actions_json_flat")
      self.assertIsInstance(call_body["requestId"], int)
      # The request payload is base64-encoded JSON; decode it to inspect it.
      query = json.loads(
          base64.b64decode(call_body["inputBytes"]).decode("utf-8")
      )
      self.assertEqual(
          query["images/overhead_cam"],
          base64.b64encode(image_bytes).decode("utf-8"),
      )
      self.assertEqual(query["task_instruction"], "test_task_instruction")
      self.assertEqual(query["joints_pos"], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

  def test_genai_create_client_via_auth_library(self):
    """With no explicit api_key argument, the flag value reaches the client."""
    with mock.patch("google.genai.Client") as mock_genai_client:
      mock_client = mock.Mock()
      mock_genai_client.return_value = mock_client
      FLAGS.api_key = "test_api_key"

      client = genai_robotics.Client(
          robotics_api_connection=genai_robotics.constants.RoboticsApiConnectionType.CLOUD_GENAI,
          project="test_project"
      )
      self.assertIsNotNone(client)
      mock_genai_client.assert_called_once_with(
          api_key="test_api_key", project="test_project"
      )

  def test_genai_create_client_via_param(self):
    """An api_key passed as an argument is used when the flag is unset."""
    with mock.patch("google.genai.Client") as mock_genai_client:
      mock_client = mock.Mock()
      mock_genai_client.return_value = mock_client
      # Clear the flag so that only the explicit argument can supply the key.
      FLAGS.api_key = None

      client = genai_robotics.Client(
          robotics_api_connection=genai_robotics.constants.RoboticsApiConnectionType.CLOUD_GENAI,
          api_key="test_api_key",
          project="test_project",
      )
      self.assertIsNotNone(client)
      mock_genai_client.assert_called_once_with(
          api_key="test_api_key", project="test_project"
      )
135 |
136 |
137 | if __name__ == "__main__":
138 | absltest.main()
139 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/current_robot_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for current_robot.py."""
16 |
17 | from unittest import mock
18 |
19 | from absl.testing import absltest
20 | from googleapiclient import errors
21 |
22 | from safari_sdk.orchestrator.client.libs import current_robot
23 |
24 |
class CurrentRobotTest(absltest.TestCase):
  """Tests for `OrchestratorCurrentRobotInfo` with a mocked connection."""

  class _MockHttpError:
    """Minimal response object handed to `errors.HttpError`."""

    def __init__(self):
      self.status = "Mock status"
      self.reason = "Mock reason"
      self.error_details = "Mock error details"

  @staticmethod
  def _raise_error_side_effect():
    """Side effect that simulates a failed HTTP call."""
    raise errors.HttpError(
        CurrentRobotTest._MockHttpError(), "Mock failed HTTP call.".encode()
    )

  def test_get_current_robot_info_good(self):
    """All fields of the server response are exposed on the result."""

    mock_connection = mock.MagicMock()
    mock_connection.orchestrator().currentRobotInfo().execute.return_value = {
        "robotId": "test_robot_id",
        "isOperational": True,
        "operatorId": "test_operator_id",
        "robotJobId": "test_robot_job_id",
        "workUnitId": "test_work_unit_id",
        "stage": "WORK_UNIT_STAGE_QUEUED_TO_ROBOT",
    }

    current_robot_lib = current_robot.OrchestratorCurrentRobotInfo(
        connection=mock_connection,
        robot_id="test_robot_id",
    )
    response = current_robot_lib.get_current_robot_info()
    self.assertTrue(response.success)
    self.assertEqual(response.robot_id, "test_robot_id")
    self.assertEqual(response.robot_job_id, "test_robot_job_id")
    self.assertEqual(response.work_unit_id, "test_work_unit_id")
    self.assertEqual(
        response.work_unit_stage,
        current_robot.current_robot_info.work_unit.WorkUnitStage.WORK_UNIT_STAGE_QUEUED_TO_ROBOT,
    )
    self.assertEqual(response.operator_id, "test_operator_id")
    self.assertTrue(response.is_operational)

    # Same robot, now reported as not operational.
    mock_connection.orchestrator().currentRobotInfo().execute.return_value = {
        "robotId": "test_robot_id",
        "isOperational": False,
        "operatorId": "test_operator_id",
        "robotJobId": "test_robot_job_id",
        "workUnitId": "test_work_unit_id",
        "stage": "WORK_UNIT_STAGE_QUEUED_TO_ROBOT",
    }
    response = current_robot_lib.get_current_robot_info()
    self.assertTrue(response.success)
    self.assertFalse(response.is_operational)

  def test_get_current_robot_info_bad_server_call(self):
    """An HttpError from currentRobotInfo becomes a failed response."""

    mock_connection = mock.MagicMock()
    mock_connection.orchestrator().currentRobotInfo().execute.side_effect = (
        self._raise_error_side_effect
    )

    current_robot_lib = current_robot.OrchestratorCurrentRobotInfo(
        connection=mock_connection,
        robot_id="test_robot_id",
    )
    response = current_robot_lib.get_current_robot_info()

    self.assertFalse(response.success)
    self.assertIn(
        current_robot._ERROR_GET_CURRENT_ROBOT_INFO, response.error_message
    )

  def test_set_current_robot_operator_id_good(self):
    """The robot and operator ids of the server response are returned."""

    mock_connection = mock.MagicMock()
    mock_connection.orchestrator().currentRobotSetOperatorId().execute.return_value = {
        "robotId": "test_robot_id",
        "operatorId": "test_operator_id",
    }

    current_robot_lib = current_robot.OrchestratorCurrentRobotInfo(
        connection=mock_connection,
        robot_id="test_robot_id",
    )
    response = current_robot_lib.set_current_robot_operator_id(
        operator_id="test_operator_id"
    )
    self.assertTrue(response.success)
    self.assertEqual(response.robot_id, "test_robot_id")
    self.assertEqual(response.operator_id, "test_operator_id")

  def test_set_current_robot_operator_id_bad_server_call(self):
    """An HttpError from currentRobotSetOperatorId becomes a failed response."""

    mock_connection = mock.MagicMock()
    mock_connection.orchestrator().currentRobotSetOperatorId().execute.side_effect = (
        self._raise_error_side_effect
    )

    current_robot_lib = current_robot.OrchestratorCurrentRobotInfo(
        connection=mock_connection,
        robot_id="test_robot_id",
    )
    response = current_robot_lib.set_current_robot_operator_id(
        operator_id="test_operator_id"
    )

    self.assertFalse(response.success)
    self.assertIn(
        current_robot._ERROR_SET_CURRENT_ROBOT_OPERATOR_ID,
        response.error_message,
    )
144 |
145 |
146 | if __name__ == "__main__":
147 | absltest.main()
148 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/dataclass/artifact_dataclass_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from safari_sdk.orchestrator.client.dataclass import artifact
17 |
18 |
19 | class ArtifactTest(absltest.TestCase):
20 |
21 | def test_artifact_dataclass_parses_correctly(self):
22 | input_json = """
23 | {
24 | "artifact": {
25 | "uri": "gs://test-bucket/test-artifact.png",
26 | "artifactId": "test-artifact-id",
27 | "name": "test-artifact-name",
28 | "desc": "test-artifact-desc",
29 | "artifactObjectType": "ARTIFACT_OBJECT_TYPE_IMAGE",
30 | "commitTime": "2025-01-01T00:00:00Z",
31 | "tags": ["test-tag-1", "test-tag-2"],
32 | "version": "test-artifact-version",
33 | "isZipped": true
34 | }
35 | }
36 | """
37 | artifact_response = artifact.LoadArtifactResponse.from_json(input_json)
38 | self.assertEqual(
39 | artifact_response.artifact.uri, "gs://test-bucket/test-artifact.png"
40 | )
41 | self.assertEqual(artifact_response.artifact.artifactId, "test-artifact-id")
42 | self.assertEqual(artifact_response.artifact.name, "test-artifact-name")
43 | self.assertEqual(artifact_response.artifact.desc, "test-artifact-desc")
44 | self.assertEqual(
45 | artifact_response.artifact.artifactObjectType,
46 | artifact.ArtifactObjectType.ARTIFACT_OBJECT_TYPE_IMAGE,
47 | )
48 | self.assertEqual(
49 | artifact_response.artifact.commitTime, "2025-01-01T00:00:00Z"
50 | )
51 | self.assertEqual(
52 | artifact_response.artifact.tags, ["test-tag-1", "test-tag-2"]
53 | )
54 | self.assertEqual(
55 | artifact_response.artifact.version, "test-artifact-version"
56 | )
57 | self.assertEqual(artifact_response.artifact.isZipped, True)
58 |
59 | def test_artifact_dataclass_parses_correctly_with_defaults(self):
60 | input_json = """
61 | {
62 | "artifact": {
63 | "uri": "gs://test-bucket/test-artifact.png"
64 | }
65 | }
66 | """
67 | artifact_response = artifact.LoadArtifactResponse.from_json(input_json)
68 | self.assertEqual(
69 | artifact_response.artifact.uri, "gs://test-bucket/test-artifact.png"
70 | )
71 | self.assertEqual(
72 | artifact_response.artifact.artifactObjectType,
73 | artifact.ArtifactObjectType.ARTIFACT_OBJECT_TYPE_UNSPECIFIED,
74 | )
75 |
76 | def test_artifact_dataclass_parses_correctly_with_none_values(self):
77 | input_json = """
78 | {
79 | "artifact": {
80 | "uri": null,
81 | "artifactId": null,
82 | "name": null,
83 | "desc": null,
84 | "artifactObjectType": null,
85 | "commitTime": null,
86 | "tags": null,
87 | "version": null,
88 | "isZipped": null
89 | }
90 | }
91 | """
92 | artifact_response = artifact.LoadArtifactResponse.from_json(input_json)
93 | self.assertIsNone(artifact_response.artifact.uri)
94 | self.assertIsNone(artifact_response.artifact.artifactId)
95 | self.assertIsNone(artifact_response.artifact.name)
96 | self.assertIsNone(artifact_response.artifact.desc)
97 | self.assertIsNone(artifact_response.artifact.commitTime)
98 | self.assertIsNone(artifact_response.artifact.tags)
99 | self.assertIsNone(artifact_response.artifact.version)
100 | self.assertIsNone(artifact_response.artifact.isZipped)
101 | self.assertEqual(
102 | artifact_response.artifact.artifactObjectType,
103 | artifact.ArtifactObjectType.ARTIFACT_OBJECT_TYPE_UNSPECIFIED,
104 | )
105 |
106 | def test_artifact_dataclass_does_not_parse_with_empty_values(self):
107 | input_json = """
108 | {
109 | "artifact": {
110 | "uri": "",
111 | "artifactId": "",
112 | "name": "",
113 | "desc": "",
114 | "artifactObjectType": "",
115 | "commitTime": "",
116 | "tags": [],
117 | "version": "",
118 | "isZipped": false
119 | }
120 | }
121 | """
122 | with self.assertRaises(ValueError):
123 | artifact.LoadArtifactResponse.from_json(input_json)
124 |
def test_artifact_dataclass_parses_correctly_with_no_values(self):
  """Checks parsing of an empty `artifact` object.

  All fields must be None, except the object type, which must be
  ARTIFACT_OBJECT_TYPE_UNSPECIFIED.
  """
  input_json = """
{
"artifact": { }
}
"""
  parsed_artifact = artifact.LoadArtifactResponse.from_json(
      input_json
  ).artifact
  for field_name in ("uri", "artifactId", "name", "desc"):
    self.assertIsNone(getattr(parsed_artifact, field_name))
  self.assertEqual(
      parsed_artifact.artifactObjectType,
      artifact.ArtifactObjectType.ARTIFACT_OBJECT_TYPE_UNSPECIFIED,
  )
  for field_name in ("commitTime", "tags", "version", "isZipped"):
    self.assertIsNone(getattr(parsed_artifact, field_name))
144 |
145 |
# Run the test cases through absltest when this file is executed directly.
if __name__ == "__main__":
  absltest.main()
148 |
--------------------------------------------------------------------------------
/examples/aloha/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Run Robotics Policy Eval on Aloha Robot."""
16 |
17 | import socket
18 | from absl import app
19 | from absl import flags
20 | import env
21 | from gdm_robotics.runtime import runloop as runloop_lib
22 | import rclpy
23 | from safari_sdk.logging.python import episodic_logger
24 | from safari_sdk.model import constants
25 | from safari_sdk.model import gemini_robotics_policy
26 |
# Robot constants
# Robot configuration name handed to env.create_aloha_environment in main().
ROBOT_CONFIG_NAME = 'aloha_stationary'
# Configuration directory handed to env.create_aloha_environment in main().
# NOTE(review): this is an absolute path under one user's home directory;
# confirm it exists on the machine running the eval.
CONFIG_BASE_PATH = '/home/juggler/interbotix_ws/src/aloha/config/'

# Command-line flags. Each module-level name holds the flag object; main()
# reads the parsed value through `.value`.

# Forwarded to GeminiRoboticsPolicy as `serve_id`.
SERVE_ID = flags.DEFINE_string(
    'serve_id',
    'gemini_robotics_on_device',
    'The serve id to use for the Gemini Robotics Policy.',
)
# Forwarded to GeminiRoboticsPolicy as `inference_mode`.
INFERENCE_MODE = flags.DEFINE_enum_class(
    'inference_mode',
    constants.InferenceMode.SYNCHRONOUS,
    constants.InferenceMode,
    'The inference mode to use for the Gemini Robotics Policy.',
)
# Forwarded to GeminiRoboticsPolicy as `robotics_api_connection`. Note that the
# Python name (ROBOTS_...) differs from the flag name (robotics_...).
ROBOTS_API_CONNECTION = flags.DEFINE_enum_class(
    'robotics_api_connection',
    constants.RoboticsApiConnectionType.LOCAL,
    constants.RoboticsApiConnectionType,
    'The robotics API connection type to use.',
)
# Optional initial instruction; defaults to None when the flag is not given.
INSTRUCTION = flags.DEFINE_string(
    'instruction', None, 'Specify the instruction to give to the robot.'
)
# Exposed on the command line as --steps; passed to the environment as
# `max_num_steps`.
MAX_NUM_STEPS = flags.DEFINE_integer(
    'steps', 5000, 'Number of steps to run the episode for.'
)
# Defaults to this machine's hostname, evaluated once at import time.
AGENT_ID = flags.DEFINE_string(
    'agent_id',
    socket.gethostname(),
    'The agent id to use for the episodic logger.',
)
59 |
60 |
class UserInputRunloopOperations(runloop_lib.RunloopRuntimeOperations):
  """Runloop runtime operations that handle user input.

  Before every episode reset the operator is prompted on stdin. The answer
  either becomes the instruction for the next episode, keeps the current
  instruction (blank answer), or requests a clean exit (`quit`).
  """

  def __init__(self, default_instruction: str | None):
    """Initializes the operations.

    Args:
      default_instruction: Instruction to use until the operator enters a new
        one. May be None when no instruction was supplied up front.
    """
    self._instruction = default_instruction
    self._has_quit = False

  @property
  def instruction(self) -> str | None:
    """The most recently accepted instruction, or None if there is none yet."""
    return self._instruction

  @property
  def has_quit(self) -> bool:
    """True if the operator answered the last prompt with `quit`."""
    return self._has_quit

  def before_episode_reset(self) -> bool:
    """Prompts the operator for the next instruction.

    A blank answer keeps the current instruction, so an instruction supplied
    at construction time is not lost at the first prompt and an accidental
    Enter never sends an empty instruction. If there is no current
    instruction, the prompt is repeated.

    Returns:
      False if the operator typed `quit`, True once an instruction is
      available.
    """
    # Reset the quit flag.
    self._has_quit = False

    while True:
      # Surrounding whitespace is dropped so that e.g. ' quit ' still quits.
      entered = (
          input("\nEnter a new instruction or 'quit' to cleanly exit: ")
          .strip()
          .lower()
      )

      if entered == 'quit':
        self._has_quit = True
        return False

      if entered:
        # It is an instruction. Save it.
        self._instruction = entered
        return True

      if self._instruction:
        # Blank input: fall back to the instruction we already have.
        print(f'Reusing instruction: {self._instruction}')
        return True

      print('No instruction has been set yet. Please enter an instruction.')
91 |
92 |
def main(argv):
  """Runs interactive policy evaluation episodes on the Aloha robot.

  Creates the Aloha environment, the Gemini Robotics policy and an episodic
  logger, then repeatedly resets the runloop and runs one episode until the
  operator quits. The environment is closed on every exit path.

  Args:
    argv: Positional command-line arguments from absl; unused.

  Raises:
    ValueError: If the serve_id flag is None.
  """
  del argv  # Unused.
  if SERVE_ID.value is None:
    raise ValueError('serve_id must be specified.')

  if INSTRUCTION.value is None:
    print('Script started. Enter an instruction to begin.')
  else:
    print(f'Script started with instruction: {INSTRUCTION.value}')

  # Observation keys shared by the policy and the logger.
  image_observation_keys = (
      'overhead_cam',
      'worms_eye_cam',
      'wrist_cam_left',
      'wrist_cam_right',
  )
  proprioceptive_observation_keys = ('joints_pos',)

  # Create environment and policy.
  environment = env.create_aloha_environment(
      robot_config_name=ROBOT_CONFIG_NAME,
      config_base_path=CONFIG_BASE_PATH,
      max_num_steps=MAX_NUM_STEPS.value,
  )
  try:
    # Uninstalls ros signal handlers (signal.SIGINT, signal.SIGTERM) to avoid
    # automatic ROS shutdown during keyboard interrupt.
    rclpy.signals.uninstall_signal_handlers()
    policy = gemini_robotics_policy.GeminiRoboticsPolicy(
        serve_id=SERVE_ID.value,
        task_instruction_key=env.INSTRUCTION_RESET_OPTION_KEY,
        image_observation_keys=image_observation_keys,
        proprioceptive_observation_keys=proprioceptive_observation_keys,
        inference_mode=INFERENCE_MODE.value,
        robotics_api_connection=ROBOTS_API_CONNECTION.value,
    )
    policy.step_spec(environment.timestep_spec())

    user_input_ops = UserInputRunloopOperations(INSTRUCTION.value)

    def _update_instruction_on_reset():
      # Evaluated on every reset so the latest operator instruction is used.
      return env.ResetOptions(
          options={env.INSTRUCTION_RESET_OPTION_KEY: user_input_ops.instruction}
      )

    # NOTE(review): task_id is read once here, so it is None when no
    # --instruction flag was given and is not updated by later operator
    # input; confirm that is intended.
    logger = episodic_logger.EpisodicLogger.create(
        agent_id=AGENT_ID.value,
        task_id=user_input_ops.instruction,
        proprioceptive_observation_keys=list(proprioceptive_observation_keys),
        output_directory='/tmp/eval_logs',
        action_spec=environment.action_spec(),
        timestep_spec=environment.timestep_spec(),
        image_observation_keys=list(image_observation_keys),
        policy_extra_spec={},
    )

    runloop = runloop_lib.Runloop(
        environment=environment,
        policy=policy,
        loggers=[logger],
        runloop_runtime_operations=(user_input_ops,),
        reset_options_provider=_update_instruction_on_reset,
    )

    # The start-up message was already printed above; repeating "Enter an
    # instruction to begin" here would contradict the "started with
    # instruction" variant.
    while True:
      try:
        runloop.reset()
        runloop.run_single_episode()
        if user_input_ops.has_quit:
          break
      except KeyboardInterrupt:
        # Ctrl-C only stops the runloop; the loop then continues with the
        # next reset. Exiting is done by answering the prompt with 'quit'.
        runloop.stop()
  finally:
    # Close the environment even if setup or an episode raised.
    environment.close()
171 |
# Hand control to absl, which parses the flags and then calls main().
if __name__ == '__main__':
  app.run(main)
174 |
--------------------------------------------------------------------------------
/safari_sdk/flywheel/upload_data_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import datetime
16 | import os
17 | from unittest import mock
18 |
19 | import pytz
20 |
21 | # Assuming the module is accessible like this for testing
22 | from absl.testing import absltest
23 | from absl.testing import parameterized
24 | from safari_sdk.flywheel import upload_data
25 |
26 |
class UploadFileTest(absltest.TestCase):
  """Tests for the private `upload_data._upload_file` helper."""

  @mock.patch('safari_sdk.flywheel.upload_data.requests.post')
  def test_upload_file_calls_requests_post_correctly(self, mock_post):
    """The status code and reason of the mocked POST response are returned."""
    fake_response = mock.Mock()
    fake_response.status_code = 200
    fake_response.reason = 'OK'
    mock_post.return_value = fake_response

    upload_kwargs = {
        'api_endpoint': 'https://example.com/upload',
        'agent_id': 'test_agent_001',
        'filename': 'data.mcap',
        'file_content_bytes': b'dummy file content',
        'api_key': 'test_api_key_123',
        # Provide a timezone-aware datetime object for 'now'
        'now': datetime.datetime(2023, 10, 26, 10, 0, 0, tzinfo=pytz.utc),  # pylint: disable=g-tzinfo-datetime
    }

    status_code, reason = upload_data._upload_file(**upload_kwargs)

    self.assertEqual(status_code, 200)
    self.assertEqual(reason, 'OK')

    mock_post.assert_called_once()
60 |
class UploadDataDirectoryTest(parameterized.TestCase):
  """Tests for `upload_data.upload_data_directory`."""

  @mock.patch('safari_sdk.flywheel.upload_data._upload_file')
  @mock.patch('safari_sdk.flywheel.upload_data.auth.get_api_key')
  def test_upload_data_directory_success_and_rename(
      self,
      mock_get_api_key,
      mock_upload_file,
  ):
    """Every .mcap file is uploaded once and renamed to `*.mcap.uploaded`."""
    root_dir = self.create_tempdir()
    root_dir.create_file('data1.mcap', content='dummy file content 1')
    root_dir.create_file('data2.mcap', content='dummy file content 2')

    nested_dir = root_dir.mkdir()
    nested_dir.create_file('data3.mcap', content='dummy file content 3')

    mock_upload_file.return_value = (200, 'OK')
    mock_get_api_key.return_value = 'test_api_key_123'

    upload_data.upload_data_directory(
        api_endpoint='https://example.com/upload',
        data_directory=root_dir.full_path,
        robot_id='test_agent_001',
    )

    # One upload per file, including the one in the sub-directory.
    self.assertEqual(mock_upload_file.call_count, 3)

    # Each file must have been uploaded with its own name and content.
    expected_uploads = {
        'data1.mcap': b'dummy file content 1',
        'data2.mcap': b'dummy file content 2',
        'data3.mcap': b'dummy file content 3',
    }
    expected_calls = [
        mock.call(
            api_endpoint='https://example.com/upload',
            agent_id='test_agent_001',
            filename=file_name,
            file_content_bytes=file_bytes,
            api_key='test_api_key_123',
            now=mock.ANY,
        )
        for file_name, file_bytes in expected_uploads.items()
    ]
    mock_upload_file.assert_has_calls(any_order=True, calls=expected_calls)

    # After the upload only the renamed file may exist, not the original.
    uploaded_files = (
        (root_dir, 'data1.mcap'),
        (root_dir, 'data2.mcap'),
        (nested_dir, 'data3.mcap'),
    )
    for directory, file_name in uploaded_files:
      original_path = os.path.join(directory.full_path, file_name)
      self.assertTrue(os.path.exists(original_path + '.uploaded'))
      self.assertFalse(os.path.exists(original_path))

  @mock.patch('safari_sdk.flywheel.upload_data.auth.get_api_key')
  def test_upload_data_directory_no_api_key_raises_error(
      self, mock_get_api_key
  ):
    """A missing API key makes upload_data_directory raise ValueError."""
    mock_get_api_key.return_value = None
    self.assertRaises(
        ValueError,
        upload_data.upload_data_directory,
        api_endpoint='https://example.com/upload',
        data_directory='test_data_dir',
        robot_id='test_agent_001',
    )
162 |
# Run the test cases through absltest when this file is executed directly.
if __name__ == '__main__':
  absltest.main()
165 |
--------------------------------------------------------------------------------
/safari_sdk/orchestrator/client/libs/artifact_lib_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for artifact.py."""
16 |
17 | from unittest import mock
18 |
19 | from absl.testing import absltest
20 | from googleapiclient import errors
21 |
22 | from safari_sdk.orchestrator.client.libs import artifact
23 |
24 |
class ArtifactTest(absltest.TestCase):
  """Tests OrchestratorArtifact against a mocked orchestrator connection."""

  def setUp(self):
    """Builds a mock connection whose loadArtifact call returns an image."""
    super().setUp()
    self.mock_connection = mock.MagicMock()
    self.artifact_lib = artifact.OrchestratorArtifact(
        connection=self.mock_connection
    )
    load_artifact_call = self.mock_connection.orchestrator().loadArtifact()
    load_artifact_call.execute.return_value = {
        "success": True,
        "artifact": {
            "uri": "test_artifact_uri",
            "artifactId": "test_artifact_id",
            "name": "test_name",
            "desc": "test_description",
            "artifactObjectType": "ARTIFACT_OBJECT_TYPE_IMAGE",
            "commitTime": "2025-01-01T00:00:00Z",
            "tags": ["tag1", "tag2"],
            "version": "1",
            "isZipped": False,
        },
    }

  @staticmethod
  def _lib_with_execute_result(result):
    """Returns a library whose mocked loadArtifact().execute() gives `result`."""
    connection = mock.MagicMock()
    connection.orchestrator().loadArtifact().execute.return_value = result
    return artifact.OrchestratorArtifact(connection=connection)

  def test_get_artifact_success(self):
    """All fields of the mocked artifact are exposed on the response."""
    lib = artifact.OrchestratorArtifact(connection=self.mock_connection)

    response = lib.get_artifact(artifact_id="test_artifact_id")

    self.assertTrue(response.success)
    loaded = response.artifact
    self.assertEqual(loaded.uri, "test_artifact_uri")
    self.assertEqual(loaded.artifactId, "test_artifact_id")
    self.assertEqual(loaded.name, "test_name")
    self.assertEqual(loaded.desc, "test_description")
    self.assertEqual(
        loaded.artifactObjectType.value,
        "ARTIFACT_OBJECT_TYPE_IMAGE",
    )
    self.assertEqual(loaded.commitTime, "2025-01-01T00:00:00Z")
    self.assertEqual(loaded.tags, ["tag1", "tag2"])
    self.assertEqual(loaded.version, "1")
    self.assertFalse(loaded.isZipped)

  def test_get_artifact_bad_server_call(self):
    """An HttpError from the server call yields an unsuccessful response."""

    class MockHttpError:
      """Stand-in for the HTTP response object wrapped by errors.HttpError."""

      def __init__(self):
        self.status = "Mock status"
        self.reason = "Mock reason"
        self.error_details = "Mock error details"

    def fail_with_http_error():
      raise errors.HttpError(MockHttpError(), "Mock failed HTTP call.".encode())

    failing_connection = mock.MagicMock()
    execute_mock = failing_connection.orchestrator().loadArtifact().execute
    execute_mock.side_effect = fail_with_http_error

    lib = artifact.OrchestratorArtifact(connection=failing_connection)
    response = lib.get_artifact(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertIn(artifact._ERROR_GET_ARTIFACT, response.error_message)

  def test_get_artifact_empty_response(self):
    """An empty dict from the server is reported as an empty response."""
    lib = self._lib_with_execute_result({})

    response = lib.get_artifact(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertEqual(artifact._ERROR_EMPTY_RESPONSE, response.error_message)

  def test_get_artifact_none_response(self):
    """A None result from the server is reported as an empty response."""
    lib = self._lib_with_execute_result(None)

    response = lib.get_artifact(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertEqual(artifact._ERROR_EMPTY_RESPONSE, response.error_message)

  def test_get_artifact_bad_connection(self):
    """Without a connection get_artifact reports the missing connection."""
    lib = artifact.OrchestratorArtifact(connection=None)

    response = lib.get_artifact(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertEqual(
        artifact._ERROR_NO_ORCHESTRATOR_CONNECTION, response.error_message
    )

  def test_get_artifact_uri_success(self):
    """get_artifact_uri returns the uri of the mocked artifact."""
    lib = artifact.OrchestratorArtifact(connection=self.mock_connection)

    response = lib.get_artifact_uri(artifact_id="test_artifact_id")

    self.assertTrue(response.success)
    self.assertEqual(response.artifact_uri, "test_artifact_uri")

  def test_get_artifact_uri_bad_connection(self):
    """Without a connection get_artifact_uri reports the missing connection."""
    lib = artifact.OrchestratorArtifact(connection=None)

    response = lib.get_artifact_uri(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertEqual(
        artifact._ERROR_NO_ORCHESTRATOR_CONNECTION, response.error_message
    )

  def test_get_artifact_uri_get_artifact_fails(self):
    """A payload without an artifact makes get_artifact_uri fail."""
    lib = self._lib_with_execute_result({
        "success": False,
        "error_message": "get artifact failed",
        "artifact": None,
    })

    response = lib.get_artifact_uri(artifact_id="test_artifact_id")

    self.assertFalse(response.success)
    self.assertEqual(
        "OrchestratorArtifact: Received empty response for get artifact"
        " request.",
        response.error_message,
    )

  def test_invalid_artifact_id(self):
    """An empty artifact id is rejected with a dedicated error message."""
    lib = self._lib_with_execute_result({})

    response = lib.get_artifact_uri(artifact_id="")

    self.assertFalse(response.success)
    self.assertEqual("Artifact ID is empty.", response.error_message)
165 |
# Run the test cases through absltest when this file is executed directly.
if __name__ == "__main__":
  absltest.main()
168 |
--------------------------------------------------------------------------------
/safari_sdk/logging/python/stream_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Safari robot Stream Logger class."""
16 |
17 | from collections.abc import Collection
18 | import copy
19 | import threading
20 |
21 | from safari_sdk.logging.python import base_logger
22 | from safari_sdk.logging.python import constants
23 | from safari_sdk.protos import image_pb2
24 | from safari_sdk.protos import joints_pb2
25 | from safari_sdk.protos import pose_pb2
26 | from safari_sdk.protos import sensor_calibration_pb2
27 | from safari_sdk.protos import transform_pb2
28 | from safari_sdk.protos import vector_pb2
29 | from safari_sdk.protos.logging import contact_surface_pb2
30 | from safari_sdk.protos.logging import metadata_pb2
31 | from safari_sdk.protos.logging import robot_base_pb2
32 | from safari_sdk.protos.logging import tracker_pb2
33 | from tensorflow.core.example import example_pb2
34 |
35 |
# Union of the proto message classes used as the type annotation of the
# `message` argument of
# StreamLogger.update_synchronization_and_maybe_write_message.
_LOG_MESSAGE_TYPE = (
    contact_surface_pb2.ContactSurface
    | image_pb2.Image
    | joints_pb2.Joints
    | joints_pb2.JointsTrajectory
    | metadata_pb2.Session
    | metadata_pb2.FileMetadata
    | metadata_pb2.TimeSynchronization
    | pose_pb2.Poses
    | robot_base_pb2.RobotBase
    | sensor_calibration_pb2.SensorCalibration
    | tracker_pb2.Trackers
    | transform_pb2.Transforms
    | vector_pb2.NamedVectorDouble
    | vector_pb2.NamedVectorInt64
    | example_pb2.Example
)
53 |
54 |
class StreamLogger(base_logger.BaseLogger):
  """Safari robot Stream Logger class.

  On top of the base logger this class keeps a TimeSynchronization proto that
  records the latest publish timestamp seen on every topic. Sessions can only
  be started once every required topic has been seen at least once.
  """

  def __init__(
      self,
      agent_id: str,
      output_directory: str,
      required_topics: Collection[str],
      optional_topics: Collection[str] | None = None,
      file_shard_size_limit_bytes: int = constants.DEFAULT_FILE_SHARD_SIZE_LIMIT_BYTES,
      message_queue_size_limit: int = 0,
  ):
    """Initializes the logger; all arguments are forwarded to the base class.

    The sync topic is registered with the base class as the only internal
    topic.
    """
    super().__init__(
        agent_id=agent_id,
        output_directory=output_directory,
        required_topics=required_topics,
        optional_topics=optional_topics,
        internal_topics={constants.SYNC_TOPIC_NAME},
        file_shard_size_limit_bytes=file_shard_size_limit_bytes,
        message_queue_size_limit=message_queue_size_limit,
    )

    # Guards every access to self._sync_message.
    self._sync_message_lock: threading.Lock = threading.Lock()
    # Time of the most recent message on each topic.
    self._sync_message: metadata_pb2.TimeSynchronization = (
        metadata_pb2.TimeSynchronization()
    )
    # Latched to True the first time every required topic has been seen.
    self._have_all_required_topics: bool = False

  def has_received_all_required_topics(self) -> bool:
    """True if we have seen at least one message on each required topic."""
    if self._have_all_required_topics:
      return True
    with self._sync_message_lock:
      seen_topics = self._sync_message.last_timestamp_by_topic
      if any(topic not in seen_topics for topic in self._required_topics):
        # Have not received all required topics. Cannot start session
        # logging.
        return False
    # The sync message is never cleared, so once every required topic has
    # shown up the answer stays True and can be cached.
    self._have_all_required_topics = True
    return True

  def start_session(
      self,
      *,
      start_nsec: int,
      task_id: str,
      output_file_prefix: str = '',
  ) -> bool:
    """Starts a session if every required topic has already been seen.

    Args:
      start_nsec: Forwarded to the base class start_session.
      task_id: Forwarded to the base class start_session.
      output_file_prefix: Forwarded to the base class start_session.

    Returns:
      False if a required topic is still missing or the base class refused to
      start the session, True otherwise.
    """
    if not self.has_received_all_required_topics():
      return False

    session_started = super().start_session(
        task_id=task_id,
        start_nsec=start_nsec,
        output_file_prefix=output_file_prefix,
    )
    return bool(session_started)

  def stop_session(self, stop_nsec: int) -> None:
    """Stops the session via the base class and clears the started flag."""
    super().stop_session(stop_nsec=stop_nsec)
    self._session_started = False

  def write_sync_message(self, publish_time_nsec: int) -> None:
    """Writes a snapshot of the sync message to the sync topic.

    Must only be called while recording (start_session or
    start_outside_session_logging has been called) and after at least one
    message has been seen on every required topic.

    Args:
      publish_time_nsec: The publish time of the sync message; also used as
        its log time.

    Raises:
      ValueError: If a required topic has not been seen yet, or if the logger
        is not recording.
    """
    if not self.has_received_all_required_topics():
      raise ValueError(
          'write_sync_message is called before all required topics have been'
          ' received.'
      )
    if not self.is_recording():
      raise ValueError(
          'write_sync_message was called, but no session is active and'
          ' start_outside_session_logging was not called..'
      )
    # Copy under the lock, write outside of it.
    snapshot: metadata_pb2.TimeSynchronization = self.get_latest_sync_message()
    super().write_proto_message(
        topic=constants.SYNC_TOPIC_NAME,
        message=snapshot,
        log_time_nsec=publish_time_nsec,
        publish_time_nsec=publish_time_nsec,
    )

  # Called within callback functions, maybe multi-threaded.
  def update_synchronization_and_maybe_write_message(
      self,
      topic: str,
      message: _LOG_MESSAGE_TYPE,
      publish_time_nsec: int,
      log_time_nsec: int = 0,
  ) -> None:
    """Records the topic's timestamp and writes the message when recording.

    Args:
      topic: The safari_logging_topic of the message.
      message: The proto message to be written.
      publish_time_nsec: The timestamp of the message (this may be the time the
        message was published, or the time the data in the message was
        sampled).
      log_time_nsec: The time when the logger received the message. Passed on
        unchanged; presumably the base logger substitutes the current system
        time for 0 -- verify against base_logger.

    Raises:
      ValueError: If the topic was not registered at construction time.
    """
    if topic not in self._all_topics:
      raise ValueError(
          f'Unknown topic not present in during initialization: {topic}'
      )
    with self._sync_message_lock:
      self._sync_message.last_timestamp_by_topic[topic] = publish_time_nsec
    if not self.is_recording():
      return
    super().write_proto_message(
        topic=topic,
        message=message,
        log_time_nsec=log_time_nsec,
        publish_time_nsec=publish_time_nsec,
    )

  def get_latest_sync_message(self) -> metadata_pb2.TimeSynchronization:
    """Returns a deep copy of the current sync message, taken under the lock."""
    with self._sync_message_lock:
      return copy.deepcopy(self._sync_message)
191 |
--------------------------------------------------------------------------------
/safari_sdk/logging/python/file_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A helper class for writing log entries to a mcap file, with file rotation."""
16 |
17 | from collections.abc import Collection
18 | import datetime
19 | import os
20 | import pathlib
21 | import stat
22 | import sys
23 | import threading
24 |
25 | from mcap_protobuf import writer as mcap_protobuf_writer
26 |
27 | from safari_sdk.logging.python import constants
28 | from safari_sdk.logging.python import message as message_lib
29 | from safari_sdk.protos import label_pb2
30 | from safari_sdk.protos.logging import metadata_pb2
31 |
32 |
# Topics that FileHandler.write_message accepts in addition to the topics
# passed to the FileHandler constructor.
RESERVED_TOPICS = frozenset([
    constants.FILE_METADATA_TOPIC_NAME,
    constants.SESSION_TOPIC_NAME,
    constants.SYNC_TOPIC_NAME,
    constants.TIMESTEP_TOPIC_NAME,
    constants.ACTION_TOPIC_NAME,
])
40 |
41 |
class FileHandler:
  """A helper class for writing log entries to a mcap file, with file rotation.

  Shards are written to `<output_directory>/tmp`. When a shard is finalized a
  FileMetadata message is appended, the file is moved to
  `<output_directory>/YYYY/MM/DD/` (current local date) and its write
  permissions are removed.
  """

  def __init__(
      self,
      agent_id: str,
      topics: Collection[str],
      output_directory: str,
      file_shard_size_limit_bytes: int = constants.DEFAULT_FILE_SHARD_SIZE_LIMIT_BYTES,
  ):
    """Initializes the handler; no file is opened until reset_for_new_file.

    Args:
      agent_id: Stored in the FileMetadata of every shard.
      topics: Topics accepted by write_message, in addition to RESERVED_TOPICS.
      output_directory: Root directory for temporary and finalized shards.
      file_shard_size_limit_bytes: Accumulated message size at which
        write_message rotates to a new shard.
    """
    # Invariant throughout the lifetime of the object.
    self._agent_id: str = agent_id
    self._topics: set[str] = set(topics)
    self._recognized_topics: set[str] = self._topics.union(RESERVED_TOPICS)
    self._output_directory: str = output_directory
    self._file_shard_size_limit_bytes: int = file_shard_size_limit_bytes

    # Invariant until reset_for_new_file call.
    self._output_file_prefix: str = ''

    # Invariant during operation on a particular file shard. Both are None
    # while no shard is open.
    self._shard: int = 0
    self._file_handle = None
    self._mcap_writer: mcap_protobuf_writer.Writer | None = None

    # Variable with each write_message call.
    self._file_shard_bytes: int = 0

    # start timestamp of current file shard.
    self._start_nsec: int = sys.maxsize
    # stop timestamp of current file shard.
    self._stop_nsec: int = 0

    self._lock: threading.Lock = threading.Lock()

  def reset_for_new_file(
      self, output_file_prefix: str, start_nsec: int
  ) -> None:
    """Opens shard 0 of a new file and resets the covered time range.

    Args:
      output_file_prefix: Prefix of the shard file names.
      start_nsec: Initial start and stop timestamp of the data coverage.
    """
    with self._lock:
      self._output_file_prefix = output_file_prefix
      self._reset_for_new_shard(is_first_shard=True)

      self._start_nsec = start_nsec
      self._stop_nsec = start_nsec

  def _reset_for_new_shard(self, is_first_shard: bool) -> None:
    """Reset the file handler states for a new shard.

    Args:
      is_first_shard: If True, the shard number is set to 0; otherwise it is
        incremented.
    """
    if is_first_shard:
      self._shard = 0
    else:
      self._shard += 1
    tmp_dir = pathlib.Path(f'{self._output_directory}/tmp')
    # exist_ok avoids the check-then-create race when several loggers share
    # the same output directory.
    os.makedirs(tmp_dir, exist_ok=True)
    self._file_handle = open(
        tmp_dir / f'{self._output_file_prefix}-shard{self._shard}.mcap',
        'wb',
    )
    self._mcap_writer = mcap_protobuf_writer.Writer(self._file_handle)
    self._file_shard_bytes = 0

  def write_message(
      self,
      message: message_lib.Message,
  ) -> None:
    """Write message with file rotation.

    If the current shard already holds data and adding this message would
    exceed the shard size limit, the current shard is finalized and a new one
    is opened before the message is written.

    Args:
      message: The Safari message object.

    Raises:
      ValueError: If the message topic is neither one of the topics given at
        construction nor a reserved topic.
    """
    with self._lock:
      if message.topic not in self._recognized_topics:
        raise ValueError(
            'Unknown topic not present in during initialization: %s'
            % message.topic
        )
      # Reserved topics join the coverage set the first time they are used.
      self._topics.add(message.topic)
      msg_size = message.message.ByteSize()
      if self._file_shard_bytes > 0 and (
          self._file_shard_bytes + msg_size > self._file_shard_size_limit_bytes
      ):
        self._finalize_and_close_file()
        self._reset_for_new_shard(is_first_shard=False)
      self._start_nsec = min(self._start_nsec, message.publish_time_nsec)
      # +1 makes the stop bound strictly greater than the last publish time.
      self._stop_nsec = max(self._stop_nsec, message.publish_time_nsec + 1)
      self._mcap_writer.write_message(
          topic=message.topic,
          message=message.message,
          log_time=message.log_time_nsec,
          publish_time=message.publish_time_nsec,
      )
      self._file_shard_bytes += msg_size

  def finalize_and_close_file(self, stop_nsec: int) -> None:
    """Finalize the file metadata and the mcap writer and close the file handle.

    Calling this while no shard is open does nothing.

    Args:
      stop_nsec: The stop time of data coverage in the current file shard.
    """
    with self._lock:
      self._start_nsec = min(self._start_nsec, stop_nsec)
      self._stop_nsec = max(self._stop_nsec, stop_nsec)
      self._finalize_and_close_file()

  def _finalize_and_close_file(self) -> None:
    """Private method to finalize and close the mcap writer and file handle.

    Must be called with self._lock held. Does nothing when no shard is open,
    so finalizing twice cannot write to a finished writer or try to move a
    file that is already gone.
    """
    if self._mcap_writer is None:
      return

    file_metadata = metadata_pb2.FileMetadata(
        agent_id=self._agent_id,
    )
    if self._stop_nsec > self._start_nsec:
      # there's actually data in the file, so we need to add the stream
      # coverages.
      for topic in self._topics:
        file_metadata.stream_coverages.append(
            metadata_pb2.KeyRange(
                topic=topic,
                interval=label_pb2.IntervalValue(
                    start_nsec=self._start_nsec,
                    stop_nsec=self._stop_nsec,
                ),
            )
        )
      # the start_nsec of the next file shard should <= the stop_nsec of
      # the current file shard so the backend can observe continuous data
      # coverage after receiving both shards.
      self._start_nsec = self._stop_nsec
    self._mcap_writer.write_message(
        topic=constants.FILE_METADATA_TOPIC_NAME,
        message=file_metadata,
        log_time=self._stop_nsec,
        publish_time=self._stop_nsec,
    )
    self._mcap_writer.finish()

    # Drop the references so the handler knows that no shard is open.
    self._mcap_writer = None
    file_handle, self._file_handle = self._file_handle, None
    if file_handle:
      tmp_file_path = pathlib.Path(file_handle.name)
      file_handle.close()
      date_now = datetime.datetime.now()
      final_dir = (
          pathlib.Path(self._output_directory)
          / date_now.strftime('%Y')
          / date_now.strftime('%m')
          / date_now.strftime('%d')
      )
      final_file_path = final_dir / tmp_file_path.name
      os.makedirs(final_dir, exist_ok=True)
      os.rename(tmp_file_path, final_file_path)
      current_permissions = os.stat(final_file_path).st_mode
      # Remove write permissions for all users (owner, group, others)
      os.chmod(
          final_file_path,
          current_permissions & ~stat.S_IWUSR & ~stat.S_IWGRP & ~stat.S_IWOTH,
      )
201 |
--------------------------------------------------------------------------------
/examples/model/gemini_robotics_aloha_eval_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Run Gemini Robotics Policy Eval on Aloha Robot."""
16 |
17 | # This is not a google3-compatible python file. It is intended to be run on a
18 | # non-corp machine.
19 |
20 | import select
21 | import signal
22 | import sys
23 | import termios
24 | import threading
25 | import time
26 | import tty
27 |
28 | from absl import app
29 | from absl import flags
30 | # TODO: Remove dependency on hostbot.aloha and use interbotix instead.
31 | from hostbot.aloha import aloha_ros_robot_client
32 | import rclpy
33 | from safari_sdk.model import gemini_robotics_policy
34 |
35 |
36 | _TASK_INSTRUCTION = flags.DEFINE_string(
37 | 'task_instruction',
38 | 'pick up banana and hand over',
39 | 'Task instruction to use for the policy.',
40 | )
41 | _SERVE_ID = flags.DEFINE_string(
42 | 'serve_id',
43 | None,
44 | 'The serve ID to use.',
45 | required=True,
46 | )
47 |
48 | _DT = 0.02
49 | _IMAGE_SIZE = (480, 848)
50 | _ALOHA_CAMERAS = {
51 | 'overhead_cam': _IMAGE_SIZE,
52 | 'worms_eye_cam': _IMAGE_SIZE,
53 | 'wrist_cam_left': _IMAGE_SIZE,
54 | 'wrist_cam_right': _IMAGE_SIZE,
55 | }
56 | _ALOHA_JOINTS = {'joints_pos': 14}
57 |
58 |
def main(_):
  """Runs Gemini Robotics policy episodes on an Aloha robot until interrupted.

  Sets up the ROS robot client and the policy client, installs a SIGINT
  handler that parks the robot and exits, then runs episodes in an endless
  loop. Each episode homes the robot, resets the policy, asks the operator for
  a task instruction and steps the policy at roughly `1 / _DT` Hz until the
  operator presses "q".
  """
  rclpy.init()
  robot_client = aloha_ros_robot_client.AlohaROSRobotClient(
      include_leaders=False,
      subscribe_to_raw_images=False,
  )
  robot_client.prep_robots()

  def shutdown():
    print('Shutting down.')
    robot_client.move_to_rest_poses()
    robot_client.close()
    rclpy.try_shutdown()
    sys.exit(0)

  def handler_fn(sig, frame):
    del sig, frame
    shutdown()

  # Kept in a local so the handler stays installed for the lifetime of main().
  unused_sigint_handler = SigintHandler(handler_fn)

  task_instruction = TaskInstruction(_TASK_INSTRUCTION.value)
  model_client = gemini_robotics_policy.GeminiRoboticsPolicy(
      serve_id=_SERVE_ID.value,
      task_instruction=str(task_instruction),
      cameras=_ALOHA_CAMERAS,
      joints=_ALOHA_JOINTS,
  )

  def run_episode():
    print('Homing...')
    robot_client.move_to_home()

    print('Policy reset')
    model_client.reset()

    # Get new task instruction from user.
    task = task_instruction.get_user_input()
    print('Task instruction: ', task)
    # The attribute name was previously misspelled (`_task_instrution`), so the
    # new instruction was stored on an unused attribute and never took effect.
    # NOTE(review): assumes the policy keeps its instruction in
    # `_task_instruction` (matching its constructor keyword) -- confirm.
    model_client._task_instruction = task  # pylint: disable=protected-access

    with KeyDetect() as detector:
      print('Running policy... Press "q" to terminate episode.')
      while True:
        frame_start_time = time.time()
        obs = {
            camera_name: bytes(robot_client.get_image_jpeg(camera_name))
            for camera_name in _ALOHA_CAMERAS
        } | {
            joint_name: robot_client.get_follower_joints_pos()
            for joint_name in _ALOHA_JOINTS
        }
        gemini_actions = model_client.step(obs)
        # Action layout used here: [0:6] left arm, [6] left gripper,
        # [7:13] right arm, [13:] right gripper.
        cmd = aloha_ros_robot_client.robot_client.RobotCommand(
            left_arm_joint_target=gemini_actions[:6],
            right_arm_joint_target=gemini_actions[7:13],
            left_gripper_joint_target=gemini_actions[6:7],
            right_gripper_joint_target=gemini_actions[13:],
        )
        robot_client.step(cmd)

        if detector.is_down('q'):
          print('Episode terminated.')
          detector.clear()
          break

        # Sleep off whatever is left of the control period.
        frame_time = time.time() - frame_start_time
        time.sleep(max(0, _DT - frame_time))

  while True:
    run_episode()
130 |
131 |
class TaskInstruction:
  """Holds the current task instruction and lets the operator replace it."""

  def __init__(self, task_instruction: str):
    self._task_instruction = task_instruction

  def __str__(self):
    return self._task_instruction

  def get_user_input(self) -> str:
    """Prompts for a new instruction; empty input keeps the current one.

    Returns:
      The instruction in effect after the prompt.
    """
    prompt = f'Input task instruction [{self._task_instruction}]:'
    entered = input(prompt)
    if not entered:
      return self._task_instruction
    self._task_instruction = entered
    return entered
148 |
149 |
class SigintHandler:
  """Lightweight utility to call a function on SIGINT.

  The previous SIGINT handler is restored once this object is deleted.
  """

  def __init__(self, handler_fn):
    """Installs `handler_fn` as the SIGINT handler.

    Args:
      handler_fn: Callable taking `(signum, frame)`, as required by
        `signal.signal`.
    """
    self._prev_sigint_signal = signal.signal(
        signal.SIGINT,
        handler_fn,
    )

  def __del__(self):
    # getattr guards against __init__ having failed before the attribute was
    # set. Compare against None rather than truthiness: signal.SIG_DFL is an
    # enum with value 0 and would otherwise never be restored.
    previous = getattr(self, '_prev_sigint_signal', None)
    if previous is not None:
      signal.signal(signal.SIGINT, previous)
165 |
166 |
class KeyDetect:
  """A non-blocking key detection class.

  Used as a context manager: on entry stdin is switched to raw mode and a
  daemon thread starts collecting single characters into a buffer; on exit the
  thread is stopped and the original terminal settings are restored. Keys stay
  in the buffer until `clear()` is called.
  """

  def __init__(self):
    self._original_settings = None
    self._key_buffer = set()
    self._lock = threading.Lock()
    self._event = threading.Event()
    self._stop_flag = False
    self._thread = None

  def __enter__(self):
    """Enters the context, setting up non-blocking input."""
    self._original_settings = termios.tcgetattr(sys.stdin)
    tty.setraw(sys.stdin.fileno())
    try:
      self._start_listening()
    except BaseException:
      # __exit__ is not called when __enter__ raises, so undo raw mode here.
      self._restore_terminal()
      raise
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exits the context, restoring terminal settings and stopping listener."""
    try:
      self._stop_listening()
    finally:
      # Always leave the terminal usable, even if stopping the thread fails.
      self._restore_terminal()

  def _restore_terminal(self):
    """Restores the terminal settings captured in `__enter__`."""
    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._original_settings)

  def _start_listening(self):
    """Starts a background thread to listen for key presses."""
    self._stop_flag = False
    self._thread = threading.Thread(target=self._listen)
    self._thread.daemon = True
    self._thread.start()

  def _stop_listening(self):
    """Sets the stop flag and waits for the background thread to finish."""
    self._stop_flag = True
    self._event.set()  # Wake up the thread if it's waiting
    if self._thread and self._thread.is_alive():
      self._thread.join()

  def _listen(self):
    """Listens for key presses and updates the key buffer."""
    while not self._stop_flag:
      # Check for input with a timeout so the stop flag is polled regularly.
      if select.select([sys.stdin], [], [], 0.1)[0]:
        try:
          key = sys.stdin.read(1)
        except BlockingIOError:
          key = ''  # No input available
        # An empty read means EOF / no data; it is not a key press.
        if key:
          with self._lock:
            self._key_buffer.add(key)
      self._event.wait(0.01)  # Small delay to reduce CPU usage
      self._event.clear()

  def is_down(self, key):
    """Checks whether `key` has been read since the buffer was last cleared."""
    with self._lock:
      return key in self._key_buffer

  def get_pressed(self):
    """Returns a copy of the set of keys read since the last clear."""
    with self._lock:
      return set(self._key_buffer)

  def clear(self):
    """Clears the buffer of recorded keys."""
    with self._lock:
      self._key_buffer.clear()
233 |
234 |
235 | if __name__ == '__main__':
236 | app.run(main)
237 |
--------------------------------------------------------------------------------
/safari_sdk/protos/logging/machine_info.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2025 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package safari_sdk.protos.logging;
18 |
19 | import "safari_sdk/protos/label.proto";
20 |
21 | message MachineInfo {
// Static CPU description plus utilization, time, counter, frequency and load
// statistics. Fields prefixed with `per_cpu_` are repeated, presumably with
// one entry per logical CPU -- confirm against the collector.
message CPUInfo {
  optional uint32 physical_cores = 1;  // Number of physical cores
  optional uint32 logical_cores = 2;   // Number of logical cores
  optional string model_name = 3;      // CPU model name
  // NOTE: field number 4 is not used in this message.
  // Total CPU usage percentage ([0, 100])
  optional float utilization_percent = 5;
  repeated float per_cpu_utilization_percent = 6 [packed = true];

  message CPUTimes {  // Time spent by a CPU in seconds
    optional float user = 1;
    optional float system = 2;
    optional float idle = 3;
    optional float nice = 4;
    optional float iowait = 5;
    optional float irq = 6;
    optional float softirq = 7;
    optional float steal = 8;
    optional float guest = 9;
    optional float guest_nice = 10;
  }

  optional CPUTimes cpu_times = 7;
  repeated CPUTimes per_cpu_times = 8;

  // Event counters (context switches, interrupts, syscalls).
  message CPUStats {
    optional uint64 ctx_switches = 1;
    optional uint64 interrupts = 2;
    optional uint64 soft_interrupts = 3;
    optional uint64 syscalls = 4;
  }
  optional CPUStats cpu_stats = 9;

  // NOTE(review): the comments below say Hz; confirm the unit the collector
  // actually writes.
  message CPUFrequency {
    optional float current = 1;  // Current CPU frequency in Hz
    optional float min = 2;      // Minimum supported CPU frequency in Hz
    optional float max = 3;      // Maximum supported CPU frequency in Hz
  }

  optional CPUFrequency cpu_frequency = 10;
  repeated CPUFrequency per_cpu_frequency = 11;

  // Load averages over 1, 5 and 15 minutes.
  message LoadStats {
    optional float load_avg1_min = 1;
    optional float load_avg5_min = 2;
    optional float load_avg15_min = 3;
  }
  optional LoadStats load_avg = 12;  // Machine load statistics (like `top`)
}
70 |
// Operating-system identification of the machine the data was collected on.
message LinuxInfo {
  optional string platform = 1;
  optional string kernel = 2;
  optional string linux_version = 3;
  optional string hostname = 4;
  optional string architecture = 5;
  // Environment variables exposed to the process collecting the data,
  // keyed by variable name. A map field must declare its key and value types.
  map<string, string> env_variables = 6;
}
80 |
// System memory statistics. All sizes are in bytes.
// NOTE(review): the field set presumably mirrors the collector's virtual
// memory report (e.g. psutil.virtual_memory()) -- confirm.
message VirtualMemoryInfo {
  optional uint64 total = 1;      // Total physical memory (bytes)
  optional uint64 available = 2;  // Available memory (bytes)
  optional float percent = 3;     // Memory usage percentage
  optional uint64 used = 4;       // Used memory (bytes)
  optional uint64 free = 5;       // Free memory (bytes)
  optional uint64 active = 6;     // Active memory (bytes)
  optional uint64 inactive = 7;   // Inactive memory (bytes)
  optional uint64 buffers = 8;    // Buffer memory (bytes)
  optional uint64 cached = 9;     // Cached memory (bytes)
  optional uint64 shared = 10;    // Shared memory (bytes)
  optional uint64 slab = 11;      // Slab memory (bytes)
}
94 |
// Swap statistics. Field numbers match VirtualMemoryInfo where the names
// coincide, which is why number 2 is not used here.
message SwapMemoryInfo {
  optional uint64 total = 1;   // Total swap memory (bytes)
  optional float percent = 3;  // Swap usage percentage
  optional uint64 used = 4;    // Used swap memory (bytes)
  optional uint64 free = 5;    // Free swap memory (bytes)
  // NOTE(review): presumably bytes swapped in from / out to disk -- confirm
  // against the collector.
  optional uint64 sin = 6;     // Swapped-in amount (bytes)
  optional uint64 sout = 7;    // Swapped-out amount (bytes)
}
103 |
// Disk partitions with their usage, plus aggregate disk I/O counters.
message DiskInfo {
  // One mounted partition and its space usage.
  message Partition {
    optional string device = 1;  // Device identifier
    optional string mountpoint = 2;
    optional string fstype = 3;  // File system type
    optional string opts = 4;    // Mount options
    optional uint32 maxfile = 5;
    optional uint32 maxpath = 6;
    optional uint64 total = 7;    // Total size (bytes)
    optional uint64 used = 8;     // Used space (bytes)
    optional uint64 free = 9;     // Free space (bytes)
    optional float percent = 10;  // Usage percentage
  }

  repeated Partition partitions = 1;

  // Read/write operation counts, byte counts and times in milliseconds.
  message DiskIOStats {
    optional uint64 read_count = 1;
    optional uint64 write_count = 2;
    optional uint64 read_bytes = 3;
    optional uint64 write_bytes = 4;
    optional uint64 read_time_milliseconds = 5;
    optional uint64 write_time_milliseconds = 6;
    optional uint64 busy_time_milliseconds = 7;
    optional uint64 read_merged_count = 8;
    optional uint64 write_merged_count = 9;
  }
  optional DiskIOStats io_stats = 2;
}
133 |
// One user session: who, on which terminal, from which host, since when.
message UserInfo {
  optional string name = 1;
  optional string terminal = 2;
  optional string host = 3;
  // DomainTimestamp is not declared in this file; presumably it comes from
  // the imported label.proto.
  optional DomainTimestamp started = 4;
}
140 |
// Network interfaces with their addresses and traffic counters.
message NetworkInfo {
  message InterfaceInfo {
    optional string name = 1;         // Interface name (e.g., "eth0", "Wi-Fi")
    optional bool is_up = 2;          // True if interface is up
    optional int32 mtu = 3;           // Maximum transmission unit
    optional string mac_address = 4;  // Hardware MAC address (if available)

    repeated string ip_addresses = 5;  // List of IP addresses
    repeated string netmasks = 6;      // Corresponding netmasks

    // Network statistics
    optional uint64 bytes_sent = 7;
    optional uint64 bytes_recv = 8;
    optional uint64 packets_sent = 9;
    optional uint64 packets_recv = 10;
    optional uint64 errin = 11;
    optional uint64 errout = 12;
    optional uint64 dropin = 13;
    optional uint64 dropout = 14;
  }

  repeated InterfaceInfo interfaces = 1;
}
164 |
// Hardware sensor readings; currently only temperature sensors.
message SensorInfo {
  // One temperature reading with its high and critical thresholds.
  message TemperatureSensor {
    optional string sensor_type = 1;
    optional string sensor_name = 2;
    optional float temperature = 3;           // C
    optional float temperature_high = 4;      // C
    optional float temperature_critical = 5;  // C
  }
  repeated TemperatureSensor temperature_sensors = 1;
}
175 |
// Snapshot of a single process: identity, state, CPU times, context-switch
// counts, thread count and memory share.
message ProcessInfo {
  optional string status = 1;
  optional uint32 cpu_num = 2;
  optional uint64 pid = 3;
  optional string cmdline = 4;
  optional DomainTimestamp create_time = 5;
  optional float cpu_percent = 6;
  optional string terminal = 7;
  optional uint64 ppid = 8;
  optional string cwd = 9;
  optional int32 nice = 10;
  optional string username = 11;
  // CPU times below are in seconds, as the field names state.
  optional float cpu_time_user_seconds = 12;
  optional float cpu_time_system_seconds = 13;
  optional float cpu_time_children_user_seconds = 14;
  optional float cpu_time_children_system_seconds = 15;
  optional float cpu_time_iowait_seconds = 16;
  optional uint64 num_ctx_switches_voluntary = 17;
  optional uint64 num_ctx_switches_involuntary = 18;
  optional string name = 19;
  optional uint64 num_threads = 20;
  optional float memory_percent = 21;
}
199 |
200 | optional LinuxInfo linux = 1;
201 | optional CPUInfo cpu = 2;
202 | optional VirtualMemoryInfo virtual_memory = 3;
203 | optional SwapMemoryInfo swap_memory = 4;
204 | optional DiskInfo disk = 5;
205 | optional DomainTimestamp boot_time = 6;
206 | repeated UserInfo users = 7;
207 | optional NetworkInfo network = 8;
208 | optional SensorInfo sensors = 9;
209 | repeated ProcessInfo processes = 10;
210 | }
211 |
--------------------------------------------------------------------------------
/cmake/protobuf-generate.cmake:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Functions and commands for generating protobuf code from proto files.
16 |
# Ideally we would use the cmake protobuf package (https://cmake.org/cmake/help/latest/module/FindProtobuf.html)
# and this file would not be needed, but a long chain of version and dependency
# issues precludes that. Specifically, we need ROS Noetic, which in turn requires
# Ubuntu 20.04. That has really old versions of the protobuf compiler and cmake.
# A newer version of the protobuf compiler must be used, but the version of cmake
# that comes with Ubuntu 20.04 is only capable of using the system protobuf
# compiler (the ability to specify the protoc executable was added in cmake 4.0,
# while Ubuntu 20.04 ships with cmake 3.16). The root cause of all this is the
# dependency on ROS 1, which is end of life. Once we switch to ROS 2, we can
# upgrade Ubuntu and everything that comes with that (including cmake and protoc),
# and this file can probably be removed.
28 |
29 | # This was adapted from https://github.com/protocolbuffers/protobuf/blob/main/cmake/protobuf-generate.cmake
30 | # This file alone is released under the same license as that source.
31 |
# Location inside the build tree where the prebuilt protoc release is unpacked,
# and the path of the protoc executable within it.
set(_DOWNLOADED_PROTOC_DIR ${CMAKE_BINARY_DIR}/protoc)
set(_DOWNLOADED_PROTOC_EXE_PATH ${_DOWNLOADED_PROTOC_DIR}/bin/protoc)

# Protobuf versions defined:
# https://protobuf.dev/support/version-support/#python

# Release archive to fetch (Linux x86_64 only) and its download URL.
set(_DOWNLOADED_PROTOC_ZIP_FILE "protoc-32.0-linux-x86_64.zip")
set(_DOWNLOADED_PROTOC_WEB_PATH "https://github.com/protocolbuffers/protobuf/releases/download/v32.0/${_DOWNLOADED_PROTOC_ZIP_FILE}")

# Add a command to download and extract the protobuf compiler. The timestamps
# are not set in a useful way (or at all), so touch the output to update the
# timestamps and prevent this command from running every time.
# Requires `wget`, `unzip` and `touch` to be available on the build machine.
add_custom_command(
  OUTPUT ${_DOWNLOADED_PROTOC_EXE_PATH}
  COMMAND wget -P ${_DOWNLOADED_PROTOC_DIR} ${_DOWNLOADED_PROTOC_WEB_PATH}
  COMMAND unzip -o "${_DOWNLOADED_PROTOC_DIR}/${_DOWNLOADED_PROTOC_ZIP_FILE}" -d ${_DOWNLOADED_PROTOC_DIR}
  COMMAND touch ${_DOWNLOADED_PROTOC_EXE_PATH}
  COMMENT "Downloading ${_DOWNLOADED_PROTOC_WEB_PATH} and extracting to ${_DOWNLOADED_PROTOC_DIR}"
  VERBATIM )
54 |
# Low-level function to add commands to invoke the protoc compiler. This should
# be a drop-in replacement for the cmake protobuf package, except it defaults to
# using the prebuilt protoc defined above rather than the system protoc.
#
# Options:
#   APPEND_PATH          Add the directory of every proto file as an include path.
# Single-value arguments:
#   LANGUAGE             Target language (default: cpp).
#   OUT_VAR              Variable that receives the list of generated files.
#   EXPORT_MACRO         Passed as dllexport_decl (cpp only).
#   PROTOC_OUT_DIR       Output directory (default: CMAKE_CURRENT_BINARY_DIR).
#   PLUGIN               Passed to protoc as --plugin=<value>.
#   PLUGIN_OPTIONS       Extra options for the language generator.
#   PROTOC_EXE           protoc executable (default: the downloaded protoc).
#   TARGET               Target whose .proto sources are compiled and which
#                        receives the generated sources.
# Multi-value arguments:
#   PROTOS, IMPORT_DIRS, GENERATE_EXTENSIONS, PROTOC_OPTIONS, DEPENDENCIES
function(protobuf_generate)
  include(CMakeParseArguments)

  set(_options APPEND_PATH)
  set(_singleargs LANGUAGE OUT_VAR EXPORT_MACRO PROTOC_OUT_DIR PLUGIN PLUGIN_OPTIONS PROTOC_EXE)
  if(COMMAND target_sources)
    list(APPEND _singleargs TARGET)
  endif()
  set(_multiargs PROTOS IMPORT_DIRS GENERATE_EXTENSIONS PROTOC_OPTIONS DEPENDENCIES)

  cmake_parse_arguments(protobuf_generate "${_options}" "${_singleargs}" "${_multiargs}" "${ARGN}")

  if(NOT protobuf_generate_PROTOS AND NOT protobuf_generate_TARGET)
    message(SEND_ERROR "Error: protobuf_generate called without any targets or source files")
    return()
  endif()

  if(NOT protobuf_generate_OUT_VAR AND NOT protobuf_generate_TARGET)
    message(SEND_ERROR "Error: protobuf_generate called without a target or output variable")
    return()
  endif()

  if(NOT protobuf_generate_LANGUAGE)
    set(protobuf_generate_LANGUAGE cpp)
  endif()
  string(TOLOWER ${protobuf_generate_LANGUAGE} protobuf_generate_LANGUAGE)

  if(NOT protobuf_generate_PROTOC_OUT_DIR)
    set(protobuf_generate_PROTOC_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR})
  endif()

  if(protobuf_generate_EXPORT_MACRO AND protobuf_generate_LANGUAGE STREQUAL cpp)
    set(_dll_export_decl "dllexport_decl=${protobuf_generate_EXPORT_MACRO}")
  endif()

  foreach(_option ${_dll_export_decl} ${protobuf_generate_PLUGIN_OPTIONS})
    # append comma - not using CMake lists and string replacement as users
    # might have semicolons in options
    if(_plugin_options)
      set( _plugin_options "${_plugin_options},")
    endif()
    set(_plugin_options "${_plugin_options}${_option}")
  endforeach()

  if(protobuf_generate_PLUGIN)
    set(_plugin "--plugin=${protobuf_generate_PLUGIN}")
  endif()

  if(NOT protobuf_generate_GENERATE_EXTENSIONS)
    if(protobuf_generate_LANGUAGE STREQUAL cpp)
      set(protobuf_generate_GENERATE_EXTENSIONS .pb.h .pb.cc)
    elseif(protobuf_generate_LANGUAGE STREQUAL python)
      set(protobuf_generate_GENERATE_EXTENSIONS _pb2.py)
    else()
      # Report the parsed language; a bare ${LANGUAGE} is not defined here and
      # would expand to an empty string.
      message(SEND_ERROR "Error: protobuf_generate given unknown Language ${protobuf_generate_LANGUAGE}, please provide a value for GENERATE_EXTENSIONS")
      return()
    endif()
  endif()

  if(protobuf_generate_TARGET)
    # Pick up every .proto file listed in the target's sources.
    get_target_property(_source_list ${protobuf_generate_TARGET} SOURCES)
    foreach(_file ${_source_list})
      if(_file MATCHES "proto$")
        list(APPEND protobuf_generate_PROTOS ${_file})
      endif()
    endforeach()
  endif()

  if(NOT protobuf_generate_PROTOS)
    message(SEND_ERROR "Error: protobuf_generate could not find any .proto files")
    return()
  endif()

  if(protobuf_generate_APPEND_PATH)
    # Create an include path for each file specified
    foreach(_file ${protobuf_generate_PROTOS})
      get_filename_component(_abs_file ${_file} ABSOLUTE)
      get_filename_component(_abs_dir ${_abs_file} DIRECTORY)
      list(FIND _protobuf_include_path ${_abs_dir} _contains_already)
      if(${_contains_already} EQUAL -1)
        list(APPEND _protobuf_include_path -I ${_abs_dir})
      endif()
    endforeach()
  endif()

  if(NOT protobuf_generate_PROTOC_EXE)
    # Default to the protoc binary downloaded into the build tree.
    set(protobuf_generate_PROTOC_EXE ${_DOWNLOADED_PROTOC_EXE_PATH})
  endif()

  foreach(DIR ${protobuf_generate_IMPORT_DIRS})
    get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
    list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
    if(${_contains_already} EQUAL -1)
      list(APPEND _protobuf_include_path -I ${ABS_PATH})
    endif()
  endforeach()

  if(NOT _protobuf_include_path)
    set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
  endif()

  set(_generated_srcs_all)
  foreach(_proto ${protobuf_generate_PROTOS})
    get_filename_component(_abs_file ${_proto} ABSOLUTE)
    get_filename_component(_abs_dir ${_abs_file} DIRECTORY)

    # Strip only the last extension to get the base name.
    get_filename_component(_file_full_name ${_proto} NAME)
    string(FIND "${_file_full_name}" "." _file_last_ext_pos REVERSE)
    string(SUBSTRING "${_file_full_name}" 0 ${_file_last_ext_pos} _basename)

    # Find the first include directory that contains this proto file; the
    # relative path from it (_rel_dir) determines where protoc writes output.
    set(_suitable_include_found FALSE)
    foreach(DIR ${_protobuf_include_path})
      if(NOT DIR STREQUAL "-I")
        file(RELATIVE_PATH _rel_dir ${DIR} ${_abs_dir})
        if(_rel_dir STREQUAL _abs_dir)
          # When there is no relative path from DIR to _abs_dir (e.g. due to
          # different drive letters on Windows), _rel_dir is equal to _abs_dir.
          # Therefore, DIR is not a suitable include path and must be skipped.
          continue()
        endif()
        string(FIND "${_rel_dir}" "../" _is_in_parent_folder)
        if (NOT ${_is_in_parent_folder} EQUAL 0)
          set(_suitable_include_found TRUE)
          break()
        endif()
      endif()
    endforeach()

    if(NOT _suitable_include_found)
      message(SEND_ERROR "Error: protobuf_generate could not find any correct proto include directory.")
      return()
    endif()

    set(_generated_srcs)
    foreach(_ext ${protobuf_generate_GENERATE_EXTENSIONS})
      list(APPEND _generated_srcs "${protobuf_generate_PROTOC_OUT_DIR}/${_rel_dir}/${_basename}${_ext}")
    endforeach()
    list(APPEND _generated_srcs_all ${_generated_srcs})

    set(_comment "Running ${protobuf_generate_LANGUAGE} protocol buffer compiler on ${_proto}")
    if(protobuf_generate_PROTOC_OPTIONS)
      set(_comment "${_comment}, protoc-options: ${protobuf_generate_PROTOC_OPTIONS}")
    endif()
    if(_plugin_options)
      set(_comment "${_comment}, plugin-options: ${_plugin_options}")
    endif()

    # Depending on the protoc executable makes the download command above run
    # before the first compilation when the default protoc is used.
    add_custom_command(
      OUTPUT ${_generated_srcs}
      COMMAND ${protobuf_generate_PROTOC_EXE} ${protobuf_generate_PROTOC_OPTIONS} --${protobuf_generate_LANGUAGE}_out ${_plugin_options}:${protobuf_generate_PROTOC_OUT_DIR} ${_plugin} ${_protobuf_include_path} ${_abs_file}
      DEPENDS ${_abs_file} ${protobuf_generate_PROTOC_EXE} ${protobuf_generate_DEPENDENCIES}
      COMMENT "${_comment}"
      VERBATIM
    )
  endforeach()

  set_source_files_properties(${_generated_srcs_all} PROPERTIES GENERATED TRUE)
  if(protobuf_generate_OUT_VAR)
    set(${protobuf_generate_OUT_VAR} ${_generated_srcs_all} PARENT_SCOPE)
  endif()
  if(protobuf_generate_TARGET)
    target_sources(${protobuf_generate_TARGET} PRIVATE ${_generated_srcs_all})
  endif()

endfunction()
224 |
--------------------------------------------------------------------------------
/safari_sdk/logging/python/mcap_lerobot_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Logger for LeRobot data."""
16 |
17 | import concurrent.futures
18 |
19 | from absl import logging
20 | import dm_env
21 | from dm_env import specs
22 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
23 | import numpy as np
24 |
25 | from safari_sdk.logging.python import mcap_episodic_logger
26 |
27 | # LeRobot step keys.
28 | _ACTION_KEY = "action"
29 | _FRAME_INDEX_KEY = "frame_index"
30 | _NEXT_DONE_KEY = "next.done"
31 | _OBSERVATION_KEY_PREFIX = "observation."
32 | # _TASK_KEY maps to _INSTRUCTION_KEY and eventually will be used to populate the
33 | # instruction label in the SSOT Session.
34 | _TASK_KEY = "task"
35 |
36 | # LeRobot feature keys.
37 | _DTYPE_VIDEO = "video"
38 | _DTYPE_FLOAT32 = "float32"
39 | _DTYPE_INT64 = "int64"
40 | _DTYPE_BOOL = "bool"
41 | _DTYPE_KEY = "dtype"
42 |
43 | # MCAP Logger spec keys.
44 | _REWARD_KEY = "reward"
45 | _DISCOUNT_KEY = "discount"
46 | _STEP_TYPE_KEY = "step_type"
47 | _SHAPE_KEY = "shape"
48 | _INSTRUCTION_KEY = "instruction"
49 |
50 | # Others.
51 | _AGENT_ID_PREFIX = "robot_episode_"
52 |
53 |
54 | class LeRobotEpisodicLogger:
55 | """An episodic logger that writes LeRobot episodes to MCAP files."""
56 |
57 | def __init__(
58 | self,
59 | task_id: str,
60 | output_directory: str,
61 | camera_names: list[str] | None = None,
62 | proprio_key: str | None = None,
63 | features: dict | None = None,
64 | ):
65 | """Initializes the logger.
66 |
67 | Args:
68 | task_id: The task ID.
69 | output_directory: The output directory for MCAP files.
70 | camera_names: A list of camera keys for image encoding.
71 | proprio_key: The key for the proprioceptive data.
72 | features: A dictionary of dataset features, used to generate specs for
73 | validation.
74 | """
75 | self._task_id = task_id
76 | self._output_directory = output_directory
77 | self._camera_names = camera_names or []
78 | self._proprio_key = proprio_key
79 | self._timestep_spec = None
80 | self._action_spec = None
81 |
82 | if features:
83 | self._timestep_spec, self._action_spec = self._parse_features_to_specs(
84 | features
85 | )
86 |
87 | self._mcap_episodic_logger: (
88 | mcap_episodic_logger.McapEpisodicLogger | None
89 | ) = None
90 | self._current_episode_id: int = -1
91 | self._previous_action: np.ndarray | None = None
92 |
93 | def _parse_features_to_specs(
94 | self,
95 | features: dict,
96 | ) -> tuple[mcap_episodic_logger.TimeStepSpec, specs.BoundedArray]:
97 | """Converts dataset features to dm_env specs."""
98 | action_spec = None
99 | observation_spec = {}
100 |
101 | # Mapping from LeRobot dtype to numpy dtype.
102 | def _dtype_map(dtype: str) -> str:
103 | if dtype == _DTYPE_VIDEO:
104 | return "uint8"
105 | else:
106 | return dtype
107 |
108 | for key, feature_info in features.items():
109 | dtype = _dtype_map(feature_info[_DTYPE_KEY])
110 |
111 | if key == _ACTION_KEY:
112 | shape = tuple(feature_info[_SHAPE_KEY])
113 | action_spec = specs.BoundedArray(
114 | shape=shape,
115 | dtype=dtype,
116 | minimum=-2.0,
117 | maximum=3.0,
118 | name=key,
119 | )
120 | elif key.startswith(_OBSERVATION_KEY_PREFIX):
121 | obs_key = key.replace(_OBSERVATION_KEY_PREFIX, "", 1)
122 | observation_spec[obs_key] = specs.Array(
123 | shape=tuple(feature_info[_SHAPE_KEY]),
124 | dtype=dtype,
125 | name=obs_key,
126 | )
127 |
128 | if action_spec is None:
129 | raise ValueError("Action spec not found in features.")
130 | if not observation_spec:
131 | raise ValueError("Observation spec not found in features.")
132 | observation_spec[_INSTRUCTION_KEY] = specs.Array(
133 | shape=(), dtype=object, name=_INSTRUCTION_KEY
134 | )
135 | # Create timestep spec.
136 | timestep_spec = mcap_episodic_logger.TimeStepSpec(
137 | observation=observation_spec,
138 | reward=specs.Array(shape=(), dtype=np.float32, name=_REWARD_KEY),
139 | discount=specs.Array(shape=(), dtype=np.float32, name=_DISCOUNT_KEY),
140 | step_type=specs.BoundedArray(
141 | shape=(),
142 | dtype=int,
143 | minimum=min(dm_env.StepType),
144 | maximum=max(dm_env.StepType),
145 | name=_STEP_TYPE_KEY,
146 | ),
147 | )
148 |
149 | return timestep_spec, action_spec
150 |
def start_episode(self, episode_id: int) -> None:
  """Opens a new episode by creating a fresh per-episode MCAP logger.

  The episode id is remembered on the instance, a `McapEpisodicLogger` is
  built from the configuration stored on `self`, and the pending action left
  over from any earlier episode is cleared.

  Args:
    episode_id: Identifier of the episode; it is appended to the agent-id
      prefix to form the agent id handed to the underlying logger.

  Raises:
    ValueError: If an episode logger is still set, i.e. the previous episode
      has not been finished yet.
  """
  episode_in_progress = self._mcap_episodic_logger is not None
  if episode_in_progress:
    raise ValueError(
        "Cannot start a new episode, the previous one has not been finished."
    )

  self._current_episode_id = episode_id
  session_agent_id = "{}{}".format(_AGENT_ID_PREFIX, episode_id)
  logging.info(
      "Starting episode %s with agent id %s", episode_id, session_agent_id
  )

  # Gather the constructor arguments first so the call itself stays short.
  logger_kwargs = dict(
      agent_id=session_agent_id,
      task_id=self._task_id,
      output_directory=self._output_directory,
      proprio_key=self._proprio_key,
      camera_names=self._camera_names,
      timestep_spec=self._timestep_spec,
      action_spec=self._action_spec,
      policy_extra_spec={},
      validate_data_with_spec=True,
  )
  self._mcap_episodic_logger = mcap_episodic_logger.McapEpisodicLogger(
      **logger_kwargs
  )
  # No action has been taken yet in the new episode.
  self._previous_action = None
173 |
def finish_episode(self) -> None:
  """Closes the current episode, if there is one.

  When an episode logger is active, its `write()` method is invoked and the
  logger reference is then cleared. Calling this without an active episode
  does nothing.
  """
  active_logger = self._mcap_episodic_logger
  if active_logger is not None:
    active_logger.write()
    # Only reached when write() succeeded; a failed write keeps the logger.
    self._mcap_episodic_logger = None
181 |
def record_step(self, step_data: dict[str, np.ndarray]) -> None:
  """Records one frame of the current episode.

  Every entry of `step_data` whose key starts with the observation prefix is
  stored as an observation under the key with that prefix removed. 3-D arrays
  that belong to one of the configured cameras are transposed from
  (C, H, W) to (H, W, C) and converted to uint8. The task entry is added as
  the instruction observation.

  An action is logged together with the timestep that follows it, so the
  action of the frame recorded here is kept in `self._previous_action` and
  only handed to the logger on the next call.

  Args:
    step_data: Mapping from feature key to array for a single frame. It must
      contain the action, task and frame-index entries, and the next-done
      entry for every frame whose frame index is not 0.

  Raises:
    ValueError: If no episode has been started.
    KeyError: If a required entry is missing from `step_data`.
  """
  episode_logger = self._mcap_episodic_logger
  if episode_logger is None:
    # Raised explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let the call fail later on `None`.
    raise ValueError(
        "Cannot record step, episode not started. Call start_episode() first."
    )

  prefix_length = len(_OBSERVATION_KEY_PREFIX)
  observation = {}
  for key, value in step_data.items():
    if not key.startswith(_OBSERVATION_KEY_PREFIX):
      continue
    obs_key = key[prefix_length:]
    # Transpose image data from (C, H, W) (PyTorch) to (H, W, C).
    if obs_key in self._camera_names and value.ndim == 3:
      value = np.transpose(value, (1, 2, 0))
      if np.issubdtype(value.dtype, np.floating):
        # Round before casting: a bare float -> uint8 cast truncates, so a
        # value such as 127.99999 would become 127. Clipping keeps values
        # marginally outside [0, 1] from wrapping around.
        value = np.clip(np.rint(value * 255.0), 0, 255).astype(np.uint8)
      else:
        # NOTE(review): integer images are assumed to already be on the
        # 0-255 scale; scaling them by 255 would overflow uint8 -- confirm.
        value = value.astype(np.uint8)
    observation[obs_key] = value
  observation[_INSTRUCTION_KEY] = np.array(step_data[_TASK_KEY], dtype=object)

  action = step_data[_ACTION_KEY]
  frame_index = int(step_data[_FRAME_INDEX_KEY])

  # Frame 0 always opens the episode, even if it is also flagged as done.
  if frame_index == 0:
    step_type = dm_env.StepType.FIRST
  elif bool(step_data[_NEXT_DONE_KEY]):
    step_type = dm_env.StepType.LAST
  else:
    step_type = dm_env.StepType.MID

  # Reward and discount are written as constants (0.0 and 1.0).
  timestep = dm_env.TimeStep(
      step_type=step_type,
      reward=np.float32(0.0),
      discount=np.float32(1.0),
      observation=observation,
  )
  if step_type == dm_env.StepType.FIRST:
    episode_logger.reset(timestep, episode_uuid=str(self._current_episode_id))
  else:
    # Pair the action taken at the previous frame with this frame's timestep.
    episode_logger.record_action_and_next_timestep(
        action=self._previous_action,
        next_timestep=timestep,
        policy_extra={},
    )
  self._previous_action = action
226 |
227 |
def convert_lerobot_data_to_mcap(
    *,
    dataset: LeRobotDataset,
    task_id: str,
    output_directory: str,
    proprio_key: str,
    episodes_limit: int,
    max_workers: int,
) -> None:
  """Converts LeRobot data to MCAP files, processing episodes in parallel.

  Each episode is handled by its own `LeRobotEpisodicLogger` inside a worker
  thread. A failure in one episode is logged and does not stop the others.

  Args:
    dataset: Dataset to convert. Its `episode_data_index`, `features` and
      `meta.camera_keys` attributes are read, and frames are fetched by
      indexing the dataset.
    task_id: Task id passed to every episode logger.
    output_directory: Output directory passed to every episode logger.
    proprio_key: Proprioception key passed to every episode logger.
    episodes_limit: Maximum number of episodes to convert, counted from the
      first one. A value <= 0 converts all episodes.
    max_workers: Upper bound on the number of worker threads. It is capped at
      the number of episodes to process.

  Raises:
    ValueError: If `max_workers` is not greater than 0.
  """
  episode_indices = dataset.episode_data_index
  num_episodes = len(episode_indices["from"])
  if episodes_limit <= 0:
    num_episodes_to_process = num_episodes
  else:
    num_episodes_to_process = min(episodes_limit, num_episodes)

  if max_workers <= 0:
    raise ValueError("max_workers must be greater than 0.")

  if num_episodes_to_process == 0:
    # Without this guard the worker count below would be capped to 0 and
    # ThreadPoolExecutor would reject it with a misleading ValueError.
    logging.warning("Dataset contains no episodes, nothing to convert.")
    return

  max_workers = min(max_workers, num_episodes_to_process)
  logging.info(
      "Will process the first %d episodes with %d workers.",
      num_episodes_to_process,
      max_workers,
  )

  camera_names = [
      key.replace(_OBSERVATION_KEY_PREFIX, "", 1)
      for key in dataset.meta.camera_keys
  ]

  def _process_episode(episode_id: int) -> None:
    """Processes a single episode."""
    thread_logger = LeRobotEpisodicLogger(
        task_id=task_id,
        output_directory=output_directory,
        camera_names=camera_names,
        proprio_key=proprio_key,
        features=dataset.features,
    )
    # The index entries may be array/tensor scalars rather than Python ints;
    # coerce them so `range` and `%d` formatting behave uniformly.
    start_index = int(episode_indices["from"][episode_id])
    end_index = int(episode_indices["to"][episode_id])
    logging.info(
        "Processing episode %d from index %d to %d",
        episode_id,
        start_index,
        end_index,
    )

    thread_logger.start_episode(episode_id=episode_id)

    for step_index in range(start_index, end_index):
      step = dataset[step_index]
      step_np = {k: np.array(v) for k, v in step.items()}
      thread_logger.record_step(step_np)

    thread_logger.finish_episode()

  failed_episodes = []
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=max_workers
  ) as executor:
    # Remember which episode each future belongs to so that failures can be
    # attributed to a concrete episode id.
    future_to_episode = {
        executor.submit(_process_episode, episode_id): episode_id
        for episode_id in range(num_episodes_to_process)
    }
    for future in concurrent.futures.as_completed(future_to_episode):
      episode_id = future_to_episode[future]
      try:
        future.result()
      except Exception as e:  # Best effort: keep converting other episodes.
        logging.exception("Error processing episode %d: %s", episode_id, e)
        failed_episodes.append(episode_id)

  if failed_episodes:
    logging.error(
        "Failed to convert %d of %d episodes: %s",
        len(failed_episodes),
        num_episodes_to_process,
        sorted(failed_episodes),
    )
299 |
--------------------------------------------------------------------------------