├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── bin ├── run_collect_eval.py └── run_t2r_trainer.py ├── export_generators ├── __init__.py ├── abstract_export_generator.py ├── abstract_export_generator_test.py ├── default_export_generator.py └── default_export_generator_test.py ├── hooks ├── __init__.py ├── async_export_hook_builder.py ├── async_export_hook_builder_test.py ├── async_export_hook_builder_tpu_test.py ├── checkpoint_hooks.py ├── checkpoint_hooks_test.py ├── gin_config_hook_builder.py ├── golden_values_hook_builder.py ├── hook_builder.py ├── td3.py ├── td3_test.py └── variable_logger_hook.py ├── input_generators ├── __init__.py ├── abstract_input_generator.py ├── abstract_input_generator_test.py ├── default_input_generator.py └── default_input_generator_test.py ├── layers ├── __init__.py ├── bcz_networks.py ├── film_resnet_model.py ├── mdn.py ├── mdn_test.py ├── resnet.py ├── resnet_test.py ├── snail.py ├── snail_test.py ├── spatial_softmax.py ├── spatial_softmax_test.py ├── tec.py ├── tec_test.py └── vision_layers.py ├── meta_learning ├── __init__.py ├── maml_inner_loop.py ├── maml_inner_loop_test.py ├── maml_model.py ├── maml_model_test.py ├── meta_example.py ├── meta_policies.py ├── meta_tf_models.py ├── meta_tf_models_test.py ├── meta_tfdata.py ├── preprocessors.py ├── preprocessors_test.py └── run_meta_env.py ├── models ├── __init__.py ├── abstract_model.py ├── classification_model.py ├── critic_model.py ├── model_interface.py ├── optimizers.py ├── regression_model.py └── tpu_model_wrapper.py ├── policies ├── __init__.py └── policies.py ├── predictors ├── __init__.py ├── abstract_predictor.py ├── checkpoint_predictor.py ├── checkpoint_predictor_test.py ├── ensemble_exported_savedmodel_predictor.py ├── ensemble_exported_savedmodel_predictor_test.py ├── exported_savedmodel_predictor.py ├── exported_savedmodel_predictor_test.py ├── saved_model_v2_predictor.py └── saved_model_v2_predictor_test.py ├── preprocessors ├── __init__.py ├── abstract_preprocessor.py ├── abstract_preprocessor_test.py ├── distortion.py ├── image_transformations.py ├── image_transformations_test.py ├── noop_preprocessor.py ├── noop_preprocessor_test.py ├── spec_transformation_preprocessor.py ├── tpu_preprocessor_wrapper.py └── tpu_preprocessor_wrapper_test.py ├── proto └── t2r.proto ├── requirements.txt ├── research ├── bcz │ ├── README.md │ ├── configs │ │ ├── common_imagedistortions.gin │ │ ├── common_imports.gin │ │ ├── run_train_bc_gtcond_trajectory.gin │ │ └── run_train_bc_langcond_trajectory.gin │ ├── model.py │ ├── model_test.py │ └── pose_components_lib.py ├── dql_grasping_lib │ ├── __init__.py │ ├── run_env.py │ └── tf_modules.py ├── grasp2vec │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── common_imports.gin │ │ └── train_grasp2vec.gin │ ├── grasp2vec_model.py │ ├── losses.py │ ├── losses_test.py │ ├── networks.py │ ├── resnet.py │ └── visualization.py ├── pose_env │ ├── __init__.py │ ├── configs │ │ ├── common_imports.gin │ │ ├── run_random_collect.gin │ │ ├── run_train_reg.gin │ │ └── run_train_reg_maml.gin │ ├── episode_to_transitions.py │ ├── pose_env.py │ ├── pose_env_maml_models.py │ ├── pose_env_models.py │ ├── pose_env_models_test.py │ └── pose_env_test.py ├── qtopt │ ├── README.md │ ├── __init__.py │ ├── networks.py │ ├── optimizer_builder.py │ ├── pcgrad.py │ ├── pcgrad_test.py │ ├── pcgrad_tpu_test.py │ ├── t2r_models.py │ └── t2r_models_test.py └── vrgripper │ ├── README.md │ ├── __init__.py │ ├── configs │ ├── common_imports.gin │ ├── 
run_train_wtl_statespace_retrial.gin │ ├── run_train_wtl_statespace_trial.gin │ ├── run_train_wtl_vision_retrial.gin │ └── run_train_wtl_vision_trial.gin │ ├── discrete.py │ ├── episode_to_transitions.py │ ├── episode_to_transitions_test.py │ ├── maf.py │ ├── mse_decoder.py │ ├── vrgripper_env_meta_models.py │ ├── vrgripper_env_models.py │ └── vrgripper_env_wtl_models.py ├── test_data ├── mock_exported_savedmodel │ ├── assets.extra │ │ └── t2r_assets.pbtxt │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index └── pose_env_test_data.tfrecord └── utils ├── __init__.py ├── continuous_collect_eval.py ├── continuous_collect_eval_test.py ├── convert_pkl_assets_to_proto_assets.py ├── cross_entropy.py ├── global_step_functions.py ├── global_step_functions_test.py ├── image.py ├── mocks.py ├── subsample.py ├── subsample_test.py ├── t2r_test_fixture.py ├── tensorspec_utils.py ├── tensorspec_utils_test.py ├── tfdata.py ├── tfdata_test.py ├── train_eval.py ├── train_eval_test.py ├── train_eval_test_utils.py └── writer.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | Disclaimer: T2R is not an official Google product. External support is not 4 | guaranteed, and Robotics @ Google infra needs are prioritized over those of 5 | third parties. The purpose of developing this codebase in the open is to provide 6 | an avenue for open-sourcing our team's code. If you have questions about the 7 | codebase, or would like to submit a pull request, file a GitHub issue first. 8 | 9 | # Issues 10 | 11 | * Please tag your issue with `bug`, `feature request`, or `question` to help us 12 | effectively respond. 13 | * Please include the versions of TensorFlow and Tensor2Robot you are running. 14 | * Please provide the command line you ran as well as the log output. 15 | 16 | # Pull Requests 17 | 18 | We'd love to accept your patches and contributions to this project, provided 19 | that they are aligned with our internal uses of this codebase. There are 20 | just a few guidelines you need to follow. 21 | 22 | ## Contributor License Agreement 23 | 24 | Contributions to this project must be accompanied by a Contributor License 25 | Agreement. You (or your employer) retain the copyright to your contribution; 26 | this simply gives us permission to use and redistribute your contributions as 27 | part of the project. Head over to https://cla.developers.google.com/ to see 28 | your current agreements on file or to sign a new one. 29 | 30 | You generally only need to submit a CLA once, so if you've already submitted one 31 | (even if it was for a different project), you probably don't need to do it 32 | again. 33 | 34 | ## Code reviews 35 | 36 | All submissions, including submissions by project members, require review. We 37 | use GitHub pull requests for this purpose. Consult 38 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 39 | information on using pull requests. 40 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bin/run_collect_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Runs data collection and policy evaluation for RL experiments. 17 | """ 18 | 19 | from absl import app 20 | from absl import flags 21 | import gin 22 | from tensor2robot.utils import continuous_collect_eval 23 | # Contains the gin_configs and gin_bindings flag definitions. 24 | FLAGS = flags.FLAGS 25 | 26 | try: 27 | flags.DEFINE_list( 28 | 'gin_configs', None, 29 | 'A comma-separated list of paths to Gin configuration files.') 30 | flags.DEFINE_multi_string( 31 | 'gin_bindings', [], 'A newline separated list of Gin parameter bindings.') 32 | except flags.DuplicateFlagError: 33 | pass 34 | 35 | flags.DEFINE_string('root_dir', '', 36 | 'Root directory of experiment.') 37 | flags.mark_flag_as_required('gin_configs') 38 | 39 | 40 | def main(unused_argv): 41 | del unused_argv 42 | gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings) 43 | continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir) 44 | 45 | 46 | if __name__ == '__main__': 47 | app.run(main) 48 | -------------------------------------------------------------------------------- /bin/run_t2r_trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
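# NOTE (illustrative sketch; the command line below is hypothetical and not
# part of the original sources): the binaries under bin/ are configured
# entirely through gin. For example, the collection/eval binary defined above
# could be launched as:
#
#   python -m tensor2robot.bin.run_collect_eval \
#     --root_dir=/tmp/pose_env_experiment \
#     --gin_configs=tensor2robot/research/pose_env/configs/run_random_collect.gin
#
# --gin_configs accepts a comma-separated list of .gin files and
# --gin_bindings accepts individual parameter overrides; both are forwarded to
# gin.parse_config_files_and_bindings() in main().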
15 | 16 | """Binary for training TFModels with Estimator API.""" 17 | 18 | from absl import app 19 | from absl import flags 20 | import gin 21 | from tensor2robot.utils import train_eval 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def main(unused_argv): 29 | gin.parse_config_files_and_bindings( 30 | FLAGS.gin_configs, FLAGS.gin_bindings, print_includes_and_imports=True) 31 | train_eval.train_eval_model() 32 | 33 | 34 | if __name__ == '__main__': 35 | tf.logging.set_verbosity(tf.logging.INFO) 36 | app.run(main) 37 | -------------------------------------------------------------------------------- /export_generators/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /export_generators/abstract_export_generator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for exporting savedmodels.""" 17 | 18 | import abc 19 | import functools 20 | import os 21 | from typing import Any, Dict, List, Optional, Text 22 | 23 | import gin 24 | import six 25 | from tensor2robot.models import abstract_model 26 | from tensor2robot.utils import tensorspec_utils 27 | import tensorflow.compat.v1 as tf 28 | from tensorflow.compat.v1 import estimator as tf_estimator 29 | from tensorflow.contrib import util as contrib_util 30 | 31 | from tensorflow_serving.apis import predict_pb2 32 | from tensorflow_serving.apis import prediction_log_pb2 33 | 34 | MODE = tf_estimator.ModeKeys.PREDICT 35 | 36 | 37 | @gin.configurable 38 | class AbstractExportGenerator(six.with_metaclass(abc.ABCMeta, object)): 39 | """Class to manage assets related to exporting a model. 40 | 41 | Attributes: 42 | export_raw_receivers: Whether to export receiver_fns which do not have 43 | preprocessing enabled. This is useful for serving using Servo, in 44 | conjunction with client-preprocessing. 
45 | """ 46 | 47 | def __init__(self, export_raw_receivers = False): 48 | self._export_raw_receivers = export_raw_receivers 49 | self._feature_spec = None 50 | self._out_feature_spec = None 51 | self._preprocess_fn = None 52 | self._model_name = None 53 | 54 | def set_specification_from_model(self, 55 | t2r_model): 56 | """Set the feature specifications and preprocess function from the model. 57 | 58 | Args: 59 | t2r_model: A T2R model instance. 60 | """ 61 | preprocessor = t2r_model.preprocessor 62 | self._feature_spec = preprocessor.get_in_feature_specification(MODE) 63 | tensorspec_utils.assert_valid_spec_structure(self._feature_spec) 64 | self._out_feature_spec = (preprocessor.get_out_feature_specification(MODE)) 65 | tensorspec_utils.assert_valid_spec_structure(self._out_feature_spec) 66 | self._preprocess_fn = functools.partial(preprocessor.preprocess, mode=MODE) 67 | self._model_name = type(t2r_model).__name__ 68 | 69 | def _get_input_features_for_receiver_fn(self): 70 | """Helper function to return an input feature spec for receiver fns. 71 | 72 | Returns: 73 | The appropriate feature specification to use. 74 | """ 75 | if self._export_raw_receivers: 76 | return self._out_feature_spec 77 | else: 78 | return self._feature_spec 79 | 80 | @abc.abstractmethod 81 | def create_serving_input_receiver_numpy_fn(self, params=None): 82 | """Create a serving input receiver for numpy. 83 | 84 | Args: 85 | params: An optional dict of hyper parameters that will be passed into 86 | input_fn and model_fn. Keys are names of parameters, values are basic 87 | python types. There are reserved keys for TPUEstimator, including 88 | 'batch_size'. 89 | 90 | Returns: 91 | serving_input_receiver_fn: A callable which creates the serving inputs. 92 | """ 93 | 94 | @abc.abstractmethod 95 | def create_serving_input_receiver_tf_example_fn( 96 | self, params = None): 97 | """Create a serving input receiver for tf_examples. 98 | 99 | Args: 100 | params: An optional dict of hyper parameters that will be passed into 101 | input_fn and model_fn. Keys are names of parameters, values are basic 102 | python types. There are reserved keys for TPUEstimator, including 103 | 'batch_size'. 104 | 105 | Returns: 106 | serving_input_receiver_fn: A callable which creates the serving inputs. 107 | """ 108 | 109 | def create_warmup_requests_numpy(self, batch_sizes, 110 | export_dir): 111 | """Creates warm-up requests for a given feature specification. 112 | 113 | This writes an output file in 114 | `export_dir/assets.extra/tf_serving_warmup_requests` for use with Servo. 115 | 116 | Args: 117 | batch_sizes: Batch sizes of warm-up requests to write. 118 | export_dir: Base directory for the export. 119 | 120 | Returns: 121 | The filename written.
122 | """ 123 | feature_spec = self._get_input_features_for_receiver_fn() 124 | 125 | flat_feature_spec = tensorspec_utils.flatten_spec_structure(feature_spec) 126 | tf.io.gfile.makedirs(export_dir) 127 | request_filename = os.path.join(export_dir, 'tf_serving_warmup_requests') 128 | with tf.python_io.TFRecordWriter(request_filename) as writer: 129 | for batch_size in batch_sizes: 130 | request = predict_pb2.PredictRequest() 131 | request.model_spec.name = self._model_name 132 | numpy_feature_specs = tensorspec_utils.make_constant_numpy( 133 | flat_feature_spec, constant_value=0, batch_size=batch_size) 134 | 135 | for key, numpy_spec in numpy_feature_specs.items(): 136 | request.inputs[key].CopyFrom( 137 | contrib_util.make_tensor_proto(numpy_spec)) 138 | 139 | log = prediction_log_pb2.PredictionLog( 140 | predict_log=prediction_log_pb2.PredictLog(request=request)) 141 | writer.write(log.SerializeToString()) 142 | return request_filename 143 | -------------------------------------------------------------------------------- /export_generators/abstract_export_generator_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
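# NOTE (illustrative sketch; not part of the original sources): the warmup file
# produced by create_warmup_requests_numpy() above is a TFRecord of
# tensorflow_serving PredictionLog protos. When it is shipped inside a
# SavedModel under assets.extra/tf_serving_warmup_requests (see the
# assets_extra argument used by the hook builders later in this repo),
# TF Serving replays the requests at model load time. The records can be
# inspected roughly like this:
#
#   for record in tf.io.tf_record_iterator(request_filename):
#     log = prediction_log_pb2.PredictionLog()
#     log.ParseFromString(record)
#     print(log.predict_log.request.model_spec.name)
#
# The test below performs essentially this check on the batch dimension.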
15 | 16 | """Tests for tensor2robot.export_generator.sabstract_export_generator.""" 17 | 18 | from six.moves import zip 19 | from tensor2robot.export_generators import abstract_export_generator 20 | from tensor2robot.preprocessors import noop_preprocessor 21 | from tensor2robot.utils import mocks 22 | import tensorflow.compat.v1 as tf 23 | from tensorflow_serving.apis import prediction_log_pb2 24 | 25 | 26 | class AbstractExportGeneratorTest(tf.test.TestCase): 27 | 28 | def test_init_abstract(self): 29 | with self.assertRaises(TypeError): 30 | abstract_export_generator.AbstractExportGenerator() 31 | 32 | def test_create_warmup_requests_numpy(self): 33 | mock_t2r_model = mocks.MockT2RModel( 34 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor) 35 | exporter = mocks.MockExportGenerator() 36 | exporter.set_specification_from_model(mock_t2r_model) 37 | 38 | export_dir = self.create_tempdir() 39 | batch_sizes = [2, 4] 40 | request_filename = exporter.create_warmup_requests_numpy( 41 | batch_sizes=batch_sizes, export_dir=export_dir.full_path) 42 | 43 | for expected_batch_size, record in zip( 44 | batch_sizes, tf.compat.v1.io.tf_record_iterator(request_filename)): 45 | record_proto = prediction_log_pb2.PredictionLog() 46 | record_proto.ParseFromString(record) 47 | request = record_proto.predict_log.request 48 | self.assertEqual(request.model_spec.name, 'MockT2RModel') 49 | for _, in_tensor in request.inputs.items(): 50 | self.assertEqual(in_tensor.tensor_shape.dim[0].size, 51 | expected_batch_size) 52 | 53 | 54 | if __name__ == '__main__': 55 | tf.test.main() 56 | -------------------------------------------------------------------------------- /export_generators/default_export_generator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for generating SavedModel exports based on AbstractT2RModels.""" 17 | 18 | import copy 19 | from typing import Any, Dict, Optional, Text 20 | 21 | import gin 22 | import six 23 | from tensor2robot.export_generators import abstract_export_generator 24 | from tensor2robot.utils import tensorspec_utils 25 | from tensor2robot.utils import tfdata 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | MODE = tf_estimator.ModeKeys.PREDICT 30 | 31 | 32 | @gin.configurable 33 | class DefaultExportGenerator(abstract_export_generator.AbstractExportGenerator): 34 | """Class to manage assets related to exporting a model. 35 | 36 | Attributes: 37 | export_raw_receivers: Whether to export receiver_fns which do not have 38 | preprocessing enabled. This is useful for serving using Servo, in 39 | conjunction with client-preprocessing. 40 | """ 41 | 42 | def create_serving_input_receiver_numpy_fn(self, params=None): 43 | """Create a serving input receiver for numpy. 
44 | 45 | Args: 46 | params: An optional dict of hyper parameters that will be passed into 47 | input_fn and model_fn. Keys are names of parameters, values are basic 48 | python types. There are reserved keys for TPUEstimator, including 49 | 'batch_size'. 50 | 51 | Returns: 52 | serving_input_receiver_fn: A callable which creates the serving inputs. 53 | """ 54 | del params 55 | 56 | def serving_input_receiver_fn(): 57 | """Create the ServingInputReceiver to export a saved model. 58 | 59 | Returns: 60 | An instance of ServingInputReceiver. 61 | """ 62 | # We have to filter our specs since only required tensors are 63 | # used for inference time. 64 | flat_feature_spec = tensorspec_utils.flatten_spec_structure( 65 | self._get_input_features_for_receiver_fn()) 66 | required_feature_spec = ( 67 | tensorspec_utils.filter_required_flat_tensor_spec(flat_feature_spec)) 68 | receiver_tensors = tensorspec_utils.make_placeholders( 69 | required_feature_spec) 70 | 71 | # We want to ensure that our feature processing pipeline operates on a 72 | # copy of the features and does not alter the receiver_tensors. 73 | features = tensorspec_utils.flatten_spec_structure( 74 | copy.copy(receiver_tensors)) 75 | 76 | if (not self._export_raw_receivers and self._preprocess_fn is not None): 77 | features, _ = self._preprocess_fn(features=features, labels=None) 78 | 79 | return tf_estimator.export.ServingInputReceiver(features, 80 | receiver_tensors) 81 | 82 | return serving_input_receiver_fn 83 | 84 | def create_serving_input_receiver_tf_example_fn( 85 | self, params = None): 86 | """Create a serving input receiver for tf_examples. 87 | 88 | Args: 89 | params: An optional dict of hyper parameters that will be passed into 90 | input_fn and model_fn. Keys are names of parameters, values are basic 91 | python types. There are reserved keys for TPUEstimator, including 92 | 'batch_size'. 93 | 94 | Returns: 95 | serving_input_receiver_fn: A callable which creates the serving inputs. 96 | """ 97 | del params 98 | 99 | def serving_input_receiver_fn(): 100 | """Create the ServingInputReceiver to export a saved model. 101 | 102 | Returns: 103 | An instance of ServingInputReceiver. 104 | """ 105 | # We assume one input (a string which containes the serialized proto) per 106 | # dataset_key. 107 | feature_spec = self._get_input_features_for_receiver_fn() 108 | # We have to filter our specs since only required tensors are 109 | # used for inference time. 
110 | flat_feature_spec = tensorspec_utils.flatten_spec_structure(feature_spec) 111 | required_feature_spec = ( 112 | tensorspec_utils.filter_required_flat_tensor_spec(flat_feature_spec)) 113 | dataset_keys = set( 114 | [t.dataset_key for t in required_feature_spec.values()]) 115 | receiver_tensors = {} 116 | parse_tensors = {} 117 | for dataset_key in dataset_keys: 118 | receiver_name = 'input_example_' + six.ensure_str( 119 | (dataset_key or 'tensor')) 120 | parse_tensors[dataset_key] = tf.placeholder( 121 | dtype=tf.string, shape=[None], name=receiver_name) 122 | receiver_tensors[receiver_name] = parse_tensors[dataset_key] 123 | parse_tf_example_fn = tfdata.create_parse_tf_example_fn( 124 | feature_tspec=required_feature_spec) 125 | features = parse_tf_example_fn(parse_tensors) 126 | 127 | if (not self._export_raw_receivers and self._preprocess_fn is not None): 128 | features, _ = self._preprocess_fn(features=features, labels=None) 129 | 130 | return tf_estimator.export.ServingInputReceiver(features, 131 | receiver_tensors) 132 | 133 | return serving_input_receiver_fn 134 | -------------------------------------------------------------------------------- /hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /hooks/async_export_hook_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
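# NOTE (illustrative sketch; the feature name below is hypothetical): a
# SavedModel exported with create_serving_input_receiver_tf_example_fn() above
# exposes one string placeholder per dataset key, named
# 'input_example_' + (dataset_key or 'tensor'). A client therefore feeds
# batched, serialized tf.train.Example protos whose features match the model's
# feature spec, e.g.:
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'state': tf.train.Feature(
#           float_list=tf.train.FloatList(value=[0.0] * 8)),
#   }))
#   feed = {'input_example_tensor:0': [example.SerializeToString()]}
#
# The numpy receiver variant instead creates one placeholder per required
# entry of the flattened feature spec.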
15 | 16 | """Hook builders for TD3 distributed training with SavedModels.""" 17 | 18 | import os 19 | import tempfile 20 | from typing import Text, List, Callable, Optional 21 | 22 | import gin 23 | from tensor2robot.export_generators import abstract_export_generator 24 | from tensor2robot.export_generators import default_export_generator 25 | from tensor2robot.hooks import checkpoint_hooks 26 | from tensor2robot.hooks import hook_builder 27 | from tensor2robot.models import model_interface 28 | from tensor2robot.proto import t2r_pb2 29 | from tensor2robot.utils import tensorspec_utils 30 | from tensorflow.compat.v1 import estimator as tf_estimator 31 | import tensorflow.compat.v1 as tf # tf 32 | 33 | from tensorflow.contrib import tpu as contrib_tpu 34 | 35 | CreateExportFnType = Callable[[ 36 | model_interface.ModelInterface, 37 | tf_estimator.Estimator, 38 | abstract_export_generator.AbstractExportGenerator, 39 | ], Callable[[Text, int], Text]] 40 | 41 | 42 | def default_create_export_fn( 43 | t2r_model, 44 | estimator, 45 | export_generator 46 | ): 47 | """Create an export function for a device type. 48 | 49 | Args: 50 | t2r_model: A T2RModel instance. 51 | estimator: The estimator used for training. 52 | export_generator: An export generator. 53 | 54 | Returns: 55 | A callable function which exports a saved model and returns the path. 56 | """ 57 | 58 | in_feature_spec = t2r_model.get_feature_specification_for_packing( 59 | mode=tf_estimator.ModeKeys.PREDICT) 60 | in_label_spec = t2r_model.get_label_specification_for_packing( 61 | mode=tf_estimator.ModeKeys.PREDICT) 62 | t2r_assets = t2r_pb2.T2RAssets() 63 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto()) 64 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto()) 65 | 66 | def _export_fn(export_dir, global_step): 67 | """The actual closure function creating the exported model and assets.""" 68 | # Create additional assets for the exported models 69 | t2r_assets.global_step = global_step 70 | tmpdir = tempfile.mkdtemp() 71 | t2r_assets_filename = os.path.join(tmpdir, 72 | tensorspec_utils.T2R_ASSETS_FILENAME) 73 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename) 74 | assets = { 75 | tensorspec_utils.T2R_ASSETS_FILENAME: t2r_assets_filename, 76 | } 77 | return estimator.export_saved_model( 78 | export_dir_base=export_dir, 79 | serving_input_receiver_fn=export_generator 80 | .create_serving_input_receiver_numpy_fn(), 81 | assets_extra=assets) 82 | 83 | return _export_fn 84 | 85 | 86 | @gin.configurable 87 | class AsyncExportHookBuilder(hook_builder.HookBuilder): 88 | """Creates hooks for exporting for cpu and tpu for serving. 89 | 90 | Attributes: 91 | export_dir: Directory to output the latest models. 92 | save_secs: Interval to save models, and copy the latest model from 93 | `export_dir` to `lagged_export_dir`. 94 | num_versions: Number of model versions to save in each directory 95 | export_generator: The export generator used to generate the 96 | serving_input_receiver_fn. 
97 | """ 98 | 99 | def __init__( 100 | self, 101 | export_dir, 102 | save_secs = 90, 103 | num_versions = 3, 104 | create_export_fn = default_create_export_fn, 105 | export_generator = None, 106 | ): 107 | super(AsyncExportHookBuilder, self).__init__() 108 | self._save_secs = save_secs 109 | self._num_versions = num_versions 110 | self._export_dir = export_dir 111 | self._create_export_fn = create_export_fn 112 | if export_generator is None: 113 | self._export_generator = default_export_generator.DefaultExportGenerator() 114 | else: 115 | self._export_generator = export_generator 116 | 117 | def create_hooks( 118 | self, 119 | t2r_model, 120 | estimator, 121 | ): 122 | self._export_generator.set_specification_from_model(t2r_model) 123 | return [ 124 | contrib_tpu.AsyncCheckpointSaverHook( 125 | save_secs=self._save_secs, 126 | checkpoint_dir=estimator.model_dir, 127 | listeners=[ 128 | checkpoint_hooks.CheckpointExportListener( 129 | export_fn=self._create_export_fn(t2r_model, estimator, 130 | self._export_generator), 131 | num_versions=self._num_versions, 132 | export_dir=self._export_dir) 133 | ]) 134 | ] 135 | -------------------------------------------------------------------------------- /hooks/async_export_hook_builder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for TD3 Hooks.""" 17 | 18 | import os 19 | from tensor2robot.hooks import async_export_hook_builder 20 | from tensor2robot.predictors import exported_savedmodel_predictor 21 | from tensor2robot.preprocessors import noop_preprocessor 22 | from tensor2robot.utils import mocks 23 | from tensor2robot.utils import train_eval 24 | import tensorflow.compat.v1 as tf # tf 25 | 26 | _EXPORT_DIR = 'export_dir' 27 | _MAX_STEPS = 4 28 | _BATCH_SIZE = 4 29 | 30 | 31 | class AsyncExportHookBuilderTest(tf.test.TestCase): 32 | 33 | def test_with_mock_training(self): 34 | model_dir = self.create_tempdir().full_path 35 | mock_t2r_model = mocks.MockT2RModel( 36 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu') 37 | 38 | mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE) 39 | export_dir = os.path.join(model_dir, _EXPORT_DIR) 40 | 41 | hook_builder = async_export_hook_builder.AsyncExportHookBuilder( 42 | export_dir=export_dir, 43 | create_export_fn=async_export_hook_builder.default_create_export_fn) 44 | 45 | # We optimize our network. 
46 | train_eval.train_eval_model( 47 | t2r_model=mock_t2r_model, 48 | input_generator_train=mock_input_generator, 49 | train_hook_builders=[hook_builder], 50 | model_dir=model_dir, 51 | max_train_steps=_MAX_STEPS) 52 | self.assertNotEmpty(tf.io.gfile.listdir(model_dir)) 53 | self.assertNotEmpty(tf.io.gfile.listdir(export_dir)) 54 | for exported_model_dir in tf.io.gfile.listdir(export_dir): 55 | self.assertNotEmpty( 56 | tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir))) 57 | predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor( 58 | export_dir=export_dir) 59 | self.assertTrue(predictor.restore()) 60 | 61 | 62 | if __name__ == '__main__': 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /hooks/async_export_hook_builder_tpu_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for TD3 Hooks.""" 17 | 18 | import os 19 | import gin 20 | from tensor2robot.hooks import async_export_hook_builder 21 | from tensor2robot.predictors import exported_savedmodel_predictor 22 | from tensor2robot.preprocessors import noop_preprocessor 23 | from tensor2robot.utils import mocks 24 | from tensor2robot.utils import train_eval 25 | import tensorflow.compat.v1 as tf # tf 26 | 27 | _EXPORT_DIR = 'export_dir' 28 | _BATCH_SIZES_FOR_EXPORT = [128] 29 | _MAX_STEPS = 4 30 | _BATCH_SIZE = 4 31 | 32 | 33 | class AsyncExportHookBuilderTest(tf.test.TestCase): 34 | 35 | def test_with_mock_training(self): 36 | model_dir = self.create_tempdir().full_path 37 | mock_t2r_model = mocks.MockT2RModel( 38 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor, 39 | device_type='tpu', 40 | use_avg_model_params=True) 41 | 42 | mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE) 43 | export_dir = os.path.join(model_dir, _EXPORT_DIR) 44 | hook_builder = async_export_hook_builder.AsyncExportHookBuilder( 45 | export_dir=export_dir, 46 | create_export_fn=async_export_hook_builder.default_create_export_fn) 47 | 48 | gin.parse_config('tf.contrib.tpu.TPUConfig.iterations_per_loop=1') 49 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1') 50 | 51 | # We optimize our network. 
52 | train_eval.train_eval_model( 53 | t2r_model=mock_t2r_model, 54 | input_generator_train=mock_input_generator, 55 | train_hook_builders=[hook_builder], 56 | model_dir=model_dir, 57 | max_train_steps=_MAX_STEPS) 58 | self.assertNotEmpty(tf.io.gfile.listdir(model_dir)) 59 | self.assertNotEmpty(tf.io.gfile.listdir(export_dir)) 60 | for exported_model_dir in tf.io.gfile.listdir(export_dir): 61 | self.assertNotEmpty( 62 | tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir))) 63 | predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor( 64 | export_dir=export_dir) 65 | self.assertTrue(predictor.restore()) 66 | 67 | 68 | if __name__ == '__main__': 69 | tf.test.main() 70 | -------------------------------------------------------------------------------- /hooks/gin_config_hook_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Builds hooks that write out the operative gin configuration. 17 | """ 18 | 19 | from typing import List 20 | 21 | from absl import logging 22 | import gin 23 | from tensor2robot.hooks import hook_builder 24 | from tensor2robot.models import model_interface 25 | from tensorflow import estimator as tf_estimator 26 | 27 | 28 | @gin.configurable 29 | class GinConfigLoggerHook(tf_estimator.SessionRunHook): 30 | """A SessionRunHook that logs the operative config to stdout.""" 31 | 32 | def __init__(self, only_once=True): 33 | self._only_once = only_once 34 | self._written_at_least_once = False 35 | 36 | def after_create_session(self, session=None, coord=None): 37 | """Logs Gin's operative config.""" 38 | if self._only_once and self._written_at_least_once: 39 | return 40 | 41 | logging.info('Gin operative configuration:') 42 | for gin_config_line in gin.operative_config_str().splitlines(): 43 | logging.info(gin_config_line) 44 | self._written_at_least_once = True 45 | 46 | 47 | @gin.configurable 48 | class OperativeGinConfigLoggerHookBuilder(hook_builder.HookBuilder): 49 | 50 | def create_hooks( 51 | self, 52 | t2r_model, 53 | estimator, 54 | ): 55 | return [GinConfigLoggerHook()] 56 | -------------------------------------------------------------------------------- /hooks/golden_values_hook_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Hook that logs golden values to be used in unit tests. 17 | 18 | In the Data -> Checkpoint -> Inference -> Eval flow, this verifies no regression 19 | occurred in Data -> Checkpoint. 20 | """ 21 | 22 | import os 23 | from typing import List 24 | from absl import logging 25 | import gin 26 | import numpy as np 27 | from tensor2robot.hooks import hook_builder 28 | from tensor2robot.models import model_interface 29 | import tensorflow.compat.v1 as tf 30 | from tensorflow.compat.v1 import estimator as tf_estimator 31 | 32 | ModeKeys = tf_estimator.ModeKeys 33 | COLLECTION = 'golden' 34 | PREFIX = 'golden_' 35 | 36 | 37 | def add_golden_tensor(tensor, name): 38 | """Adds tensor to be tracked.""" 39 | tf.add_to_collection(COLLECTION, tf.identity(tensor, name=PREFIX + name)) 40 | 41 | 42 | class GoldenValuesHook(tf.train.SessionRunHook): 43 | """SessionRunHook that saves loss metrics to file.""" 44 | 45 | def __init__(self, 46 | log_directory): 47 | self._log_directory = log_directory 48 | 49 | def begin(self): 50 | self._measurements = [] 51 | 52 | def end(self, session): 53 | # Record measurements. 54 | del session 55 | np.save(os.path.join(self._log_directory, 'golden_values.npy'), 56 | self._measurements) 57 | 58 | def before_run(self, run_context): 59 | return tf.train.SessionRunArgs( 60 | fetches=tf.get_collection_ref(COLLECTION)) 61 | 62 | def after_run(self, run_context, run_values): 63 | # Strip the 'golden_' prefix before saving the data. 64 | golden_values = {t.name.split(PREFIX)[1]: v for t, v in 65 | zip(tf.get_collection_ref(COLLECTION), run_values.results)} 66 | logging.info('Recorded golden values for %s', golden_values.keys()) 67 | self._measurements.append(golden_values) 68 | 69 | 70 | @gin.configurable 71 | class GoldenValuesHookBuilder(hook_builder.HookBuilder): 72 | """Hook builder for generating golden values.""" 73 | 74 | def create_hooks( 75 | self, 76 | t2r_model, 77 | estimator, 78 | ): 79 | return [GoldenValuesHook(estimator.model_dir)] 80 | -------------------------------------------------------------------------------- /hooks/hook_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Interface to manage building hooks.""" 17 | 18 | import abc 19 | from typing import List 20 | 21 | import six 22 | from tensor2robot.models import model_interface 23 | from tensorflow.compat.v1 import estimator as tf_estimator 24 | import tensorflow.compat.v1 as tf # tf 25 | 26 | 27 | class HookBuilder(six.with_metaclass(abc.ABCMeta, object)): 28 | 29 | @abc.abstractmethod 30 | def create_hooks( 31 | self, t2r_model, 32 | estimator, 33 | ): 34 | """Create hooks for the trainer. 35 | 36 | Subclasses can add arguments here. 
37 | 38 | Arguments: 39 | t2r_model: Provided model. 40 | estimator: Provided estimator instance. 41 | Returns: 42 | A list of tf.train.SessionRunHooks to add to the trainer. 43 | """ 44 | -------------------------------------------------------------------------------- /hooks/td3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Hook builders for TD3 distributed training with SavedModels.""" 17 | 18 | import os 19 | import tempfile 20 | from typing import Text, List, Optional 21 | 22 | import gin 23 | from tensor2robot.export_generators import abstract_export_generator 24 | from tensor2robot.export_generators import default_export_generator 25 | from tensor2robot.hooks import checkpoint_hooks 26 | from tensor2robot.hooks import hook_builder 27 | from tensor2robot.models import model_interface 28 | from tensor2robot.proto import t2r_pb2 29 | from tensor2robot.utils import tensorspec_utils 30 | from tensorflow.compat.v1 import estimator as tf_estimator 31 | import tensorflow.compat.v1 as tf # tf 32 | 33 | from tensorflow.contrib import tpu as contrib_tpu 34 | 35 | 36 | @gin.configurable 37 | class TD3Hooks(hook_builder.HookBuilder): 38 | """Creates hooks for exporting models for serving in TD3 distributed training. 39 | 40 | See: 41 | "Addressing Function Approximation Error in Actor-Critic Methods" 42 | by Fujimoto et al. 43 | 44 | https://arxiv.org/abs/1802.09477 45 | 46 | These hooks manage exporting of SavedModels to two different directories: 47 | `export_dir` contains the latest version of the model, `lagged_export_dir` 48 | contains a lagged version, delayed by one interval of `save_secs`. 49 | 50 | Attributes: 51 | export_dir: Directory to output the latest models. 52 | lagged_export_dir: Directory containing a lagged version of SavedModels. 53 | save_secs: Interval to save models, and copy the latest model from 54 | `export_dir` to `lagged_export_dir`. 55 | num_versions: Number of model versions to save in each directory. 56 | use_preprocessed_features: Whether to export SavedModels which do *not* 57 | include preprocessing. This is useful for offloading the preprocessing 58 | graph to the client. 59 | export_generator: The export generator used to generate the 60 | serving_input_receiver_fn.
61 | """ 62 | 63 | def __init__( 64 | self, 65 | export_dir, 66 | lagged_export_dir, 67 | batch_sizes_for_export, 68 | save_secs = 90, 69 | num_versions = 3, 70 | use_preprocessed_features=False, 71 | export_generator = None, 72 | ): 73 | super(TD3Hooks, self).__init__() 74 | self._save_secs = save_secs 75 | self._num_versions = num_versions 76 | self._export_dir = export_dir 77 | self._lagged_export_dir = lagged_export_dir 78 | self._batch_sizes_for_export = batch_sizes_for_export 79 | if export_generator is None: 80 | self._export_generator = default_export_generator.DefaultExportGenerator() 81 | else: 82 | self._export_generator = export_generator 83 | 84 | def create_hooks( 85 | self, 86 | t2r_model, 87 | estimator, 88 | ): 89 | if not self._export_dir and not self._lagged_export_dir: 90 | return [] 91 | self._export_generator.set_specification_from_model(t2r_model) 92 | warmup_requests_file = self._export_generator.create_warmup_requests_numpy( 93 | batch_sizes=self._batch_sizes_for_export, 94 | export_dir=estimator.model_dir) 95 | 96 | in_feature_spec = t2r_model.get_feature_specification_for_packing( 97 | mode=tf_estimator.ModeKeys.PREDICT) 98 | in_label_spec = t2r_model.get_label_specification_for_packing( 99 | mode=tf_estimator.ModeKeys.PREDICT) 100 | t2r_assets = t2r_pb2.T2RAssets() 101 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto()) 102 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto()) 103 | 104 | def _export_fn(export_dir, global_step): 105 | """The actual closure function creating the exported model and assets.""" 106 | t2r_assets.global_step = global_step 107 | tmpdir = tempfile.mkdtemp() 108 | t2r_assets_filename = os.path.join(tmpdir, 109 | tensorspec_utils.T2R_ASSETS_FILENAME) 110 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename) 111 | res = estimator.export_saved_model( 112 | export_dir_base=export_dir, 113 | serving_input_receiver_fn=self._export_generator 114 | .create_serving_input_receiver_numpy_fn(), 115 | assets_extra={ 116 | 'tf_serving_warmup_requests': warmup_requests_file, 117 | tensorspec_utils.T2R_ASSETS_FILENAME: t2r_assets_filename 118 | }) 119 | return res 120 | 121 | return [ 122 | contrib_tpu.AsyncCheckpointSaverHook( 123 | save_secs=self._save_secs, 124 | checkpoint_dir=estimator.model_dir, 125 | listeners=[ 126 | checkpoint_hooks.LaggedCheckpointListener( 127 | export_fn=_export_fn, 128 | num_versions=self._num_versions, 129 | export_dir=self._export_dir, 130 | lagged_export_dir=self._lagged_export_dir) 131 | ]) 132 | ] 133 | -------------------------------------------------------------------------------- /hooks/td3_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
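# NOTE (illustrative sketch; model and paths below are hypothetical): TD3Hooks,
# like the AsyncExportHookBuilder exercised earlier, is a HookBuilder and is
# typically attached to training through train_eval.train_eval_model:
#
#   hook_builder = td3.TD3Hooks(
#       export_dir='/tmp/export',
#       lagged_export_dir='/tmp/lagged_export',
#       batch_sizes_for_export=[128])
#   train_eval.train_eval_model(
#       t2r_model=my_t2r_model,
#       input_generator_train=my_input_generator,
#       train_hook_builders=[hook_builder],
#       model_dir='/tmp/model_dir')
#
# Since the class is @gin.configurable, the same wiring can also be expressed
# in a .gin config.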
15 | 16 | """Tests for TD3 Hooks.""" 17 | 18 | import mock 19 | from tensor2robot.export_generators import abstract_export_generator 20 | from tensor2robot.hooks import checkpoint_hooks 21 | from tensor2robot.hooks import td3 22 | from tensor2robot.utils import mocks 23 | from tensor2robot.utils import tensorspec_utils 24 | from tensorflow.compat.v1 import estimator as tf_estimator 25 | import tensorflow.compat.v1 as tf # tf 26 | 27 | _BATCH_SIZES_FOR_EXPORT = [128] 28 | _MODEL_DIR = "model_dir" 29 | _NUMPY_WARMUP_REQUESTS = "warmup_requests" 30 | _EXPORT_DIR = "export_dir" 31 | _LAGGED_EXPORT_DIR = "lagged_export_dir" 32 | 33 | 34 | class MockEstimator(tf_estimator.Estimator): 35 | 36 | def __init__(self): 37 | pass 38 | 39 | @property 40 | def model_dir(self): 41 | return _MODEL_DIR 42 | 43 | 44 | class Td3Test(tf.test.TestCase): 45 | 46 | @mock.patch.object(MockEstimator, "export_saved_model") 47 | @mock.patch.object(checkpoint_hooks.LaggedCheckpointListener, "__init__") 48 | @mock.patch.object(mocks.MockExportGenerator, 49 | "create_serving_input_receiver_numpy_fn") 50 | @mock.patch.object(abstract_export_generator.AbstractExportGenerator, 51 | "create_warmup_requests_numpy") 52 | def test_hooks(self, mock_create_warmup_requests_numpy, 53 | mock_create_serving_input_receiver_numpy_fn, 54 | mock_checkpoint_init, mock_export_saved_model): 55 | 56 | def _checkpoint_init(export_fn, export_dir, **kwargs): 57 | del kwargs 58 | export_fn(export_dir, global_step=1) 59 | return None 60 | 61 | mock_checkpoint_init.side_effect = _checkpoint_init 62 | 63 | export_generator = mocks.MockExportGenerator() 64 | 65 | hook_builder = td3.TD3Hooks( 66 | export_dir=_EXPORT_DIR, 67 | lagged_export_dir=_LAGGED_EXPORT_DIR, 68 | batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT, 69 | export_generator=export_generator) 70 | 71 | model = mocks.MockT2RModel() 72 | estimator = MockEstimator() 73 | 74 | mock_create_warmup_requests_numpy.return_value = _NUMPY_WARMUP_REQUESTS 75 | 76 | hooks = hook_builder.create_hooks(t2r_model=model, estimator=estimator) 77 | self.assertLen(hooks, 1) 78 | 79 | mock_create_warmup_requests_numpy.assert_called_with( 80 | batch_sizes=_BATCH_SIZES_FOR_EXPORT, 81 | export_dir=_MODEL_DIR) 82 | 83 | mock_export_saved_model.assert_called_with( 84 | serving_input_receiver_fn=mock.ANY, 85 | export_dir_base=_EXPORT_DIR, 86 | assets_extra={ 87 | "tf_serving_warmup_requests": _NUMPY_WARMUP_REQUESTS, 88 | tensorspec_utils.T2R_ASSETS_FILENAME: mock.ANY 89 | }) 90 | 91 | mock_create_serving_input_receiver_numpy_fn.assert_called() 92 | 93 | 94 | if __name__ == "__main__": 95 | tf.test.main() 96 | -------------------------------------------------------------------------------- /hooks/variable_logger_hook.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """A hook to log all variables.""" 17 | 18 | from typing import Optional 19 | 20 | from absl import logging 21 | 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | from tensorflow.contrib import framework as contrib_framework 25 | 26 | 27 | class VariableLoggerHook(tf.train.SessionRunHook): 28 | """A hook to log variables via a session run hook.""" 29 | 30 | def __init__(self, max_num_variable_values = None): 31 | """Initializes a VariableLoggerHook. 32 | 33 | Args: 34 | max_num_variable_values: If not None, at most max_num_variable_values will 35 | be logged per variable. 36 | """ 37 | super(VariableLoggerHook, self).__init__() 38 | self._max_num_variable_values = max_num_variable_values 39 | 40 | def begin(self): 41 | """Captures all variables to be read out during the session run.""" 42 | self._variables_to_log = contrib_framework.get_variables() 43 | 44 | def before_run(self, run_context): 45 | """Adds the variables to the run args.""" 46 | return tf.train.SessionRunArgs(self._variables_to_log) 47 | 48 | def after_run(self, run_context, run_values): 49 | del run_context 50 | original = np.get_printoptions() 51 | np.set_printoptions(suppress=True) 52 | for variable, variable_value in zip(self._variables_to_log, 53 | run_values.results): 54 | if not isinstance(variable_value, np.ndarray): 55 | continue 56 | variable_value = variable_value.ravel() 57 | logging.info('%s.mean = %s', variable.op.name, np.mean(variable_value)) 58 | logging.info('%s.std = %s', variable.op.name, np.std(variable_value)) 59 | if self._max_num_variable_values: 60 | variable_value = variable_value[:self._max_num_variable_values] 61 | logging.info('%s = %s', variable.op.name, variable_value) 62 | np.set_printoptions(**original) 63 | -------------------------------------------------------------------------------- /input_generators/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """InputGenerator implementations for estimator models.""" 17 | from tensor2robot.input_generators import abstract_input_generator 18 | -------------------------------------------------------------------------------- /input_generators/abstract_input_generator_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for estimator_models.input_generators.abstract_input_generator.""" 17 | 18 | import functools 19 | from absl import flags 20 | from tensor2robot.input_generators import abstract_input_generator 21 | from tensor2robot.preprocessors import noop_preprocessor 22 | from tensor2robot.utils import mocks 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | BATCH_SIZE = 32 29 | 30 | 31 | class AbstractInputGeneratorTest(tf.test.TestCase): 32 | 33 | def test_init_abstract(self): 34 | with self.assertRaises(TypeError): 35 | abstract_input_generator.AbstractInputGenerator() 36 | 37 | def test_set_preprocess_fn(self): 38 | mock_input_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE) 39 | preprocessor = noop_preprocessor.NoOpPreprocessor() 40 | with self.assertRaises(ValueError): 41 | # This should raise since we pass a function with `mode` not already 42 | # filled in either by a closure or functools.partial. 43 | mock_input_generator.set_preprocess_fn(preprocessor.preprocess) 44 | 45 | preprocess_fn = functools.partial(preprocessor.preprocess, labels=None) 46 | with self.assertRaises(ValueError): 47 | # This should raise since we pass a partial function but `mode` 48 | # is not abstracted away. 49 | mock_input_generator.set_preprocess_fn(preprocess_fn) 50 | 51 | 52 | if __name__ == '__main__': 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /layers/mdn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
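# NOTE (illustrative; the layout is inferred from the tests below rather than
# stated in mdn.py within this section): the MDN parameter vector is assumed
# to be
#
#   params = tf.concat([alphas, mus, sigmas], axis=-1)
#
# with num_alphas mixture logits, num_alphas * sample_size component means and
# num_alphas * sample_size standard deviations, i.e.
# num_params = num_alphas * (1 + 2 * sample_size).
# get_mixture_distribution() then yields a tfp.distributions.MixtureSameFamily
# of Categorical mixture weights over MultivariateNormalDiag components with
# event size sample_size.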
15 | 16 | """Test for tensor2robot.layers.mdn.""" 17 | 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tensor2robot.layers import mdn 21 | import tensorflow.compat.v1 as tf 22 | import tensorflow_probability as tfp 23 | 24 | 25 | class MDNTest(tf.test.TestCase, parameterized.TestCase): 26 | 27 | def test_get_mixture_distribution(self): 28 | sample_size = 10 29 | num_alphas = 5 30 | batch_shape = (4, 2) 31 | alphas = tf.random.normal(batch_shape + (num_alphas,)) 32 | mus = tf.random.normal(batch_shape + (sample_size * num_alphas,)) 33 | sigmas = tf.random.normal(batch_shape + (sample_size * num_alphas,)) 34 | params = tf.concat([alphas, mus, sigmas], -1) 35 | output_mean_np = np.random.normal(size=(sample_size,)) 36 | gm = mdn.get_mixture_distribution( 37 | params, num_alphas, sample_size, output_mean=output_mean_np) 38 | self.assertEqual(gm.batch_shape, batch_shape) 39 | self.assertEqual(gm.event_shape, sample_size) 40 | 41 | # Check that the component means were translated by output_mean_np. 42 | component_means = gm.components_distribution.mean() 43 | with self.test_session() as sess: 44 | # Note: must get values from the same session run, since params will be 45 | # randomized across separate session runs. 46 | component_means_np, mus_np = sess.run([component_means, mus]) 47 | mus_np = np.reshape(mus_np, component_means_np.shape) 48 | self.assertAllClose(component_means_np, mus_np + output_mean_np) 49 | 50 | @parameterized.parameters((True,), (False,)) 51 | def test_predict_mdn_params(self, condition_sigmas): 52 | sample_size = 10 53 | num_alphas = 5 54 | inputs = tf.random.normal((2, 16)) 55 | with tf.variable_scope('test_scope'): 56 | dist_params = mdn.predict_mdn_params( 57 | inputs, num_alphas, sample_size, condition_sigmas=condition_sigmas) 58 | expected_num_params = num_alphas * (1 + 2 * sample_size) 59 | self.assertEqual(dist_params.shape.as_list(), [2, expected_num_params]) 60 | 61 | gm = mdn.get_mixture_distribution(dist_params, num_alphas, sample_size) 62 | stddev = gm.components_distribution.stddev() 63 | with self.test_session() as sess: 64 | sess.run(tf.global_variables_initializer()) 65 | stddev_np = sess.run(stddev) 66 | if condition_sigmas: 67 | # Standard deviations should vary with input. 68 | self.assertNotAllClose(stddev_np[0], stddev_np[1]) 69 | else: 70 | # Standard deviations should *not* vary with input. 71 | self.assertAllClose(stddev_np[0], stddev_np[1]) 72 | 73 | def test_gaussian_mixture_approximate_mode(self): 74 | sample_size = 10 75 | num_alphas = 5 76 | # Manually set alphas to 1 in zero-th column and 0 elsewhere, making the 77 | # first component the most likely. 78 | alphas = tf.one_hot(2 * [0], num_alphas) 79 | mus = tf.random.normal((2, num_alphas, sample_size)) 80 | sigmas = tf.ones_like(mus) 81 | mix_dist = tfp.distributions.Categorical(logits=alphas) 82 | comp_dist = tfp.distributions.MultivariateNormalDiag( 83 | loc=mus, scale_diag=sigmas) 84 | gm = tfp.distributions.MixtureSameFamily( 85 | mixture_distribution=mix_dist, components_distribution=comp_dist) 86 | approximate_mode = mdn.gaussian_mixture_approximate_mode(gm) 87 | with self.test_session() as sess: 88 | approximate_mode_np, mus_np = sess.run([approximate_mode, mus]) 89 | # The approximate mode should be the mean of the zero-th (most likely) 90 | # component. 
91 | self.assertAllClose(approximate_mode_np, mus_np[:, 0, :]) 92 | 93 | 94 | if __name__ == '__main__': 95 | tf.test.main() 96 | -------------------------------------------------------------------------------- /layers/resnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.layers.resnet.""" 17 | 18 | import functools 19 | from absl.testing import parameterized 20 | from six.moves import range 21 | from tensor2robot.layers import resnet 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class ResnetTest(tf.test.TestCase, parameterized.TestCase): 26 | 27 | @parameterized.parameters(('',), ('fubar',), ('dummy/scope')) 28 | def test_intermediate_values(self, scope): 29 | with tf.variable_scope(scope): 30 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32) 31 | end_points = resnet.resnet_model(image, 32 | is_training=True, 33 | num_classes=1001, 34 | return_intermediate_values=True) 35 | tensors = ['initial_conv', 'initial_max_pool', 'pre_final_pool', 36 | 'final_reduce_mean', 'final_dense'] 37 | tensors += [ 38 | 'block_layer{}'.format(i + 1) for i in range(4)] 39 | self.assertEqual(set(tensors), set(end_points.keys())) 40 | 41 | @parameterized.parameters( 42 | (18, [True, True, True, True]), 43 | (50, [True, False, True, False])) 44 | def test_film(self, resnet_size, enabled_blocks): 45 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32) 46 | embedding = tf.zeros((2, 100), dtype=tf.float32) 47 | film_generator_fn = functools.partial( 48 | resnet.linear_film_generator, enabled_block_layers=enabled_blocks) 49 | _ = resnet.resnet_model(image, 50 | is_training=True, 51 | num_classes=1001, 52 | resnet_size=resnet_size, 53 | return_intermediate_values=True, 54 | film_generator_fn=film_generator_fn, 55 | film_generator_input=embedding) 56 | 57 | def test_malformed_film_raises(self): 58 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32) 59 | embedding = tf.zeros((2, 100), dtype=tf.float32) 60 | film_generator_fn = functools.partial( 61 | resnet.linear_film_generator, enabled_block_layers=[True]*5) 62 | with self.assertRaises(ValueError): 63 | _ = resnet.resnet_model(image, 64 | is_training=True, 65 | num_classes=1001, 66 | resnet_size=18, 67 | return_intermediate_values=True, 68 | film_generator_fn=film_generator_fn, 69 | film_generator_input=embedding) 70 | 71 | if __name__ == '__main__': 72 | tf.test.main() 73 | -------------------------------------------------------------------------------- /layers/snail.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of building blocks from https://arxiv.org/abs/1707.03141. 17 | 18 | Implementation here is designed to match pseudocode in the paper. 19 | """ 20 | 21 | from typing import Text 22 | 23 | import numpy as np 24 | from six.moves import range 25 | import tensorflow.compat.v1 as tf 26 | from tensorflow.contrib import layers 27 | 28 | 29 | def CausalConv(x, dilation_rate, filters, kernel_size=2, scope = ""): 30 | """Performs causal dilated 1D convolutions. 31 | 32 | Args: 33 | x : Tensor of shape (batch_size, steps, input_dim). 34 | dilation_rate: Dilation rate of convolution. 35 | filters: Number of convolution filters. 36 | kernel_size: Width of convolution kernel. SNAIL paper uses 2 for all 37 | experiments. 38 | scope: Variable scope for this layer. 39 | Returns: 40 | y: Tensor of shape (batch_size, new_steps, D). 41 | """ 42 | with tf.variable_scope(scope): 43 | causal_pad_size = (kernel_size - 1) * dilation_rate 44 | # Pad sequence dimension. 45 | x = tf.pad(x, [[0, 0], [causal_pad_size, 0], [0, 0]]) 46 | return layers.conv1d( 47 | x, 48 | filters, 49 | kernel_size=kernel_size, 50 | padding="VALID", 51 | rate=dilation_rate) 52 | 53 | 54 | def DenseBlock(x, dilation_rate, filters, scope = ""): 55 | r"""SNAIL \'dense block\' with gated activation and concatenation. 56 | 57 | Args: 58 | x : Tensor of shape [batch, time, channels]. 59 | dilation_rate: Dilation rate of convolution. 60 | filters: Number of convolution filters. 61 | scope: Variable scope for this layer. 62 | Returns: 63 | y: Tensor of shape [batch, time, channels + filters]. 64 | """ 65 | with tf.variable_scope(scope): 66 | xf = CausalConv(x, dilation_rate, filters, scope="xf") 67 | xg = CausalConv(x, dilation_rate, filters, scope="xg") 68 | activations = tf.nn.tanh(xf) * tf.nn.sigmoid(xg) 69 | return tf.concat([x, activations], axis=2) 70 | 71 | 72 | def TCBlock(x, sequence_length, filters, scope = ""): 73 | """A stack of DenseBlocks with exponentially increasing dilations. 74 | 75 | Args: 76 | x : Tensor of shape [batch, sequence_length, channels]. 77 | sequence_length: Sequence length of x. 78 | filters: Number of convolution filters. 79 | scope: Variable scope for this layer. 80 | Returns: 81 | y: Tensor of shape [batch, sequence_length, channels + filters]. 82 | """ 83 | with tf.variable_scope(scope): 84 | for i in range(1, int(np.ceil(np.log2(sequence_length)))+1): 85 | x = DenseBlock(x, 2**i, filters, scope="DenseBlock_%d" % i) 86 | return x 87 | 88 | 89 | def CausallyMaskedSoftmax(x): 90 | """Causally masked Softmax. Zero out probabilities before and after norm. 91 | 92 | pre-softmax logits are masked by setting upper diagonal to -inf: 93 | 94 | |a 0, 0| |0, -inf, -inf| 95 | |b, d, 0| + |0, 0, -inf| 96 | |c, e, f| |0, 0, 0 | 97 | 98 | Args: 99 | x: Batched tensor of shape [batch_size, T, T]. 100 | Returns: 101 | Softmax where each row corresponds to softmax vector for each query. 
102 | """ 103 | lower_diag = tf.linalg.band_part(x, -1, 0) 104 | upper_diag = -np.inf * tf.ones_like(x) 105 | upper_diag = tf.linalg.band_part(upper_diag, 0, -1) 106 | upper_diag = tf.linalg.set_diag( 107 | upper_diag, tf.zeros_like(tf.linalg.diag_part(x))) 108 | x = lower_diag + upper_diag 109 | softmax = tf.nn.softmax(x) 110 | return tf.linalg.band_part(softmax, -1, 0) 111 | 112 | 113 | def AttentionBlock(x, key_size, value_size, scope = ""): 114 | """Self-attention key-value lookup, styled after Vaswani et al. '17. 115 | 116 | query and key are of shape [T, K]. query * transpose(key) yields logits of 117 | shape [T, T]. logits[i, j] corresponds to unnormalized attention vector over 118 | values [T, V] for each timestep i. Because this attention is over a set of 119 | temporal values, we causally mask the pre-softmax logits[i, j] := 0, for all 120 | j > i. 121 | 122 | Citations: 123 | Vaswani et al. '17: Attention is All you need 124 | https://arxiv.org/abs/1706.03762. 125 | 126 | Args: 127 | x: Input tensor of shape [batch, sequence_length, channels]. 128 | key_size: Integer key dimensionality. 129 | value_size: Integer value dimensionality. 130 | scope: Variable scope for this layer. 131 | Returns: 132 | result: Tensor of shape [batch, sequence_length, channels + value_size] 133 | end_points: Dictionary of intermediate values (e.g. debugging). 134 | """ 135 | end_points = {} 136 | with tf.variable_scope(scope): 137 | key = layers.fully_connected(x, key_size, activation_fn=None) # [T, K] 138 | query = layers.fully_connected(x, key_size, activation_fn=None) # [T, K] 139 | logits = tf.matmul(query, key, transpose_b=True) # [T, T] 140 | # Useful for visualizing attention alignment matrices. 141 | probs = CausallyMaskedSoftmax(logits/np.sqrt(key_size)) # [T, T] 142 | end_points["attn_prob"] = probs 143 | values = layers.fully_connected(x, value_size, activation_fn=None) # [T, V] 144 | read = tf.matmul(probs, values) # [T, V] 145 | result = tf.concat([x, read], axis=2) # [T, K + V] 146 | return result, end_points 147 | -------------------------------------------------------------------------------- /layers/snail_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for SNAIL.""" 17 | 18 | import numpy as np 19 | from six.moves import range 20 | from tensor2robot.layers import snail 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | class SNAILTest(tf.test.TestCase): 25 | 26 | def test_CausalConv(self): 27 | x = tf.random.normal((4, 10, 8)) 28 | y = snail.CausalConv(x, 1, 5) 29 | self.assertEqual(y.shape, (4, 10, 5)) 30 | 31 | def test_DenseBlock(self): 32 | x = tf.random.normal((4, 10, 8)) 33 | y = snail.DenseBlock(x, 1, 5) 34 | self.assertEqual(y.shape, (4, 10, 13)) 35 | 36 | def test_TCBlock(self): 37 | sequence_length = 10 38 | x = tf.random.normal((4, sequence_length, 8)) 39 | y = snail.TCBlock(x, sequence_length, 5) 40 | self.assertEqual(y.shape, (4, 10, 8 + 4*5)) 41 | 42 | def test_CausallyMaskedSoftmax(self): 43 | num_rows = 5 44 | x = tf.random.normal((num_rows, 3)) 45 | logits = tf.matmul(x, tf.linalg.transpose(x)) 46 | y = snail.CausallyMaskedSoftmax(logits) 47 | with self.test_session() as sess: 48 | y_ = sess.run(y) 49 | idx = np.triu_indices(num_rows, 1) 50 | np.testing.assert_array_equal(y_[idx], 0.) 51 | # Testing that each row sums to 1. 52 | for i in range(num_rows): 53 | np.testing.assert_almost_equal(np.sum(y_[i, :]), 1.0) 54 | 55 | def test_AttentionBlock(self): 56 | x = tf.random.normal((4, 10, 8)) 57 | y, end_points = snail.AttentionBlock(x, 3, 5) 58 | self.assertEqual(y.shape, (4, 10, 5+8)) 59 | self.assertEqual(end_points['attn_prob'].shape, (4, 10, 10)) 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /layers/spatial_softmax.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TensorFlow impl of Spatial Softmax layers. (spatial soft arg-max). 17 | 18 | TODO(T2R_CONTRIBUTORS) - consider replacing with contrib version. 19 | """ 20 | 21 | import gin 22 | import numpy as np 23 | from six.moves import range 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_probability as tfp 26 | 27 | 28 | @gin.configurable 29 | def BuildSpatialSoftmax(features, spatial_gumbel_softmax=False): 30 | """Computes the spatial softmax of the input features. 31 | 32 | Args: 33 | features: A tensor of size [batch_size, num_rows, num_cols, num_features] 34 | spatial_gumbel_softmax: If set to True, samples locations stochastically 35 | rather than computing expected coordinates with respect to heatmap. 36 | Returns: 37 | A tuple of (expected_feature_points, softmax). 38 | expected_feature_points: A tensor of size 39 | [batch_size, num_features * 2]. These are the expected feature 40 | locations, i.e., the spatial softmax of feature_maps. The inner 41 | dimension is arranged as [x1, x2, x3 ... xN, y1, y2, y3, ... yN]. 42 | softmax: A Tensor which is the softmax of the features. 43 | [batch_size, num_rows, num_cols, num_features]. 
44 | """ 45 | _, num_rows, num_cols, num_features = features.get_shape().as_list() 46 | 47 | with tf.name_scope('SpatialSoftmax'): 48 | # Create tensors for x and y positions, respectively 49 | x_pos = np.empty([num_rows, num_cols], np.float32) 50 | y_pos = np.empty([num_rows, num_cols], np.float32) 51 | 52 | # Assign values to positions 53 | for i in range(num_rows): 54 | for j in range(num_cols): 55 | x_pos[i, j] = 2.0 * j / (num_cols - 1.0) - 1.0 56 | y_pos[i, j] = 2.0 * i / (num_rows - 1.0) - 1.0 57 | 58 | x_pos = tf.reshape(x_pos, [num_rows * num_cols]) 59 | y_pos = tf.reshape(y_pos, [num_rows * num_cols]) 60 | 61 | # We reorder the features (norm3) into the following order: 62 | # [batch_size, NUM_FEATURES, num_rows, num_cols] 63 | # This lets us merge the batch_size and num_features dimensions, in order 64 | # to compute spatial softmax as a single batch operation. 65 | features = tf.reshape( 66 | tf.transpose(features, [0, 3, 1, 2]), [-1, num_rows * num_cols]) 67 | 68 | if spatial_gumbel_softmax: 69 | # Temperature is hard-coded for now, make this more flexible if results 70 | # are promising. 71 | dist = tfp.distributions.RelaxedOneHotCategorical( 72 | temperature=1.0, logits=features) 73 | softmax = dist.sample() 74 | else: 75 | softmax = tf.nn.softmax(features) 76 | # Element-wise multiplication 77 | x_output = tf.multiply(x_pos, softmax) 78 | y_output = tf.multiply(y_pos, softmax) 79 | # Sum per out_size x out_size 80 | x_output = tf.reduce_sum(x_output, [1], keep_dims=True) 81 | y_output = tf.reduce_sum(y_output, [1], keep_dims=True) 82 | # Concatenate x and y, and reshape. 83 | expected_feature_points = tf.reshape( 84 | tf.concat([x_output, y_output], 1), [-1, num_features*2]) 85 | softmax = tf.transpose( 86 | tf.reshape(softmax, [-1, num_features, num_rows, 87 | num_cols]), [0, 2, 3, 1]) 88 | return expected_feature_points, softmax 89 | -------------------------------------------------------------------------------- /layers/spatial_softmax_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for Spatial Softmax Layer.""" 17 | 18 | import numpy as np 19 | from tensor2robot.layers import spatial_softmax 20 | import tensorflow.compat.v1 as tf 21 | 22 | 23 | class SpatialSoftmaxTest(tf.test.TestCase): 24 | 25 | def test_SpatialGumbelSoftmax(self): 26 | 27 | features = tf.convert_to_tensor( 28 | np.random.normal(size=(32, 16, 16, 64)).astype(np.float32)) 29 | with tf.variable_scope('mean_pool'): 30 | expected_feature_points, softmax = spatial_softmax.BuildSpatialSoftmax( 31 | features, spatial_gumbel_softmax=False) 32 | with tf.variable_scope('gumbel_pool'): 33 | gumbel_feature_points, gumbel_softmax = ( 34 | spatial_softmax.BuildSpatialSoftmax( 35 | features, spatial_gumbel_softmax=True)) 36 | self.assertEqual(expected_feature_points.shape, gumbel_feature_points.shape) 37 | self.assertEqual(softmax.shape, gumbel_softmax.shape) 38 | 39 | if __name__ == '__main__': 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /layers/tec_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.layers.tec.""" 17 | 18 | from tensor2robot.layers import tec 19 | import tensorflow.compat.v1 as tf 20 | 21 | 22 | class TECTest(tf.test.TestCase): 23 | 24 | def test_embed_condition_images(self): 25 | images = tf.random.normal((4, 100, 100, 3)) 26 | embedding = tec.embed_condition_images( 27 | images, 'test_embed', fc_layers=(100, 20)) 28 | self.assertEqual(embedding.shape.as_list(), [4, 20]) 29 | 30 | def test_doubly_batched_embed_condition_images(self): 31 | doubly_batched_images = tf.random.normal((3, 4, 10, 12, 3)) 32 | with self.assertRaises(ValueError): 33 | tec.embed_condition_images(doubly_batched_images, 'test_embed') 34 | 35 | def test_reduce_temporal_embeddings(self): 36 | temporal_embeddings = tf.random.normal((4, 20, 16)) 37 | embedding = tec.reduce_temporal_embeddings( 38 | temporal_embeddings, 10, 'test_reduce') 39 | self.assertEqual(embedding.shape.as_list(), [4, 10]) 40 | 41 | def test_doubly_batched_reduce_temporal_embeddings(self): 42 | temporal_embeddings = tf.random.normal((2, 4, 20, 16)) 43 | with self.assertRaises(ValueError): 44 | tec.reduce_temporal_embeddings(temporal_embeddings, 10, 'test_reduce') 45 | 46 | def test_contrastive_loss(self): 47 | inf_embeddings = tf.nn.l2_normalize( 48 | tf.ones((5, 1, 10), dtype=tf.float32), axis=-1) 49 | con_embeddings = tf.nn.l2_normalize( 50 | tf.ones((5, 1, 10), dtype=tf.float32), axis=-1) 51 | tec.compute_embedding_contrastive_loss(inf_embeddings, con_embeddings) 52 | 53 | 54 | if __name__ == '__main__': 55 | tf.test.main() 56 | -------------------------------------------------------------------------------- /meta_learning/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 
2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /meta_learning/meta_example.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility function for dealing with meta-examples. 17 | """ 18 | 19 | from typing import List, Union 20 | 21 | import six 22 | import tensorflow.compat.v1 as tf # tf 23 | 24 | Example = Union[tf.train.Example, tf.train.SequenceExample] 25 | 26 | 27 | def make_meta_example( 28 | condition_examples, 29 | inference_examples, 30 | ): 31 | """Creates a single MetaExample from train_examples and val_examples.""" 32 | if isinstance(condition_examples[0], tf.train.Example): 33 | meta_example = tf.train.Example() 34 | append_fn = append_example 35 | else: 36 | meta_example = tf.train.SequenceExample() 37 | append_fn = append_sequence_example 38 | for i, train_example in enumerate(condition_examples): 39 | append_fn(meta_example, train_example, 'condition_ep{:d}'.format(i)) 40 | 41 | for i, val_example in enumerate(inference_examples): 42 | append_fn(meta_example, val_example, 'inference_ep{:d}'.format(i)) 43 | return meta_example 44 | 45 | 46 | def append_example(example, ep_example, prefix): 47 | """Add episode Example to Meta TFExample with a prefix.""" 48 | context_feature_map = example.features.feature 49 | for key, feature in six.iteritems(ep_example.features.feature): 50 | context_feature_map[six.ensure_str(prefix) + '/' + 51 | six.ensure_str(key)].CopyFrom(feature) 52 | 53 | 54 | def append_sequence_example(meta_example, ep_example, prefix): 55 | """Add episode SequenceExample to the Meta SequenceExample with a prefix.""" 56 | context_feature_map = meta_example.context.feature 57 | # Append context features. 58 | for key, feature in six.iteritems(ep_example.context.feature): 59 | context_feature_map[six.ensure_str(prefix) + '/' + 60 | six.ensure_str(key)].CopyFrom(feature) 61 | # Append Sequential features. 
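# Per-timestep features live in feature_lists (not in the context map above) and receive the same '<prefix>/' key namespacing, so condition and inference episodes stay distinguishable after merging.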
62 | sequential_feature_map = meta_example.feature_lists.feature_list 63 | for key, feature_list in six.iteritems(ep_example.feature_lists.feature_list): 64 | sequential_feature_map[six.ensure_str(prefix) + '/' + 65 | six.ensure_str(key)].CopyFrom(feature_list) 66 | -------------------------------------------------------------------------------- /meta_learning/meta_tf_models_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learning.estimator_models.meta_learning.meta_tf_models.""" 17 | 18 | from tensor2robot.meta_learning import meta_tf_models 19 | from tensor2robot.preprocessors import abstract_preprocessor 20 | from tensor2robot.utils import tensorspec_utils 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | class MockBasePreprocessor(abstract_preprocessor.AbstractPreprocessor): 25 | 26 | def _get_feature_specification(self): 27 | spec = tensorspec_utils.TensorSpecStruct() 28 | spec.action = tensorspec_utils.ExtendedTensorSpec( 29 | name='action', shape=(1,), dtype=tf.float32) 30 | spec.velocity = tensorspec_utils.ExtendedTensorSpec( 31 | name='velocity', shape=(1,), dtype=tf.float32, is_optional=True) 32 | return spec 33 | 34 | def get_in_feature_specification(self): 35 | return self._get_feature_specification() 36 | 37 | def get_out_feature_specification(self): 38 | return self._get_feature_specification() 39 | 40 | def _get_label_specification(self): 41 | spec = tensorspec_utils.TensorSpecStruct() 42 | spec.target = tensorspec_utils.ExtendedTensorSpec( 43 | name='target', shape=(1,), dtype=tf.float32) 44 | spec.proxy = tensorspec_utils.ExtendedTensorSpec( 45 | name='proxy', shape=(1,), dtype=tf.float32, is_optional=True) 46 | return spec 47 | 48 | def get_in_label_specification(self): 49 | return self._get_label_specification() 50 | 51 | def get_out_label_specification(self): 52 | return self._get_label_specification() 53 | 54 | def _preprocess_fn(self, features, labels, unused_mode): 55 | return features, labels 56 | 57 | 58 | class MetaTfModelsTest(tf.test.TestCase): 59 | 60 | def test_meta_preprocessor_required_specs(self): 61 | meta_preprocessor = meta_tf_models.MetaPreprocessor( 62 | base_preprocessor=MockBasePreprocessor(), 63 | num_train_samples_per_task=1, 64 | num_val_samples_per_task=1) 65 | ref_feature_spec = meta_preprocessor.get_in_feature_specification() 66 | filtered_feature_spec = tensorspec_utils.filter_required_flat_tensor_spec( 67 | meta_preprocessor.get_in_feature_specification()) 68 | self.assertDictEqual(ref_feature_spec, filtered_feature_spec) 69 | 70 | ref_label_spec = meta_preprocessor.get_in_label_specification() 71 | filtered_label_spec = tensorspec_utils.filter_required_flat_tensor_spec( 72 | meta_preprocessor.get_in_label_specification()) 73 | self.assertDictEqual(ref_label_spec, filtered_label_spec) 74 | 75 | 76 | if __name__ == '__main__': 
77 | tf.test.main() 78 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /policies/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /predictors/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /predictors/abstract_predictor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """An abstract predictor to load tf models and expose a predict function.""" 17 | 18 | import abc 19 | from typing import Dict, Optional, Text 20 | 21 | import numpy as np 22 | import six 23 | from tensor2robot.utils import tensorspec_utils 24 | 25 | 26 | class AbstractPredictor(six.with_metaclass(abc.ABCMeta, object)): 27 | """A predictor responsible to load a T2RModel and expose a predict function. 28 | 29 | The purpose of the predictor is to abstract model loading and running, e.g. 30 | using a raw session interface, a tensorflow predictor created from saved 31 | models or tensorflow 2.0 models. 32 | """ 33 | 34 | @abc.abstractmethod 35 | def predict(self, features): 36 | """Predicts based on feature input using the loaded model. 37 | 38 | Args: 39 | features: A dict containing the features used for predictions. 40 | Returns: 41 | The result of the queried model predictions. 42 | """ 43 | 44 | @abc.abstractmethod 45 | def get_feature_specification(self): 46 | """Exposes the required input features for evaluation of the model.""" 47 | 48 | def get_label_specification(self 49 | ): 50 | """Exposes the optional labels for evaluation of the model.""" 51 | return None 52 | 53 | @abc.abstractmethod 54 | def restore(self): 55 | """Restores the model parameters from the latest available data.""" 56 | 57 | def init_randomly(self): 58 | """Initializes model parameters from with random values.""" 59 | 60 | @abc.abstractmethod 61 | def close(self): 62 | """Closes all open handles used throughout model evaluation.""" 63 | 64 | @abc.abstractmethod 65 | def assert_is_loaded(self): 66 | """Raises a ValueError if the predictor has not been restored yet.""" 67 | 68 | @property 69 | def model_version(self): 70 | """The version of the model currently in use.""" 71 | return 0 72 | 73 | @property 74 | def global_step(self): 75 | """The global step of the model currently in use.""" 76 | return 0 77 | 78 | @property 79 | def model_path(self): 80 | """The path of the model currently in use.""" 81 | return '' 82 | -------------------------------------------------------------------------------- /predictors/checkpoint_predictor_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for py.tensor2robot.predictors.checkpoint_predictor_test.""" 17 | 18 | from absl import flags 19 | import gin 20 | import numpy as np 21 | from tensor2robot.input_generators import default_input_generator 22 | from tensor2robot.predictors import checkpoint_predictor 23 | from tensor2robot.utils import mocks 24 | from tensor2robot.utils import tensorspec_utils 25 | from tensor2robot.utils import train_eval 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | _BATCH_SIZE = 2 32 | _MAX_TRAIN_STEPS = 3 33 | 34 | 35 | class CheckpointPredictorTest(tf.test.TestCase): 36 | 37 | def setUp(self): 38 | super(CheckpointPredictorTest, self).setUp() 39 | gin.clear_config() 40 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1') 41 | 42 | def test_predictor(self): 43 | input_generator = default_input_generator.DefaultRandomInputGenerator( 44 | batch_size=_BATCH_SIZE) 45 | model_dir = self.create_tempdir().full_path 46 | mock_model = mocks.MockT2RModel() 47 | train_eval.train_eval_model( 48 | t2r_model=mock_model, 49 | input_generator_train=input_generator, 50 | max_train_steps=_MAX_TRAIN_STEPS, 51 | model_dir=model_dir) 52 | 53 | predictor = checkpoint_predictor.CheckpointPredictor( 54 | t2r_model=mock_model, checkpoint_dir=model_dir, use_gpu=False) 55 | with self.assertRaises(ValueError): 56 | predictor.predict({'does_not_matter': np.zeros(1)}) 57 | self.assertEqual(predictor.model_version, -1) 58 | self.assertEqual(predictor.global_step, -1) 59 | self.assertTrue(predictor.restore()) 60 | self.assertGreater(predictor.model_version, 0) 61 | self.assertEqual(predictor.global_step, 3) 62 | ref_feature_spec = mock_model.preprocessor.get_in_feature_specification( 63 | tf_estimator.ModeKeys.PREDICT) 64 | tensorspec_utils.assert_equal(predictor.get_feature_specification(), 65 | ref_feature_spec) 66 | features = tensorspec_utils.make_random_numpy( 67 | ref_feature_spec, batch_size=_BATCH_SIZE) 68 | predictions = predictor.predict(features) 69 | self.assertLen(predictions, 1) 70 | self.assertCountEqual(sorted(predictions.keys()), ['logit']) 71 | self.assertEqual(predictions['logit'].shape, (2, 1)) 72 | 73 | def test_predictor_timeout(self): 74 | mock_model = mocks.MockT2RModel() 75 | predictor = checkpoint_predictor.CheckpointPredictor( 76 | t2r_model=mock_model, 77 | checkpoint_dir='/random/path/which/does/not/exist', 78 | timeout=1) 79 | self.assertFalse(predictor.restore()) 80 | 81 | def test_predictor_raises(self): 82 | mock_model = mocks.MockT2RModel() 83 | # Raises because no checkpoint_dir and has been set and restore is called. 84 | predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=mock_model) 85 | with self.assertRaises(ValueError): 86 | predictor.restore() 87 | 88 | 89 | if __name__ == '__main__': 90 | tf.test.main() 91 | -------------------------------------------------------------------------------- /predictors/ensemble_exported_savedmodel_predictor_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.predictors.ensemble_exported_savedmodel_predictor.""" 17 | 18 | import os 19 | 20 | from absl import flags 21 | from absl.testing import parameterized 22 | import gin 23 | import numpy as np 24 | from tensor2robot.input_generators import default_input_generator 25 | from tensor2robot.predictors import ensemble_exported_savedmodel_predictor 26 | from tensor2robot.utils import mocks 27 | from tensor2robot.utils import tensorspec_utils 28 | from tensor2robot.utils import train_eval 29 | import tensorflow.compat.v1 as tf 30 | from tensorflow.compat.v1 import estimator as tf_estimator 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | _EXPORT_DIR = 'asyn_export' 35 | 36 | _BATCH_SIZE = 2 37 | _MAX_TRAIN_STEPS = 3 38 | _MAX_EVAL_STEPS = 2 39 | 40 | 41 | class ExportedSavedmodelPredictorTest(tf.test.TestCase, parameterized.TestCase): 42 | 43 | def setUp(self): 44 | super(ExportedSavedmodelPredictorTest, self).setUp() 45 | gin.clear_config() 46 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1') 47 | 48 | def test_predictor_with_default_exporter(self): 49 | input_generator = default_input_generator.DefaultRandomInputGenerator( 50 | batch_size=_BATCH_SIZE) 51 | model_dir = self.create_tempdir().full_path 52 | mock_model = mocks.MockT2RModel() 53 | train_eval.train_eval_model( 54 | t2r_model=mock_model, 55 | input_generator_train=input_generator, 56 | input_generator_eval=input_generator, 57 | max_train_steps=_MAX_TRAIN_STEPS, 58 | eval_steps=_MAX_EVAL_STEPS, 59 | model_dir=model_dir, 60 | create_exporters_fn=train_eval.create_default_exporters) 61 | # Create ensemble by duplicating the same directory multiple times. 
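# Both ensemble members therefore point at the same exported SavedModel; an actual deployment would presumably pass one distinct export directory per member.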
62 | export_dirs = ','.join( 63 | [os.path.join(model_dir, 'export', 'latest_exporter_numpy')] * 2) 64 | predictor = ensemble_exported_savedmodel_predictor.EnsembleExportedSavedModelPredictor( 65 | export_dirs=export_dirs, local_export_root=None, ensemble_size=2) 66 | predictor.resample_ensemble() 67 | with self.assertRaises(ValueError): 68 | predictor.get_feature_specification() 69 | with self.assertRaises(ValueError): 70 | predictor.predict({'does_not_matter': np.zeros(1)}) 71 | with self.assertRaises(ValueError): 72 | _ = predictor.model_version 73 | self.assertEqual(predictor.global_step, -1) 74 | self.assertTrue(predictor.restore(is_async=False)) 75 | self.assertGreater(predictor.model_version, 0) 76 | self.assertEqual(predictor.global_step, -1) 77 | ref_feature_spec = mock_model.preprocessor.get_in_feature_specification( 78 | tf_estimator.ModeKeys.PREDICT) 79 | tensorspec_utils.assert_equal(predictor.get_feature_specification(), 80 | ref_feature_spec) 81 | features = tensorspec_utils.make_random_numpy( 82 | ref_feature_spec, batch_size=_BATCH_SIZE) 83 | predictions = predictor.predict(features) 84 | self.assertLen(predictions, 1) 85 | self.assertCountEqual(predictions.keys(), ['logit']) 86 | self.assertEqual(predictions['logit'].shape, (2, 1)) 87 | 88 | 89 | if __name__ == '__main__': 90 | tf.test.main() 91 | -------------------------------------------------------------------------------- /predictors/saved_model_v2_predictor_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for tensor2robot.predictors.saved_model_v2_predictor.""" 17 | 18 | import os 19 | import numpy as np 20 | 21 | from tensor2robot.predictors import saved_model_v2_predictor 22 | from tensor2robot.proto import t2r_pb2 23 | from tensor2robot.utils import mocks 24 | from tensor2robot.utils import tensorspec_utils 25 | 26 | from tensorflow import estimator as tf_estimator 27 | from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator 28 | import tensorflow.compat.v2 as tf 29 | 30 | _BATCH_SIZE = 2 31 | 32 | 33 | def setUpModule(): 34 | tf.enable_v2_behavior() 35 | 36 | 37 | def _generate_assets(model, export_dir): 38 | in_feature_spec = model.get_feature_specification_for_packing( 39 | mode=tf_estimator.ModeKeys.PREDICT) 40 | in_label_spec = model.get_label_specification_for_packing( 41 | mode=tf_compat_v1_estimator.ModeKeys.PREDICT) 42 | 43 | in_feature_spec = tensorspec_utils.filter_required_flat_tensor_spec( 44 | in_feature_spec) 45 | in_label_spec = tensorspec_utils.filter_required_flat_tensor_spec( 46 | in_label_spec) 47 | 48 | t2r_assets = t2r_pb2.T2RAssets() 49 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto()) 50 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto()) 51 | t2r_assets_dir = os.path.join(export_dir, 52 | tensorspec_utils.EXTRA_ASSETS_DIRECTORY) 53 | 54 | tf.io.gfile.makedirs(t2r_assets_dir) 55 | t2r_assets_filename = os.path.join(t2r_assets_dir, 56 | tensorspec_utils.T2R_ASSETS_FILENAME) 57 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename) 58 | 59 | 60 | class SavedModelV2PredictorTest(tf.test.TestCase): 61 | 62 | def __init__(self, *args, **kwargs): 63 | super(SavedModelV2PredictorTest, self).__init__(*args, **kwargs) 64 | self._saved_model_path = None 65 | 66 | def _save_model(self, model, sample_features): 67 | if self._saved_model_path: 68 | return self._saved_model_path 69 | 70 | # Save inference_network_fn as the predict method for the saved_model. 71 | @tf.function(autograph=False) 72 | def predict(features): 73 | return model.inference_network_fn(features, None, 74 | tf_compat_v1_estimator.ModeKeys.PREDICT) 75 | 76 | # Call the model for the tf.function tracing side effects. 77 | predict(sample_features) 78 | model.predict = predict 79 | 80 | self._saved_model_path = self.create_tempdir().full_path 81 | tf.saved_model.save(model, self._saved_model_path) 82 | _generate_assets(model, self._saved_model_path) 83 | return self._saved_model_path 84 | 85 | def _test_predictor(self, predictor_cls, multi_dataset): 86 | mock_model = mocks.MockTF2T2RModel(multi_dataset=multi_dataset) 87 | 88 | # Generate a sample to evaluate 89 | feature_spec = mock_model.preprocessor.get_in_feature_specification( 90 | tf_compat_v1_estimator.ModeKeys.PREDICT) 91 | sample_features = tensorspec_utils.make_random_numpy( 92 | feature_spec, batch_size=_BATCH_SIZE) 93 | 94 | # Generate a saved model and load it. 95 | path = self._save_model(mock_model, sample_features) 96 | saved_model_predictor = predictor_cls(path) 97 | 98 | # Not restored yet. 99 | with self.assertRaises(ValueError): 100 | saved_model_predictor.predict(sample_features) 101 | 102 | saved_model_predictor.restore() 103 | 104 | # Validate evaluations are the same afterwards. 
105 | original_model_out = mock_model.inference_network_fn( 106 | sample_features, None, tf_compat_v1_estimator.ModeKeys.PREDICT) 107 | 108 | predictor_out = saved_model_predictor.predict(sample_features) 109 | 110 | np.testing.assert_almost_equal(original_model_out['logits'], 111 | predictor_out['logits']) 112 | 113 | def testTF1PredictorSingleDataset(self): 114 | self._test_predictor(saved_model_v2_predictor.SavedModelTF1Predictor, False) 115 | 116 | def testTF1PredictorMultiDataset(self): 117 | self._test_predictor(saved_model_v2_predictor.SavedModelTF1Predictor, True) 118 | 119 | def testTF2PredictorSingleDataset(self): 120 | self._test_predictor(saved_model_v2_predictor.SavedModelTF2Predictor, False) 121 | 122 | def testTF2PredictorMultiDataset(self): 123 | self._test_predictor(saved_model_v2_predictor.SavedModelTF2Predictor, True) 124 | 125 | 126 | if __name__ == '__main__': 127 | tf.test.main() 128 | -------------------------------------------------------------------------------- /preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Preprocessor implementations for estimator models.""" 17 | from tensor2robot.preprocessors import abstract_preprocessor 18 | -------------------------------------------------------------------------------- /preprocessors/abstract_preprocessor_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.preprocessors.abstract_preprocessor.""" 17 | 18 | from tensor2robot.preprocessors import abstract_preprocessor 19 | import tensorflow.compat.v1 as tf 20 | 21 | 22 | class AbstractPreprocessorTest(tf.test.TestCase): 23 | 24 | def test_init_abstract(self): 25 | with self.assertRaises(TypeError): 26 | abstract_preprocessor.AbstractPreprocessor() 27 | 28 | 29 | if __name__ == '__main__': 30 | tf.test.main() 31 | -------------------------------------------------------------------------------- /preprocessors/noop_preprocessor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple no operation preprocessor.""" 17 | 18 | from typing import Optional, Tuple 19 | 20 | import gin 21 | from tensor2robot.preprocessors import abstract_preprocessor 22 | from tensor2robot.utils import tensorspec_utils 23 | from tensorflow.compat.v1 import estimator as tf_estimator 24 | 25 | 26 | @gin.configurable 27 | class NoOpPreprocessor(abstract_preprocessor.AbstractPreprocessor): 28 | """A convenience preprocessor which does not perform any preprocessing. 29 | 30 | This preprocessor provides convenience functionality in case we simply want 31 | to ensure that single examples already contain the right information for our 32 | model. This preprocessor does not perform any preprocessing, but allows 33 | existing models to initialize a preprocessor without any additional runtime 34 | overhead. 35 | """ 36 | 37 | def get_in_feature_specification( 38 | self, mode): 39 | """The specification for the input features for the preprocess_fn. 40 | 41 | Arguments: 42 | mode: mode key for this feature specification 43 | Returns: 44 | A TensorSpecStruct describing the required and optional tensors. 45 | """ 46 | return tensorspec_utils.flatten_spec_structure( 47 | self._model_feature_specification_fn(mode)) 48 | 49 | def get_in_label_specification( 50 | self, mode): 51 | """The specification for the input labels for the preprocess_fn. 52 | 53 | Arguments: 54 | mode: mode key for this feature specification 55 | Returns: 56 | A TensorSpecStruct describing the required and optional tensors. 57 | """ 58 | return tensorspec_utils.flatten_spec_structure( 59 | self._model_label_specification_fn(mode)) 60 | 61 | def get_out_feature_specification( 62 | self, mode): 63 | """The specification for the output features after executing preprocess_fn. 64 | 65 | Arguments: 66 | mode: mode key for this feature specification 67 | Returns: 68 | A TensorSpecStruct describing the required and optional tensors. 69 | """ 70 | return tensorspec_utils.flatten_spec_structure( 71 | self._model_feature_specification_fn(mode)) 72 | 73 | def get_out_label_specification( 74 | self, mode): 75 | """The specification for the output labels after executing preprocess_fn. 76 | 77 | Arguments: 78 | mode: mode key for this feature specification 79 | Returns: 80 | A TensorSpecStruct describing the required and optional tensors. 81 | """ 82 | return tensorspec_utils.flatten_spec_structure( 83 | self._model_label_specification_fn(mode)) 84 | 85 | def _preprocess_fn( 86 | self, features, 87 | labels, 88 | mode 89 | ): 90 | """The preprocessing function which will be executed prior to the model_fn. 91 | 92 | As the name NoOpPreprocessor suggests, we do not perform any preprocessing. 93 | 94 | Args: 95 | features: The input features extracted from a single example in our 96 | in_features_specification format. 97 | labels: (Optional) The input labels extracted from a single example 98 | in our in_labels_specification format.
99 | mode: (ModeKeys) Specifies if this is training, evaluation or prediction. 100 | 101 | Returns: 102 | features: The preprocessed features, potentially adding 103 | additional tensors derived from the input features. 104 | labels: (Optional) The preprocessed labels, potentially 105 | adding additional tensors derived from the input features and labels. 106 | """ 107 | return features, labels 108 | -------------------------------------------------------------------------------- /proto/t2r.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Tensor2Robot Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | syntax = "proto2"; 16 | 17 | package third_party.py.tensor2robot; 18 | 19 | message ExtendedTensorSpec { 20 | // This message allows (de)serializing tensorspec_utils.ExtendedTensorSpec. 21 | // Each field has a one-to-one mapping to the class constructor. 22 | repeated int32 shape = 1; 23 | optional int32 dtype = 2; 24 | optional string name = 3; 25 | optional bool is_optional = 4; 26 | optional bool is_extracted = 5; 27 | optional string data_format = 6; 28 | optional string dataset_key = 7; 29 | optional float varlen_default_value = 8; 30 | } 31 | 32 | message TensorSpecStruct { 33 | // This message allows (de)serializing tensorspec_utils.TensorSpecStruct. 34 | // This structure is essentially an OrderedDict which is therefore 35 | // serializable through a key: value map. 36 | map<string, ExtendedTensorSpec> key_value = 1; 37 | } 38 | 39 | message T2RAssets { 40 | optional TensorSpecStruct feature_spec = 1; 41 | optional TensorSpecStruct label_spec = 2; 42 | optional int32 global_step = 3; 43 | } 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.5.0 2 | numpy>=1.13.3 3 | tensorflow>=1.13.0 4 | tensorflow-serving-api>=1.13.0 5 | gin-config>=0.1.4 6 | pybullet==2.5.0 7 | Pillow==5.3.0 8 | gym>=0.10.9 9 | tensorflow-probability>=0.6.0 10 | tf-slim>=1.0 11 | -------------------------------------------------------------------------------- /research/bcz/README.md: -------------------------------------------------------------------------------- 1 | # BC-Z 2 | 3 | Source code for reproducing "BC-Z: Zero-Shot Task Generalization with Robotic 4 | Imitation Learning".
5 | 6 | Links 7 | 8 | - [Project Website](https://sites.google.com/view/bc-z/home/) 9 | - [Paper](https://arxiv.org/abs/2202.02005) 10 | - [Google AI Blog Post](https://ai.googleblog.com/2022/02/can-robots-follow-instructions-for-new.html) 11 | 12 | ## Training the Model 13 | 14 | The data is open-sourced in TFRecord format and can be downloaded here: 15 | https://www.kaggle.com/google/bc-z-robot 16 | 17 | ``` 18 | TRAIN_DATA="/path/to/bcz-21task_v9.0.1.tfrecord/train*,/path/to/bcz-79task_v16.0.0.tfrecord/train*" 19 | EVAL_DATA="/path/to/bcz-21task_v9.0.1.tfrecord/val*,/path/to/bcz-79task_v16.0.0.tfrecord/val*" 20 | python3 -m tensor2robot.bin.run_t2r_trainer --logtostderr \ 21 | --gin_configs="tensor2robot/research/bcz/configs/run_train_bc_langcond_trajectory.gin" \ 22 | --gin_bindings="train_eval_model.model_dir='/tmp/bcz/'" \ 23 | --gin_bindings="TRAIN_DATA='${TRAIN_DATA}'" \ 24 | --gin_bindings="EVAL_DATA='${EVAL_DATA}'" 25 | ``` 26 | -------------------------------------------------------------------------------- /research/bcz/configs/common_imagedistortions.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | ApplyPhotometricImageDistortions.random_brightness = True 17 | ApplyPhotometricImageDistortions.random_saturation = True 18 | ApplyPhotometricImageDistortions.random_hue = True 19 | ApplyPhotometricImageDistortions.random_contrast = True 20 | -------------------------------------------------------------------------------- /research/bcz/configs/common_imports.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
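# Note: the "import" statements below are gin module imports; their only
# purpose is to register the @gin.configurable input generators, models,
# networks, hooks, and predictors so that the BC-Z training configs can bind
# and reference them by name.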
15 | 16 | import tensor2robot.input_generators.default_input_generator 17 | import tensor2robot.utils.train_eval 18 | import tensor2robot.models.abstract_model 19 | import tensor2robot.layers.bcz_networks 20 | import tensor2robot.research.bcz.model 21 | import tensor2robot.hooks.golden_values_hook_builder 22 | import tensor2robot.hooks.gin_config_hook_builder 23 | import tensor2robot.predictors.exported_savedmodel_predictor 24 | -------------------------------------------------------------------------------- /research/bcz/configs/run_train_bc_gtcond_trajectory.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | include 'tensor2robot/research/bcz/configs/common_imports.gin' 17 | include 'tensor2robot/research/bcz/configs/common_imagedistortions.gin' 18 | 19 | ####################################### 20 | # INPUT GENERATION 21 | ####################################### 22 | 23 | TRAIN_DATA="" 24 | EVAL_DATA="" 25 | 26 | STATE_COMPONENTS = [] 27 | ACTION_COMPONENTS = [ 28 | ('xyz', 3, True, 100.), 29 | ('axis_angle', 3, True, 10.), 30 | ('target_close', 1, False, 0.5), # best m104 model used this param. 
31 | ] 32 | 33 | NUM_WAYPOINTS = 10 34 | 35 | TRAIN_BATCH_SIZE = 32 36 | EVAL_BATCH_SIZE = 64 37 | TRAIN_FRACTION = 1.0 38 | 39 | TRAIN_INPUT_GENERATOR = @train_input_generator/FractionalRecordInputGenerator() 40 | train_input_generator/FractionalRecordInputGenerator.file_patterns = %TRAIN_DATA 41 | train_input_generator/FractionalRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 42 | train_input_generator/FractionalRecordInputGenerator.file_fraction = %TRAIN_FRACTION 43 | 44 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 45 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 46 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 47 | 48 | ####################################### 49 | # PREPROCESSOR & DATA AUGMENTATION 50 | ####################################### 51 | CROP_SIZE = 450 52 | IMAGE_SIZE = 150 53 | BCZPreprocessor.binarize_gripper = True 54 | BCZPreprocessor.crop_size = (%CROP_SIZE, %CROP_SIZE) 55 | BCZPreprocessor.image_size = (%IMAGE_SIZE, %IMAGE_SIZE) 56 | 57 | ####################################### 58 | # MODEL 59 | ####################################### 60 | 61 | BCZModel.image_size = (%IMAGE_SIZE, %IMAGE_SIZE) 62 | BCZModel.network_fn = @resnet_film_network 63 | BCZModel.predict_stop = False 64 | MultiHeadMLP.stop_gradient_future_waypoints = False 65 | 66 | resnet_film_network.film_generator_fn = @linear_film_generator 67 | resnet_model.resnet_size = 18 68 | BCZModel.ignore_task_embedding = False 69 | BCZModel.task_embedding_noise_std = 0.1 70 | linear_film_generator.enabled_block_layers = [True, True, True, True] 71 | MultiHeadMLP.stop_gradient_future_waypoints = False 72 | 73 | BCZPreprocessor.cutout_size = 0 # Was 20 in paper, but not implemented in OSS. 74 | resnet_film_network.fc_layers = (256, 256) 75 | 76 | compute_stop_state_loss.class_weights = [[1.0309278350515465, 0, 33.333333333333336]] 77 | BCZModel.num_past = 0 78 | BCZModel.num_waypoints = %NUM_WAYPOINTS 79 | BCZModel.summarize_gradients = False 80 | BCZModel.state_components = %STATE_COMPONENTS 81 | BCZModel.action_components = %ACTION_COMPONENTS 82 | resnet_model.resnet_size = 18 83 | train_eval_model.t2r_model = @BCZModel() 84 | 85 | ##################################### 86 | # TRAINING 87 | ###################################### 88 | default_create_optimizer_fn.learning_rate = %LEARNING_RATE 89 | LEARNING_RATE = 2.5e-4 90 | 91 | train_eval_model.max_train_steps = 50000 92 | train_eval_model.eval_steps = 1000 93 | train_eval_model.eval_throttle_secs = 300 # Export model every 5 min. 94 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 95 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 96 | train_eval_model.create_exporters_fn = @create_default_exporters 97 | train_eval_model.chief_train_hook_builders = [@OperativeGinConfigLoggerHookBuilder()] 98 | create_default_exporters.exports_to_keep = None # Keep all ckpts to support evaluation. 99 | 100 | # Export best numpy models based on arm joint loss. 101 | create_valid_result_smaller.result_key = 'mean_first_xyz_error' 102 | 103 | # Save checkpoints frequently for evaluation. 
104 | tf.estimator.RunConfig.save_summary_steps = 1000 # save summary every n global steps 105 | 106 | -------------------------------------------------------------------------------- /research/bcz/configs/run_train_bc_langcond_trajectory.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | include 'tensor2robot/research/bcz/configs/run_train_bc_gtcond_trajectory.gin' 17 | 18 | TRAIN_INPUT_GENERATOR = @train_input_generator/WeightedRecordInputGenerator() 19 | train_input_generator/WeightedRecordInputGenerator.file_patterns = %TRAIN_DATA 20 | train_input_generator/WeightedRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 21 | train_input_generator/WeightedRecordInputGenerator.weights = [0.1, 0.9] # 10% 21-task data, 90% 83-task data. 22 | train_input_generator/WeightedRecordInputGenerator.seed = 0 23 | 24 | BCZModel.cond_modality = %ConditionMode.LANGUAGE_EMBEDDING 25 | BCZModel.task_embedding_noise_std = 0.1 26 | BCZPreprocessor.binarize_gripper = False 27 | BCZPreprocessor.rescale_gripper = True 28 | 29 | IMAGE_SIZE = 200 30 | -------------------------------------------------------------------------------- /research/bcz/model_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests BCZ models with placeholder data.""" 17 | 18 | import itertools 19 | import os 20 | from absl import flags 21 | from absl.testing import parameterized 22 | import gin 23 | from tensor2robot.research.bcz import model 24 | from tensor2robot.utils.t2r_test_fixture import T2RModelFixture 25 | import tensorflow.compat.v1 as tf 26 | from tensorflow.compat.v1 import estimator as tf_estimator 27 | 28 | FLAGS = flags.FLAGS 29 | TRAIN = tf_estimator.ModeKeys.TRAIN 30 | _POSE_COMPONENTS_LIST = list(itertools.product(*[ 31 | [True, False], [True, False], ['axis_angle', 'quaternion'], [False]])) 32 | 33 | 34 | class BCZModelTest(tf.test.TestCase, parameterized.TestCase): 35 | 36 | def setUp(self): 37 | gin.clear_config() 38 | super(BCZModelTest, self).setUp() 39 | self._fixture = T2RModelFixture(test_case=self, use_tpu=False) 40 | 41 | @parameterized.parameters( 42 | (model.spatial_softmax_network), 43 | (model.resnet_film_network), 44 | ) 45 | def test_network_fn(self, network_fn): 46 | model_name = 'BCZModel' 47 | gin.bind_parameter( 48 | 'BCZModel.network_fn', network_fn) 49 | gin.parse_config('BCZPreprocessor.mock_subtask = True') 50 | gin.parse_config( 51 | 'resnet_film_network.film_generator_fn = @linear_film_generator') 52 | self._fixture.random_train(model, model_name) 53 | 54 | def test_all_components(self): 55 | """Train with all pose components.""" 56 | model_name = 'BCZModel' 57 | pose_components = [ 58 | ('xyz', 3, True, 100.), 59 | ('quaternion', 4, False, 10.), 60 | ('axis_angle', 3, True, 10.), 61 | ('arm_joints', 7, True, 1.), 62 | ('target_close', 1, False, 1.), 63 | ] 64 | gin.bind_parameter( 65 | 'BCZModel.action_components', pose_components) 66 | gin.parse_config('BCZPreprocessor.mock_subtask = True') 67 | gin.parse_config( 68 | 'resnet_film_network.film_generator_fn = @linear_film_generator') 69 | self._fixture.random_train(model, model_name) 70 | 71 | @parameterized.parameters(*_POSE_COMPONENTS_LIST) 72 | def test_pose_components(self, 73 | residual_xyz, 74 | residual_angle, 75 | angle_format, 76 | residual_gripper): 77 | """Tests with different action configurations.""" 78 | model_name = 'BCZModel' 79 | if angle_format == 'axis_angle': 80 | angle_size = 3 81 | elif angle_format == 'quaternion': 82 | angle_size = 4 83 | action_components = [ 84 | ('xyz', 3, residual_xyz, 100.), 85 | (angle_format, angle_size, residual_angle, 10.), 86 | ('target_close', 1, residual_gripper, 1.), 87 | ] 88 | gin.bind_parameter( 89 | 'BCZModel.action_components', action_components) 90 | gin.bind_parameter( 91 | 'BCZModel.state_components', []) 92 | gin.parse_config('BCZPreprocessor.mock_subtask = True') 93 | gin.parse_config( 94 | 'resnet_film_network.film_generator_fn = @linear_film_generator') 95 | self._fixture.random_train(model, model_name) 96 | 97 | def test_random_train(self): 98 | base_dir = 'tensor2robot' 99 | 100 | gin_config = os.path.join( 101 | FLAGS.test_srcdir, base_dir, 'research/bcz/configs', 102 | 'run_train_bc_langcond_trajectory.gin') 103 | model_name = 'BCZModel' 104 | gin_bindings = [ 105 | 'train_eval_model.eval_steps = 1', 106 | 'EVAL_INPUT_GENERATOR=None', 107 | ] 108 | gin.parse_config_files_and_bindings( 109 | [gin_config], gin_bindings, finalize_config=False) 110 | self._fixture.random_train(model, model_name) 111 | 112 | def tearDown(self): 113 | gin.clear_config() 114 | super(BCZModelTest, self).tearDown() 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | 
-------------------------------------------------------------------------------- /research/bcz/pose_components_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Action space definitions for MetaTidy models. 17 | """ 18 | 19 | from typing import Text, Tuple 20 | 21 | # Name, size, whether it is residual or not, and loss weight. 22 | # This is used to parameterize action labels. 23 | ActionComponent = Tuple[Text, int, bool, float] 24 | 25 | # Name, size, whether residual or not. 26 | # This is used to parameterize proprioceptive state inputs. 27 | StateComponent = Tuple[Text, int, bool] 28 | 29 | DEFAULT_STATE_COMPONENTS = [] 30 | DEFAULT_ACTION_COMPONENTS = [ 31 | ('xyz', 3, True, 100.), 32 | ('quaternion', 4, False, 10.), 33 | ('target_close', 1, False, 1.), 34 | ] 35 | -------------------------------------------------------------------------------- /research/dql_grasping_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /research/dql_grasping_lib/tf_modules.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Reused modules for building actors/critics for grasping task. 
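
Includes a default slim arg_scope for convnet-based grasping models (argscope)
and helpers that broadcast per-example context vectors across convolutional
feature maps (tile_to_match_context, add_context).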
17 | """ 18 | 19 | import gin 20 | import tensorflow.compat.v1 as tf 21 | from tensorflow.contrib import slim 22 | 23 | 24 | @gin.configurable 25 | def argscope(is_training=None, normalizer_fn=slim.layer_norm): 26 | """Default TF argscope used for convnet-based grasping models. 27 | 28 | Args: 29 | is_training: Whether this argscope is for training or inference. 30 | normalizer_fn: Which conv/fc normalizer to use. 31 | Returns: 32 | Dictionary of argument overrides. 33 | """ 34 | with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training): 35 | with slim.arg_scope( 36 | [slim.conv2d, slim.fully_connected], 37 | weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 38 | activation_fn=tf.nn.relu, 39 | normalizer_fn=normalizer_fn): 40 | with slim.arg_scope( 41 | [slim.conv2d, slim.max_pool2d], stride=2, padding='VALID') as scope: 42 | return scope 43 | 44 | 45 | def tile_to_match_context(net, context): 46 | """Tiles net along a new axis=1 to match context. 47 | 48 | Repeats minibatch elements of `net` tensor to match multiple corresponding 49 | minibatch elements from `context`. 50 | Args: 51 | net: Tensor of shape [num_batch_net, ....]. 52 | context: Tensor of shape [num_batch_net, num_examples, context_size]. 53 | Returns: 54 | Tensor of shape [num_batch_net, num_examples, ...], where each minibatch 55 | element of net has been tiled M times where M = num_batch_context / 56 | num_batch_net. 57 | """ 58 | with tf.name_scope('tile_to_context'): 59 | num_samples = tf.shape(context)[1] 60 | net_examples = tf.expand_dims(net, 1) # [batch_size, 1, ...] 61 | 62 | net_ndim = len(net_examples.get_shape().as_list()) 63 | # Tile net by num_samples in axis=1. 64 | multiples = [1]*net_ndim 65 | multiples[1] = num_samples 66 | net_examples = tf.tile(net_examples, multiples) 67 | return net_examples 68 | 69 | 70 | def add_context(net, context): 71 | """Merges visual perception with context using elementwise addition. 72 | 73 | Actions are reshaped to match net dimension depth-wise, and are added to 74 | the conv layers by broadcasting element-wise across H, W extent. 75 | 76 | Args: 77 | net: Tensor of shape [batch_size, H, W, C]. 78 | context: Tensor of shape [batch_size * num_examples, C]. 79 | Returns: 80 | Tensor with shape [batch_size * num_examples, H, W, C] 81 | """ 82 | num_batch_net = tf.shape(net)[0] 83 | _, h, w, d1 = net.get_shape().as_list() 84 | _, d2 = context.get_shape().as_list() 85 | assert d1 == d2 86 | context = tf.reshape(context, [num_batch_net, -1, d2]) 87 | net_examples = tile_to_match_context(net, context) 88 | # Flatten first two dimensions. 89 | net = tf.reshape(net_examples, [-1, h, w, d1]) 90 | context = tf.reshape(context, [-1, 1, 1, d2]) 91 | context = tf.tile(context, [1, h, w, 1]) 92 | net = tf.add_n([net, context]) 93 | return net 94 | -------------------------------------------------------------------------------- /research/grasp2vec/README.md: -------------------------------------------------------------------------------- 1 | # Grasp2Vec 2 | 3 | Source codes for reproducing "Grasp2Vec: Learning Object Representations from 4 | Self-Supervised Grasping". 
5 | 6 | Links 7 | 8 | - [Project Website](https://sites.google.com/site/grasp2vec/) 9 | - [Paper](https://arxiv.org/abs/1811.06964) 10 | - [Google AI Blog Post](https://ai.googleblog.com/2018/12/grasp2vec-learning-object.html) 11 | 12 | ## Authors 13 | 14 | Eric Jang*<sup>1</sup>, Coline Devin*<sup>2</sup>, Vincent 15 | Vanhoucke<sup>1</sup>, Sergey Levine<sup>1,2</sup> 16 | 17 | *Equal Contribution, <sup>1</sup>Google Brain, <sup>2</sup>UC 18 | Berkeley 19 | 20 | ## Training the Model 21 | 22 | Data is not included in this repository, so you will have to provide your own 23 | training/eval datasets of TFRecords. The Grasp2Vec T2R model attempts to parse 24 | the following Feature spec from the data, before cropping and resizing the 25 | parsed images: 26 | 27 | ``` 28 | tspec.pregrasp_image = TensorSpec(shape=(512, 640, 3), 29 | dtype=tf.uint8, name='image', data_format='jpeg') 30 | tspec.postgrasp_image = TensorSpec( 31 | shape=(512, 640, 3), dtype=tf.uint8, name='postgrasp_image', 32 | data_format='jpeg') 33 | tspec.goal_image = TensorSpec( 34 | shape=(512, 640, 3), dtype=tf.uint8, name='present_image', 35 | data_format='jpeg') 36 | ``` 37 | 38 | Note that `image`, `postgrasp_image`, `present_image` are the names of features 39 | stored in the TFExample feature map. 40 | 41 | ``` 42 | python3 -m tensor2robot.bin.run_t2r_trainer --logtostderr \ 43 | --gin_configs="tensor2robot/research/grasp2vec/configs/train_grasp2vec.gin" \ 44 | --gin_bindings="train_eval_model.model_dir='/tmp/grasp2vec/'" \ 45 | --gin_bindings="TRAIN_DATA='/path/to/your/data/train*'" \ 46 | --gin_bindings="EVAL_DATA='/path/to/your/data/val*'" 47 | ``` 48 | 49 | TensorBoard will show heatmap localization visualization summaries as shown in 50 | the paper. 51 | -------------------------------------------------------------------------------- /research/grasp2vec/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /research/grasp2vec/configs/common_imports.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
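# Importing these modules registers the gin-configurables (the default input
# generators, train_eval_model, and Grasp2VecModel) that train_grasp2vec.gin
# binds by name after including this file.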
15 | 16 | import tensor2robot.input_generators.default_input_generator 17 | import tensor2robot.utils.train_eval 18 | import tensor2robot.models.abstract_model 19 | import tensor2robot.research.grasp2vec.grasp2vec_model 20 | -------------------------------------------------------------------------------- /research/grasp2vec/configs/train_grasp2vec.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | include 'tensor2robot/research/grasp2vec/configs/common_imports.gin' 17 | 18 | ######## INPUT GENERATION 19 | 20 | TRAIN_DATA="/path/to/your/data/train*" 21 | EVAL_DATA="/path/to/your/data/val*" 22 | 23 | TRAIN_BATCH_SIZE = 8 24 | EVAL_BATCH_SIZE = 8 25 | 26 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator() 27 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA 28 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 29 | 30 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 31 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 32 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 33 | 34 | ####################################### 35 | # MODEL 36 | ####################################### 37 | 38 | train_eval_model.t2r_model = @Grasp2VecModel() 39 | 40 | default_create_optimizer_fn.learning_rate = 0.0001 41 | 42 | ##################################### 43 | # TRAINING 44 | ###################################### 45 | 46 | train_eval_model.max_train_steps = 50000 47 | train_eval_model.eval_steps = 200 48 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 49 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 50 | train_eval_model.create_exporters_fn = @create_default_exporters 51 | -------------------------------------------------------------------------------- /research/grasp2vec/networks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implements forward pass for embeddings. 
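
Embedding() runs a spatial ResNet-50 over a batch of scene or goal images and
mean-pools the resulting spatial feature map into a single embedding vector
per image.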
17 | """ 18 | 19 | from tensor2robot.research.grasp2vec import resnet 20 | import tensorflow.compat.v1 as tf 21 | from tensorflow.compat.v1 import estimator as tf_estimator 22 | 23 | 24 | def Embedding(image, mode, params, reuse=tf.AUTO_REUSE, scope='scene'): 25 | """Implements scene or goal embedding. 26 | 27 | Args: 28 | image: Batch of images corresponding to scene or goal. 29 | mode: Mode is tf.estimator.ModeKeys.EVAL, TRAIN, or PREDICT (unused). 30 | params: Hyperparameters for the network. 31 | reuse: Reuse parameter for variable scope. 32 | scope: The variable_scope to use for the variables. 33 | Returns: 34 | A tuple (batch of summed embeddings, batch of embedding maps). 35 | """ 36 | del params 37 | is_training = mode == tf_estimator.ModeKeys.TRAIN 38 | with tf.variable_scope(scope, reuse=reuse): 39 | scene = resnet.get_resnet50_spatial(image, is_training) 40 | scene = tf.nn.relu(scene) 41 | summed_scene = tf.reduce_mean(scene, axis=[1, 2]) 42 | return summed_scene, scene 43 | -------------------------------------------------------------------------------- /research/pose_env/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /research/pose_env/configs/common_imports.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import tensor2robot.input_generators.default_input_generator 17 | import tensor2robot.meta_learning.meta_policies 18 | import tensor2robot.meta_learning.run_meta_env 19 | import tensor2robot.utils.train_eval 20 | import tensor2robot.models.abstract_model 21 | import tensor2robot.utils.continuous_collect_eval 22 | import tensor2robot.research.pose_env.episode_to_transitions 23 | import tensor2robot.research.pose_env.pose_env 24 | import tensor2robot.research.pose_env.pose_env_maml_models 25 | import tensor2robot.research.pose_env.pose_env_models 26 | import tensor2robot.utils.writer 27 | import tensor2robot.research.dql_grasping_lib.run_env 28 | -------------------------------------------------------------------------------- /research/pose_env/configs/run_random_collect.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin' 17 | 18 | collect_eval_loop.collect_env = @train/PoseToyEnv() 19 | collect_eval_loop.eval_env = None 20 | collect_eval_loop.run_agent_fn = @run_meta_env 21 | collect_eval_loop.policy_class = @PoseEnvRandomPolicy 22 | collect_eval_loop.num_collect = None # this is ignored 23 | 24 | train/PoseToyEnv.render_mode = 'DIRECT' 25 | train/PoseToyEnv.hidden_drift = True 26 | 27 | run_meta_env.num_adaptations_per_task = 1 28 | run_meta_env.num_tasks = 5000 29 | run_meta_env.num_episodes_per_adaptation = 4 30 | 31 | # Save out for visualization. 32 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy 33 | run_meta_env.replay_writer = @TFRecordReplayWriter() 34 | -------------------------------------------------------------------------------- /research/pose_env/configs/run_train_reg.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
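# This config trains PoseEnvRegressionModel on constant placeholder inputs
# (train_eval_model bindings) and configures evaluation of the resulting
# RegressionPolicy in the PoseToyEnv simulation (collect_eval_loop bindings).
#
# A training launch might look like the following sketch (the model_dir value
# is an arbitrary example):
#   python3 -m tensor2robot.bin.run_t2r_trainer --logtostderr \
#     --gin_configs="tensor2robot/research/pose_env/configs/run_train_reg.gin" \
#     --gin_bindings="train_eval_model.model_dir='/tmp/pose_env_reg/'"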
15 | 16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin' 17 | 18 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultConstantInputGenerator() 19 | train_input_generator/DefaultConstantInputGenerator.constant_value = 1.0 20 | train_input_generator/DefaultRecordInputGenerator.batch_size = 64 21 | 22 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultConstantInputGenerator() 23 | eval_input_generator/DefaultConstantInputGenerator.constant_value = 1.0 24 | eval_input_generator/DefaultRecordInputGenerator.batch_size = 64 25 | 26 | # Training - input generator and preprocessor numbers need to match up. 27 | train_eval_model.t2r_model = @PoseEnvRegressionModel() 28 | train_eval_model.max_train_steps = 5000 29 | train_eval_model.eval_steps = 1000 30 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 31 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 32 | 33 | # Collection & Evaluation. 34 | collect_eval_loop.collect_env = None 35 | collect_eval_loop.eval_env = @train/PoseToyEnv() 36 | collect_eval_loop.run_agent_fn = @run_meta_env 37 | collect_eval_loop.num_eval = 100 38 | 39 | train/PoseToyEnv.hidden_drift = True 40 | 41 | train/PoseToyEnv.render_mode = 'DIRECT' 42 | 43 | run_meta_env.num_episodes = 100 44 | run_meta_env.num_episodes_per_adaptation = 2 45 | 46 | collect_eval_loop.policy_class = @RegressionPolicy 47 | RegressionPolicy.t2r_model = @PoseEnvRegressionModel() 48 | 49 | # Save data out for visualization. 50 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy 51 | run_meta_env.replay_writer = @TFRecordReplayWriter() 52 | -------------------------------------------------------------------------------- /research/pose_env/configs/run_train_reg_maml.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
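# Same toy pose-regression task as run_train_reg.gin, but the base model is
# wrapped in PoseEnvRegressionModelMAML with a FixedLenMetaExamplePreprocessor,
# so every batch is split into condition and inference samples for MAML-style
# inner-loop adaptation; evaluation uses MAMLRegressionPolicy, which adapts
# from previously collected episodes.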
15 | 16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin' 17 | 18 | # Input Pipeline 19 | 20 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultConstantInputGenerator() 21 | train_input_generator/DefaultConstantInputGenerator.constant_value = 1.0 22 | train_input_generator/DefaultRecordInputGenerator.batch_size = 64 23 | 24 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultConstantInputGenerator() 25 | eval_input_generator/DefaultConstantInputGenerator.constant_value = 1.0 26 | eval_input_generator/DefaultRecordInputGenerator.batch_size = 64 27 | 28 | # Model 29 | 30 | train_eval_model.t2r_model = @train/PoseEnvRegressionModelMAML() 31 | train_eval_model.max_train_steps = 5000 32 | train_eval_model.eval_steps = 100 33 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 34 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 35 | 36 | PoseEnvRegressionModelMAML.base_model = @PoseEnvRegressionModel() 37 | PoseEnvRegressionModelMAML.preprocessor_cls = @FixedLenMetaExamplePreprocessor 38 | train_eval_model.t2r_model = @train/PoseEnvRegressionModelMAML() 39 | 40 | FixedLenMetaExamplePreprocessor.num_condition_samples_per_task = 1 41 | FixedLenMetaExamplePreprocessor.num_inference_samples_per_task = 1 42 | 43 | # MAMLInnerLoopGradientDescent.learning_rate = 0.001 44 | # MAMLModel.num_inner_loop_steps = 4 45 | 46 | # Collection & Evaluation. 47 | collect_eval_loop.collect_env = None 48 | collect_eval_loop.eval_env = @train/PoseToyEnv() 49 | collect_eval_loop.run_agent_fn = @run_meta_env 50 | train/PoseToyEnv.hidden_drift = True 51 | train/PoseToyEnv.render_mode = 'DIRECT' 52 | collect_eval_loop.num_eval = 100 53 | run_meta_env.num_adaptations_per_task = 4 54 | run_meta_env.num_episodes_per_adaptation = 1 55 | 56 | collect_eval_loop.policy_class = @MAMLRegressionPolicy 57 | MAMLRegressionPolicy.t2r_model = @eval/PoseEnvRegressionModelMAML() 58 | 59 | # Save data out for visualization. 60 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy 61 | run_meta_env.replay_writer = @TFRecordReplayWriter() 62 | -------------------------------------------------------------------------------- /research/pose_env/episode_to_transitions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Functions for converting env episode data to tfrecords of transitions.""" 17 | 18 | import gin 19 | from PIL import Image 20 | from tensor2robot.utils import image 21 | import tensorflow.compat.v1 as tf 22 | 23 | _bytes_feature = ( 24 | lambda v: tf.train.Feature(bytes_list=tf.train.BytesList(value=v))) 25 | _int64_feature = ( 26 | lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v))) 27 | _float_feature = ( 28 | lambda v: tf.train.Feature(float_list=tf.train.FloatList(value=v))) 29 | 30 | 31 | @gin.configurable 32 | def episode_to_transitions_pose_toy(episode_data): 33 | """Converts pose toy env episode data to transition Examples.""" 34 | # This is just saving data for a supervised regression problem, so obs_tp1 35 | # can be discarded. 36 | transitions = [] 37 | for transition in episode_data: 38 | (obs_t, action, reward, obs_tp1, done, debug) = transition 39 | del obs_tp1 40 | del done 41 | features = {} 42 | obs_t = Image.fromarray(obs_t) 43 | features['state/image'] = _bytes_feature([image.jpeg_string(obs_t)]) 44 | features['pose'] = _float_feature(action.flatten().tolist()) 45 | features['reward'] = _float_feature([reward]) 46 | features['target_pose'] = _float_feature(debug['target_pose'].tolist()) 47 | transitions.append( 48 | tf.train.Example(features=tf.train.Features(feature=features))) 49 | return transitions 50 | 51 | 52 | -------------------------------------------------------------------------------- /research/pose_env/pose_env_maml_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """MAML-based meta-learning models for the duck task.""" 17 | 18 | import gin 19 | import numpy as np 20 | from tensor2robot.meta_learning import maml_model 21 | from tensor2robot.utils import tensorspec_utils 22 | from tensorflow.compat.v1 import estimator as tf_estimator 23 | from tensorflow.contrib import framework as contrib_framework 24 | nest = contrib_framework.nest 25 | 26 | 27 | @gin.configurable 28 | class PoseEnvRegressionModelMAML(maml_model.MAMLModel): 29 | """MAML Regression environment for duck task.""" 30 | 31 | def _make_dummy_labels(self): 32 | """Helper function to make dummy labels for pack_labels.""" 33 | label_spec = self._base_model.get_label_specification( 34 | tf_estimator.ModeKeys.TRAIN) 35 | reward_shape = tuple(label_spec.reward.shape) 36 | pose_shape = tuple(label_spec.target_pose.shape) 37 | dummy_reward = np.zeros(reward_shape).astype(np.float32) 38 | dummy_pose = np.zeros(pose_shape).astype(np.float32) 39 | return tensorspec_utils.TensorSpecStruct( 40 | reward=dummy_reward, target_pose=dummy_pose) 41 | 42 | def _select_inference_output(self, predictions): 43 | """Inference output selection for regression models.""" 44 | # We select our output for inference. 
45 | predictions.condition_output = ( 46 | predictions.full_condition_output.inference_output) 47 | predictions.inference_output = ( 48 | predictions.full_inference_output.inference_output) 49 | return predictions 50 | 51 | def pack_features(self, state, prev_episode_data, timestep): 52 | """Combines current state and conditioning data into MetaExample spec. 53 | 54 | See create_metaexample_spec for an example of the spec layout. 55 | 56 | If prev_episode_data does not contain enough episodes to fill 57 | num_condition_samples_per_task, we stuff dummy episodes with reward=0.5 58 | so that no inner gradients are applied. 59 | 60 | Args: 61 | state: VRGripperObservation containing image and pose. 62 | prev_episode_data: A list of episode data, each of which is a list of 63 | tuples containing transition data. Each transition tuple takes the form 64 | (obs, action, rew, new_obs, done, debug). 65 | timestep: Current episode timestep. 66 | Returns: 67 | TensorSpecStruct containing conditioning (features, labels) 68 | and inference (features) keys. 69 | Raises: 70 | ValueError: If no demonstration is provided. 71 | """ 72 | meta_features = tensorspec_utils.TensorSpecStruct() 73 | meta_features['inference/features/state/0'] = state 74 | def pack_condition_features(episode_data, idx, dummy_values=False): 75 | """Pack previous episode data into condition_ep* features/labels. 76 | 77 | Args: 78 | episode_data: List of (obs, action, rew, new_obs, done, debug) tuples. 79 | idx: Index of the conditioning episode. 0 for demo, 1 for first trial, 80 | etc. 81 | dummy_values: If an episode is not available yet, set the loss_mask 82 | to 0. 83 | 84 | """ 85 | transition = episode_data[0] 86 | meta_features['condition/features/state/%d' % idx] = transition[0] 87 | reward = np.array([transition[2]]) 88 | reward = 2 * reward - 1 89 | if dummy_values: 90 | # success_weight of 0. = no gradients in inner loop for this batch. 91 | reward = np.array([0.]) 92 | meta_features['condition/labels/target_pose/%d' % idx] = transition[1] 93 | meta_features['condition/labels/reward/%d' % idx] = reward.astype( 94 | np.float32) 95 | 96 | if prev_episode_data: 97 | pack_condition_features(prev_episode_data[0], 0) 98 | else: 99 | dummy_labels = self._make_dummy_labels() 100 | dummy_episode = [(state, dummy_labels.target_pose, dummy_labels.reward)] 101 | pack_condition_features(dummy_episode, 0, dummy_values=True) 102 | return nest.map_structure(lambda x: np.expand_dims(x, 0), meta_features) 103 | -------------------------------------------------------------------------------- /research/pose_env/pose_env_models_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Integration tests for training pose_env models.""" 17 | 18 | import os 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import gin 23 | from tensor2robot.input_generators import default_input_generator 24 | from tensor2robot.meta_learning import meta_policies 25 | from tensor2robot.meta_learning import preprocessors 26 | from tensor2robot.predictors import checkpoint_predictor 27 | from tensor2robot.research.pose_env import pose_env 28 | from tensor2robot.research.pose_env import pose_env_maml_models 29 | from tensor2robot.research.pose_env import pose_env_models 30 | from tensor2robot.utils import train_eval 31 | from tensor2robot.utils import train_eval_test_utils 32 | import tensorflow.compat.v1 as tf # tf 33 | 34 | 35 | BATCH_SIZE = 1 36 | MAX_TRAIN_STEPS = 1 37 | EVAL_STEPS = 1 38 | 39 | NUM_TRAIN_SAMPLES_PER_TASK = 1 40 | NUM_VAL_SAMPLES_PER_TASK = 1 41 | 42 | FLAGS = tf.app.flags.FLAGS 43 | 44 | 45 | class PoseEnvModelsTest(parameterized.TestCase): 46 | 47 | def setUp(self): 48 | super(PoseEnvModelsTest, self).setUp() 49 | base_dir = 'tensor2robot' 50 | test_data = os.path.join(FLAGS.test_srcdir, 51 | base_dir, 52 | 'test_data/pose_env_test_data.tfrecord') 53 | self._train_log_dir = FLAGS.test_tmpdir 54 | if tf.io.gfile.exists(self._train_log_dir): 55 | tf.io.gfile.rmtree(self._train_log_dir) 56 | gin.bind_parameter('train_eval_model.max_train_steps', 3) 57 | gin.bind_parameter('train_eval_model.eval_steps', 2) 58 | 59 | self._record_input_generator = ( 60 | default_input_generator.DefaultRecordInputGenerator( 61 | batch_size=BATCH_SIZE, file_patterns=test_data)) 62 | 63 | self._meta_record_input_generator_train = ( 64 | default_input_generator.DefaultRandomInputGenerator( 65 | batch_size=BATCH_SIZE)) 66 | self._meta_record_input_generator_eval = ( 67 | default_input_generator.DefaultRandomInputGenerator( 68 | batch_size=BATCH_SIZE)) 69 | 70 | def test_mc(self): 71 | train_eval.train_eval_model( 72 | t2r_model=pose_env_models.PoseEnvContinuousMCModel(), 73 | input_generator_train=self._record_input_generator, 74 | input_generator_eval=self._record_input_generator, 75 | create_exporters_fn=None) 76 | 77 | def test_regression(self): 78 | train_eval.train_eval_model( 79 | t2r_model=pose_env_models.PoseEnvRegressionModel(), 80 | input_generator_train=self._record_input_generator, 81 | input_generator_eval=self._record_input_generator, 82 | create_exporters_fn=None) 83 | 84 | def test_regression_maml(self): 85 | maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML( 86 | base_model=pose_env_models.PoseEnvRegressionModel()) 87 | train_eval.train_eval_model( 88 | t2r_model=maml_model, 89 | input_generator_train=self._meta_record_input_generator_train, 90 | input_generator_eval=self._meta_record_input_generator_eval, 91 | create_exporters_fn=None) 92 | 93 | def _test_policy_interface(self, policy, restore=True): 94 | urdf_root = pose_env.get_pybullet_urdf_root() 95 | self.assertTrue(os.path.exists(urdf_root)) 96 | env = pose_env.PoseToyEnv( 97 | urdf_root=urdf_root, render_mode='DIRECT') 98 | env.reset_task() 99 | obs = env.reset() 100 | if restore: 101 | policy.restore() 102 | policy.reset_task() 103 | action = policy.SelectAction(obs, None, 0) 104 | 105 | new_obs, rew, done, env_debug = env.step(action) 106 | episode_data = [[(obs, action, rew, new_obs, done, env_debug)]] 107 | policy.adapt(episode_data) 108 | 109 | policy.SelectAction(new_obs, None, 1) 110 | 111 | def test_regression_maml_policy_interface(self): 112 | 
t2r_model = pose_env_maml_models.PoseEnvRegressionModelMAML( 113 | base_model=pose_env_models.PoseEnvRegressionModel(), 114 | preprocessor_cls=preprocessors.FixedLenMetaExamplePreprocessor) 115 | predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=t2r_model) 116 | predictor.init_randomly() 117 | policy = meta_policies.MAMLRegressionPolicy(t2r_model, predictor=predictor) 118 | self._test_policy_interface(policy, restore=False) 119 | 120 | @parameterized.parameters( 121 | ('run_train_reg_maml.gin',), 122 | ('run_train_reg.gin',)) 123 | def test_train_eval_gin(self, gin_file): 124 | base_dir = 'tensor2robot' 125 | full_gin_path = os.path.join( 126 | FLAGS.test_srcdir, base_dir, 'research/pose_env/configs', gin_file) 127 | model_dir = os.path.join(FLAGS.test_tmpdir, 'test_train_eval_gin', gin_file) 128 | train_eval_test_utils.test_train_eval_gin( 129 | test_case=self, 130 | model_dir=model_dir, 131 | full_gin_path=full_gin_path, 132 | max_train_steps=MAX_TRAIN_STEPS, 133 | eval_steps=EVAL_STEPS) 134 | 135 | 136 | if __name__ == '__main__': 137 | absltest.main() 138 | -------------------------------------------------------------------------------- /research/pose_env/pose_env_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.research.pose_env.pose_env.""" 17 | 18 | import os 19 | from absl.testing import absltest 20 | from six.moves import range 21 | from tensor2robot.research.pose_env import pose_env 22 | 23 | 24 | class PoseEnvTest(absltest.TestCase): 25 | 26 | def test_PoseEnv(self): 27 | urdf_root = pose_env.get_pybullet_urdf_root() 28 | self.assertTrue(os.path.exists(urdf_root)) 29 | env = pose_env.PoseToyEnv(urdf_root=urdf_root) 30 | obs = env.reset() 31 | policy = pose_env.PoseEnvRandomPolicy() 32 | action, _ = policy.sample_action(obs, 0) 33 | for _ in range(3): 34 | obs, _, done, _ = env.step(action) 35 | if done: 36 | obs = env.reset() 37 | 38 | if __name__ == '__main__': 39 | absltest.main() 40 | -------------------------------------------------------------------------------- /research/qtopt/README.md: -------------------------------------------------------------------------------- 1 | # QT-Opt 2 | 3 | This directory contains network architecture definitions for the grasping critic 4 | described in [QT-Opt: Scalable Deep Reinforcement Learning for 5 | Vision-Based Robotic Manipulation](https://arxiv.org/abs/1806.10293). 6 | 7 | ## Running the code 8 | 9 | The following command trains the QT-Opt critic architecture for a few gradient 10 | steps with mock data (real data is not included in this repo). The learning 11 | objective resembles supervised learning, since Bellman targets in QT-Opt are 12 | computed in a separate process (not open-sourced).
13 | 14 | ``` 15 | git clone https://github.com/google/tensor2robot 16 | # Optional: Create a virtualenv 17 | python3 -m venv ~/venv 18 | source ~/venv/bin/activate 19 | pip install -r tensor2robot/requirements.txt 20 | python -m tensor2robot.research.qtopt.t2r_models_test 21 | ``` 22 | 23 | ## PCGrad 24 | 25 | This directory also contains a multi-task optimization method 26 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) that is implemented in the form 27 | of an optimization wrapper. This is based on the open-source implementation 28 | [here](https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py). 29 | -------------------------------------------------------------------------------- /research/qtopt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /research/qtopt/optimizer_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Build optimizer with the given hyperparameters. 17 | """ 18 | 19 | from absl import logging 20 | import tensorflow.compat.v1 as tf 21 | from tensorflow.contrib import opt as contrib_opt 22 | from tensorflow.contrib.tpu.python.tpu import tpu_function 23 | 24 | 25 | def BuildOpt(hparams): 26 | """Constructs the optimizer. 27 | 28 | Args: 29 | hparams: An instance of tf.HParams, with these parameters: 30 | - batch_size 31 | - examples_per_epoch 32 | - learning_rate 33 | - learning_rate_decay_factor 34 | - model_weights_averaging 35 | - momentum 36 | - num_epochs_per_decay 37 | - optimizer 38 | - rmsprop_decay 39 | - use_avg_model_params 40 | 41 | Returns: 42 | opt: The optimizer.
43 | """ 44 | logging.info('Hyperparameters: %s', hparams) 45 | batch_size = hparams.batch_size 46 | examples_per_epoch = hparams.examples_per_epoch 47 | learning_rate_decay_factor = hparams.learning_rate_decay_factor 48 | learning_rate = hparams.learning_rate 49 | model_weights_averaging = hparams.model_weights_averaging 50 | momentum = hparams.momentum 51 | num_epochs_per_decay = hparams.num_epochs_per_decay 52 | optimizer = hparams.optimizer 53 | rmsprop_decay = hparams.rmsprop_decay 54 | rmsprop_epsilon = hparams.rmsprop_epsilon 55 | adam_beta2 = hparams.get('adam_beta2', 0.999) 56 | adam_epsilon = hparams.get('adam_epsilon', 1e-8) 57 | use_avg_model_params = hparams.use_avg_model_params 58 | 59 | global_step = tf.train.get_or_create_global_step() 60 | 61 | # Configure the learning rate using an exponetial decay. 62 | decay_steps = int(examples_per_epoch / batch_size * 63 | num_epochs_per_decay) 64 | 65 | learning_rate = tf.train.exponential_decay( 66 | learning_rate, 67 | global_step, 68 | decay_steps, 69 | learning_rate_decay_factor, 70 | staircase=True) 71 | if not tpu_function.get_tpu_context(): 72 | tf.summary.scalar('Learning Rate', learning_rate) 73 | 74 | if optimizer == 'momentum': 75 | opt = tf.train.MomentumOptimizer(learning_rate, momentum) 76 | elif optimizer == 'rmsprop': 77 | opt = tf.train.RMSPropOptimizer( 78 | learning_rate, 79 | decay=rmsprop_decay, 80 | momentum=momentum, 81 | epsilon=rmsprop_epsilon) 82 | else: 83 | opt = tf.train.AdamOptimizer( 84 | learning_rate, 85 | beta1=momentum, 86 | beta2=adam_beta2, 87 | epsilon=adam_epsilon) 88 | 89 | if use_avg_model_params: 90 | # Callers of BuildOpt() with use_avg_model_params=True expect the 91 | # MovingAverageOptimizer to be the last optimizer returned by this function 92 | # so that the swapping_saver can be constructed from it. 93 | return contrib_opt.MovingAverageOptimizer( 94 | opt, average_decay=model_weights_averaging) 95 | 96 | return opt 97 | -------------------------------------------------------------------------------- /research/qtopt/pcgrad_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for tensor2robot.research.qtopt.pcgrad.""" 17 | 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tensor2robot.research.qtopt import pcgrad 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | class PcgradTest(tf.test.TestCase, parameterized.TestCase): 25 | 26 | @parameterized.parameters( 27 | (None, None, [0, 1]), 28 | (None, ['*var*'], [0, 1]), 29 | (['second*'], None, [0]), 30 | (None, ['first*'], [0]), 31 | (None, ['*0'], [0]), 32 | (['first*'], None, [1]), 33 | (['*var*'], None, []), 34 | ) 35 | def testPCgradBasic(self, 36 | denylist, 37 | allowlist, 38 | pcgrad_var_idx): 39 | tf.disable_eager_execution() 40 | for dtype in [tf.dtypes.float32, tf.dtypes.float64]: 41 | with self.session(graph=tf.Graph()): 42 | var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) 43 | var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) 44 | const0_np = np.array([1., 0.], dtype=dtype.as_numpy_dtype) 45 | const1_np = np.array([-1., -1.], dtype=dtype.as_numpy_dtype) 46 | const2_np = np.array([-1., 1.], dtype=dtype.as_numpy_dtype) 47 | 48 | var0 = tf.Variable(var0_np, dtype=dtype, name='first_var/var0') 49 | var1 = tf.Variable(var1_np, dtype=dtype, name='second_var/var1') 50 | const0 = tf.constant(const0_np) 51 | const1 = tf.constant(const1_np) 52 | const2 = tf.constant(const2_np) 53 | loss0 = tf.tensordot(var0, const0, 1) + tf.tensordot(var1, const2, 1) 54 | loss1 = tf.tensordot(var0, const1, 1) + tf.tensordot(var1, const0, 1) 55 | 56 | learning_rate = lambda: 0.001 57 | opt = tf.train.GradientDescentOptimizer(learning_rate) 58 | losses = loss0 + loss1 59 | opt_grads = opt.compute_gradients(losses, var_list=[var0, var1]) 60 | 61 | pcgrad_opt = pcgrad.PCGrad( 62 | tf.train.GradientDescentOptimizer(learning_rate), 63 | denylist=denylist, 64 | allowlist=allowlist) 65 | pcgrad_col_opt = pcgrad.PCGrad( 66 | tf.train.GradientDescentOptimizer(learning_rate), 67 | use_collection_losses=True, 68 | denylist=denylist, 69 | allowlist=allowlist) 70 | losses = [loss0, loss1] 71 | pcgrad_grads = pcgrad_opt.compute_gradients( 72 | losses, var_list=[var0, var1]) 73 | tf.add_to_collection(pcgrad.PCGRAD_LOSSES_COLLECTION, loss0) 74 | tf.add_to_collection(pcgrad.PCGRAD_LOSSES_COLLECTION, loss1) 75 | pcgrad_grads_collection = pcgrad_col_opt.compute_gradients( 76 | None, var_list=[var0, var1]) 77 | 78 | with tf.Graph().as_default(): 79 | # Shouldn't return non-slot variables from other graphs. 80 | self.assertEmpty(opt.variables()) 81 | 82 | self.evaluate(tf.global_variables_initializer()) 83 | grad_vec, pcgrad_vec, pcgrad_col_vec = self.evaluate( 84 | [opt_grads, pcgrad_grads, pcgrad_grads_collection]) 85 | # Make sure that both methods take grads of the same vars. 
86 | self.assertAllCloseAccordingToType(pcgrad_vec, pcgrad_col_vec) 87 | 88 | results = [{ 89 | 'var': var0, 90 | 'pcgrad_vec': [0.5, -1.5], 91 | 'result': [0.9995, 2.0015] 92 | }, { 93 | 'var': var1, 94 | 'pcgrad_vec': [0.5, 1.5], 95 | 'result': [2.9995, 3.9985] 96 | }] 97 | grad_var_idx = {0, 1}.difference(pcgrad_var_idx) 98 | 99 | self.assertAllCloseAccordingToType( 100 | grad_vec[0][0], [0.0, -1.0], atol=1e-5) 101 | self.assertAllCloseAccordingToType( 102 | grad_vec[1][0], [0.0, 1.0], atol=1e-5) 103 | pcgrad_vec_idx = 0 104 | for var_idx in pcgrad_var_idx: 105 | self.assertAllCloseAccordingToType( 106 | pcgrad_vec[pcgrad_vec_idx][0], 107 | results[var_idx]['pcgrad_vec'], 108 | atol=1e-5) 109 | pcgrad_vec_idx += 1 110 | 111 | for var_idx in grad_var_idx: 112 | self.assertAllCloseAccordingToType( 113 | pcgrad_vec[pcgrad_vec_idx][0], grad_vec[var_idx][0], atol=1e-5) 114 | pcgrad_vec_idx += 1 115 | 116 | self.evaluate(opt.apply_gradients(pcgrad_grads)) 117 | self.assertAllCloseAccordingToType( 118 | self.evaluate([results[idx]['var'] for idx in pcgrad_var_idx]), 119 | [results[idx]['result'] for idx in pcgrad_var_idx]) 120 | 121 | 122 | if __name__ == '__main__': 123 | tf.test.main() 124 | -------------------------------------------------------------------------------- /research/qtopt/pcgrad_tpu_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.research.qtopt.pcgrad.""" 17 | 18 | from tensor2robot.research.qtopt import pcgrad 19 | import tensorflow.compat.v1 as tf 20 | 21 | 22 | class PcgradTest(tf.test.TestCase): 23 | 24 | def testPCgradNetworkTPU(self): 25 | tf.reset_default_graph() 26 | tf.disable_eager_execution() 27 | learning_rate = lambda: 0.001 28 | def pcgrad_computation(): 29 | x = tf.constant(1., shape=[64, 472, 472, 3]) 30 | layers = [ 31 | tf.keras.layers.Conv2D(filters=64, kernel_size=3), 32 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)), 33 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)), 34 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)), 35 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)), 36 | ] 37 | y = x 38 | for layer in layers: 39 | y = layer(y) 40 | n_tasks = 10 41 | task_loss_0 = tf.reduce_sum(y) 42 | task_losses = [task_loss_0 * (1. 
+ (n / 10.)) for n in range(n_tasks)] 43 | 44 | pcgrad_opt = pcgrad.PCGrad( 45 | tf.train.GradientDescentOptimizer(learning_rate)) 46 | pcgrad_grads_and_vars = pcgrad_opt.compute_gradients( 47 | task_losses, var_list=tf.trainable_variables()) 48 | return pcgrad_opt.apply_gradients(pcgrad_grads_and_vars) 49 | 50 | tpu_computation = tf.compat.v1.tpu.batch_parallel(pcgrad_computation, 51 | num_shards=2) 52 | self.evaluate(tf.compat.v1.tpu.initialize_system()) 53 | self.evaluate(tf.compat.v1.global_variables_initializer()) 54 | self.evaluate(tpu_computation) 55 | self.evaluate(tf.compat.v1.tpu.shutdown_system()) 56 | 57 | if __name__ == "__main__": 58 | tf.test.main() 59 | -------------------------------------------------------------------------------- /research/qtopt/t2r_models_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests that we can run a few training steps (mock data) on T2R model.""" 17 | 18 | from absl import flags 19 | from absl.testing import parameterized 20 | from tensor2robot.research.qtopt import t2r_models 21 | from tensor2robot.utils.t2r_test_fixture import T2RModelFixture 22 | import tensorflow.compat.v1 as tf 23 | from tensorflow.compat.v1 import estimator as tf_estimator 24 | 25 | FLAGS = flags.FLAGS 26 | TRAIN = tf_estimator.ModeKeys.TRAIN 27 | MODEL_NAME = 'Grasping44E2EOpenCloseTerminateGripperStatusHeightToBottom' 28 | 29 | 30 | class GraspPredictT2RTest(parameterized.TestCase): 31 | 32 | def setUp(self): 33 | super(GraspPredictT2RTest, self).setUp() 34 | self._fixture = T2RModelFixture( 35 | test_case=self, 36 | use_tpu=False, 37 | ) 38 | 39 | @parameterized.parameters( 40 | (MODEL_NAME,)) 41 | def test_random_train(self, model_name): 42 | self._fixture.random_train( 43 | module_name=t2r_models, model_name=model_name) 44 | 45 | @parameterized.parameters( 46 | (MODEL_NAME,)) 47 | def test_inference(self, model_name): 48 | result = self._fixture.random_predict( 49 | t2r_models, model_name, action_batch_size=64) 50 | self.assertIsNotNone(result) 51 | self.assertDictContainsSubset({'global_step': 0}, result) 52 | 53 | 54 | if __name__ == '__main__': 55 | tf.test.main() 56 | -------------------------------------------------------------------------------- /research/vrgripper/README.md: -------------------------------------------------------------------------------- 1 | # VRGripper Environment Models 2 | 3 | Contains code for training models in the VRGripper environment from 4 | "Watch, Try, Learn: Meta-Learning from Demonstrations and Rewards." 5 | 6 | Includes the models used in the Watch, Try, Learn (WTL) gripping experiments. 
7 | 8 | Links 9 | 10 | - [Project Website](https://sites.google.com/corp/view/watch-try-learn-project) 11 | - [Paper Preprint](https://arxiv.org/abs/1906.03352) 12 | 13 | ## Authors 14 | 15 | Allan Zhou1, Eric Jang1, Daniel Kappler2, 16 | Alex Herzog2, Mohi Khansari2, 17 | Paul Wohlhart2, Yunfei Bai2, 18 | Mrinal Kalakrishnan2, Sergey Levine1,3, 19 | Chelsea Finn1 20 | 21 | 1 Google Brain, 2X, 3UC Berkeley 22 | 23 | ## Training the WTL gripping experiment models. 24 | 25 | WTL experiment models are located in `vrgripper_env_wtl_models.py`. 26 | Data is not included in this repository, so you will have to provide your own 27 | training/eval datasets. Training is configured by the following gin configs: 28 | 29 | * `configs/run_train_wtl_statespace_trial.gin`: Train a trial policy on 30 | state-space observations. 31 | * `configs/run_train_wtl_statespace_retrial.gin`: Train a retrial policy 32 | on state-space observations. 33 | * `configs/run_train_wtl_vision_trial.gin`: Train a trial policy on image 34 | observations. 35 | * `configs/run_train_wtl_vision_retrial.gin`: Train a retrial policy on 36 | image observations. 37 | -------------------------------------------------------------------------------- /research/vrgripper/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /research/vrgripper/configs/common_imports.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import tensor2robot.utils.continuous_collect_eval 17 | import tensor2robot.input_generators.default_input_generator 18 | import tensor2robot.meta_learning.meta_policies 19 | import tensor2robot.meta_learning.run_meta_env 20 | import tensor2robot.models.abstract_model 21 | import tensor2robot.predictors.checkpoint_predictor 22 | import tensor2robot.research.vrgripper.episode_to_transitions 23 | import tensor2robot.research.vrgripper.vrgripper_env_models 24 | import tensor2robot.research.vrgripper.maf 25 | import tensor2robot.research.vrgripper.mse_decoder 26 | import tensor2robot.research.vrgripper.discrete 27 | import tensor2robot.research.vrgripper.vrgripper_env_meta_models 28 | import tensor2robot.research.vrgripper.vrgripper_env_wtl_models 29 | import tensor2robot.utils.train_eval 30 | -------------------------------------------------------------------------------- /research/vrgripper/configs/run_train_wtl_statespace_retrial.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Train a state-space-based WTL retrial policy. 17 | 18 | import tensor2robot.input_generators.default_input_generator 19 | 20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin' 21 | 22 | # Input Generation. 
23 | TRAIN_DATA = '' 24 | EVAL_DATA = '' 25 | 26 | TRAIN_BATCH_SIZE = 8 27 | EVAL_BATCH_SIZE = 8 28 | 29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator() 30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA 31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 32 | 33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 36 | 37 | ####################################### 38 | # MODEL 39 | ####################################### 40 | 41 | train_eval_model.t2r_model = @retrial/VRGripperEnvSimpleTrialModel() 42 | retrial/VRGripperEnvSimpleTrialModel.use_sync_replicas_optimizer = True 43 | retrial/VRGripperEnvSimpleTrialModel.retrial = True 44 | retrial/VRGripperEnvSimpleTrialModel.num_condition_samples_per_task = 2 45 | BuildImageFeaturesToPoseModel.bias_transform_size = 0 46 | 47 | reduce_temporal_embeddings.conv1d_layers = (64, 32) 48 | reduce_temporal_embeddings.fc_hidden_layers = (100,) 49 | default_create_optimizer_fn.learning_rate = 1e-3 50 | 51 | train_eval_model.max_train_steps = 100 52 | train_eval_model.eval_steps = 1000 53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 55 | train_eval_model.create_exporters_fn = @create_default_exporters 56 | 57 | -------------------------------------------------------------------------------- /research/vrgripper/configs/run_train_wtl_statespace_trial.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Train a state-space-based WTL trial policy. 17 | 18 | import tensor2robot.input_generators.default_input_generator 19 | 20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin' 21 | 22 | # Input Generation. 
23 | TRAIN_DATA = '' 24 | EVAL_DATA = '' 25 | 26 | TRAIN_BATCH_SIZE = 8 27 | EVAL_BATCH_SIZE = 8 28 | 29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator() 30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA 31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 32 | 33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 36 | 37 | ####################################### 38 | # MODEL 39 | ####################################### 40 | 41 | train_eval_model.t2r_model = @VRGripperEnvSimpleTrialModel() 42 | VRGripperEnvSimpleTrialModel.num_mixture_components = 10 43 | VRGripperEnvSimpleTrialModel.use_sync_replicas_optimizer = True 44 | VRGripperEnvSimpleTrialModel.embed_type = 'temporal' 45 | BuildImageFeaturesToPoseModel.bias_transform_size = 0 46 | 47 | reduce_temporal_embeddings.conv1d_layers = (64, 32) 48 | reduce_temporal_embeddings.fc_hidden_layers = (100,) 49 | default_create_optimizer_fn.learning_rate = 1e-3 50 | 51 | train_eval_model.max_train_steps = 100 52 | train_eval_model.eval_steps = 1000 53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 55 | train_eval_model.create_exporters_fn = @create_default_exporters 56 | 57 | -------------------------------------------------------------------------------- /research/vrgripper/configs/run_train_wtl_vision_retrial.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Train a vision-based WTL retrial policy. 17 | 18 | import tensor2robot.input_generators.default_input_generator 19 | 20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin' 21 | 22 | # Input Generation. 
23 | TRAIN_DATA = '' 24 | EVAL_DATA = '' 25 | 26 | TRAIN_BATCH_SIZE = 8 27 | EVAL_BATCH_SIZE = 8 28 | 29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator() 30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA 31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 32 | 33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 36 | 37 | ####################################### 38 | # MODEL 39 | ####################################### 40 | 41 | train_eval_model.t2r_model = @retrial/VRGripperEnvVisionTrialModel() 42 | retrial/VRGripperEnvVisionTrialModel.use_sync_replicas_optimizer = True 43 | retrial/VRGripperEnvVisionTrialModel.num_condition_samples_per_task = 2 44 | BuildImageFeaturesToPoseModel.bias_transform_size = 0 45 | 46 | reduce_temporal_embeddings.conv1d_layers = (64, 32) 47 | reduce_temporal_embeddings.fc_hidden_layers = (100,) 48 | default_create_optimizer_fn.learning_rate = 1e-3 49 | 50 | train_eval_model.max_train_steps = 100 51 | train_eval_model.eval_steps = 1000 52 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 53 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 54 | train_eval_model.create_exporters_fn = @create_default_exporters 55 | 56 | -------------------------------------------------------------------------------- /research/vrgripper/configs/run_train_wtl_vision_trial.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Train a vision-based WTL trial policy. 17 | 18 | import tensor2robot.input_generators.default_input_generator 19 | 20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin' 21 | 22 | # Input Generation. 
23 | TRAIN_DATA = '' 24 | EVAL_DATA = '' 25 | 26 | TRAIN_BATCH_SIZE = 8 27 | EVAL_BATCH_SIZE = 8 28 | 29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator() 30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA 31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE 32 | 33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator() 34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA 35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE 36 | 37 | ####################################### 38 | # MODEL 39 | ####################################### 40 | 41 | train_eval_model.t2r_model = @VRGripperEnvVisionTrialModel() 42 | VRGripperEnvVisionTrialModel.num_mixture_components = 20 43 | VRGripperEnvVisionTrialModel.use_sync_replicas_optimizer = True 44 | BuildImageFeaturesToPoseModel.bias_transform_size = 0 45 | 46 | embed_condition_images.fc_layers = (100, 64) 47 | reduce_temporal_embeddings.conv1d_layers = (32,) 48 | reduce_temporal_embeddings.fc_hidden_layers = (100,) 49 | default_create_optimizer_fn.learning_rate = 5e-4 50 | 51 | train_eval_model.max_train_steps = 5000 52 | train_eval_model.eval_steps = 1000 53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR 54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR 55 | train_eval_model.create_exporters_fn = @create_default_exporters 56 | 57 | -------------------------------------------------------------------------------- /research/vrgripper/episode_to_transitions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functions for converting env episode data to tfrecords of transitions.""" 17 | 18 | import collections 19 | 20 | import gin 21 | import numpy as np 22 | from PIL import Image 23 | import six 24 | from six.moves import range 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | 29 | _bytes_feature = ( 30 | lambda v: tf.train.Feature(bytes_list=tf.train.BytesList(value=v))) 31 | _int64_feature = ( 32 | lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v))) 33 | _float_feature = ( 34 | lambda v: tf.train.Feature(float_list=tf.train.FloatList(value=v))) 35 | 36 | _IMAGE_KEY_PREFIX = 'image' 37 | 38 | 39 | @gin.configurable 40 | def make_fixed_length( 41 | input_list, 42 | fixed_length, 43 | always_include_endpoints=True, 44 | randomized=True): 45 | """Create a fixed length list by sampling entries from input_list. 46 | 47 | Args: 48 | input_list: The original list we sample entries from. 49 | fixed_length: An integer: the desired length of the output list. 50 | always_include_endpoints: If True, always include the first and last entries 51 | of input_list in the output. 
52 | randomized: If True, select entries from input_list by random sampling with 53 | replacement. If False, select entries from input_list deterministically. 54 | Returns: 55 | A list of length fixed_length containing sampled entries of input_list. 56 | """ 57 | original_length = len(input_list) 58 | if original_length <= 2: 59 | return None 60 | if not randomized: 61 | indices = np.sort(np.mod(np.arange(fixed_length), original_length)) 62 | return [input_list[i] for i in indices] 63 | if always_include_endpoints: 64 | # Always include entries 0 and N-1. 65 | endpoint_indices = np.array([0, original_length - 1]) 66 | # The remaining (fixed_length-2) frames are sampled with replacement 67 | # from entries [1, N-1) of input_list. 68 | other_indices = 1 + np.random.choice( 69 | original_length - 2, fixed_length-2, replace=True) 70 | indices = np.concatenate( 71 | (endpoint_indices, other_indices), 72 | axis=0) 73 | else: 74 | indices = np.random.choice( 75 | original_length, fixed_length, replace=True) 76 | indices = np.sort(indices) 77 | return [input_list[i] for i in indices] 78 | 79 | 80 | 81 | 82 | @gin.configurable 83 | def episode_to_transitions_reacher(episode_data, is_demo=False): 84 | """Converts reacher env data to transition examples.""" 85 | transitions = [] 86 | for i, transition in enumerate(episode_data): 87 | del i 88 | feature_dict = {} 89 | (obs_t, action, reward, obs_tp1, done, debug) = transition 90 | del debug 91 | feature_dict['pose_t'] = _float_feature(obs_t) 92 | feature_dict['pose_tp1'] = _float_feature(obs_tp1) 93 | feature_dict['action'] = _float_feature(action) 94 | feature_dict['reward'] = _float_feature([reward]) 95 | feature_dict['done'] = _int64_feature([int(done)]) 96 | feature_dict['is_demo'] = _int64_feature([int(is_demo)]) 97 | example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) 98 | transitions.append(example) 99 | return transitions 100 | 101 | 102 | @gin.configurable 103 | def episode_to_transitions_metareacher(episode_data): 104 | """Converts metareacher env data to transition examples.""" 105 | context_features = {} 106 | feature_lists = collections.defaultdict(list) 107 | 108 | context_features['is_demo'] = _int64_feature( 109 | [int(episode_data[0][-1]['is_demo'])]) 110 | context_features['target_idx'] = _int64_feature( 111 | [episode_data[0][-1]['target_idx']]) 112 | 113 | for i, transition in enumerate(episode_data): 114 | del i 115 | (obs_t, action, reward, obs_tp1, done, debug) = transition 116 | del debug 117 | feature_lists['pose_t'].append(_float_feature(obs_t)) 118 | feature_lists['pose_tp1'].append(_float_feature(obs_tp1)) 119 | feature_lists['action'].append(_float_feature(action)) 120 | feature_lists['reward'].append(_float_feature([reward])) 121 | feature_lists['done'].append(_int64_feature([int(done)])) 122 | 123 | tf_feature_lists = {} 124 | for key in feature_lists: 125 | tf_feature_lists[key] = tf.train.FeatureList(feature=feature_lists[key]) 126 | 127 | return [tf.train.SequenceExample( 128 | context=tf.train.Features(feature=context_features), 129 | feature_lists=tf.train.FeatureLists(feature_list=tf_feature_lists))] 130 | 131 | 132 | -------------------------------------------------------------------------------- /research/vrgripper/episode_to_transitions_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for episodes_to_transitions.""" 17 | 18 | from six.moves import range 19 | from tensor2robot.research.vrgripper import episode_to_transitions 20 | import tensorflow.compat.v1 as tf 21 | 22 | 23 | class EpisodeToTransitionsTest(tf.test.TestCase): 24 | 25 | def test_make_fixed_length(self): 26 | fixed_length = 10 27 | dummy_feature_dict_lists = [ 28 | [{'dummy_feature': i} for i in range(5)], 29 | [{'dummy_feature': i} for i in range(20)], 30 | ] 31 | 32 | for feature_dict_list in dummy_feature_dict_lists: 33 | filtered_feature_dict_list = episode_to_transitions.make_fixed_length( 34 | feature_dict_list, 35 | fixed_length=fixed_length, 36 | always_include_endpoints=True) 37 | self.assertLen(filtered_feature_dict_list, fixed_length) 38 | 39 | # The first and last entries of the original list should be present in 40 | # the filtered list. 41 | self.assertEqual(feature_dict_list[0], filtered_feature_dict_list[0]) 42 | self.assertEqual(feature_dict_list[-1], filtered_feature_dict_list[-1]) 43 | -------------------------------------------------------------------------------- /research/vrgripper/maf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Conditional density estimation with masked autoregressive flow. 17 | """ 18 | 19 | import gin 20 | import numpy as np 21 | from six.moves import range 22 | import tensorflow.compat.v1 as tf 23 | import tensorflow_probability as tfp 24 | from tensorflow.contrib import slim 25 | tfd = tfp.distributions 26 | tfb = tfp.bijectors 27 | 28 | 29 | def init_once(x, name): 30 | """Return a variable initialized with a constant value. 31 | 32 | This is used to initialize Permutation bijectors. See [1] for more information 33 | on why returning Permute(np.random.permutation(event_size)) is unsafe. 34 | 35 | Args: 36 | x: A TF Variables initializer or constant-valued tensor. 37 | name: String name for the returned variable. 38 | 39 | Returns: 40 | Variable copy of the tensor. 
41 | 42 | References: 43 | 44 | [1] https://www.tensorflow.org/probability/api_docs/python/ 45 | tfp/bijectors/Permute 46 | """ 47 | return tf.get_variable(name, initializer=x, trainable=False) 48 | 49 | 50 | def maf_bijector(event_size, num_flows, hidden_layers): 51 | """Construct a chain of MAF flows into a single bijector.""" 52 | bijectors = [] 53 | for i in range(num_flows): 54 | bijectors.append(tfb.MaskedAutoregressiveFlow( 55 | shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( 56 | hidden_layers=hidden_layers))) 57 | bijectors.append( 58 | tfb.Permute( 59 | permutation=init_once( 60 | np.random.permutation(event_size).astype('int32'), 61 | name='permute_%d' % i))) 62 | # Chain the bijectors, leaving out the last permutation bijector. 63 | return tfb.Chain(list(reversed(bijectors[:-1]))) 64 | 65 | 66 | @gin.configurable 67 | class MAFDecoder(object): 68 | """Decoder using a Masked Autoregressive Flow. 69 | 70 | Conditioning is specified by warping the centers of the base isotropic normal 71 | distributions, e.g. MAF(N(mu, 1)), where mu is the incoming conditioning 72 | parameters. This allows us to avoid having to incorporate conditioning into 73 | the actual bijector. 74 | """ 75 | 76 | def __init__(self, num_flows=1, hidden_layers=None): 77 | self._num_flows = num_flows 78 | self._hidden_layers = hidden_layers or [512, 512] 79 | 80 | def __call__(self, params, output_size): 81 | mus = slim.fully_connected( 82 | params, output_size, activation_fn=None, scope='maf_mus') 83 | base_dist = tfd.MultivariateNormalDiag( 84 | loc=mus, scale_diag=tf.ones_like(mus)) 85 | event_shape = base_dist.event_shape.as_list() 86 | if np.any([event_shape[0] > l for l in self._hidden_layers]): 87 | raise ValueError( 88 | 'MAF hidden layers have to be at least as wide as event size.') 89 | self._maf = tfd.TransformedDistribution( 90 | distribution=base_dist, 91 | bijector=maf_bijector( 92 | event_shape[0], self._num_flows, self._hidden_layers)) 93 | return self._maf.sample() 94 | 95 | def loss(self, labels): 96 | nll_local = -self._maf.log_prob(labels.action) 97 | # Average across batch, sequence. 98 | return tf.reduce_mean(nll_local) 99 | -------------------------------------------------------------------------------- /research/vrgripper/mse_decoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Abstract decoder and MSE decoder. 
17 | """ 18 | 19 | import gin 20 | 21 | import tensorflow.compat.v1 as tf 22 | from tensorflow.contrib import slim 23 | 24 | 25 | @gin.configurable 26 | class MSEDecoder(object): 27 | """Default MSE decoder.""" 28 | 29 | def __call__(self, params, output_size): 30 | self._predictions = slim.fully_connected( 31 | params, output_size, activation_fn=None, scope='pose') 32 | return self._predictions 33 | 34 | def loss(self, labels): 35 | return tf.losses.mean_squared_error(labels=labels.action, 36 | predictions=self._predictions) 37 | -------------------------------------------------------------------------------- /test_data/mock_exported_savedmodel/assets.extra/t2r_assets.pbtxt: -------------------------------------------------------------------------------- 1 | feature_spec { 2 | key_value { 3 | key: "x" 4 | value { 5 | shape: 3 6 | dtype: 1 7 | name: "measured_position" 8 | is_optional: false 9 | is_extracted: false 10 | } 11 | } 12 | } 13 | label_spec { 14 | key_value { 15 | key: "y" 16 | value { 17 | shape: 1 18 | dtype: 1 19 | name: "valid_position" 20 | is_optional: false 21 | is_extracted: false 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /test_data/mock_exported_savedmodel/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/saved_model.pb -------------------------------------------------------------------------------- /test_data/mock_exported_savedmodel/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test_data/mock_exported_savedmodel/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/variables/variables.index -------------------------------------------------------------------------------- /test_data/pose_env_test_data.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/pose_env_test_data.tfrecord -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
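Stepping back to the action decoders defined in research/vrgripper above: `MSEDecoder` and `MAFDecoder` share the same two-step contract — call the decoder on a batch of conditioning features to obtain predictions, then ask it for a loss against labels that carry an `action` field. A minimal sketch for the MSE decoder (the batch size, dimensions, and label tuple are placeholders; a TF 1.x environment with `tf.contrib.slim`, as required by these modules, is assumed):

```python
import collections
import tensorflow.compat.v1 as tf
from tensor2robot.research.vrgripper import mse_decoder

tf.disable_eager_execution()

Labels = collections.namedtuple('Labels', ['action'])
features = tf.zeros([4, 128])              # Batch of conditioning embeddings.
labels = Labels(action=tf.zeros([4, 7]))   # Batch of action targets.

decoder = mse_decoder.MSEDecoder()
predicted_action = decoder(features, output_size=7)  # Fully connected head.
loss = decoder.loss(labels)                # MSE against labels.action.

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(loss))
```

Swapping in `maf.MAFDecoder()` keeps the same call pattern but replaces the squared error with the negative log-likelihood of a masked autoregressive flow.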
15 | 16 | -------------------------------------------------------------------------------- /utils/continuous_collect_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Collect/Eval a policy on the live environment.""" 17 | 18 | import os 19 | import time 20 | from typing import Text 21 | 22 | import gin 23 | import gym 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | @gin.configurable 28 | def collect_eval_loop( 29 | collect_env, 30 | eval_env, 31 | policy_class, 32 | num_collect = 2000, 33 | num_eval = 100, 34 | run_agent_fn=None, 35 | root_dir = '', 36 | continuous = False, 37 | min_collect_eval_step = 0, 38 | max_steps = 1, 39 | pre_collect_eval_fn=None, 40 | record_eval_env_video = False, 41 | init_with_random_variables = False): 42 | """Like dql_grasping.collect_eval, but can run continuously. 43 | 44 | Args: 45 | collect_env: (gym.Env) Gym environment to collect data from (and train the 46 | policy on). 47 | eval_env: (gym.Env) Gym environment to evaluate the policy on. Can be 48 | another instance of collect_env, or a different environment if one 49 | wishes to evaluate generalization capability. The only constraint is 50 | that the action and observation spaces have to be equivalent. If None, 51 | eval_env is not evaluated. 52 | policy_class: Policy class that we want to train. 53 | num_collect: (int) Number of episodes to collect from collect_env. 54 | num_eval: (int) Number of episodes to evaluate from eval_env. 55 | run_agent_fn: (Optional) Python function that executes the interaction of 56 | the policy with the environment. Defaults to run_env.run_env. 57 | root_dir: Base directory where collect data and eval data are 58 | stored. 59 | continuous: If True, loop and wait for new ckpt to load a policy from 60 | (up until the ckpt number exceeds max_steps). 61 | min_collect_eval_step: An integer which specifies the lowest ckpt step 62 | number that we will collect/evaluate. 63 | max_steps: (Ignored unless continuous=True). An integer controlling when 64 | to stop looping: once we see a policy with global_step > max_steps, we 65 | stop. 66 | pre_collect_eval_fn: This callable will be run prior to the start of this 67 | collect/eval loop. Example use: pushing a record dataset into a replay 68 | buffer at the start of training. 69 | record_eval_env_video: Whether to enable video recording in our eval env. 70 | init_with_random_variables: If True, initializes policy model with random 71 | variables instead (useful for unit testing). 
72 | """ 73 | if pre_collect_eval_fn: 74 | pre_collect_eval_fn() 75 | 76 | collect_dir = os.path.join(root_dir, 'policy_collect') 77 | eval_dir = os.path.join(root_dir, 'eval') 78 | 79 | policy = policy_class() 80 | prev_global_step = -1 81 | while True: 82 | global_step = None 83 | if hasattr(policy, 'restore'): 84 | if init_with_random_variables: 85 | policy.init_randomly() 86 | else: 87 | policy.restore() 88 | global_step = policy.global_step 89 | 90 | if global_step is None or global_step < min_collect_eval_step \ 91 | or global_step <= prev_global_step: 92 | time.sleep(10) 93 | continue 94 | 95 | if collect_env: 96 | run_agent_fn(collect_env, policy=policy, num_episodes=num_collect, 97 | root_dir=collect_dir, global_step=global_step, tag='collect') 98 | if eval_env: 99 | if record_eval_env_video and hasattr(eval_env, 'set_video_output_dir'): 100 | eval_env.set_video_output_dir( 101 | os.path.join(root_dir, 'videos', str(global_step))) 102 | run_agent_fn(eval_env, policy=policy, num_episodes=num_eval, 103 | root_dir=eval_dir, global_step=global_step, tag='eval') 104 | if not continuous or global_step >= max_steps: 105 | tf.logging.info('Completed collect/eval on final ckpt.') 106 | break 107 | 108 | prev_global_step = global_step 109 | -------------------------------------------------------------------------------- /utils/continuous_collect_eval_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
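`collect_eval_loop` above is normally wired up entirely through gin, as in the test that follows, but its contract is small: instantiate `policy_class`, restore (or randomly initialize) the policy, and hand each environment to `run_agent_fn`. A minimal sketch with stand-in objects — the policy class, agent function, environment placeholder, and output directory below are illustrative, not repository code:

```python
import tempfile

from tensor2robot.utils import continuous_collect_eval


class StandInPolicy(object):
  """Hypothetical policy exposing the attributes collect_eval_loop reads."""
  global_step = 1

  def restore(self):
    pass  # A real policy would load its latest checkpoint here.


def stand_in_run_agent_fn(env, policy, num_episodes, root_dir, global_step,
                          tag):
  # A real run_agent_fn rolls out the policy in env and writes the episodes
  # under root_dir; this stub only reports that it was called.
  print('%s: %d episode(s) at global_step %d' % (tag, num_episodes, global_step))


continuous_collect_eval.collect_eval_loop(
    collect_env=object(),          # Any non-None env-like object for this sketch.
    eval_env=None,                 # Skip evaluation entirely.
    policy_class=StandInPolicy,
    num_collect=1,
    run_agent_fn=stand_in_run_agent_fn,
    root_dir=tempfile.mkdtemp(),
    continuous=False,              # Do a single collect pass and return.
    max_steps=1)
```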
15 | 16 | """Tests for run_collect_eval.""" 17 | 18 | import os 19 | from absl import flags 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import gin 23 | from tensor2robot.research.pose_env import pose_env 24 | from tensor2robot.utils import continuous_collect_eval 25 | import tensorflow.compat.v1 as tf 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class PoseEnvModelsTest(parameterized.TestCase): 30 | 31 | @parameterized.parameters( 32 | (pose_env.PoseEnvRandomPolicy,), 33 | ) 34 | def test_run_pose_env_collect(self, demo_policy_cls): 35 | urdf_root = pose_env.get_pybullet_urdf_root() 36 | 37 | config_dir = 'research/pose_env/configs' 38 | gin_config = os.path.join( 39 | FLAGS.test_srcdir, config_dir, 'run_random_collect.gin') 40 | gin.parse_config_file(gin_config) 41 | tmp_dir = absltest.get_default_test_tmpdir() 42 | root_dir = os.path.join(tmp_dir, str(demo_policy_cls)) 43 | gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root) 44 | gin.bind_parameter( 45 | 'collect_eval_loop.root_dir', root_dir) 46 | gin.bind_parameter('run_meta_env.num_tasks', 2) 47 | gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1) 48 | gin.bind_parameter( 49 | 'collect_eval_loop.policy_class', demo_policy_cls) 50 | continuous_collect_eval.collect_eval_loop() 51 | output_files = tf.io.gfile.glob(os.path.join( 52 | root_dir, 'policy_collect', '*.tfrecord')) 53 | self.assertLen(output_files, 2) 54 | 55 | 56 | if __name__ == '__main__': 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /utils/convert_pkl_assets_to_proto_assets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
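The conversion utility defined in the file that follows rewrites the legacy pickle-based assets of an exported SavedModel (`input_specs.pkl`, plus `global_step.pkl` if present) into the proto-based t2r_assets format, like the `t2r_assets.pbtxt` under test_data/mock_exported_savedmodel/assets.extra above. A minimal sketch of invoking it directly; the path is a placeholder:

```python
from tensor2robot.utils import convert_pkl_assets_to_proto_assets

# Points at an exported SavedModel's assets.extra directory that still
# contains the legacy input_specs.pkl (placeholder path).
convert_pkl_assets_to_proto_assets.convert('/tmp/exported_savedmodel/assets.extra')
```

The same conversion can also be run from the command line via the `--assets_filepath` flag declared below.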
15 | 16 | """Convert existing pickle based assets to t2r_pb2 based assets.""" 17 | 18 | import os 19 | from typing import Text 20 | 21 | from absl import app 22 | from absl import flags 23 | 24 | from tensor2robot.proto import t2r_pb2 25 | from tensor2robot.utils import tensorspec_utils 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string('assets_filepath', None, 32 | 'The path to the exported savedmodel assets directory.') 33 | 34 | 35 | def convert(assets_filepath): 36 | """Converts existing asset pickle based files to t2r proto based assets.""" 37 | 38 | t2r_assets = t2r_pb2.T2RAssets() 39 | input_spec_filepath = os.path.join(assets_filepath, 'input_specs.pkl') 40 | if not tf.io.gfile.exists(input_spec_filepath): 41 | raise ValueError('No file exists for {}.'.format(input_spec_filepath)) 42 | feature_spec, label_spec = tensorspec_utils.load_input_spec_from_file( 43 | input_spec_filepath) 44 | 45 | t2r_assets.feature_spec.CopyFrom(feature_spec.to_proto()) 46 | t2r_assets.label_spec.CopyFrom(label_spec.to_proto()) 47 | 48 | global_step_filepath = os.path.join(assets_filepath, 'global_step.pkl') 49 | if tf.io.gfile.exists(global_step_filepath): 50 | global_step = tensorspec_utils.load_input_spec_from_file( 51 | global_step_filepath) 52 | t2r_assets.global_step = global_step 53 | 54 | t2r_assets_filepath = os.path.join(assets_filepath, 55 | tensorspec_utils.T2R_ASSETS_FILENAME) 56 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filepath) 57 | 58 | 59 | def main(unused_argv): 60 | flags.mark_flag_as_required('assets_filepath') 61 | convert(FLAGS.assets_filepath) 62 | 63 | 64 | if __name__ == '__main__': 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /utils/global_step_functions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Gin configurable functions returning tf.Tensors based on the global_step. 17 | """ 18 | 19 | from typing import Optional, Sequence, Text, Union 20 | 21 | import gin 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | @gin.configurable 27 | def piecewise_linear(boundaries, 28 | values, 29 | name = None): 30 | """Piecewise linear function assuming given values at given boundaries. 31 | 32 | Args: 33 | boundaries: A list of `Tensor`s or `int`s or `float`s with strictly 34 | increasing entries. The first entry must be 0. 35 | values: A list of `Tensor`s or float`s or `int`s that specifies the values 36 | at the `boundaries`. It must have the same number of elements as 37 | `boundaries`, and all elements should have the same type. 38 | name: A string. Optional name of the operation. Defaults to 39 | 'PiecewiseConstant'. 40 | 41 | Returns: 42 | A 0-D Tensor. 
Its value is `values[0]` if `x < boundaries[0]` and 43 | `values[-1]` if `x >= boundaries[-1]. If `boundaries[i] <= x < 44 | boundaries[i+1]` it is the linear interpolation between `values[i]` and 45 | `values[i+1]`: `values[i] + (values[i+1]-values[i]) * (x-boundaries[i]) / 46 | (boundaries[i+1]-boundaries[i])`. 47 | 48 | Raises: 49 | AssertionError: if values or boundaries is empty, or not the same size. 50 | """ 51 | global_step = tf.train.get_or_create_global_step() 52 | with tf.name_scope(name, 'PiecewiseLinear', [global_step, boundaries, values, 53 | name]) as name: 54 | values = tf.convert_to_tensor(values) 55 | x = tf.cast(tf.convert_to_tensor(global_step), values.dtype) 56 | boundaries = tf.cast(tf.convert_to_tensor(boundaries), values.dtype) 57 | 58 | num_boundaries = np.prod(boundaries.shape.as_list()) 59 | num_values = np.prod(values.shape.as_list()) 60 | assert num_boundaries > 0, 'Need more than 0 boundaries' 61 | assert num_values > 0, 'Need more than 0 values' 62 | assert num_values == num_boundaries, ('boundaries and values must be of ' 63 | 'same size') 64 | 65 | # Make sure there is an unmet last boundary with the same value as the 66 | # last one that was passed in, and at least one boundary was met. 67 | values = tf.concat([values, tf.reshape(values[-1], [1])], 0) 68 | boundaries = tf.concat( 69 | [boundaries, 70 | tf.reshape(tf.maximum(x + 1, boundaries[-1]), [1])], 0) 71 | 72 | # Make sure there is at least one boundary that was already met, with the 73 | # same value as the first one that was passed in. 74 | values = tf.concat([tf.reshape(values[0], [1]), values], 0) 75 | boundaries = tf.concat( 76 | [tf.reshape(tf.minimum(x - 1, boundaries[0]), [1]), boundaries], 0) 77 | 78 | # Identify index of the last boundary that was passed. 79 | unreached_boundaries = tf.reshape( 80 | tf.where(tf.greater(boundaries, x)), [-1]) 81 | unreached_boundaries = tf.concat( 82 | [unreached_boundaries, [tf.cast(tf.size(boundaries), tf.int64)]], 0) 83 | index = tf.reshape(tf.reduce_min(unreached_boundaries), [1]) 84 | 85 | # Get values at last and next boundaries. 86 | value_left = tf.reshape(tf.slice(values, index - 1, [1]), []) 87 | left_boundary = tf.reshape(tf.slice(boundaries, index - 1, [1]), []) 88 | value_right = tf.reshape(tf.slice(values, index, [1]), []) 89 | right_boundary = tf.reshape(tf.slice(boundaries, index, [1]), []) 90 | 91 | # Calculate linear interpolation. 92 | a = (value_right - value_left) / (right_boundary - left_boundary) 93 | b = value_left - a * left_boundary 94 | return a * x + b 95 | 96 | 97 | @gin.configurable 98 | def exponential_decay(initial_value = 0.0001, 99 | decay_steps = 10000, 100 | decay_rate = 0.9, 101 | staircase = True): 102 | """Create a value that decays exponentially with global_step. 103 | 104 | Args: 105 | initial_value: A scalar float32 or float64 Tensor or a Python 106 | number. The initial value returned for global_step == 0. 107 | decay_steps: A scalar int32 or int64 Tensor or a Python number. Must be 108 | positive. See the decay computation in `tf.exponential_decay`. 109 | decay_rate: A scalar float32 or float64 Tensor or a Python number. The decay 110 | rate. 111 | staircase: Boolean. If True, decay the value at discrete intervals. 112 | 113 | Returns: 114 | value: Scalar tf.Tensor with the value decaying based on the global_step. 
115 | """ 116 | global_step = tf.train.get_or_create_global_step() 117 | value = tf.compat.v1.train.exponential_decay( 118 | learning_rate=initial_value, 119 | global_step=global_step, 120 | decay_steps=decay_steps, 121 | decay_rate=decay_rate, 122 | staircase=staircase) 123 | return value 124 | -------------------------------------------------------------------------------- /utils/global_step_functions_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2robot.utils.global_step_functions.""" 17 | 18 | from absl.testing import parameterized 19 | from tensor2robot.utils import global_step_functions 20 | import tensorflow.compat.v1 as tf 21 | 22 | 23 | class GlobalStepFunctionsTest(parameterized.TestCase, tf.test.TestCase): 24 | 25 | @parameterized.named_parameters({ 26 | 'testcase_name': 'constant', 27 | 'boundaries': [1], 28 | 'values': [5.0], 29 | 'test_inputs': [0, 1, 10], 30 | 'expected_outputs': [5.0, 5.0, 5.0] 31 | }, { 32 | 'testcase_name': 'ramp_up', 33 | 'boundaries': [10, 20], 34 | 'values': [1.0, 11.0], 35 | 'test_inputs': [0, 10, 13, 15, 18, 20, 25], 36 | 'expected_outputs': [1.0, 1.0, 4.0, 6.0, 9.0, 11.0, 11.0] 37 | }) 38 | def test_piecewise_linear(self, boundaries, values, test_inputs, 39 | expected_outputs): 40 | global_step = tf.train.get_or_create_global_step() 41 | global_step_value = tf.placeholder(tf.int64, []) 42 | set_global_step = tf.assign(global_step, global_step_value) 43 | 44 | test_function = global_step_functions.piecewise_linear(boundaries, values) 45 | with tf.Session() as sess: 46 | for x, y_expected in zip(test_inputs, expected_outputs): 47 | sess.run(set_global_step, {global_step_value: x}) 48 | y = sess.run(test_function) 49 | self.assertEqual(y, y_expected) 50 | 51 | # Test the same with tensors as inputs 52 | test_function = global_step_functions.piecewise_linear( 53 | tf.convert_to_tensor(boundaries), tf.convert_to_tensor(values)) 54 | with tf.Session() as sess: 55 | for x, y_expected in zip(test_inputs, expected_outputs): 56 | sess.run(set_global_step, {global_step_value: x}) 57 | y = sess.run(test_function) 58 | self.assertEqual(y, y_expected) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for encoding images.""" 17 | 18 | import io 19 | import numpy as np 20 | from PIL import Image 21 | from PIL import ImageFile 22 | 23 | 24 | def jpeg_string(image, jpeg_quality = 90): 25 | """Returns given PIL.Image instance as jpeg string. 26 | 27 | Args: 28 | image: A PIL image. 29 | jpeg_quality: The image quality, on a scale from 1 (worst) to 95 (best). 30 | 31 | Returns: 32 | a jpeg_string. 33 | """ 34 | # This fix to PIL makes sure that we don't get an error when saving large 35 | # jpeg files. This is a workaround for a bug in PIL. The value should be 36 | # substantially larger than the size of the image being saved. 37 | ImageFile.MAXBLOCK = 640 * 512 * 64 38 | 39 | output_jpeg = io.BytesIO() 40 | image.save(output_jpeg, 'jpeg', quality=jpeg_quality, optimize=True) 41 | return output_jpeg.getvalue() 42 | 43 | 44 | def numpy_to_image_string(image_array, image_format='jpeg', 45 | data_type=np.uint8): 46 | image = Image.fromarray(image_array.astype(data_type)) 47 | output_jpeg = io.BytesIO() 48 | image.save(output_jpeg, image_format) 49 | return output_jpeg.getvalue() 50 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Tensor2Robot Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Write episode transitions to Recordio-backed replay buffer. 17 | 18 | TODO(T2R_CONTRIBUTORS) - re-base class using tf.python_io.TFRecordWriter. 19 | """ 20 | 21 | import os 22 | import gin 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | @gin.configurable 27 | class TFRecordReplayWriter(object): 28 | """Saves transitions to a TFRecord-backed replay buffer.""" 29 | 30 | def __init__(self): 31 | self.writer = None 32 | 33 | def open(self, path): 34 | if self.writer is not None: 35 | raise ValueError('Writer is already open!') 36 | 37 | path_dirname = os.path.dirname(path) 38 | if not tf.gfile.IsDirectory(path_dirname): 39 | tf.gfile.MakeDirs(path_dirname) 40 | 41 | self.writer = tf.python_io.TFRecordWriter(path + '.tfrecord') 42 | 43 | def close(self): 44 | if self.writer is None: 45 | raise ValueError('Writer is not open!') 46 | self.writer.close() 47 | self.writer = None 48 | 49 | def write(self, transitions): 50 | """Writes entire episode to a TFRecord file. 51 | 52 | Args: 53 | transitions: List of tf.Examples. 54 | 55 | Raises: 56 | ValueError: If writer has not been opened. 
57 | """ 58 | if self.writer is None: 59 | raise ValueError('Writer is not open!') 60 | for transition in transitions: 61 | self.writer.write(transition.SerializeToString()) 62 | --------------------------------------------------------------------------------
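As a closing example, a minimal sketch of how the `TFRecordReplayWriter` above pairs with the `episode_to_transitions_*` converters from research/vrgripper: convert an episode into `tf.train.Example`s, then stream them into a `.tfrecord` replay file. The episode contents and output path are placeholders:

```python
import numpy as np

from tensor2robot.research.vrgripper import episode_to_transitions
from tensor2robot.utils import writer as writer_lib

# One placeholder reacher transition: (obs_t, action, reward, obs_tp1, done, debug).
episode = [(np.zeros(4), np.zeros(2), 0.0, np.zeros(4), False, None)]
transitions = episode_to_transitions.episode_to_transitions_reacher(episode)

replay_writer = writer_lib.TFRecordReplayWriter()
replay_writer.open('/tmp/replay/episode_000')  # Writes /tmp/replay/episode_000.tfrecord.
replay_writer.write(transitions)               # Serializes each tf.train.Example.
replay_writer.close()
```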