├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── bin
│   ├── run_collect_eval.py
│   └── run_t2r_trainer.py
├── export_generators
│   ├── __init__.py
│   ├── abstract_export_generator.py
│   ├── abstract_export_generator_test.py
│   ├── default_export_generator.py
│   └── default_export_generator_test.py
├── hooks
│   ├── __init__.py
│   ├── async_export_hook_builder.py
│   ├── async_export_hook_builder_test.py
│   ├── async_export_hook_builder_tpu_test.py
│   ├── checkpoint_hooks.py
│   ├── checkpoint_hooks_test.py
│   ├── gin_config_hook_builder.py
│   ├── golden_values_hook_builder.py
│   ├── hook_builder.py
│   ├── td3.py
│   ├── td3_test.py
│   └── variable_logger_hook.py
├── input_generators
│   ├── __init__.py
│   ├── abstract_input_generator.py
│   ├── abstract_input_generator_test.py
│   ├── default_input_generator.py
│   └── default_input_generator_test.py
├── layers
│   ├── __init__.py
│   ├── bcz_networks.py
│   ├── film_resnet_model.py
│   ├── mdn.py
│   ├── mdn_test.py
│   ├── resnet.py
│   ├── resnet_test.py
│   ├── snail.py
│   ├── snail_test.py
│   ├── spatial_softmax.py
│   ├── spatial_softmax_test.py
│   ├── tec.py
│   ├── tec_test.py
│   └── vision_layers.py
├── meta_learning
│   ├── __init__.py
│   ├── maml_inner_loop.py
│   ├── maml_inner_loop_test.py
│   ├── maml_model.py
│   ├── maml_model_test.py
│   ├── meta_example.py
│   ├── meta_policies.py
│   ├── meta_tf_models.py
│   ├── meta_tf_models_test.py
│   ├── meta_tfdata.py
│   ├── preprocessors.py
│   ├── preprocessors_test.py
│   └── run_meta_env.py
├── models
│   ├── __init__.py
│   ├── abstract_model.py
│   ├── classification_model.py
│   ├── critic_model.py
│   ├── model_interface.py
│   ├── optimizers.py
│   ├── regression_model.py
│   └── tpu_model_wrapper.py
├── policies
│   ├── __init__.py
│   └── policies.py
├── predictors
│   ├── __init__.py
│   ├── abstract_predictor.py
│   ├── checkpoint_predictor.py
│   ├── checkpoint_predictor_test.py
│   ├── ensemble_exported_savedmodel_predictor.py
│   ├── ensemble_exported_savedmodel_predictor_test.py
│   ├── exported_savedmodel_predictor.py
│   ├── exported_savedmodel_predictor_test.py
│   ├── saved_model_v2_predictor.py
│   └── saved_model_v2_predictor_test.py
├── preprocessors
│   ├── __init__.py
│   ├── abstract_preprocessor.py
│   ├── abstract_preprocessor_test.py
│   ├── distortion.py
│   ├── image_transformations.py
│   ├── image_transformations_test.py
│   ├── noop_preprocessor.py
│   ├── noop_preprocessor_test.py
│   ├── spec_transformation_preprocessor.py
│   ├── tpu_preprocessor_wrapper.py
│   └── tpu_preprocessor_wrapper_test.py
├── proto
│   └── t2r.proto
├── requirements.txt
├── research
│   ├── bcz
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── common_imagedistortions.gin
│   │   │   ├── common_imports.gin
│   │   │   ├── run_train_bc_gtcond_trajectory.gin
│   │   │   └── run_train_bc_langcond_trajectory.gin
│   │   ├── model.py
│   │   ├── model_test.py
│   │   └── pose_components_lib.py
│   ├── dql_grasping_lib
│   │   ├── __init__.py
│   │   ├── run_env.py
│   │   └── tf_modules.py
│   ├── grasp2vec
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── common_imports.gin
│   │   │   └── train_grasp2vec.gin
│   │   ├── grasp2vec_model.py
│   │   ├── losses.py
│   │   ├── losses_test.py
│   │   ├── networks.py
│   │   ├── resnet.py
│   │   └── visualization.py
│   ├── pose_env
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── common_imports.gin
│   │   │   ├── run_random_collect.gin
│   │   │   ├── run_train_reg.gin
│   │   │   └── run_train_reg_maml.gin
│   │   ├── episode_to_transitions.py
│   │   ├── pose_env.py
│   │   ├── pose_env_maml_models.py
│   │   ├── pose_env_models.py
│   │   ├── pose_env_models_test.py
│   │   └── pose_env_test.py
│   ├── qtopt
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── networks.py
│   │   ├── optimizer_builder.py
│   │   ├── pcgrad.py
│   │   ├── pcgrad_test.py
│   │   ├── pcgrad_tpu_test.py
│   │   ├── t2r_models.py
│   │   └── t2r_models_test.py
│   └── vrgripper
│       ├── README.md
│       ├── __init__.py
│       ├── configs
│       │   ├── common_imports.gin
│       │   ├── run_train_wtl_statespace_retrial.gin
│       │   ├── run_train_wtl_statespace_trial.gin
│       │   ├── run_train_wtl_vision_retrial.gin
│       │   └── run_train_wtl_vision_trial.gin
│       ├── discrete.py
│       ├── episode_to_transitions.py
│       ├── episode_to_transitions_test.py
│       ├── maf.py
│       ├── mse_decoder.py
│       ├── vrgripper_env_meta_models.py
│       ├── vrgripper_env_models.py
│       └── vrgripper_env_wtl_models.py
├── test_data
│   ├── mock_exported_savedmodel
│   │   ├── assets.extra
│   │   │   └── t2r_assets.pbtxt
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   └── pose_env_test_data.tfrecord
└── utils
    ├── __init__.py
    ├── continuous_collect_eval.py
    ├── continuous_collect_eval_test.py
    ├── convert_pkl_assets_to_proto_assets.py
    ├── cross_entropy.py
    ├── global_step_functions.py
    ├── global_step_functions_test.py
    ├── image.py
    ├── mocks.py
    ├── subsample.py
    ├── subsample_test.py
    ├── t2r_test_fixture.py
    ├── tensorspec_utils.py
    ├── tensorspec_utils_test.py
    ├── tfdata.py
    ├── tfdata_test.py
    ├── train_eval.py
    ├── train_eval_test.py
    ├── train_eval_test_utils.py
    └── writer.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | Disclaimer: T2R is not an official Google product. External support is not
4 | guaranteed, and Robotics @ Google infra needs are prioritized over those of
5 | third parties. The purpose of developing this codebase in the open is to
6 | provide an avenue for open-sourcing our team's code. If you have questions or
7 | would like to submit a pull request, please file a GitHub issue first.
8 |
9 | # Issues
10 |
11 | * Please tag your issue with `bug`, `feature request`, or `question` to help us
12 | effectively respond.
13 | * Please include the versions of TensorFlow and Tensor2Robot you are running.
14 | * Please provide the command line you ran as well as the log output.
15 |
16 | # Pull Requests
17 |
18 | We'd love to accept your patches and contributions to this project, provided
19 | that they are aligned with our internal uses of this codebase. There are
20 | just a few guidelines you need to follow.
21 |
22 | ## Contributor License Agreement
23 |
24 | Contributions to this project must be accompanied by a Contributor License
25 | Agreement. You (or your employer) retain the copyright to your contribution;
26 | this simply gives us permission to use and redistribute your contributions as
27 | part of the project. Head over to <https://cla.developers.google.com/> to see
28 | your current agreements on file or to sign a new one.
29 |
30 | You generally only need to submit a CLA once, so if you've already submitted one
31 | (even if it was for a different project), you probably don't need to do it
32 | again.
33 |
34 | ## Code reviews
35 |
36 | All submissions, including submissions by project members, require review. We
37 | use GitHub pull requests for this purpose. Consult
38 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
39 | information on using pull requests.
40 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/bin/run_collect_eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Runs data collection and policy evaluation for RL experiments.
17 | """
18 |
19 | from absl import app
20 | from absl import flags
21 | import gin
22 | from tensor2robot.utils import continuous_collect_eval
23 | # Contains the gin_configs and gin_bindings flag definitions.
24 | FLAGS = flags.FLAGS
25 |
26 | try:
27 | flags.DEFINE_list(
28 | 'gin_configs', None,
29 | 'A comma-separated list of paths to Gin configuration files.')
30 | flags.DEFINE_multi_string(
31 | 'gin_bindings', [], 'A newline separated list of Gin parameter bindings.')
32 | except flags.DuplicateFlagError:
33 | pass
34 |
35 | flags.DEFINE_string('root_dir', '',
36 | 'Root directory of experiment.')
37 | flags.mark_flag_as_required('gin_configs')
38 |
39 |
40 | def main(unused_argv):
41 | del unused_argv
42 | gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
43 | continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir)
44 |
45 |
46 | if __name__ == '__main__':
47 | app.run(main)
48 |
--------------------------------------------------------------------------------
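For orientation, the sketch below shows what this binary does with its two Gin flags; the `.gin` path and the binding string are hypothetical placeholders, not files or parameters shipped in this repository.

    # Rough Python equivalent of run_collect_eval.py with example flag values.
    # The config path and the binding below are placeholders (assumptions).
    import gin
    from tensor2robot.utils import continuous_collect_eval

    gin.parse_config_files_and_bindings(
        ['/path/to/experiment_config.gin'],     # --gin_configs (placeholder)
        ['collect_eval_loop.max_evals = 1'])    # --gin_bindings (assumed param)
    continuous_collect_eval.collect_eval_loop(root_dir='/tmp/t2r_experiment')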
/bin/run_t2r_trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Binary for training TFModels with Estimator API."""
17 |
18 | from absl import app
19 | from absl import flags
20 | import gin
21 | from tensor2robot.utils import train_eval
22 | import tensorflow.compat.v1 as tf
23 |
24 |
25 | FLAGS = flags.FLAGS
26 |
27 |
28 | def main(unused_argv):
29 | gin.parse_config_files_and_bindings(
30 | FLAGS.gin_configs, FLAGS.gin_bindings, print_includes_and_imports=True)
31 | train_eval.train_eval_model()
32 |
33 |
34 | if __name__ == '__main__':
35 | tf.logging.set_verbosity(tf.logging.INFO)
36 | app.run(main)
37 |
--------------------------------------------------------------------------------
/export_generators/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/export_generators/abstract_export_generator.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for exporting savedmodels."""
17 |
18 | import abc
19 | import functools
20 | import os
21 | from typing import Any, Dict, List, Optional, Text
22 |
23 | import gin
24 | import six
25 | from tensor2robot.models import abstract_model
26 | from tensor2robot.utils import tensorspec_utils
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.compat.v1 import estimator as tf_estimator
29 | from tensorflow.contrib import util as contrib_util
30 |
31 | from tensorflow_serving.apis import predict_pb2
32 | from tensorflow_serving.apis import prediction_log_pb2
33 |
34 | MODE = tf_estimator.ModeKeys.PREDICT
35 |
36 |
37 | @gin.configurable
38 | class AbstractExportGenerator(six.with_metaclass(abc.ABCMeta, object)):
39 | """Class to manage assets related to exporting a model.
40 |
41 | Attributes:
42 | export_raw_receivers: Whether to export receiver_fns which do not have
43 | preprocessing enabled. This is useful for serving using Servo, in
44 | conjunction with client-preprocessing.
45 | """
46 |
47 | def __init__(self, export_raw_receivers = False):
48 | self._export_raw_receivers = export_raw_receivers
49 | self._feature_spec = None
50 | self._out_feature_spec = None
51 | self._preprocess_fn = None
52 | self._model_name = None
53 |
54 | def set_specification_from_model(self,
55 | t2r_model):
56 | """Set the feature specifications and preprocess function from the model.
57 |
58 | Args:
59 | t2r_model: A T2R model instance.
60 | """
61 | preprocessor = t2r_model.preprocessor
62 | self._feature_spec = preprocessor.get_in_feature_specification(MODE)
63 | tensorspec_utils.assert_valid_spec_structure(self._feature_spec)
64 | self._out_feature_spec = (preprocessor.get_out_feature_specification(MODE))
65 | tensorspec_utils.assert_valid_spec_structure(self._out_feature_spec)
66 | self._preprocess_fn = functools.partial(preprocessor.preprocess, mode=MODE)
67 | self._model_name = type(t2r_model).__name__
68 |
69 | def _get_input_features_for_receiver_fn(self):
70 |     """Helper function to return an input feature spec for receiver fns.
71 |
72 | Returns:
73 | The appropriate feature specification to use
74 | """
75 | if self._export_raw_receivers:
76 | return self._out_feature_spec
77 | else:
78 | return self._feature_spec
79 |
80 | @abc.abstractmethod
81 | def create_serving_input_receiver_numpy_fn(self, params=None):
82 | """Create a serving input receiver for numpy.
83 |
84 | Args:
85 | params: An optional dict of hyper parameters that will be passed into
86 | input_fn and model_fn. Keys are names of parameters, values are basic
87 | python types. There are reserved keys for TPUEstimator, including
88 | 'batch_size'.
89 |
90 | Returns:
91 | serving_input_receiver_fn: A callable which creates the serving inputs.
92 | """
93 |
94 | @abc.abstractmethod
95 | def create_serving_input_receiver_tf_example_fn(
96 | self, params = None):
97 | """Create a serving input receiver for tf_examples.
98 |
99 | Args:
100 | params: An optional dict of hyper parameters that will be passed into
101 | input_fn and model_fn. Keys are names of parameters, values are basic
102 | python types. There are reserved keys for TPUEstimator, including
103 | 'batch_size'.
104 |
105 | Returns:
106 | serving_input_receiver_fn: A callable which creates the serving inputs.
107 | """
108 |
109 | def create_warmup_requests_numpy(self, batch_sizes,
110 | export_dir):
111 | """Creates warm-up requests for a given feature specification.
112 |
113 |     This writes an output file to
114 |     `export_dir/tf_serving_warmup_requests` for use with Servo.
115 |
116 | Args:
117 | batch_sizes: Batch sizes of warm-up requests to write.
118 | export_dir: Base directory for the export.
119 |
120 | Returns:
121 | The filename written.
122 | """
123 | feature_spec = self._get_input_features_for_receiver_fn()
124 |
125 | flat_feature_spec = tensorspec_utils.flatten_spec_structure(feature_spec)
126 | tf.io.gfile.makedirs(export_dir)
127 | request_filename = os.path.join(export_dir, 'tf_serving_warmup_requests')
128 | with tf.python_io.TFRecordWriter(request_filename) as writer:
129 | for batch_size in batch_sizes:
130 | request = predict_pb2.PredictRequest()
131 | request.model_spec.name = self._model_name
132 | numpy_feature_specs = tensorspec_utils.make_constant_numpy(
133 | flat_feature_spec, constant_value=0, batch_size=batch_size)
134 |
135 | for key, numpy_spec in numpy_feature_specs.items():
136 | request.inputs[key].CopyFrom(
137 | contrib_util.make_tensor_proto(numpy_spec))
138 |
139 | log = prediction_log_pb2.PredictionLog(
140 | predict_log=prediction_log_pb2.PredictLog(request=request))
141 | writer.write(log.SerializeToString())
142 | return request_filename
143 |
--------------------------------------------------------------------------------
/export_generators/abstract_export_generator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.export_generators.abstract_export_generator."""
17 |
18 | from six.moves import zip
19 | from tensor2robot.export_generators import abstract_export_generator
20 | from tensor2robot.preprocessors import noop_preprocessor
21 | from tensor2robot.utils import mocks
22 | import tensorflow.compat.v1 as tf
23 | from tensorflow_serving.apis import prediction_log_pb2
24 |
25 |
26 | class AbstractExportGeneratorTest(tf.test.TestCase):
27 |
28 | def test_init_abstract(self):
29 | with self.assertRaises(TypeError):
30 | abstract_export_generator.AbstractExportGenerator()
31 |
32 | def test_create_warmup_requests_numpy(self):
33 | mock_t2r_model = mocks.MockT2RModel(
34 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
35 | exporter = mocks.MockExportGenerator()
36 | exporter.set_specification_from_model(mock_t2r_model)
37 |
38 | export_dir = self.create_tempdir()
39 | batch_sizes = [2, 4]
40 | request_filename = exporter.create_warmup_requests_numpy(
41 | batch_sizes=batch_sizes, export_dir=export_dir.full_path)
42 |
43 | for expected_batch_size, record in zip(
44 | batch_sizes, tf.compat.v1.io.tf_record_iterator(request_filename)):
45 | record_proto = prediction_log_pb2.PredictionLog()
46 | record_proto.ParseFromString(record)
47 | request = record_proto.predict_log.request
48 | self.assertEqual(request.model_spec.name, 'MockT2RModel')
49 | for _, in_tensor in request.inputs.items():
50 | self.assertEqual(in_tensor.tensor_shape.dim[0].size,
51 | expected_batch_size)
52 |
53 |
54 | if __name__ == '__main__':
55 | tf.test.main()
56 |
--------------------------------------------------------------------------------
/export_generators/default_export_generator.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for generating SavedModel exports based on AbstractT2RModels."""
17 |
18 | import copy
19 | from typing import Any, Dict, Optional, Text
20 |
21 | import gin
22 | import six
23 | from tensor2robot.export_generators import abstract_export_generator
24 | from tensor2robot.utils import tensorspec_utils
25 | from tensor2robot.utils import tfdata
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 | MODE = tf_estimator.ModeKeys.PREDICT
30 |
31 |
32 | @gin.configurable
33 | class DefaultExportGenerator(abstract_export_generator.AbstractExportGenerator):
34 | """Class to manage assets related to exporting a model.
35 |
36 | Attributes:
37 | export_raw_receivers: Whether to export receiver_fns which do not have
38 | preprocessing enabled. This is useful for serving using Servo, in
39 | conjunction with client-preprocessing.
40 | """
41 |
42 | def create_serving_input_receiver_numpy_fn(self, params=None):
43 | """Create a serving input receiver for numpy.
44 |
45 | Args:
46 | params: An optional dict of hyper parameters that will be passed into
47 | input_fn and model_fn. Keys are names of parameters, values are basic
48 | python types. There are reserved keys for TPUEstimator, including
49 | 'batch_size'.
50 |
51 | Returns:
52 | serving_input_receiver_fn: A callable which creates the serving inputs.
53 | """
54 | del params
55 |
56 | def serving_input_receiver_fn():
57 | """Create the ServingInputReceiver to export a saved model.
58 |
59 | Returns:
60 | An instance of ServingInputReceiver.
61 | """
62 | # We have to filter our specs since only required tensors are
63 | # used for inference time.
64 | flat_feature_spec = tensorspec_utils.flatten_spec_structure(
65 | self._get_input_features_for_receiver_fn())
66 | required_feature_spec = (
67 | tensorspec_utils.filter_required_flat_tensor_spec(flat_feature_spec))
68 | receiver_tensors = tensorspec_utils.make_placeholders(
69 | required_feature_spec)
70 |
71 | # We want to ensure that our feature processing pipeline operates on a
72 | # copy of the features and does not alter the receiver_tensors.
73 | features = tensorspec_utils.flatten_spec_structure(
74 | copy.copy(receiver_tensors))
75 |
76 | if (not self._export_raw_receivers and self._preprocess_fn is not None):
77 | features, _ = self._preprocess_fn(features=features, labels=None)
78 |
79 | return tf_estimator.export.ServingInputReceiver(features,
80 | receiver_tensors)
81 |
82 | return serving_input_receiver_fn
83 |
84 | def create_serving_input_receiver_tf_example_fn(
85 | self, params = None):
86 | """Create a serving input receiver for tf_examples.
87 |
88 | Args:
89 | params: An optional dict of hyper parameters that will be passed into
90 | input_fn and model_fn. Keys are names of parameters, values are basic
91 | python types. There are reserved keys for TPUEstimator, including
92 | 'batch_size'.
93 |
94 | Returns:
95 | serving_input_receiver_fn: A callable which creates the serving inputs.
96 | """
97 | del params
98 |
99 | def serving_input_receiver_fn():
100 | """Create the ServingInputReceiver to export a saved model.
101 |
102 | Returns:
103 | An instance of ServingInputReceiver.
104 | """
105 |       # We assume one input (a string which contains the serialized proto) per
106 | # dataset_key.
107 | feature_spec = self._get_input_features_for_receiver_fn()
108 | # We have to filter our specs since only required tensors are
109 | # used for inference time.
110 | flat_feature_spec = tensorspec_utils.flatten_spec_structure(feature_spec)
111 | required_feature_spec = (
112 | tensorspec_utils.filter_required_flat_tensor_spec(flat_feature_spec))
113 | dataset_keys = set(
114 | [t.dataset_key for t in required_feature_spec.values()])
115 | receiver_tensors = {}
116 | parse_tensors = {}
117 | for dataset_key in dataset_keys:
118 | receiver_name = 'input_example_' + six.ensure_str(
119 | (dataset_key or 'tensor'))
120 | parse_tensors[dataset_key] = tf.placeholder(
121 | dtype=tf.string, shape=[None], name=receiver_name)
122 | receiver_tensors[receiver_name] = parse_tensors[dataset_key]
123 | parse_tf_example_fn = tfdata.create_parse_tf_example_fn(
124 | feature_tspec=required_feature_spec)
125 | features = parse_tf_example_fn(parse_tensors)
126 |
127 | if (not self._export_raw_receivers and self._preprocess_fn is not None):
128 | features, _ = self._preprocess_fn(features=features, labels=None)
129 |
130 | return tf_estimator.export.ServingInputReceiver(features,
131 | receiver_tensors)
132 |
133 | return serving_input_receiver_fn
134 |
--------------------------------------------------------------------------------
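To make the receiver-fn plumbing above concrete, here is a short sketch using the mock model from tensor2robot.utils.mocks; the export directory is a placeholder, and the estimator call is shown only in comments since constructing an estimator is out of scope here.

    # Sketch: derive a numpy serving_input_receiver_fn from a model's specs.
    from tensor2robot.export_generators import default_export_generator
    from tensor2robot.preprocessors import noop_preprocessor
    from tensor2robot.utils import mocks

    t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
    export_generator = default_export_generator.DefaultExportGenerator()
    export_generator.set_specification_from_model(t2r_model)
    # Builds placeholders from the required feature spec and, unless
    # export_raw_receivers is set, runs the preprocessor on copies of them.
    serving_input_receiver_fn = (
        export_generator.create_serving_input_receiver_numpy_fn())
    # An estimator created by train_eval.train_eval_model could then export:
    #   estimator.export_saved_model(
    #       export_dir_base='/tmp/exported_savedmodels',
    #       serving_input_receiver_fn=serving_input_receiver_fn)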
/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/hooks/async_export_hook_builder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook builders for asynchronously exporting SavedModels during training."""
17 |
18 | import os
19 | import tempfile
20 | from typing import Text, List, Callable, Optional
21 |
22 | import gin
23 | from tensor2robot.export_generators import abstract_export_generator
24 | from tensor2robot.export_generators import default_export_generator
25 | from tensor2robot.hooks import checkpoint_hooks
26 | from tensor2robot.hooks import hook_builder
27 | from tensor2robot.models import model_interface
28 | from tensor2robot.proto import t2r_pb2
29 | from tensor2robot.utils import tensorspec_utils
30 | from tensorflow.compat.v1 import estimator as tf_estimator
31 | import tensorflow.compat.v1 as tf # tf
32 |
33 | from tensorflow.contrib import tpu as contrib_tpu
34 |
35 | CreateExportFnType = Callable[[
36 | model_interface.ModelInterface,
37 | tf_estimator.Estimator,
38 | abstract_export_generator.AbstractExportGenerator,
39 | ], Callable[[Text, int], Text]]
40 |
41 |
42 | def default_create_export_fn(
43 | t2r_model,
44 | estimator,
45 | export_generator
46 | ):
47 | """Create an export function for a device type.
48 |
49 | Args:
50 | t2r_model: A T2RModel instance.
51 | estimator: The estimator used for training.
52 | export_generator: An export generator.
53 |
54 | Returns:
55 | A callable function which exports a saved model and returns the path.
56 | """
57 |
58 | in_feature_spec = t2r_model.get_feature_specification_for_packing(
59 | mode=tf_estimator.ModeKeys.PREDICT)
60 | in_label_spec = t2r_model.get_label_specification_for_packing(
61 | mode=tf_estimator.ModeKeys.PREDICT)
62 | t2r_assets = t2r_pb2.T2RAssets()
63 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto())
64 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto())
65 |
66 | def _export_fn(export_dir, global_step):
67 | """The actual closure function creating the exported model and assets."""
68 | # Create additional assets for the exported models
69 | t2r_assets.global_step = global_step
70 | tmpdir = tempfile.mkdtemp()
71 | t2r_assets_filename = os.path.join(tmpdir,
72 | tensorspec_utils.T2R_ASSETS_FILENAME)
73 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename)
74 | assets = {
75 | tensorspec_utils.T2R_ASSETS_FILENAME: t2r_assets_filename,
76 | }
77 | return estimator.export_saved_model(
78 | export_dir_base=export_dir,
79 | serving_input_receiver_fn=export_generator
80 | .create_serving_input_receiver_numpy_fn(),
81 | assets_extra=assets)
82 |
83 | return _export_fn
84 |
85 |
86 | @gin.configurable
87 | class AsyncExportHookBuilder(hook_builder.HookBuilder):
88 |   """Creates hooks that export SavedModels for CPU and TPU serving.
89 |
90 | Attributes:
91 | export_dir: Directory to output the latest models.
92 |     save_secs: Interval in seconds at which checkpoints are saved and a new
93 |       SavedModel is exported to `export_dir`.
94 | num_versions: Number of model versions to save in each directory
95 | export_generator: The export generator used to generate the
96 | serving_input_receiver_fn.
97 | """
98 |
99 | def __init__(
100 | self,
101 | export_dir,
102 | save_secs = 90,
103 | num_versions = 3,
104 | create_export_fn = default_create_export_fn,
105 | export_generator = None,
106 | ):
107 | super(AsyncExportHookBuilder, self).__init__()
108 | self._save_secs = save_secs
109 | self._num_versions = num_versions
110 | self._export_dir = export_dir
111 | self._create_export_fn = create_export_fn
112 | if export_generator is None:
113 | self._export_generator = default_export_generator.DefaultExportGenerator()
114 | else:
115 | self._export_generator = export_generator
116 |
117 | def create_hooks(
118 | self,
119 | t2r_model,
120 | estimator,
121 | ):
122 | self._export_generator.set_specification_from_model(t2r_model)
123 | return [
124 | contrib_tpu.AsyncCheckpointSaverHook(
125 | save_secs=self._save_secs,
126 | checkpoint_dir=estimator.model_dir,
127 | listeners=[
128 | checkpoint_hooks.CheckpointExportListener(
129 | export_fn=self._create_export_fn(t2r_model, estimator,
130 | self._export_generator),
131 | num_versions=self._num_versions,
132 | export_dir=self._export_dir)
133 | ])
134 | ]
135 |
--------------------------------------------------------------------------------
/hooks/async_export_hook_builder_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for async export hook builders."""
17 |
18 | import os
19 | from tensor2robot.hooks import async_export_hook_builder
20 | from tensor2robot.predictors import exported_savedmodel_predictor
21 | from tensor2robot.preprocessors import noop_preprocessor
22 | from tensor2robot.utils import mocks
23 | from tensor2robot.utils import train_eval
24 | import tensorflow.compat.v1 as tf # tf
25 |
26 | _EXPORT_DIR = 'export_dir'
27 | _MAX_STEPS = 4
28 | _BATCH_SIZE = 4
29 |
30 |
31 | class AsyncExportHookBuilderTest(tf.test.TestCase):
32 |
33 | def test_with_mock_training(self):
34 | model_dir = self.create_tempdir().full_path
35 | mock_t2r_model = mocks.MockT2RModel(
36 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu')
37 |
38 | mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
39 | export_dir = os.path.join(model_dir, _EXPORT_DIR)
40 |
41 | hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
42 | export_dir=export_dir,
43 | create_export_fn=async_export_hook_builder.default_create_export_fn)
44 |
45 | # We optimize our network.
46 | train_eval.train_eval_model(
47 | t2r_model=mock_t2r_model,
48 | input_generator_train=mock_input_generator,
49 | train_hook_builders=[hook_builder],
50 | model_dir=model_dir,
51 | max_train_steps=_MAX_STEPS)
52 | self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
53 | self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
54 | for exported_model_dir in tf.io.gfile.listdir(export_dir):
55 | self.assertNotEmpty(
56 | tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))
57 | predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
58 | export_dir=export_dir)
59 | self.assertTrue(predictor.restore())
60 |
61 |
62 | if __name__ == '__main__':
63 | tf.test.main()
64 |
--------------------------------------------------------------------------------
/hooks/async_export_hook_builder_tpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for async export hook builders on TPU."""
17 |
18 | import os
19 | import gin
20 | from tensor2robot.hooks import async_export_hook_builder
21 | from tensor2robot.predictors import exported_savedmodel_predictor
22 | from tensor2robot.preprocessors import noop_preprocessor
23 | from tensor2robot.utils import mocks
24 | from tensor2robot.utils import train_eval
25 | import tensorflow.compat.v1 as tf # tf
26 |
27 | _EXPORT_DIR = 'export_dir'
28 | _BATCH_SIZES_FOR_EXPORT = [128]
29 | _MAX_STEPS = 4
30 | _BATCH_SIZE = 4
31 |
32 |
33 | class AsyncExportHookBuilderTest(tf.test.TestCase):
34 |
35 | def test_with_mock_training(self):
36 | model_dir = self.create_tempdir().full_path
37 | mock_t2r_model = mocks.MockT2RModel(
38 | preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
39 | device_type='tpu',
40 | use_avg_model_params=True)
41 |
42 | mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
43 | export_dir = os.path.join(model_dir, _EXPORT_DIR)
44 | hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
45 | export_dir=export_dir,
46 | create_export_fn=async_export_hook_builder.default_create_export_fn)
47 |
48 | gin.parse_config('tf.contrib.tpu.TPUConfig.iterations_per_loop=1')
49 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')
50 |
51 | # We optimize our network.
52 | train_eval.train_eval_model(
53 | t2r_model=mock_t2r_model,
54 | input_generator_train=mock_input_generator,
55 | train_hook_builders=[hook_builder],
56 | model_dir=model_dir,
57 | max_train_steps=_MAX_STEPS)
58 | self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
59 | self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
60 | for exported_model_dir in tf.io.gfile.listdir(export_dir):
61 | self.assertNotEmpty(
62 | tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))
63 | predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
64 | export_dir=export_dir)
65 | self.assertTrue(predictor.restore())
66 |
67 |
68 | if __name__ == '__main__':
69 | tf.test.main()
70 |
--------------------------------------------------------------------------------
/hooks/gin_config_hook_builder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Builds hooks that write out the operative gin configuration.
17 | """
18 |
19 | from typing import List
20 |
21 | from absl import logging
22 | import gin
23 | from tensor2robot.hooks import hook_builder
24 | from tensor2robot.models import model_interface
25 | from tensorflow import estimator as tf_estimator
26 |
27 |
28 | @gin.configurable
29 | class GinConfigLoggerHook(tf_estimator.SessionRunHook):
30 | """A SessionRunHook that logs the operative config to stdout."""
31 |
32 | def __init__(self, only_once=True):
33 | self._only_once = only_once
34 | self._written_at_least_once = False
35 |
36 | def after_create_session(self, session=None, coord=None):
37 | """Logs Gin's operative config."""
38 | if self._only_once and self._written_at_least_once:
39 | return
40 |
41 | logging.info('Gin operative configuration:')
42 | for gin_config_line in gin.operative_config_str().splitlines():
43 | logging.info(gin_config_line)
44 | self._written_at_least_once = True
45 |
46 |
47 | @gin.configurable
48 | class OperativeGinConfigLoggerHookBuilder(hook_builder.HookBuilder):
49 |
50 | def create_hooks(
51 | self,
52 | t2r_model,
53 | estimator,
54 | ):
55 | return [GinConfigLoggerHook()]
56 |
--------------------------------------------------------------------------------
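Attaching this hook builder to a training run is a one-liner in train_eval; below is a minimal sketch using the mocks from tensor2robot.utils (the model_dir path, batch size, and step count are arbitrary choices).

    # Sketch: log the operative Gin config once the training session starts.
    from tensor2robot.hooks import gin_config_hook_builder
    from tensor2robot.preprocessors import noop_preprocessor
    from tensor2robot.utils import mocks
    from tensor2robot.utils import train_eval

    train_eval.train_eval_model(
        t2r_model=mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor),
        input_generator_train=mocks.MockInputGenerator(batch_size=2),
        train_hook_builders=[
            gin_config_hook_builder.OperativeGinConfigLoggerHookBuilder()],
        model_dir='/tmp/t2r_gin_logging_demo',
        max_train_steps=2)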
/hooks/golden_values_hook_builder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook that logs golden values to be used in unit tests.
17 |
18 | In the Data -> Checkpoint -> Inference -> Eval flow, this verifies no regression
19 | occurred in Data -> Checkpoint.
20 | """
21 |
22 | import os
23 | from typing import List
24 | from absl import logging
25 | import gin
26 | import numpy as np
27 | from tensor2robot.hooks import hook_builder
28 | from tensor2robot.models import model_interface
29 | import tensorflow.compat.v1 as tf
30 | from tensorflow.compat.v1 import estimator as tf_estimator
31 |
32 | ModeKeys = tf_estimator.ModeKeys
33 | COLLECTION = 'golden'
34 | PREFIX = 'golden_'
35 |
36 |
37 | def add_golden_tensor(tensor, name):
38 | """Adds tensor to be tracked."""
39 | tf.add_to_collection(COLLECTION, tf.identity(tensor, name=PREFIX + name))
40 |
41 |
42 | class GoldenValuesHook(tf.train.SessionRunHook):
43 | """SessionRunHook that saves loss metrics to file."""
44 |
45 | def __init__(self,
46 | log_directory):
47 | self._log_directory = log_directory
48 |
49 | def begin(self):
50 | self._measurements = []
51 |
52 | def end(self, session):
53 | # Record measurements.
54 | del session
55 | np.save(os.path.join(self._log_directory, 'golden_values.npy'),
56 | self._measurements)
57 |
58 | def before_run(self, run_context):
59 | return tf.train.SessionRunArgs(
60 | fetches=tf.get_collection_ref(COLLECTION))
61 |
62 | def after_run(self, run_context, run_values):
63 | # Strip the 'golden_' prefix before saving the data.
64 | golden_values = {t.name.split(PREFIX)[1]: v for t, v in
65 | zip(tf.get_collection_ref(COLLECTION), run_values.results)}
66 | logging.info('Recorded golden values for %s', golden_values.keys())
67 | self._measurements.append(golden_values)
68 |
69 |
70 | @gin.configurable
71 | class GoldenValuesHookBuilder(hook_builder.HookBuilder):
72 | """Hook builder for generating golden values."""
73 |
74 | def create_hooks(
75 | self,
76 | t2r_model,
77 | estimator,
78 | ):
79 | return [GoldenValuesHook(estimator.model_dir)]
80 |
--------------------------------------------------------------------------------
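In practice, model-building code tags the tensors it wants recorded and the builder is passed to the trainer; the tensor and name below are purely illustrative.

    # Sketch: tag a tensor for golden-value recording, then attach the hook
    # builder so each run writes golden_values.npy into the model directory.
    import tensorflow.compat.v1 as tf
    from tensor2robot.hooks import golden_values_hook_builder

    # Somewhere inside model-building code (illustrative tensor and name):
    logits = tf.constant([[0.1, 0.9]])
    golden_values_hook_builder.add_golden_tensor(logits, name='logits')

    # When launching training (other train_eval_model arguments elided):
    #   train_hook_builders=[golden_values_hook_builder.GoldenValuesHookBuilder()]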
/hooks/hook_builder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Interface to manage building hooks."""
17 |
18 | import abc
19 | from typing import List
20 |
21 | import six
22 | from tensor2robot.models import model_interface
23 | from tensorflow.compat.v1 import estimator as tf_estimator
24 | import tensorflow.compat.v1 as tf # tf
25 |
26 |
27 | class HookBuilder(six.with_metaclass(abc.ABCMeta, object)):
28 |
29 | @abc.abstractmethod
30 | def create_hooks(
31 | self, t2r_model,
32 | estimator,
33 | ):
34 | """Create hooks for the trainer.
35 |
36 | Subclasses can add arguments here.
37 |
38 | Arguments:
39 | t2r_model: Provided model
40 | estimator: Provided estimator instance
41 | Returns:
42 | A list of tf.train.SessionRunHooks to add to the trainer.
43 | """
44 |
--------------------------------------------------------------------------------
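A minimal concrete implementation of this interface, purely as an illustration; the choice of StepCounterHook is arbitrary.

    # Sketch: the smallest useful HookBuilder.
    import tensorflow.compat.v1 as tf
    from tensor2robot.hooks import hook_builder


    class StepCounterHookBuilder(hook_builder.HookBuilder):
      """Adds a StepCounterHook that logs steps/sec under the model_dir."""

      def create_hooks(self, t2r_model, estimator):
        del t2r_model  # Only the estimator's model_dir is needed here.
        return [tf.train.StepCounterHook(output_dir=estimator.model_dir)]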
/hooks/td3.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook builders for TD3 distributed training with SavedModels."""
17 |
18 | import os
19 | import tempfile
20 | from typing import Text, List, Optional
21 |
22 | import gin
23 | from tensor2robot.export_generators import abstract_export_generator
24 | from tensor2robot.export_generators import default_export_generator
25 | from tensor2robot.hooks import checkpoint_hooks
26 | from tensor2robot.hooks import hook_builder
27 | from tensor2robot.models import model_interface
28 | from tensor2robot.proto import t2r_pb2
29 | from tensor2robot.utils import tensorspec_utils
30 | from tensorflow.compat.v1 import estimator as tf_estimator
31 | import tensorflow.compat.v1 as tf # tf
32 |
33 | from tensorflow.contrib import tpu as contrib_tpu
34 |
35 |
36 | @gin.configurable
37 | class TD3Hooks(hook_builder.HookBuilder):
38 | """Creates hooks for exporting models for serving in TD3 distributed training.
39 |
40 | See:
41 | "Addressing Function Approximation Error in Actor-Critic Methods"
42 | by Fujimoto et al.
43 |
44 | https://arxiv.org/abs/1802.09477
45 |
46 | These hooks manage exporting of SavedModels to two different directories:
47 | `export_dir` contains the latest version of the model, `lagged_export_dir`
48 | contains a lagged version, delayed by one interval of `save_secs`.
49 |
50 | Attributes:
51 | export_dir: Directory to output the latest models.
52 | lagged_export_dir: Directory containing a lagged version of SavedModels
53 | save_secs: Interval to save models, and copy the latest model from
54 | `export_dir` to `lagged_export_dir`.
55 | num_versions: Number of model versions to save in each directory
56 | use_preprocessed_features: Whether to export SavedModels which do *not*
57 |       include preprocessing. This is useful for offloading the preprocessing
58 | graph to the client.
59 | export_generator: The export generator used to generate the
60 | serving_input_receiver_fn.
61 | """
62 |
63 | def __init__(
64 | self,
65 | export_dir,
66 | lagged_export_dir,
67 | batch_sizes_for_export,
68 | save_secs = 90,
69 | num_versions = 3,
70 | use_preprocessed_features=False,
71 | export_generator = None,
72 | ):
73 | super(TD3Hooks, self).__init__()
74 | self._save_secs = save_secs
75 | self._num_versions = num_versions
76 | self._export_dir = export_dir
77 | self._lagged_export_dir = lagged_export_dir
78 | self._batch_sizes_for_export = batch_sizes_for_export
79 | if export_generator is None:
80 | self._export_generator = default_export_generator.DefaultExportGenerator()
81 | else:
82 | self._export_generator = export_generator
83 |
84 | def create_hooks(
85 | self,
86 | t2r_model,
87 | estimator,
88 | ):
89 | if not self._export_dir and not self._lagged_export_dir:
90 | return []
91 | self._export_generator.set_specification_from_model(t2r_model)
92 | warmup_requests_file = self._export_generator.create_warmup_requests_numpy(
93 | batch_sizes=self._batch_sizes_for_export,
94 | export_dir=estimator.model_dir)
95 |
96 | in_feature_spec = t2r_model.get_feature_specification_for_packing(
97 | mode=tf_estimator.ModeKeys.PREDICT)
98 | in_label_spec = t2r_model.get_label_specification_for_packing(
99 | mode=tf_estimator.ModeKeys.PREDICT)
100 | t2r_assets = t2r_pb2.T2RAssets()
101 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto())
102 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto())
103 |
104 | def _export_fn(export_dir, global_step):
105 | """The actual closure function creating the exported model and assets."""
106 | t2r_assets.global_step = global_step
107 | tmpdir = tempfile.mkdtemp()
108 | t2r_assets_filename = os.path.join(tmpdir,
109 | tensorspec_utils.T2R_ASSETS_FILENAME)
110 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename)
111 | res = estimator.export_saved_model(
112 | export_dir_base=export_dir,
113 | serving_input_receiver_fn=self._export_generator
114 | .create_serving_input_receiver_numpy_fn(),
115 | assets_extra={
116 | 'tf_serving_warmup_requests': warmup_requests_file,
117 | tensorspec_utils.T2R_ASSETS_FILENAME: t2r_assets_filename
118 | })
119 | return res
120 |
121 | return [
122 | contrib_tpu.AsyncCheckpointSaverHook(
123 | save_secs=self._save_secs,
124 | checkpoint_dir=estimator.model_dir,
125 | listeners=[
126 | checkpoint_hooks.LaggedCheckpointListener(
127 | export_fn=_export_fn,
128 | num_versions=self._num_versions,
129 | export_dir=self._export_dir,
130 | lagged_export_dir=self._lagged_export_dir)
131 | ])
132 | ]
133 |
--------------------------------------------------------------------------------
/hooks/td3_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for TD3 Hooks."""
17 |
18 | import mock
19 | from tensor2robot.export_generators import abstract_export_generator
20 | from tensor2robot.hooks import checkpoint_hooks
21 | from tensor2robot.hooks import td3
22 | from tensor2robot.utils import mocks
23 | from tensor2robot.utils import tensorspec_utils
24 | from tensorflow.compat.v1 import estimator as tf_estimator
25 | import tensorflow.compat.v1 as tf # tf
26 |
27 | _BATCH_SIZES_FOR_EXPORT = [128]
28 | _MODEL_DIR = "model_dir"
29 | _NUMPY_WARMUP_REQUESTS = "warmup_requests"
30 | _EXPORT_DIR = "export_dir"
31 | _LAGGED_EXPORT_DIR = "lagged_export_dir"
32 |
33 |
34 | class MockEstimator(tf_estimator.Estimator):
35 |
36 | def __init__(self):
37 | pass
38 |
39 | @property
40 | def model_dir(self):
41 | return _MODEL_DIR
42 |
43 |
44 | class Td3Test(tf.test.TestCase):
45 |
46 | @mock.patch.object(MockEstimator, "export_saved_model")
47 | @mock.patch.object(checkpoint_hooks.LaggedCheckpointListener, "__init__")
48 | @mock.patch.object(mocks.MockExportGenerator,
49 | "create_serving_input_receiver_numpy_fn")
50 | @mock.patch.object(abstract_export_generator.AbstractExportGenerator,
51 | "create_warmup_requests_numpy")
52 | def test_hooks(self, mock_create_warmup_requests_numpy,
53 | mock_create_serving_input_receiver_numpy_fn,
54 | mock_checkpoint_init, mock_export_saved_model):
55 |
56 | def _checkpoint_init(export_fn, export_dir, **kwargs):
57 | del kwargs
58 | export_fn(export_dir, global_step=1)
59 | return None
60 |
61 | mock_checkpoint_init.side_effect = _checkpoint_init
62 |
63 | export_generator = mocks.MockExportGenerator()
64 |
65 | hook_builder = td3.TD3Hooks(
66 | export_dir=_EXPORT_DIR,
67 | lagged_export_dir=_LAGGED_EXPORT_DIR,
68 | batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT,
69 | export_generator=export_generator)
70 |
71 | model = mocks.MockT2RModel()
72 | estimator = MockEstimator()
73 |
74 | mock_create_warmup_requests_numpy.return_value = _NUMPY_WARMUP_REQUESTS
75 |
76 | hooks = hook_builder.create_hooks(t2r_model=model, estimator=estimator)
77 | self.assertLen(hooks, 1)
78 |
79 | mock_create_warmup_requests_numpy.assert_called_with(
80 | batch_sizes=_BATCH_SIZES_FOR_EXPORT,
81 | export_dir=_MODEL_DIR)
82 |
83 | mock_export_saved_model.assert_called_with(
84 | serving_input_receiver_fn=mock.ANY,
85 | export_dir_base=_EXPORT_DIR,
86 | assets_extra={
87 | "tf_serving_warmup_requests": _NUMPY_WARMUP_REQUESTS,
88 | tensorspec_utils.T2R_ASSETS_FILENAME: mock.ANY
89 | })
90 |
91 | mock_create_serving_input_receiver_numpy_fn.assert_called()
92 |
93 |
94 | if __name__ == "__main__":
95 | tf.test.main()
96 |
--------------------------------------------------------------------------------
/hooks/variable_logger_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A hook to log all variables."""
17 |
18 | from typing import Optional
19 |
20 | from absl import logging
21 |
22 | import numpy as np
23 | import tensorflow.compat.v1 as tf
24 | from tensorflow.contrib import framework as contrib_framework
25 |
26 |
27 | class VariableLoggerHook(tf.train.SessionRunHook):
28 | """A hook to log variables via a session run hook."""
29 |
30 | def __init__(self, max_num_variable_values = None):
31 | """Initializes a VariableLoggerHook.
32 |
33 | Args:
34 | max_num_variable_values: If not None, at most max_num_variable_values will
35 | be logged per variable.
36 | """
37 | super(VariableLoggerHook, self).__init__()
38 | self._max_num_variable_values = max_num_variable_values
39 |
40 | def begin(self):
41 | """Captures all variables to be read out during the session run."""
42 | self._variables_to_log = contrib_framework.get_variables()
43 |
44 | def before_run(self, run_context):
45 | """Adds the variables to the run args."""
46 | return tf.train.SessionRunArgs(self._variables_to_log)
47 |
48 | def after_run(self, run_context, run_values):
49 | del run_context
50 | original = np.get_printoptions()
51 | np.set_printoptions(suppress=True)
52 | for variable, variable_value in zip(self._variables_to_log,
53 | run_values.results):
54 | if not isinstance(variable_value, np.ndarray):
55 | continue
56 | variable_value = variable_value.ravel()
57 | logging.info('%s.mean = %s', variable.op.name, np.mean(variable_value))
58 | logging.info('%s.std = %s', variable.op.name, np.std(variable_value))
59 | if self._max_num_variable_values:
60 | variable_value = variable_value[:self._max_num_variable_values]
61 | logging.info('%s = %s', variable.op.name, variable_value)
62 | np.set_printoptions(**original)
63 |
--------------------------------------------------------------------------------
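Because train_eval consumes hook builders rather than raw hooks, using this hook typically means wrapping it; a small sketch follows (the cap of 5 logged values per variable is an arbitrary choice).

    # Sketch: expose VariableLoggerHook through the HookBuilder interface so
    # it can be passed via train_hook_builders.
    from tensor2robot.hooks import hook_builder
    from tensor2robot.hooks import variable_logger_hook


    class VariableLoggerHookBuilder(hook_builder.HookBuilder):
      """Logs mean/std and a few raw values of every variable after each run."""

      def create_hooks(self, t2r_model, estimator):
        del t2r_model, estimator  # The hook reads variables from the graph.
        return [variable_logger_hook.VariableLoggerHook(
            max_num_variable_values=5)]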
/input_generators/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """InputGenerator implementations for estimator models."""
17 | from tensor2robot.input_generators import abstract_input_generator
18 |
--------------------------------------------------------------------------------
/input_generators/abstract_input_generator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.input_generators.abstract_input_generator."""
17 |
18 | import functools
19 | from absl import flags
20 | from tensor2robot.input_generators import abstract_input_generator
21 | from tensor2robot.preprocessors import noop_preprocessor
22 | from tensor2robot.utils import mocks
23 | import tensorflow.compat.v1 as tf
24 |
25 |
26 | FLAGS = flags.FLAGS
27 |
28 | BATCH_SIZE = 32
29 |
30 |
31 | class AbstractInputGeneratorTest(tf.test.TestCase):
32 |
33 | def test_init_abstract(self):
34 | with self.assertRaises(TypeError):
35 | abstract_input_generator.AbstractInputGenerator()
36 |
37 | def test_set_preprocess_fn(self):
38 | mock_input_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
39 | preprocessor = noop_preprocessor.NoOpPreprocessor()
40 | with self.assertRaises(ValueError):
41 | # This should raise since we pass a function with `mode` not already
42 | # filled in either by a closure or functools.partial.
43 | mock_input_generator.set_preprocess_fn(preprocessor.preprocess)
44 |
45 | preprocess_fn = functools.partial(preprocessor.preprocess, labels=None)
46 | with self.assertRaises(ValueError):
47 | # This should raise since we pass a partial function but `mode`
48 | # is not abstracted away.
49 | mock_input_generator.set_preprocess_fn(preprocess_fn)
50 |
51 |
52 | if __name__ == '__main__':
53 | tf.test.main()
54 |
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/layers/mdn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test for tensor2robot.layers.mdn."""
17 |
18 | from absl.testing import parameterized
19 | import numpy as np
20 | from tensor2robot.layers import mdn
21 | import tensorflow.compat.v1 as tf
22 | import tensorflow_probability as tfp
23 |
24 |
25 | class MDNTest(tf.test.TestCase, parameterized.TestCase):
26 |
27 | def test_get_mixture_distribution(self):
28 | sample_size = 10
29 | num_alphas = 5
30 | batch_shape = (4, 2)
31 | alphas = tf.random.normal(batch_shape + (num_alphas,))
32 | mus = tf.random.normal(batch_shape + (sample_size * num_alphas,))
33 | sigmas = tf.random.normal(batch_shape + (sample_size * num_alphas,))
34 | params = tf.concat([alphas, mus, sigmas], -1)
35 | output_mean_np = np.random.normal(size=(sample_size,))
36 | gm = mdn.get_mixture_distribution(
37 | params, num_alphas, sample_size, output_mean=output_mean_np)
38 | self.assertEqual(gm.batch_shape, batch_shape)
39 | self.assertEqual(gm.event_shape, sample_size)
40 |
41 | # Check that the component means were translated by output_mean_np.
42 | component_means = gm.components_distribution.mean()
43 | with self.test_session() as sess:
44 | # Note: must get values from the same session run, since params will be
45 | # randomized across separate session runs.
46 | component_means_np, mus_np = sess.run([component_means, mus])
47 | mus_np = np.reshape(mus_np, component_means_np.shape)
48 | self.assertAllClose(component_means_np, mus_np + output_mean_np)
49 |
50 | @parameterized.parameters((True,), (False,))
51 | def test_predict_mdn_params(self, condition_sigmas):
52 | sample_size = 10
53 | num_alphas = 5
54 | inputs = tf.random.normal((2, 16))
55 | with tf.variable_scope('test_scope'):
56 | dist_params = mdn.predict_mdn_params(
57 | inputs, num_alphas, sample_size, condition_sigmas=condition_sigmas)
58 | expected_num_params = num_alphas * (1 + 2 * sample_size)
59 | self.assertEqual(dist_params.shape.as_list(), [2, expected_num_params])
60 |
61 | gm = mdn.get_mixture_distribution(dist_params, num_alphas, sample_size)
62 | stddev = gm.components_distribution.stddev()
63 | with self.test_session() as sess:
64 | sess.run(tf.global_variables_initializer())
65 | stddev_np = sess.run(stddev)
66 | if condition_sigmas:
67 | # Standard deviations should vary with input.
68 | self.assertNotAllClose(stddev_np[0], stddev_np[1])
69 | else:
70 | # Standard deviations should *not* vary with input.
71 | self.assertAllClose(stddev_np[0], stddev_np[1])
72 |
73 | def test_gaussian_mixture_approximate_mode(self):
74 | sample_size = 10
75 | num_alphas = 5
76 | # Manually set alphas to 1 in zero-th column and 0 elsewhere, making the
77 | # first component the most likely.
78 | alphas = tf.one_hot(2 * [0], num_alphas)
79 | mus = tf.random.normal((2, num_alphas, sample_size))
80 | sigmas = tf.ones_like(mus)
81 | mix_dist = tfp.distributions.Categorical(logits=alphas)
82 | comp_dist = tfp.distributions.MultivariateNormalDiag(
83 | loc=mus, scale_diag=sigmas)
84 | gm = tfp.distributions.MixtureSameFamily(
85 | mixture_distribution=mix_dist, components_distribution=comp_dist)
86 | approximate_mode = mdn.gaussian_mixture_approximate_mode(gm)
87 | with self.test_session() as sess:
88 | approximate_mode_np, mus_np = sess.run([approximate_mode, mus])
89 | # The approximate mode should be the mean of the zero-th (most likely)
90 | # component.
91 | self.assertAllClose(approximate_mode_np, mus_np[:, 0, :])
92 |
93 |
94 | if __name__ == '__main__':
95 | tf.test.main()
96 |
--------------------------------------------------------------------------------
/layers/resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.layers.resnet."""
17 |
18 | import functools
19 | from absl.testing import parameterized
20 | from six.moves import range
21 | from tensor2robot.layers import resnet
22 | import tensorflow.compat.v1 as tf
23 |
24 |
25 | class ResnetTest(tf.test.TestCase, parameterized.TestCase):
26 |
27 | @parameterized.parameters(('',), ('fubar',), ('dummy/scope'))
28 | def test_intermediate_values(self, scope):
29 | with tf.variable_scope(scope):
30 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
31 | end_points = resnet.resnet_model(image,
32 | is_training=True,
33 | num_classes=1001,
34 | return_intermediate_values=True)
35 | tensors = ['initial_conv', 'initial_max_pool', 'pre_final_pool',
36 | 'final_reduce_mean', 'final_dense']
37 | tensors += [
38 | 'block_layer{}'.format(i + 1) for i in range(4)]
39 | self.assertEqual(set(tensors), set(end_points.keys()))
40 |
41 | @parameterized.parameters(
42 | (18, [True, True, True, True]),
43 | (50, [True, False, True, False]))
44 | def test_film(self, resnet_size, enabled_blocks):
45 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
46 | embedding = tf.zeros((2, 100), dtype=tf.float32)
47 | film_generator_fn = functools.partial(
48 | resnet.linear_film_generator, enabled_block_layers=enabled_blocks)
49 | _ = resnet.resnet_model(image,
50 | is_training=True,
51 | num_classes=1001,
52 | resnet_size=resnet_size,
53 | return_intermediate_values=True,
54 | film_generator_fn=film_generator_fn,
55 | film_generator_input=embedding)
56 |
57 | def test_malformed_film_raises(self):
58 | image = tf.zeros((2, 224, 224, 3), dtype=tf.float32)
59 | embedding = tf.zeros((2, 100), dtype=tf.float32)
60 | film_generator_fn = functools.partial(
61 | resnet.linear_film_generator, enabled_block_layers=[True]*5)
62 | with self.assertRaises(ValueError):
63 | _ = resnet.resnet_model(image,
64 | is_training=True,
65 | num_classes=1001,
66 | resnet_size=18,
67 | return_intermediate_values=True,
68 | film_generator_fn=film_generator_fn,
69 | film_generator_input=embedding)
70 |
71 | if __name__ == '__main__':
72 | tf.test.main()
73 |
--------------------------------------------------------------------------------
/layers/snail.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of building blocks from https://arxiv.org/abs/1707.03141.
17 |
18 | Implementation here is designed to match pseudocode in the paper.
19 | """
20 |
21 | from typing import Text
22 |
23 | import numpy as np
24 | from six.moves import range
25 | import tensorflow.compat.v1 as tf
26 | from tensorflow.contrib import layers
27 |
28 |
29 | def CausalConv(x, dilation_rate, filters, kernel_size=2, scope = ""):
30 | """Performs causal dilated 1D convolutions.
31 |
32 | Args:
33 | x : Tensor of shape (batch_size, steps, input_dim).
34 | dilation_rate: Dilation rate of convolution.
35 | filters: Number of convolution filters.
36 | kernel_size: Width of convolution kernel. SNAIL paper uses 2 for all
37 | experiments.
38 | scope: Variable scope for this layer.
39 | Returns:
40 | y: Tensor of shape (batch_size, new_steps, D).
41 | """
42 | with tf.variable_scope(scope):
43 | causal_pad_size = (kernel_size - 1) * dilation_rate
44 | # Pad sequence dimension.
45 | x = tf.pad(x, [[0, 0], [causal_pad_size, 0], [0, 0]])
46 | return layers.conv1d(
47 | x,
48 | filters,
49 | kernel_size=kernel_size,
50 | padding="VALID",
51 | rate=dilation_rate)
52 |
53 |
54 | def DenseBlock(x, dilation_rate, filters, scope = ""):
55 | r"""SNAIL 'dense block' with gated activation and concatenation.
56 |
57 | Args:
58 | x : Tensor of shape [batch, time, channels].
59 | dilation_rate: Dilation rate of convolution.
60 | filters: Number of convolution filters.
61 | scope: Variable scope for this layer.
62 | Returns:
63 | y: Tensor of shape [batch, time, channels + filters].
64 | """
65 | with tf.variable_scope(scope):
66 | xf = CausalConv(x, dilation_rate, filters, scope="xf")
67 | xg = CausalConv(x, dilation_rate, filters, scope="xg")
68 | activations = tf.nn.tanh(xf) * tf.nn.sigmoid(xg)
69 | return tf.concat([x, activations], axis=2)
70 |
71 |
72 | def TCBlock(x, sequence_length, filters, scope = ""):
73 | """A stack of DenseBlocks with exponentially increasing dilations.
74 |
75 | Args:
76 | x : Tensor of shape [batch, sequence_length, channels].
77 | sequence_length: Sequence length of x.
78 | filters: Number of convolution filters.
79 | scope: Variable scope for this layer.
80 | Returns:
81 | y: Tensor of shape [batch, sequence_length, channels + filters].
82 | """
83 | with tf.variable_scope(scope):
84 | for i in range(1, int(np.ceil(np.log2(sequence_length)))+1):
85 | x = DenseBlock(x, 2**i, filters, scope="DenseBlock_%d" % i)
86 | return x
87 |
88 |
89 | def CausallyMaskedSoftmax(x):
90 | """Causally masked softmax. Zeros out future-timestep probabilities before and after norm.
91 |
92 | Pre-softmax logits are masked by setting the strictly upper triangle to -inf:
93 |
94 | |a, 0, 0|   |0, -inf, -inf|
95 | |b, d, 0| + |0,    0, -inf|
96 | |c, e, f|   |0,    0,    0|
97 |
98 | Args:
99 | x: Batched tensor of shape [batch_size, T, T].
100 | Returns:
101 | Softmax where each row corresponds to softmax vector for each query.
102 | """
103 | lower_diag = tf.linalg.band_part(x, -1, 0)
104 | upper_diag = -np.inf * tf.ones_like(x)
105 | upper_diag = tf.linalg.band_part(upper_diag, 0, -1)
106 | upper_diag = tf.linalg.set_diag(
107 | upper_diag, tf.zeros_like(tf.linalg.diag_part(x)))
108 | x = lower_diag + upper_diag
109 | softmax = tf.nn.softmax(x)
110 | return tf.linalg.band_part(softmax, -1, 0)
111 |
112 |
113 | def AttentionBlock(x, key_size, value_size, scope = ""):
114 | """Self-attention key-value lookup, styled after Vaswani et al. '17.
115 |
116 | query and key are of shape [T, K]. query * transpose(key) yields logits of
117 | shape [T, T]. logits[i, j] corresponds to unnormalized attention vector over
118 | values [T, V] for each timestep i. Because this attention is over a set of
119 | temporal values, we causally mask the pre-softmax logits[i, j] := -inf for
120 | all j > i.
121 |
122 | Citations:
123 | Vaswani et al. '17: Attention Is All You Need
124 | https://arxiv.org/abs/1706.03762.
125 |
126 | Args:
127 | x: Input tensor of shape [batch, sequence_length, channels].
128 | key_size: Integer key dimensionality.
129 | value_size: Integer value dimensionality.
130 | scope: Variable scope for this layer.
131 | Returns:
132 | result: Tensor of shape [batch, sequence_length, channels + value_size]
133 | end_points: Dictionary of intermediate values (e.g. debugging).
134 | """
135 | end_points = {}
136 | with tf.variable_scope(scope):
137 | key = layers.fully_connected(x, key_size, activation_fn=None) # [T, K]
138 | query = layers.fully_connected(x, key_size, activation_fn=None) # [T, K]
139 | logits = tf.matmul(query, key, transpose_b=True) # [T, T]
140 | # Useful for visualizing attention alignment matrices.
141 | probs = CausallyMaskedSoftmax(logits/np.sqrt(key_size)) # [T, T]
142 | end_points["attn_prob"] = probs
143 | values = layers.fully_connected(x, value_size, activation_fn=None) # [T, V]
144 | read = tf.matmul(probs, values) # [T, V]
145 | result = tf.concat([x, read], axis=2) # [T, K + V]
146 | return result, end_points
147 |
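As a usage illustration (a minimal sketch, not part of the module), the blocks above can be stacked on a `[batch, time, channels]` sequence in the TC-block-then-attention pattern described in the SNAIL paper; the filter, key, and value sizes below are arbitrary:

```
import tensorflow.compat.v1 as tf
from tensor2robot.layers import snail

features = tf.random.normal((8, 16, 32))  # [batch, T, channels]
with tf.variable_scope('snail_sketch'):
  # TCBlock adds ceil(log2(16)) = 4 DenseBlocks of 8 filters each.
  net = snail.TCBlock(features, sequence_length=16, filters=8, scope='tc')
  # AttentionBlock concatenates a causally attended read of size value_size.
  net, end_points = snail.AttentionBlock(net, key_size=16, value_size=16,
                                         scope='attention')
# net: [8, 16, 32 + 4 * 8 + 16]; end_points['attn_prob'] holds the causally
# masked attention weights for visualization.
```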
--------------------------------------------------------------------------------
/layers/snail_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for SNAIL."""
17 |
18 | import numpy as np
19 | from six.moves import range
20 | from tensor2robot.layers import snail
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | class SNAILTest(tf.test.TestCase):
25 |
26 | def test_CausalConv(self):
27 | x = tf.random.normal((4, 10, 8))
28 | y = snail.CausalConv(x, 1, 5)
29 | self.assertEqual(y.shape, (4, 10, 5))
30 |
31 | def test_DenseBlock(self):
32 | x = tf.random.normal((4, 10, 8))
33 | y = snail.DenseBlock(x, 1, 5)
34 | self.assertEqual(y.shape, (4, 10, 13))
35 |
36 | def test_TCBlock(self):
37 | sequence_length = 10
38 | x = tf.random.normal((4, sequence_length, 8))
39 | y = snail.TCBlock(x, sequence_length, 5)
40 | self.assertEqual(y.shape, (4, 10, 8 + 4*5))
41 |
42 | def test_CausallyMaskedSoftmax(self):
43 | num_rows = 5
44 | x = tf.random.normal((num_rows, 3))
45 | logits = tf.matmul(x, tf.linalg.transpose(x))
46 | y = snail.CausallyMaskedSoftmax(logits)
47 | with self.test_session() as sess:
48 | y_ = sess.run(y)
49 | idx = np.triu_indices(num_rows, 1)
50 | np.testing.assert_array_equal(y_[idx], 0.)
51 | # Testing that each row sums to 1.
52 | for i in range(num_rows):
53 | np.testing.assert_almost_equal(np.sum(y_[i, :]), 1.0)
54 |
55 | def test_AttentionBlock(self):
56 | x = tf.random.normal((4, 10, 8))
57 | y, end_points = snail.AttentionBlock(x, 3, 5)
58 | self.assertEqual(y.shape, (4, 10, 5+8))
59 | self.assertEqual(end_points['attn_prob'].shape, (4, 10, 10))
60 |
61 | if __name__ == '__main__':
62 | tf.test.main()
63 |
--------------------------------------------------------------------------------
/layers/spatial_softmax.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TensorFlow implementation of spatial softmax (spatial soft arg-max) layers.
17 |
18 | TODO(T2R_CONTRIBUTORS) - consider replacing with contrib version.
19 | """
20 |
21 | import gin
22 | import numpy as np
23 | from six.moves import range
24 | import tensorflow.compat.v1 as tf
25 | import tensorflow_probability as tfp
26 |
27 |
28 | @gin.configurable
29 | def BuildSpatialSoftmax(features, spatial_gumbel_softmax=False):
30 | """Computes the spatial softmax of the input features.
31 |
32 | Args:
33 | features: A tensor of size [batch_size, num_rows, num_cols, num_features]
34 | spatial_gumbel_softmax: If set to True, samples locations stochastically
35 | rather than computing expected coordinates with respect to the heatmap.
36 | Returns:
37 | A tuple of (expected_feature_points, softmax).
38 | expected_feature_points: A tensor of size
39 | [batch_size, num_features * 2]. These are the expected feature
40 | locations, i.e., the spatial softmax of feature_maps. The inner
41 | dimension is arranged as [x1, x2, x3 ... xN, y1, y2, y3, ... yN].
42 | softmax: A Tensor which is the softmax of the features.
43 | [batch_size, num_rows, num_cols, num_features].
44 | """
45 | _, num_rows, num_cols, num_features = features.get_shape().as_list()
46 |
47 | with tf.name_scope('SpatialSoftmax'):
48 | # Create tensors for x and y positions, respectively
49 | x_pos = np.empty([num_rows, num_cols], np.float32)
50 | y_pos = np.empty([num_rows, num_cols], np.float32)
51 |
52 | # Assign values to positions
53 | for i in range(num_rows):
54 | for j in range(num_cols):
55 | x_pos[i, j] = 2.0 * j / (num_cols - 1.0) - 1.0
56 | y_pos[i, j] = 2.0 * i / (num_rows - 1.0) - 1.0
57 |
58 | x_pos = tf.reshape(x_pos, [num_rows * num_cols])
59 | y_pos = tf.reshape(y_pos, [num_rows * num_cols])
60 |
61 | # We reorder the features into the following order:
62 | # [batch_size, NUM_FEATURES, num_rows, num_cols]
63 | # This lets us merge the batch_size and num_features dimensions, in order
64 | # to compute spatial softmax as a single batch operation.
65 | features = tf.reshape(
66 | tf.transpose(features, [0, 3, 1, 2]), [-1, num_rows * num_cols])
67 |
68 | if spatial_gumbel_softmax:
69 | # Temperature is hard-coded for now, make this more flexible if results
70 | # are promising.
71 | dist = tfp.distributions.RelaxedOneHotCategorical(
72 | temperature=1.0, logits=features)
73 | softmax = dist.sample()
74 | else:
75 | softmax = tf.nn.softmax(features)
76 | # Element-wise multiplication
77 | x_output = tf.multiply(x_pos, softmax)
78 | y_output = tf.multiply(y_pos, softmax)
79 | # Sum over the flattened spatial dimension.
80 | x_output = tf.reduce_sum(x_output, [1], keep_dims=True)
81 | y_output = tf.reduce_sum(y_output, [1], keep_dims=True)
82 | # Concatenate x and y, and reshape.
83 | expected_feature_points = tf.reshape(
84 | tf.concat([x_output, y_output], 1), [-1, num_features*2])
85 | softmax = tf.transpose(
86 | tf.reshape(softmax, [-1, num_features, num_rows,
87 | num_cols]), [0, 2, 3, 1])
88 | return expected_feature_points, softmax
89 |
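For illustration only, a minimal sketch (not part of the module) of applying `BuildSpatialSoftmax` to a batch of convolutional feature maps:

```
import numpy as np
import tensorflow.compat.v1 as tf
from tensor2robot.layers import spatial_softmax

# [batch, num_rows, num_cols, num_features] feature maps.
feature_maps = tf.convert_to_tensor(
    np.random.normal(size=(4, 16, 16, 32)).astype(np.float32))
expected_points, softmax = spatial_softmax.BuildSpatialSoftmax(feature_maps)
# expected_points: [4, 64], arranged as [x1..x32, y1..y32] in [-1, 1].
# softmax: [4, 16, 16, 32], one spatial attention map per feature channel.
with tf.Session() as sess:
  points_np = sess.run(expected_points)
```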
--------------------------------------------------------------------------------
/layers/spatial_softmax_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Spatial Softmax Layer."""
17 |
18 | import numpy as np
19 | from tensor2robot.layers import spatial_softmax
20 | import tensorflow.compat.v1 as tf
21 |
22 |
23 | class SpatialSoftmaxTest(tf.test.TestCase):
24 |
25 | def test_SpatialGumbelSoftmax(self):
26 |
27 | features = tf.convert_to_tensor(
28 | np.random.normal(size=(32, 16, 16, 64)).astype(np.float32))
29 | with tf.variable_scope('mean_pool'):
30 | expected_feature_points, softmax = spatial_softmax.BuildSpatialSoftmax(
31 | features, spatial_gumbel_softmax=False)
32 | with tf.variable_scope('gumbel_pool'):
33 | gumbel_feature_points, gumbel_softmax = (
34 | spatial_softmax.BuildSpatialSoftmax(
35 | features, spatial_gumbel_softmax=True))
36 | self.assertEqual(expected_feature_points.shape, gumbel_feature_points.shape)
37 | self.assertEqual(softmax.shape, gumbel_softmax.shape)
38 |
39 | if __name__ == '__main__':
40 | tf.test.main()
41 |
--------------------------------------------------------------------------------
/layers/tec_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.layers.tec."""
17 |
18 | from tensor2robot.layers import tec
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | class TECTest(tf.test.TestCase):
23 |
24 | def test_embed_condition_images(self):
25 | images = tf.random.normal((4, 100, 100, 3))
26 | embedding = tec.embed_condition_images(
27 | images, 'test_embed', fc_layers=(100, 20))
28 | self.assertEqual(embedding.shape.as_list(), [4, 20])
29 |
30 | def test_doubly_batched_embed_condition_images(self):
31 | doubly_batched_images = tf.random.normal((3, 4, 10, 12, 3))
32 | with self.assertRaises(ValueError):
33 | tec.embed_condition_images(doubly_batched_images, 'test_embed')
34 |
35 | def test_reduce_temporal_embeddings(self):
36 | temporal_embeddings = tf.random.normal((4, 20, 16))
37 | embedding = tec.reduce_temporal_embeddings(
38 | temporal_embeddings, 10, 'test_reduce')
39 | self.assertEqual(embedding.shape.as_list(), [4, 10])
40 |
41 | def test_doubly_batched_reduce_temporal_embeddings(self):
42 | temporal_embeddings = tf.random.normal((2, 4, 20, 16))
43 | with self.assertRaises(ValueError):
44 | tec.reduce_temporal_embeddings(temporal_embeddings, 10, 'test_reduce')
45 |
46 | def test_contrastive_loss(self):
47 | inf_embeddings = tf.nn.l2_normalize(
48 | tf.ones((5, 1, 10), dtype=tf.float32), axis=-1)
49 | con_embeddings = tf.nn.l2_normalize(
50 | tf.ones((5, 1, 10), dtype=tf.float32), axis=-1)
51 | tec.compute_embedding_contrastive_loss(inf_embeddings, con_embeddings)
52 |
53 |
54 | if __name__ == '__main__':
55 | tf.test.main()
56 |
--------------------------------------------------------------------------------
/meta_learning/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/meta_learning/meta_example.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility function for dealing with meta-examples.
17 | """
18 |
19 | from typing import List, Union
20 |
21 | import six
22 | import tensorflow.compat.v1 as tf # tf
23 |
24 | Example = Union[tf.train.Example, tf.train.SequenceExample]
25 |
26 |
27 | def make_meta_example(
28 | condition_examples,
29 | inference_examples,
30 | ):
31 | """Creates a single MetaExample from condition and inference examples."""
32 | if isinstance(condition_examples[0], tf.train.Example):
33 | meta_example = tf.train.Example()
34 | append_fn = append_example
35 | else:
36 | meta_example = tf.train.SequenceExample()
37 | append_fn = append_sequence_example
38 | for i, train_example in enumerate(condition_examples):
39 | append_fn(meta_example, train_example, 'condition_ep{:d}'.format(i))
40 |
41 | for i, val_example in enumerate(inference_examples):
42 | append_fn(meta_example, val_example, 'inference_ep{:d}'.format(i))
43 | return meta_example
44 |
45 |
46 | def append_example(example, ep_example, prefix):
47 | """Add episode Example to Meta TFExample with a prefix."""
48 | context_feature_map = example.features.feature
49 | for key, feature in six.iteritems(ep_example.features.feature):
50 | context_feature_map[six.ensure_str(prefix) + '/' +
51 | six.ensure_str(key)].CopyFrom(feature)
52 |
53 |
54 | def append_sequence_example(meta_example, ep_example, prefix):
55 | """Add episode SequenceExample to the Meta SequenceExample with a prefix."""
56 | context_feature_map = meta_example.context.feature
57 | # Append context features.
58 | for key, feature in six.iteritems(ep_example.context.feature):
59 | context_feature_map[six.ensure_str(prefix) + '/' +
60 | six.ensure_str(key)].CopyFrom(feature)
61 | # Append Sequential features.
62 | sequential_feature_map = meta_example.feature_lists.feature_list
63 | for key, feature_list in six.iteritems(ep_example.feature_lists.feature_list):
64 | sequential_feature_map[six.ensure_str(prefix) + '/' +
65 | six.ensure_str(key)].CopyFrom(feature_list)
66 |
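A minimal sketch of how `make_meta_example` might be used; the `reward` feature key below is made up for illustration:

```
import tensorflow.compat.v1 as tf
from tensor2robot.meta_learning import meta_example

def _example(value):
  # Build a tiny tf.train.Example with a single float feature.
  ex = tf.train.Example()
  ex.features.feature['reward'].float_list.value.append(value)
  return ex

meta = meta_example.make_meta_example(
    condition_examples=[_example(0.0), _example(1.0)],
    inference_examples=[_example(0.5)])
# Features are copied under prefixed keys, e.g. 'condition_ep0/reward',
# 'condition_ep1/reward' and 'inference_ep0/reward'.
```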
--------------------------------------------------------------------------------
/meta_learning/meta_tf_models_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.meta_learning.meta_tf_models."""
17 |
18 | from tensor2robot.meta_learning import meta_tf_models
19 | from tensor2robot.preprocessors import abstract_preprocessor
20 | from tensor2robot.utils import tensorspec_utils
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | class MockBasePreprocessor(abstract_preprocessor.AbstractPreprocessor):
25 |
26 | def _get_feature_specification(self):
27 | spec = tensorspec_utils.TensorSpecStruct()
28 | spec.action = tensorspec_utils.ExtendedTensorSpec(
29 | name='action', shape=(1,), dtype=tf.float32)
30 | spec.velocity = tensorspec_utils.ExtendedTensorSpec(
31 | name='velocity', shape=(1,), dtype=tf.float32, is_optional=True)
32 | return spec
33 |
34 | def get_in_feature_specification(self):
35 | return self._get_feature_specification()
36 |
37 | def get_out_feature_specification(self):
38 | return self._get_feature_specification()
39 |
40 | def _get_label_specification(self):
41 | spec = tensorspec_utils.TensorSpecStruct()
42 | spec.target = tensorspec_utils.ExtendedTensorSpec(
43 | name='target', shape=(1,), dtype=tf.float32)
44 | spec.proxy = tensorspec_utils.ExtendedTensorSpec(
45 | name='proxy', shape=(1,), dtype=tf.float32, is_optional=True)
46 | return spec
47 |
48 | def get_in_label_specification(self):
49 | return self._get_label_specification()
50 |
51 | def get_out_label_specification(self):
52 | return self._get_label_specification()
53 |
54 | def _preprocess_fn(self, features, labels, unused_mode):
55 | return features, labels
56 |
57 |
58 | class MetaTfModelsTest(tf.test.TestCase):
59 |
60 | def test_meta_preprocessor_required_specs(self):
61 | meta_preprocessor = meta_tf_models.MetaPreprocessor(
62 | base_preprocessor=MockBasePreprocessor(),
63 | num_train_samples_per_task=1,
64 | num_val_samples_per_task=1)
65 | ref_feature_spec = meta_preprocessor.get_in_feature_specification()
66 | filtered_feature_spec = tensorspec_utils.filter_required_flat_tensor_spec(
67 | meta_preprocessor.get_in_feature_specification())
68 | self.assertDictEqual(ref_feature_spec, filtered_feature_spec)
69 |
70 | ref_label_spec = meta_preprocessor.get_in_label_specification()
71 | filtered_label_spec = tensorspec_utils.filter_required_flat_tensor_spec(
72 | meta_preprocessor.get_in_label_specification())
73 | self.assertDictEqual(ref_label_spec, filtered_label_spec)
74 |
75 |
76 | if __name__ == '__main__':
77 | tf.test.main()
78 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/policies/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/predictors/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/predictors/abstract_predictor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An abstract predictor to load tf models and expose a predict function."""
17 |
18 | import abc
19 | from typing import Dict, Optional, Text
20 |
21 | import numpy as np
22 | import six
23 | from tensor2robot.utils import tensorspec_utils
24 |
25 |
26 | class AbstractPredictor(six.with_metaclass(abc.ABCMeta, object)):
27 | """A predictor responsible to load a T2RModel and expose a predict function.
28 |
29 | The purpose of the predictor is to abstract model loading and running, e.g.
30 | using a raw session interface, a tensorflow predictor created from saved
31 | models or tensorflow 2.0 models.
32 | """
33 |
34 | @abc.abstractmethod
35 | def predict(self, features):
36 | """Predicts based on feature input using the loaded model.
37 |
38 | Args:
39 | features: A dict containing the features used for predictions.
40 | Returns:
41 | The result of the queried model predictions.
42 | """
43 |
44 | @abc.abstractmethod
45 | def get_feature_specification(self):
46 | """Exposes the required input features for evaluation of the model."""
47 |
48 | def get_label_specification(self
49 | ):
50 | """Exposes the optional labels for evaluation of the model."""
51 | return None
52 |
53 | @abc.abstractmethod
54 | def restore(self):
55 | """Restores the model parameters from the latest available data."""
56 |
57 | def init_randomly(self):
58 | """Initializes model parameters with random values."""
59 |
60 | @abc.abstractmethod
61 | def close(self):
62 | """Closes all open handles used throughout model evaluation."""
63 |
64 | @abc.abstractmethod
65 | def assert_is_loaded(self):
66 | """Raises a ValueError if the predictor has not been restored yet."""
67 |
68 | @property
69 | def model_version(self):
70 | """The version of the model currently in use."""
71 | return 0
72 |
73 | @property
74 | def global_step(self):
75 | """The global step of the model currently in use."""
76 | return 0
77 |
78 | @property
79 | def model_path(self):
80 | """The path of the model currently in use."""
81 | return ''
82 |
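The restore/predict/close life cycle sketched by these abstract methods is shared by all concrete predictors. A minimal sketch using `CheckpointPredictor` with the mock model from `tensor2robot.utils.mocks`; the checkpoint directory path is a placeholder and must already contain checkpoints trained for that model:

```
from tensor2robot.predictors import checkpoint_predictor
from tensor2robot.utils import mocks
from tensor2robot.utils import tensorspec_utils

model = mocks.MockT2RModel()
predictor = checkpoint_predictor.CheckpointPredictor(
    t2r_model=model, checkpoint_dir='/tmp/mock_model_dir', use_gpu=False)
if not predictor.restore():  # Load the latest available checkpoint.
  raise ValueError('No checkpoint could be restored.')
features = tensorspec_utils.make_random_numpy(
    predictor.get_feature_specification(), batch_size=2)
predictions = predictor.predict(features)
predictor.close()
```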
--------------------------------------------------------------------------------
/predictors/checkpoint_predictor_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.predictors.checkpoint_predictor."""
17 |
18 | from absl import flags
19 | import gin
20 | import numpy as np
21 | from tensor2robot.input_generators import default_input_generator
22 | from tensor2robot.predictors import checkpoint_predictor
23 | from tensor2robot.utils import mocks
24 | from tensor2robot.utils import tensorspec_utils
25 | from tensor2robot.utils import train_eval
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | _BATCH_SIZE = 2
32 | _MAX_TRAIN_STEPS = 3
33 |
34 |
35 | class CheckpointPredictorTest(tf.test.TestCase):
36 |
37 | def setUp(self):
38 | super(CheckpointPredictorTest, self).setUp()
39 | gin.clear_config()
40 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')
41 |
42 | def test_predictor(self):
43 | input_generator = default_input_generator.DefaultRandomInputGenerator(
44 | batch_size=_BATCH_SIZE)
45 | model_dir = self.create_tempdir().full_path
46 | mock_model = mocks.MockT2RModel()
47 | train_eval.train_eval_model(
48 | t2r_model=mock_model,
49 | input_generator_train=input_generator,
50 | max_train_steps=_MAX_TRAIN_STEPS,
51 | model_dir=model_dir)
52 |
53 | predictor = checkpoint_predictor.CheckpointPredictor(
54 | t2r_model=mock_model, checkpoint_dir=model_dir, use_gpu=False)
55 | with self.assertRaises(ValueError):
56 | predictor.predict({'does_not_matter': np.zeros(1)})
57 | self.assertEqual(predictor.model_version, -1)
58 | self.assertEqual(predictor.global_step, -1)
59 | self.assertTrue(predictor.restore())
60 | self.assertGreater(predictor.model_version, 0)
61 | self.assertEqual(predictor.global_step, 3)
62 | ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
63 | tf_estimator.ModeKeys.PREDICT)
64 | tensorspec_utils.assert_equal(predictor.get_feature_specification(),
65 | ref_feature_spec)
66 | features = tensorspec_utils.make_random_numpy(
67 | ref_feature_spec, batch_size=_BATCH_SIZE)
68 | predictions = predictor.predict(features)
69 | self.assertLen(predictions, 1)
70 | self.assertCountEqual(sorted(predictions.keys()), ['logit'])
71 | self.assertEqual(predictions['logit'].shape, (2, 1))
72 |
73 | def test_predictor_timeout(self):
74 | mock_model = mocks.MockT2RModel()
75 | predictor = checkpoint_predictor.CheckpointPredictor(
76 | t2r_model=mock_model,
77 | checkpoint_dir='/random/path/which/does/not/exist',
78 | timeout=1)
79 | self.assertFalse(predictor.restore())
80 |
81 | def test_predictor_raises(self):
82 | mock_model = mocks.MockT2RModel()
83 | # Raises because no checkpoint_dir has been set and restore is called.
84 | predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=mock_model)
85 | with self.assertRaises(ValueError):
86 | predictor.restore()
87 |
88 |
89 | if __name__ == '__main__':
90 | tf.test.main()
91 |
--------------------------------------------------------------------------------
/predictors/ensemble_exported_savedmodel_predictor_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.predictors.ensemble_exported_savedmodel_predictor."""
17 |
18 | import os
19 |
20 | from absl import flags
21 | from absl.testing import parameterized
22 | import gin
23 | import numpy as np
24 | from tensor2robot.input_generators import default_input_generator
25 | from tensor2robot.predictors import ensemble_exported_savedmodel_predictor
26 | from tensor2robot.utils import mocks
27 | from tensor2robot.utils import tensorspec_utils
28 | from tensor2robot.utils import train_eval
29 | import tensorflow.compat.v1 as tf
30 | from tensorflow.compat.v1 import estimator as tf_estimator
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | _EXPORT_DIR = 'asyn_export'
35 |
36 | _BATCH_SIZE = 2
37 | _MAX_TRAIN_STEPS = 3
38 | _MAX_EVAL_STEPS = 2
39 |
40 |
41 | class ExportedSavedmodelPredictorTest(tf.test.TestCase, parameterized.TestCase):
42 |
43 | def setUp(self):
44 | super(ExportedSavedmodelPredictorTest, self).setUp()
45 | gin.clear_config()
46 | gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')
47 |
48 | def test_predictor_with_default_exporter(self):
49 | input_generator = default_input_generator.DefaultRandomInputGenerator(
50 | batch_size=_BATCH_SIZE)
51 | model_dir = self.create_tempdir().full_path
52 | mock_model = mocks.MockT2RModel()
53 | train_eval.train_eval_model(
54 | t2r_model=mock_model,
55 | input_generator_train=input_generator,
56 | input_generator_eval=input_generator,
57 | max_train_steps=_MAX_TRAIN_STEPS,
58 | eval_steps=_MAX_EVAL_STEPS,
59 | model_dir=model_dir,
60 | create_exporters_fn=train_eval.create_default_exporters)
61 | # Create ensemble by duplicating the same directory multiple times.
62 | export_dirs = ','.join(
63 | [os.path.join(model_dir, 'export', 'latest_exporter_numpy')] * 2)
64 | predictor = ensemble_exported_savedmodel_predictor.EnsembleExportedSavedModelPredictor(
65 | export_dirs=export_dirs, local_export_root=None, ensemble_size=2)
66 | predictor.resample_ensemble()
67 | with self.assertRaises(ValueError):
68 | predictor.get_feature_specification()
69 | with self.assertRaises(ValueError):
70 | predictor.predict({'does_not_matter': np.zeros(1)})
71 | with self.assertRaises(ValueError):
72 | _ = predictor.model_version
73 | self.assertEqual(predictor.global_step, -1)
74 | self.assertTrue(predictor.restore(is_async=False))
75 | self.assertGreater(predictor.model_version, 0)
76 | self.assertEqual(predictor.global_step, -1)
77 | ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
78 | tf_estimator.ModeKeys.PREDICT)
79 | tensorspec_utils.assert_equal(predictor.get_feature_specification(),
80 | ref_feature_spec)
81 | features = tensorspec_utils.make_random_numpy(
82 | ref_feature_spec, batch_size=_BATCH_SIZE)
83 | predictions = predictor.predict(features)
84 | self.assertLen(predictions, 1)
85 | self.assertCountEqual(predictions.keys(), ['logit'])
86 | self.assertEqual(predictions['logit'].shape, (2, 1))
87 |
88 |
89 | if __name__ == '__main__':
90 | tf.test.main()
91 |
--------------------------------------------------------------------------------
/predictors/saved_model_v2_predictor_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.predictors.saved_model_v2_predictor."""
17 |
18 | import os
19 | import numpy as np
20 |
21 | from tensor2robot.predictors import saved_model_v2_predictor
22 | from tensor2robot.proto import t2r_pb2
23 | from tensor2robot.utils import mocks
24 | from tensor2robot.utils import tensorspec_utils
25 |
26 | from tensorflow import estimator as tf_estimator
27 | from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
28 | import tensorflow.compat.v2 as tf
29 |
30 | _BATCH_SIZE = 2
31 |
32 |
33 | def setUpModule():
34 | tf.enable_v2_behavior()
35 |
36 |
37 | def _generate_assets(model, export_dir):
38 | in_feature_spec = model.get_feature_specification_for_packing(
39 | mode=tf_estimator.ModeKeys.PREDICT)
40 | in_label_spec = model.get_label_specification_for_packing(
41 | mode=tf_compat_v1_estimator.ModeKeys.PREDICT)
42 |
43 | in_feature_spec = tensorspec_utils.filter_required_flat_tensor_spec(
44 | in_feature_spec)
45 | in_label_spec = tensorspec_utils.filter_required_flat_tensor_spec(
46 | in_label_spec)
47 |
48 | t2r_assets = t2r_pb2.T2RAssets()
49 | t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto())
50 | t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto())
51 | t2r_assets_dir = os.path.join(export_dir,
52 | tensorspec_utils.EXTRA_ASSETS_DIRECTORY)
53 |
54 | tf.io.gfile.makedirs(t2r_assets_dir)
55 | t2r_assets_filename = os.path.join(t2r_assets_dir,
56 | tensorspec_utils.T2R_ASSETS_FILENAME)
57 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filename)
58 |
59 |
60 | class SavedModelV2PredictorTest(tf.test.TestCase):
61 |
62 | def __init__(self, *args, **kwargs):
63 | super(SavedModelV2PredictorTest, self).__init__(*args, **kwargs)
64 | self._saved_model_path = None
65 |
66 | def _save_model(self, model, sample_features):
67 | if self._saved_model_path:
68 | return self._saved_model_path
69 |
70 | # Save inference_network_fn as the predict method for the saved_model.
71 | @tf.function(autograph=False)
72 | def predict(features):
73 | return model.inference_network_fn(features, None,
74 | tf_compat_v1_estimator.ModeKeys.PREDICT)
75 |
76 | # Call the model for the tf.function tracing side effects.
77 | predict(sample_features)
78 | model.predict = predict
79 |
80 | self._saved_model_path = self.create_tempdir().full_path
81 | tf.saved_model.save(model, self._saved_model_path)
82 | _generate_assets(model, self._saved_model_path)
83 | return self._saved_model_path
84 |
85 | def _test_predictor(self, predictor_cls, multi_dataset):
86 | mock_model = mocks.MockTF2T2RModel(multi_dataset=multi_dataset)
87 |
88 | # Generate a sample to evaluate
89 | feature_spec = mock_model.preprocessor.get_in_feature_specification(
90 | tf_compat_v1_estimator.ModeKeys.PREDICT)
91 | sample_features = tensorspec_utils.make_random_numpy(
92 | feature_spec, batch_size=_BATCH_SIZE)
93 |
94 | # Generate a saved model and load it.
95 | path = self._save_model(mock_model, sample_features)
96 | saved_model_predictor = predictor_cls(path)
97 |
98 | # Not restored yet.
99 | with self.assertRaises(ValueError):
100 | saved_model_predictor.predict(sample_features)
101 |
102 | saved_model_predictor.restore()
103 |
104 | # Validate evaluations are the same afterwards.
105 | original_model_out = mock_model.inference_network_fn(
106 | sample_features, None, tf_compat_v1_estimator.ModeKeys.PREDICT)
107 |
108 | predictor_out = saved_model_predictor.predict(sample_features)
109 |
110 | np.testing.assert_almost_equal(original_model_out['logits'],
111 | predictor_out['logits'])
112 |
113 | def testTF1PredictorSingleDataset(self):
114 | self._test_predictor(saved_model_v2_predictor.SavedModelTF1Predictor, False)
115 |
116 | def testTF1PredictorMultiDataset(self):
117 | self._test_predictor(saved_model_v2_predictor.SavedModelTF1Predictor, True)
118 |
119 | def testTF2PredictorSingleDataset(self):
120 | self._test_predictor(saved_model_v2_predictor.SavedModelTF2Predictor, False)
121 |
122 | def testTF2PredictorMultiDataset(self):
123 | self._test_predictor(saved_model_v2_predictor.SavedModelTF2Predictor, True)
124 |
125 |
126 | if __name__ == '__main__':
127 | tf.test.main()
128 |
--------------------------------------------------------------------------------
/preprocessors/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Preprocessor implementations for estimator models."""
17 | from tensor2robot.preprocessors import abstract_preprocessor
18 |
--------------------------------------------------------------------------------
/preprocessors/abstract_preprocessor_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.preprocessors.abstract_preprocessor."""
17 |
18 | from tensor2robot.preprocessors import abstract_preprocessor
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | class AbstractPreprocessorTest(tf.test.TestCase):
23 |
24 | def test_init_abstract(self):
25 | with self.assertRaises(TypeError):
26 | abstract_preprocessor.AbstractPreprocessor()
27 |
28 |
29 | if __name__ == '__main__':
30 | tf.test.main()
31 |
--------------------------------------------------------------------------------
/preprocessors/noop_preprocessor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple no operation preprocessor."""
17 |
18 | from typing import Optional, Tuple
19 |
20 | import gin
21 | from tensor2robot.preprocessors import abstract_preprocessor
22 | from tensor2robot.utils import tensorspec_utils
23 | from tensorflow.compat.v1 import estimator as tf_estimator
24 |
25 |
26 | @gin.configurable
27 | class NoOpPreprocessor(abstract_preprocessor.AbstractPreprocessor):
28 | """A convenience preprocessor which does not perform any preprocessing.
29 |
30 | This preprocessor provides convenience functionality in case we simply want
31 | to ensure that single examples already contain the right information for our
32 | model. This preprocessor does not perform any preprocessing, but allows
33 | existing models to initialize a preprocessor without any additional runtime
34 | overhead.
35 | """
36 |
37 | def get_in_feature_specification(
38 | self, mode):
39 | """The specification for the input features for the preprocess_fn.
40 |
41 | Arguments:
42 | mode: mode key for this feature specification
43 | Returns:
44 | A TensorSpecStruct describing the required and optional tensors.
45 | """
46 | return tensorspec_utils.flatten_spec_structure(
47 | self._model_feature_specification_fn(mode))
48 |
49 | def get_in_label_specification(
50 | self, mode):
51 | """The specification for the input labels for the preprocess_fn.
52 |
53 | Arguments:
54 | mode: mode key for this feature specification
55 | Returns:
56 | A TensorSpecStruct describing the required and optional tensors.
57 | """
58 | return tensorspec_utils.flatten_spec_structure(
59 | self._model_label_specification_fn(mode))
60 |
61 | def get_out_feature_specification(
62 | self, mode):
63 | """The specification for the output features after executing preprocess_fn.
64 |
65 | Arguments:
66 | mode: mode key for this feature specification
67 | Returns:
68 | A TensorSpecStruct describing the required and optional tensors.
69 | """
70 | return tensorspec_utils.flatten_spec_structure(
71 | self._model_feature_specification_fn(mode))
72 |
73 | def get_out_label_specification(
74 | self, mode):
75 | """The specification for the output labels after executing preprocess_fn.
76 |
77 | Arguments:
78 | mode: mode key for this feature specification
79 | Returns:
80 | A TensorSpecStruct describing the required and optional tensors.
81 | """
82 | return tensorspec_utils.flatten_spec_structure(
83 | self._model_label_specification_fn(mode))
84 |
85 | def _preprocess_fn(
86 | self, features,
87 | labels,
88 | mode
89 | ):
90 | """The preprocessing function which will be executed prior to the fn.
91 |
92 | As the name NoOpPreprocessor suggests, we do not perform any preprocessing.
93 |
94 | Args:
95 | features: The input features extracted from a single example in our
96 | in_features_specification format.
97 | labels: (Optional) The input labels extracted from a single example
98 | in our in_label_specification format.
99 | mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
100 |
101 | Returns:
102 | features: The preprocessed features, potentially adding
103 | additional tensors derived from the input features.
104 | labels: (Optional) The preprocessed labels, potentially
105 | adding additional tensors derived from the input features and labels.
106 | """
107 | return features, labels
108 |
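Below is a minimal usage sketch (not part of this repo). The constructor argument names `model_feature_specification_fn` and `model_label_specification_fn` are assumptions inferred from the private attributes used above, and the toy specs only illustrate that the in/out specifications are simply the flattened model specs.

```
# Hypothetical sketch: constructor argument names and ExtendedTensorSpec usage
# are assumptions inferred from this file and the proto definition below.
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensor2robot.preprocessors import noop_preprocessor
from tensor2robot.utils import tensorspec_utils


def feature_spec(mode):
  del mode  # Same spec for TRAIN/EVAL/PREDICT in this toy example.
  spec = tensorspec_utils.TensorSpecStruct()
  spec.image = tensorspec_utils.ExtendedTensorSpec(
      shape=(64, 64, 3), dtype=tf.float32, name='image')
  return spec


def label_spec(mode):
  del mode
  spec = tensorspec_utils.TensorSpecStruct()
  spec.reward = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='reward')
  return spec


preprocessor = noop_preprocessor.NoOpPreprocessor(
    model_feature_specification_fn=feature_spec,
    model_label_specification_fn=label_spec)
# The input and output specifications are just the flattened model specs.
print(preprocessor.get_in_feature_specification(tf_estimator.ModeKeys.TRAIN))
```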
--------------------------------------------------------------------------------
/proto/t2r.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2024 The Tensor2Robot Authors.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto2";
16 |
17 | package third_party.py.tensor2robot;
18 |
19 | message ExtendedTensorSpec {
20 | // This message allows (de)serializing tensorspec_utils.ExtendedTensorSpec.
21 | // Each field has a one-to-one mapping to the class constructor.
22 | repeated int32 shape = 1;
23 | optional int32 dtype = 2;
24 | optional string name = 3;
25 | optional bool is_optional = 4;
26 | optional bool is_extracted = 5;
27 | optional string data_format = 6;
28 | optional string dataset_key = 7;
29 | optional float varlen_default_value = 8;
30 | }
31 |
32 | message TensorSpecStruct {
33 | // This message allows (de)serializing tensorspec_utils.TensorSpecStruct.
34 | // This structure is essentially an OrderedDict and is therefore
35 | // serializable through a key: value map.
36 | map<string, ExtendedTensorSpec> key_value = 1;
37 | }
38 |
39 | message T2RAssets {
40 | optional TensorSpecStruct feature_spec = 1;
41 | optional TensorSpecStruct label_spec = 2;
42 | optional int32 global_step = 3;
43 | }
44 |
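For reference, a hedged sketch of building these messages from Python; the generated module name `t2r_pb2` and the convention that `dtype` stores the integer TensorFlow DataType enum are assumptions, not something this file specifies.

```
# Hypothetical sketch: the generated module name and dtype convention are
# assumptions, not documented in t2r.proto itself.
import tensorflow.compat.v1 as tf
from tensor2robot.proto import t2r_pb2  # assumed generated proto module

spec = t2r_pb2.ExtendedTensorSpec(
    shape=[512, 640, 3],
    dtype=tf.uint8.as_datatype_enum,  # assumed: TF DataType enum value
    name='image',
    is_optional=False,
    data_format='jpeg')

struct = t2r_pb2.TensorSpecStruct()
struct.key_value['pregrasp_image'].CopyFrom(spec)

assets = t2r_pb2.T2RAssets(feature_spec=struct, global_step=0)
serialized = assets.SerializeToString()
```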
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.5.0
2 | numpy>=1.13.3
3 | tensorflow>=1.13.0
4 | tensorflow-serving-api>=1.13.0
5 | gin-config>=0.1.4
6 | pybullet==2.5.0
7 | Pillow==5.3.0
8 | gym>=0.10.9
9 | tensorflow-probability>=0.6.0
10 | tf-slim>=1.0
11 |
--------------------------------------------------------------------------------
/research/bcz/README.md:
--------------------------------------------------------------------------------
1 | # BC-Z
2 |
3 | Source code for reproducing "BC-Z: Zero-Shot Task Generalization with Robotic
4 | Imitation Learning".
5 |
6 | Links
7 |
8 | - [Project Website](https://sites.google.com/view/bc-z/home/)
9 | - [Paper](https://arxiv.org/abs/2202.02005)
10 | - [Google AI Blog Post](https://ai.googleblog.com/2022/02/can-robots-follow-instructions-for-new.html)
11 |
12 | ## Training the Model
13 |
14 | The training data, in TFRecord format, is open-sourced here:
15 | https://www.kaggle.com/google/bc-z-robot
16 |
17 | ```
18 | TRAIN_DATA="/path/to/bcz-21task_v9.0.1.tfrecord/train*,/path/to/bcz-79task_v16.0.0.tfrecord/train*"
19 | EVAL_DATA="/path/to/bcz-21task_v9.0.1.tfrecord/val*,/path/to/bcz-79task_v16.0.0.tfrecord/val*"
20 | python3 -m tensor2robot.bin.run_t2r_trainer --logtostderr \
21 | --gin_configs="tensor2robot/research/bcz/configs/run_train_bc_langcond_trajectory.gin" \
22 | --gin_bindings="train_eval_model.model_dir='/tmp/bcz/'" \
23 | --gin_bindings="TRAIN_DATA='${TRAIN_DATA}' \
24 | --gin_bindings="EVAL_DATA='${EVAL_DATA}'"
25 | ```
26 |
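The same configuration can also be loaded from Python rather than via flags; the sketch below mirrors how the tests in this repository parse gin files and assumes that all remaining `train_eval_model` arguments are supplied through gin (as the `run_t2r_trainer` binary does). Paths are placeholders.

```
# Sketch: load the BC-Z gin config and bindings programmatically.
# Assumes every required train_eval_model argument is bound through gin.
import gin
from tensor2robot.utils import train_eval

gin.parse_config_files_and_bindings(
    ['tensor2robot/research/bcz/configs/run_train_bc_langcond_trajectory.gin'],
    [
        "train_eval_model.model_dir = '/tmp/bcz/'",
        "TRAIN_DATA = '/path/to/bcz-21task_v9.0.1.tfrecord/train*'",
        "EVAL_DATA = '/path/to/bcz-21task_v9.0.1.tfrecord/val*'",
    ],
    finalize_config=False)
train_eval.train_eval_model()
```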
--------------------------------------------------------------------------------
/research/bcz/configs/common_imagedistortions.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | ApplyPhotometricImageDistortions.random_brightness = True
17 | ApplyPhotometricImageDistortions.random_saturation = True
18 | ApplyPhotometricImageDistortions.random_hue = True
19 | ApplyPhotometricImageDistortions.random_contrast = True
20 |
--------------------------------------------------------------------------------
/research/bcz/configs/common_imports.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensor2robot.input_generators.default_input_generator
17 | import tensor2robot.utils.train_eval
18 | import tensor2robot.models.abstract_model
19 | import tensor2robot.layers.bcz_networks
20 | import tensor2robot.research.bcz.model
21 | import tensor2robot.hooks.golden_values_hook_builder
22 | import tensor2robot.hooks.gin_config_hook_builder
23 | import tensor2robot.predictors.exported_savedmodel_predictor
24 |
--------------------------------------------------------------------------------
/research/bcz/configs/run_train_bc_gtcond_trajectory.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/bcz/configs/common_imports.gin'
17 | include 'tensor2robot/research/bcz/configs/common_imagedistortions.gin'
18 |
19 | #######################################
20 | # INPUT GENERATION
21 | #######################################
22 |
23 | TRAIN_DATA=""
24 | EVAL_DATA=""
25 |
26 | STATE_COMPONENTS = []
27 | ACTION_COMPONENTS = [
28 | ('xyz', 3, True, 100.),
29 | ('axis_angle', 3, True, 10.),
30 | ('target_close', 1, False, 0.5), # best m104 model used this param.
31 | ]
32 |
33 | NUM_WAYPOINTS = 10
34 |
35 | TRAIN_BATCH_SIZE = 32
36 | EVAL_BATCH_SIZE = 64
37 | TRAIN_FRACTION = 1.0
38 |
39 | TRAIN_INPUT_GENERATOR = @train_input_generator/FractionalRecordInputGenerator()
40 | train_input_generator/FractionalRecordInputGenerator.file_patterns = %TRAIN_DATA
41 | train_input_generator/FractionalRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
42 | train_input_generator/FractionalRecordInputGenerator.file_fraction = %TRAIN_FRACTION
43 |
44 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
45 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
46 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
47 |
48 | #######################################
49 | # PREPROCESSOR & DATA AUGMENTATION
50 | #######################################
51 | CROP_SIZE = 450
52 | IMAGE_SIZE = 150
53 | BCZPreprocessor.binarize_gripper = True
54 | BCZPreprocessor.crop_size = (%CROP_SIZE, %CROP_SIZE)
55 | BCZPreprocessor.image_size = (%IMAGE_SIZE, %IMAGE_SIZE)
56 |
57 | #######################################
58 | # MODEL
59 | #######################################
60 |
61 | BCZModel.image_size = (%IMAGE_SIZE, %IMAGE_SIZE)
62 | BCZModel.network_fn = @resnet_film_network
63 | BCZModel.predict_stop = False
64 | MultiHeadMLP.stop_gradient_future_waypoints = False
65 |
66 | resnet_film_network.film_generator_fn = @linear_film_generator
67 | resnet_model.resnet_size = 18
68 | BCZModel.ignore_task_embedding = False
69 | BCZModel.task_embedding_noise_std = 0.1
70 | linear_film_generator.enabled_block_layers = [True, True, True, True]
71 | MultiHeadMLP.stop_gradient_future_waypoints = False
72 |
73 | BCZPreprocessor.cutout_size = 0 # Was 20 in paper, but not implemented in OSS.
74 | resnet_film_network.fc_layers = (256, 256)
75 |
76 | compute_stop_state_loss.class_weights = [[1.0309278350515465, 0, 33.333333333333336]]
77 | BCZModel.num_past = 0
78 | BCZModel.num_waypoints = %NUM_WAYPOINTS
79 | BCZModel.summarize_gradients = False
80 | BCZModel.state_components = %STATE_COMPONENTS
81 | BCZModel.action_components = %ACTION_COMPONENTS
82 | resnet_model.resnet_size = 18
83 | train_eval_model.t2r_model = @BCZModel()
84 |
85 | #####################################
86 | # TRAINING
87 | ######################################
88 | default_create_optimizer_fn.learning_rate = %LEARNING_RATE
89 | LEARNING_RATE = 2.5e-4
90 |
91 | train_eval_model.max_train_steps = 50000
92 | train_eval_model.eval_steps = 1000
93 | train_eval_model.eval_throttle_secs = 300 # Export model every 5 min.
94 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
95 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
96 | train_eval_model.create_exporters_fn = @create_default_exporters
97 | train_eval_model.chief_train_hook_builders = [@OperativeGinConfigLoggerHookBuilder()]
98 | create_default_exporters.exports_to_keep = None # Keep all ckpts to support evaluation.
99 |
100 | # Export best numpy models based on arm joint loss.
101 | create_valid_result_smaller.result_key = 'mean_first_xyz_error'
102 |
103 | # Save checkpoints frequently for evaluation.
104 | tf.estimator.RunConfig.save_summary_steps = 1000 # save summary every n global steps
105 |
106 |
--------------------------------------------------------------------------------
/research/bcz/configs/run_train_bc_langcond_trajectory.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/bcz/configs/run_train_bc_gtcond_trajectory.gin'
17 |
18 | TRAIN_INPUT_GENERATOR = @train_input_generator/WeightedRecordInputGenerator()
19 | train_input_generator/WeightedRecordInputGenerator.file_patterns = %TRAIN_DATA
20 | train_input_generator/WeightedRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
21 | train_input_generator/WeightedRecordInputGenerator.weights = [0.1, 0.9] # 10% 21-task data, 90% 83-task data.
22 | train_input_generator/WeightedRecordInputGenerator.seed = 0
23 |
24 | BCZModel.cond_modality = %ConditionMode.LANGUAGE_EMBEDDING
25 | BCZModel.task_embedding_noise_std = 0.1
26 | BCZPreprocessor.binarize_gripper = False
27 | BCZPreprocessor.rescale_gripper = True
28 |
29 | IMAGE_SIZE = 200
30 |
--------------------------------------------------------------------------------
/research/bcz/model_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests BCZ models with placeholder data."""
17 |
18 | import itertools
19 | import os
20 | from absl import flags
21 | from absl.testing import parameterized
22 | import gin
23 | from tensor2robot.research.bcz import model
24 | from tensor2robot.utils.t2r_test_fixture import T2RModelFixture
25 | import tensorflow.compat.v1 as tf
26 | from tensorflow.compat.v1 import estimator as tf_estimator
27 |
28 | FLAGS = flags.FLAGS
29 | TRAIN = tf_estimator.ModeKeys.TRAIN
30 | _POSE_COMPONENTS_LIST = list(itertools.product(*[
31 | [True, False], [True, False], ['axis_angle', 'quaternion'], [False]]))
32 |
33 |
34 | class BCZModelTest(tf.test.TestCase, parameterized.TestCase):
35 |
36 | def setUp(self):
37 | gin.clear_config()
38 | super(BCZModelTest, self).setUp()
39 | self._fixture = T2RModelFixture(test_case=self, use_tpu=False)
40 |
41 | @parameterized.parameters(
42 | (model.spatial_softmax_network),
43 | (model.resnet_film_network),
44 | )
45 | def test_network_fn(self, network_fn):
46 | model_name = 'BCZModel'
47 | gin.bind_parameter(
48 | 'BCZModel.network_fn', network_fn)
49 | gin.parse_config('BCZPreprocessor.mock_subtask = True')
50 | gin.parse_config(
51 | 'resnet_film_network.film_generator_fn = @linear_film_generator')
52 | self._fixture.random_train(model, model_name)
53 |
54 | def test_all_components(self):
55 | """Train with all pose components."""
56 | model_name = 'BCZModel'
57 | pose_components = [
58 | ('xyz', 3, True, 100.),
59 | ('quaternion', 4, False, 10.),
60 | ('axis_angle', 3, True, 10.),
61 | ('arm_joints', 7, True, 1.),
62 | ('target_close', 1, False, 1.),
63 | ]
64 | gin.bind_parameter(
65 | 'BCZModel.action_components', pose_components)
66 | gin.parse_config('BCZPreprocessor.mock_subtask = True')
67 | gin.parse_config(
68 | 'resnet_film_network.film_generator_fn = @linear_film_generator')
69 | self._fixture.random_train(model, model_name)
70 |
71 | @parameterized.parameters(*_POSE_COMPONENTS_LIST)
72 | def test_pose_components(self,
73 | residual_xyz,
74 | residual_angle,
75 | angle_format,
76 | residual_gripper):
77 | """Tests with different action configurations."""
78 | model_name = 'BCZModel'
79 | if angle_format == 'axis_angle':
80 | angle_size = 3
81 | elif angle_format == 'quaternion':
82 | angle_size = 4
83 | action_components = [
84 | ('xyz', 3, residual_xyz, 100.),
85 | (angle_format, angle_size, residual_angle, 10.),
86 | ('target_close', 1, residual_gripper, 1.),
87 | ]
88 | gin.bind_parameter(
89 | 'BCZModel.action_components', action_components)
90 | gin.bind_parameter(
91 | 'BCZModel.state_components', [])
92 | gin.parse_config('BCZPreprocessor.mock_subtask = True')
93 | gin.parse_config(
94 | 'resnet_film_network.film_generator_fn = @linear_film_generator')
95 | self._fixture.random_train(model, model_name)
96 |
97 | def test_random_train(self):
98 | base_dir = 'tensor2robot'
99 |
100 | gin_config = os.path.join(
101 | FLAGS.test_srcdir, base_dir, 'research/bcz/configs',
102 | 'run_train_bc_langcond_trajectory.gin')
103 | model_name = 'BCZModel'
104 | gin_bindings = [
105 | 'train_eval_model.eval_steps = 1',
106 | 'EVAL_INPUT_GENERATOR=None',
107 | ]
108 | gin.parse_config_files_and_bindings(
109 | [gin_config], gin_bindings, finalize_config=False)
110 | self._fixture.random_train(model, model_name)
111 |
112 | def tearDown(self):
113 | gin.clear_config()
114 | super(BCZModelTest, self).tearDown()
115 |
116 |
117 | if __name__ == '__main__':
118 | tf.test.main()
119 |
--------------------------------------------------------------------------------
/research/bcz/pose_components_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Action space definitions for MetaTidy models.
17 | """
18 |
19 | from typing import Text, Tuple
20 |
21 | # Name, size, whether it is residual or not, and loss weight.
22 | # This is used to parameterize action labels.
23 | ActionComponent = Tuple[Text, int, bool, float]
24 |
25 | # Name, size, whether residual or not.
26 | # This is used to parameterize proprioceptive state inputs.
27 | StateComponent = Tuple[Text, int, bool]
28 |
29 | DEFAULT_STATE_COMPONENTS = []
30 | DEFAULT_ACTION_COMPONENTS = [
31 | ('xyz', 3, True, 100.),
32 | ('quaternion', 4, False, 10.),
33 | ('target_close', 1, False, 1.),
34 | ]
35 |
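For example, a custom action space is just another list of these tuples, bound to the BC-Z model through gin exactly as `research/bcz/model_test.py` does:

```
# Bind a custom (name, size, residual, loss weight) action space via gin.
import gin

custom_action_components = [
    ('xyz', 3, True, 100.),          # residual Cartesian translation
    ('axis_angle', 3, True, 10.),    # residual rotation
    ('target_close', 1, False, 1.),  # absolute gripper command
]
gin.bind_parameter('BCZModel.action_components', custom_action_components)
```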
--------------------------------------------------------------------------------
/research/dql_grasping_lib/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/research/dql_grasping_lib/tf_modules.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Reused modules for building actors/critics for grasping task.
17 | """
18 |
19 | import gin
20 | import tensorflow.compat.v1 as tf
21 | from tensorflow.contrib import slim
22 |
23 |
24 | @gin.configurable
25 | def argscope(is_training=None, normalizer_fn=slim.layer_norm):
26 | """Default TF argscope used for convnet-based grasping models.
27 |
28 | Args:
29 | is_training: Whether this argscope is for training or inference.
30 | normalizer_fn: Which conv/fc normalizer to use.
31 | Returns:
32 | Dictionary of argument overrides.
33 | """
34 | with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
35 | with slim.arg_scope(
36 | [slim.conv2d, slim.fully_connected],
37 | weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
38 | activation_fn=tf.nn.relu,
39 | normalizer_fn=normalizer_fn):
40 | with slim.arg_scope(
41 | [slim.conv2d, slim.max_pool2d], stride=2, padding='VALID') as scope:
42 | return scope
43 |
44 |
45 | def tile_to_match_context(net, context):
46 | """Tiles net along a new axis=1 to match context.
47 |
48 | Repeats minibatch elements of `net` tensor to match multiple corresponding
49 | minibatch elements from `context`.
50 | Args:
51 | net: Tensor of shape [num_batch_net, ....].
52 | context: Tensor of shape [num_batch_net, num_examples, context_size].
53 | Returns:
54 | Tensor of shape [num_batch_net, num_examples, ...], where each minibatch
55 | element of net has been tiled along axis=1 to match the number of
56 | examples in `context`.
57 | """
58 | with tf.name_scope('tile_to_context'):
59 | num_samples = tf.shape(context)[1]
60 | net_examples = tf.expand_dims(net, 1) # [batch_size, 1, ...]
61 |
62 | net_ndim = len(net_examples.get_shape().as_list())
63 | # Tile net by num_samples in axis=1.
64 | multiples = [1]*net_ndim
65 | multiples[1] = num_samples
66 | net_examples = tf.tile(net_examples, multiples)
67 | return net_examples
68 |
69 |
70 | def add_context(net, context):
71 | """Merges visual perception with context using elementwise addition.
72 |
73 | Actions are reshaped to match net dimension depth-wise, and are added to
74 | the conv layers by broadcasting element-wise across H, W extent.
75 |
76 | Args:
77 | net: Tensor of shape [batch_size, H, W, C].
78 | context: Tensor of shape [batch_size * num_examples, C].
79 | Returns:
80 | Tensor with shape [batch_size * num_examples, H, W, C]
81 | """
82 | num_batch_net = tf.shape(net)[0]
83 | _, h, w, d1 = net.get_shape().as_list()
84 | _, d2 = context.get_shape().as_list()
85 | assert d1 == d2
86 | context = tf.reshape(context, [num_batch_net, -1, d2])
87 | net_examples = tile_to_match_context(net, context)
88 | # Flatten first two dimensions.
89 | net = tf.reshape(net_examples, [-1, h, w, d1])
90 | context = tf.reshape(context, [-1, 1, 1, d2])
91 | context = tf.tile(context, [1, h, w, 1])
92 | net = tf.add_n([net, context])
93 | return net
94 |
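To make the shape contract concrete, here is a small check written as a sketch; it assumes a TF1-style environment where `tensorflow.contrib` (pulled in by this module's slim import) is available.

```
# Sketch: verify the shape contract of add_context with toy tensors.
import numpy as np
import tensorflow.compat.v1 as tf
from tensor2robot.research.dql_grasping_lib import tf_modules

tf.disable_eager_execution()
net = tf.constant(np.zeros((2, 8, 8, 16), dtype=np.float32))  # [B, H, W, C]
context = tf.constant(np.zeros((6, 16), dtype=np.float32))    # [B * num_examples, C]
merged = tf_modules.add_context(net, context)
with tf.Session() as sess:
  print(sess.run(tf.shape(merged)))  # Expected: [6 8 8 16]
```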
--------------------------------------------------------------------------------
/research/grasp2vec/README.md:
--------------------------------------------------------------------------------
1 | # Grasp2Vec
2 |
3 | Source code for reproducing "Grasp2Vec: Learning Object Representations from
4 | Self-Supervised Grasping".
5 |
6 | Links
7 |
8 | - [Project Website](https://sites.google.com/site/grasp2vec/)
9 | - [Paper](https://arxiv.org/abs/1811.06964)
10 | - [Google AI Blog Post](https://ai.googleblog.com/2018/12/grasp2vec-learning-object.html)
11 |
12 | ## Authors
13 |
14 | Eric Jang*¹, Coline Devin*², Vincent
15 | Vanhoucke¹, Sergey Levine¹²
16 |
17 | *Equal Contribution, ¹Google Brain, ²UC
18 | Berkeley
19 |
20 | ## Training the Model
21 |
22 | Data is not included in this repository, so you will have to provide your own
23 | training/eval datasets of TFRecords. The Grasp2Vec T2R model attempts to parse
24 | the following Feature spec from the data, before cropping and resizing the
25 | parsed images:
26 |
27 | ```
28 | tspec.pregrasp_image = TensorSpec(shape=(512, 640, 3),
29 | dtype=tf.uint8, name='image', data_format='jpeg')
30 | tspec.postgrasp_image = TensorSpec(
31 | shape=(512, 640, 3), dtype=tf.uint8, name='postgrasp_image',
32 | data_format='jpeg')
33 | tspec.goal_image = TensorSpec(
34 | shape=(512, 640, 3), dtype=tf.uint8, name='present_image',
35 | data_format='jpeg')
36 | ```
37 |
38 | Note that `image`, `postgrasp_image`, `present_image` are the names of features
39 | stored in the TFExample feature map.
40 |
41 | ```
42 | python3 -m tensor2robot.bin.run_t2r_trainer --logtostderr \
43 | --gin_configs="tensor2robot/research/grasp2vec/configs/train_grasp2vec.gin" \
44 | --gin_bindings="train_eval_model.model_dir='/tmp/grasp2vec/'" \
45 | --gin_bindings="TRAIN_DATA='/path/to/your/data/train*' \
46 | --gin_bindings="EVAL_DATA='/path/to/your/data/val*'"
47 | ```
48 |
49 | TensorBoard will show heatmap localization summaries like those shown in
50 | the paper.
51 |
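Referring back to the feature names above (`image`, `postgrasp_image`, `present_image`), the sketch below shows one way a compatible TFRecord could be assembled; it is illustrative only, using a blank synthetic JPEG rather than real robot data.

```
# Sketch: write a TFExample with the feature keys the Grasp2Vec spec expects.
import io
import numpy as np
import tensorflow.compat.v1 as tf
from PIL import Image


def _fake_jpeg(height=512, width=640):
  # JPEG-encode a blank uint8 image as a stand-in for real camera data.
  buffer = io.BytesIO()
  Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8)).save(
      buffer, format='JPEG')
  return buffer.getvalue()


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


example = tf.train.Example(features=tf.train.Features(feature={
    'image': _bytes_feature(_fake_jpeg()),
    'postgrasp_image': _bytes_feature(_fake_jpeg()),
    'present_image': _bytes_feature(_fake_jpeg()),
}))
with tf.io.TFRecordWriter('/tmp/grasp2vec_train-00000-of-00001') as writer:
  writer.write(example.SerializeToString())
```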
--------------------------------------------------------------------------------
/research/grasp2vec/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/research/grasp2vec/configs/common_imports.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensor2robot.input_generators.default_input_generator
17 | import tensor2robot.utils.train_eval
18 | import tensor2robot.models.abstract_model
19 | import tensor2robot.research.grasp2vec.grasp2vec_model
20 |
--------------------------------------------------------------------------------
/research/grasp2vec/configs/train_grasp2vec.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/grasp2vec/configs/common_imports.gin'
17 |
18 | ######## INPUT GENERATION
19 |
20 | TRAIN_DATA="/path/to/your/data/train*"
21 | EVAL_DATA="/path/to/your/data/val*"
22 |
23 | TRAIN_BATCH_SIZE = 8
24 | EVAL_BATCH_SIZE = 8
25 |
26 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator()
27 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA
28 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
29 |
30 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
31 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
32 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
33 |
34 | #######################################
35 | # MODEL
36 | #######################################
37 |
38 | train_eval_model.t2r_model = @Grasp2VecModel()
39 |
40 | default_create_optimizer_fn.learning_rate = 0.0001
41 |
42 | #####################################
43 | # TRAINING
44 | ######################################
45 |
46 | train_eval_model.max_train_steps = 50000
47 | train_eval_model.eval_steps = 200
48 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
49 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
50 | train_eval_model.create_exporters_fn = @create_default_exporters
51 |
--------------------------------------------------------------------------------
/research/grasp2vec/networks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implements forward pass for embeddings.
17 | """
18 |
19 | from tensor2robot.research.grasp2vec import resnet
20 | import tensorflow.compat.v1 as tf
21 | from tensorflow.compat.v1 import estimator as tf_estimator
22 |
23 |
24 | def Embedding(image, mode, params, reuse=tf.AUTO_REUSE, scope='scene'):
25 | """Implements scene or goal embedding.
26 |
27 | Args:
28 | image: Batch of images corresponding to scene or goal.
29 | mode: Mode is tf.estimator.ModeKeys.EVAL, TRAIN, or PREDICT.
30 | params: Hyperparameters for the network (unused).
31 | reuse: Reuse parameter for variable scope.
32 | scope: The variable_scope to use for the variables.
33 | Returns:
34 | A tuple (batch of summed embeddings, batch of embedding maps).
35 | """
36 | del params
37 | is_training = mode == tf_estimator.ModeKeys.TRAIN
38 | with tf.variable_scope(scope, reuse=reuse):
39 | scene = resnet.get_resnet50_spatial(image, is_training)
40 | scene = tf.nn.relu(scene)
41 | summed_scene = tf.reduce_mean(scene, axis=[1, 2])
42 | return summed_scene, scene
43 |
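A hedged usage sketch follows; the input resolution and absence of preprocessing are arbitrary choices for illustration, and the call simply mirrors the signature defined above.

```
# Sketch: separate 'scene' and 'goal' embedding towers sharing the same code.
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensor2robot.research.grasp2vec import networks

tf.disable_eager_execution()
scene_images = tf.placeholder(tf.float32, shape=[None, 472, 472, 3])
goal_images = tf.placeholder(tf.float32, shape=[None, 472, 472, 3])
scene_vec, scene_map = networks.Embedding(
    scene_images, tf_estimator.ModeKeys.TRAIN, params=None, scope='scene')
goal_vec, goal_map = networks.Embedding(
    goal_images, tf_estimator.ModeKeys.TRAIN, params=None, scope='goal')
# scene_vec: [batch, embedding_dim]; scene_map keeps the spatial layout.
```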
--------------------------------------------------------------------------------
/research/pose_env/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/research/pose_env/configs/common_imports.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensor2robot.input_generators.default_input_generator
17 | import tensor2robot.meta_learning.meta_policies
18 | import tensor2robot.meta_learning.run_meta_env
19 | import tensor2robot.utils.train_eval
20 | import tensor2robot.models.abstract_model
21 | import tensor2robot.utils.continuous_collect_eval
22 | import tensor2robot.research.pose_env.episode_to_transitions
23 | import tensor2robot.research.pose_env.pose_env
24 | import tensor2robot.research.pose_env.pose_env_maml_models
25 | import tensor2robot.research.pose_env.pose_env_models
26 | import tensor2robot.utils.writer
27 | import tensor2robot.research.dql_grasping_lib.run_env
28 |
--------------------------------------------------------------------------------
/research/pose_env/configs/run_random_collect.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin'
17 |
18 | collect_eval_loop.collect_env = @train/PoseToyEnv()
19 | collect_eval_loop.eval_env = None
20 | collect_eval_loop.run_agent_fn = @run_meta_env
21 | collect_eval_loop.policy_class = @PoseEnvRandomPolicy
22 | collect_eval_loop.num_collect = None # this is ignored
23 |
24 | train/PoseToyEnv.render_mode = 'DIRECT'
25 | train/PoseToyEnv.hidden_drift = True
26 |
27 | run_meta_env.num_adaptations_per_task = 1
28 | run_meta_env.num_tasks = 5000
29 | run_meta_env.num_episodes_per_adaptation = 4
30 |
31 | # Save out for visualization.
32 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy
33 | run_meta_env.replay_writer = @TFRecordReplayWriter()
34 |
--------------------------------------------------------------------------------
/research/pose_env/configs/run_train_reg.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin'
17 |
18 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultConstantInputGenerator()
19 | train_input_generator/DefaultConstantInputGenerator.constant_value = 1.0
20 | train_input_generator/DefaultRecordInputGenerator.batch_size = 64
21 |
22 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultConstantInputGenerator()
23 | eval_input_generator/DefaultConstantInputGenerator.constant_value = 1.0
24 | eval_input_generator/DefaultRecordInputGenerator.batch_size = 64
25 |
26 | # Training - input generator and preprocessor numbers need to match up.
27 | train_eval_model.t2r_model = @PoseEnvRegressionModel()
28 | train_eval_model.max_train_steps = 5000
29 | train_eval_model.eval_steps = 1000
30 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
31 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
32 |
33 | # Collection & Evaluation.
34 | collect_eval_loop.collect_env = None
35 | collect_eval_loop.eval_env = @train/PoseToyEnv()
36 | collect_eval_loop.run_agent_fn = @run_meta_env
37 | collect_eval_loop.num_eval = 100
38 |
39 | train/PoseToyEnv.hidden_drift = True
40 |
41 | train/PoseToyEnv.render_mode = 'DIRECT'
42 |
43 | run_meta_env.num_episodes = 100
44 | run_meta_env.num_episodes_per_adaptation = 2
45 |
46 | collect_eval_loop.policy_class = @RegressionPolicy
47 | RegressionPolicy.t2r_model = @PoseEnvRegressionModel()
48 |
49 | # Save data out for visualization.
50 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy
51 | run_meta_env.replay_writer = @TFRecordReplayWriter()
52 |
--------------------------------------------------------------------------------
/research/pose_env/configs/run_train_reg_maml.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | include 'tensor2robot/research/pose_env/configs/common_imports.gin'
17 |
18 | # Input Pipeline
19 |
20 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultConstantInputGenerator()
21 | train_input_generator/DefaultConstantInputGenerator.constant_value = 1.0
22 | train_input_generator/DefaultRecordInputGenerator.batch_size = 64
23 |
24 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultConstantInputGenerator()
25 | eval_input_generator/DefaultConstantInputGenerator.constant_value = 1.0
26 | eval_input_generator/DefaultRecordInputGenerator.batch_size = 64
27 |
28 | # Model
29 |
30 | train_eval_model.t2r_model = @train/PoseEnvRegressionModelMAML()
31 | train_eval_model.max_train_steps = 5000
32 | train_eval_model.eval_steps = 100
33 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
34 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
35 |
36 | PoseEnvRegressionModelMAML.base_model = @PoseEnvRegressionModel()
37 | PoseEnvRegressionModelMAML.preprocessor_cls = @FixedLenMetaExamplePreprocessor
38 | train_eval_model.t2r_model = @train/PoseEnvRegressionModelMAML()
39 |
40 | FixedLenMetaExamplePreprocessor.num_condition_samples_per_task = 1
41 | FixedLenMetaExamplePreprocessor.num_inference_samples_per_task = 1
42 |
43 | # MAMLInnerLoopGradientDescent.learning_rate = 0.001
44 | # MAMLModel.num_inner_loop_steps = 4
45 |
46 | # Collection & Evaluation.
47 | collect_eval_loop.collect_env = None
48 | collect_eval_loop.eval_env = @train/PoseToyEnv()
49 | collect_eval_loop.run_agent_fn = @run_meta_env
50 | train/PoseToyEnv.hidden_drift = True
51 | train/PoseToyEnv.render_mode = 'DIRECT'
52 | collect_eval_loop.num_eval = 100
53 | run_meta_env.num_adaptations_per_task = 4
54 | run_meta_env.num_episodes_per_adaptation = 1
55 |
56 | collect_eval_loop.policy_class = @MAMLRegressionPolicy
57 | MAMLRegressionPolicy.t2r_model = @eval/PoseEnvRegressionModelMAML()
58 |
59 | # Save data out for visualization.
60 | run_meta_env.episode_to_transitions_fn = @episode_to_transitions_pose_toy
61 | run_meta_env.replay_writer = @TFRecordReplayWriter()
62 |
--------------------------------------------------------------------------------
/research/pose_env/episode_to_transitions.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Functions for converting env episode data to tfrecords of transitions."""
17 |
18 | import gin
19 | from PIL import Image
20 | from tensor2robot.utils import image
21 | import tensorflow.compat.v1 as tf
22 |
23 | _bytes_feature = (
24 | lambda v: tf.train.Feature(bytes_list=tf.train.BytesList(value=v)))
25 | _int64_feature = (
26 | lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v)))
27 | _float_feature = (
28 | lambda v: tf.train.Feature(float_list=tf.train.FloatList(value=v)))
29 |
30 |
31 | @gin.configurable
32 | def episode_to_transitions_pose_toy(episode_data):
33 | """Converts pose toy env episode data to transition Examples."""
34 | # This is just saving data for a supervised regression problem, so obs_tp1
35 | # can be discarded.
36 | transitions = []
37 | for transition in episode_data:
38 | (obs_t, action, reward, obs_tp1, done, debug) = transition
39 | del obs_tp1
40 | del done
41 | features = {}
42 | obs_t = Image.fromarray(obs_t)
43 | features['state/image'] = _bytes_feature([image.jpeg_string(obs_t)])
44 | features['pose'] = _float_feature(action.flatten().tolist())
45 | features['reward'] = _float_feature([reward])
46 | features['target_pose'] = _float_feature(debug['target_pose'].tolist())
47 | transitions.append(
48 | tf.train.Example(features=tf.train.Features(feature=features)))
49 | return transitions
50 |
51 |
52 |
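A minimal sketch of feeding this converter synthetic data and writing the result to disk; the observation, action, and debug layouts are assumptions chosen only to satisfy the transition tuple format used above.

```
# Sketch: convert one fake transition and write it to a TFRecord file.
import numpy as np
import tensorflow.compat.v1 as tf
from tensor2robot.research.pose_env import episode_to_transitions

obs_t = np.zeros((64, 64, 3), dtype=np.uint8)            # fake camera image
action = np.zeros(2, dtype=np.float32)                   # fake 2-D pose action
debug = {'target_pose': np.zeros(2, dtype=np.float32)}
episode = [(obs_t, action, 0.0, obs_t, True, debug)]

examples = episode_to_transitions.episode_to_transitions_pose_toy(episode)
with tf.io.TFRecordWriter('/tmp/pose_toy.tfrecord') as writer:
  for example in examples:
    writer.write(example.SerializeToString())
```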
--------------------------------------------------------------------------------
/research/pose_env/pose_env_maml_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """MAML-based meta-learning models for the duck task."""
17 |
18 | import gin
19 | import numpy as np
20 | from tensor2robot.meta_learning import maml_model
21 | from tensor2robot.utils import tensorspec_utils
22 | from tensorflow.compat.v1 import estimator as tf_estimator
23 | from tensorflow.contrib import framework as contrib_framework
24 | nest = contrib_framework.nest
25 |
26 |
27 | @gin.configurable
28 | class PoseEnvRegressionModelMAML(maml_model.MAMLModel):
29 | """MAML Regression environment for duck task."""
30 |
31 | def _make_dummy_labels(self):
32 | """Helper function to make dummy labels for pack_labels."""
33 | label_spec = self._base_model.get_label_specification(
34 | tf_estimator.ModeKeys.TRAIN)
35 | reward_shape = tuple(label_spec.reward.shape)
36 | pose_shape = tuple(label_spec.target_pose.shape)
37 | dummy_reward = np.zeros(reward_shape).astype(np.float32)
38 | dummy_pose = np.zeros(pose_shape).astype(np.float32)
39 | return tensorspec_utils.TensorSpecStruct(
40 | reward=dummy_reward, target_pose=dummy_pose)
41 |
42 | def _select_inference_output(self, predictions):
43 | """Inference output selection for regression models."""
44 | # We select our output for inference.
45 | predictions.condition_output = (
46 | predictions.full_condition_output.inference_output)
47 | predictions.inference_output = (
48 | predictions.full_inference_output.inference_output)
49 | return predictions
50 |
51 | def pack_features(self, state, prev_episode_data, timestep):
52 | """Combines current state and conditioning data into MetaExample spec.
53 |
54 | See create_metaexample_spec for an example of the spec layout.
55 |
56 | If prev_episode_data does not contain enough episodes to fill
57 | num_condition_samples_per_task, we pad with dummy episodes (reward=0.5)
58 | so that no inner gradients are applied.
59 |
60 | Args:
61 | state: VRGripperObservation containing image and pose.
62 | prev_episode_data: A list of episode data, each of which is a list of
63 | tuples containing transition data. Each transition tuple takes the form
64 | (obs, action, rew, new_obs, done, debug).
65 | timestep: Current episode timestep.
66 | Returns:
67 | TensorSpecStruct containing conditioning (features, labels)
68 | and inference (features) keys.
69 | Raises:
70 | ValueError: If no demonstration is provided.
71 | """
72 | meta_features = tensorspec_utils.TensorSpecStruct()
73 | meta_features['inference/features/state/0'] = state
74 | def pack_condition_features(episode_data, idx, dummy_values=False):
75 | """Pack previous episode data into condition_ep* features/labels.
76 |
77 | Args:
78 | episode_data: List of (obs, action, rew, new_obs, done, debug) tuples.
79 | idx: Index of the conditioning episode. 0 for demo, 1 for first trial,
80 | etc.
81 | dummy_values: If an episode is not available yet, set the loss_mask
82 | to 0.
83 |
84 | """
85 | transition = episode_data[0]
86 | meta_features['condition/features/state/%d' % idx] = transition[0]
87 | reward = np.array([transition[2]])
88 | reward = 2 * reward - 1
89 | if dummy_values:
90 | # success_weight of 0. = no gradients in inner loop for this batch.
91 | reward = np.array([0.])
92 | meta_features['condition/labels/target_pose/%d' % idx] = transition[1]
93 | meta_features['condition/labels/reward/%d' % idx] = reward.astype(
94 | np.float32)
95 |
96 | if prev_episode_data:
97 | pack_condition_features(prev_episode_data[0], 0)
98 | else:
99 | dummy_labels = self._make_dummy_labels()
100 | dummy_episode = [(state, dummy_labels.target_pose, dummy_labels.reward)]
101 | pack_condition_features(dummy_episode, 0, dummy_values=True)
102 | return nest.map_structure(lambda x: np.expand_dims(x, 0), meta_features)
103 |
--------------------------------------------------------------------------------
/research/pose_env/pose_env_models_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Integration tests for training pose_env models."""
17 |
18 | import os
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import gin
23 | from tensor2robot.input_generators import default_input_generator
24 | from tensor2robot.meta_learning import meta_policies
25 | from tensor2robot.meta_learning import preprocessors
26 | from tensor2robot.predictors import checkpoint_predictor
27 | from tensor2robot.research.pose_env import pose_env
28 | from tensor2robot.research.pose_env import pose_env_maml_models
29 | from tensor2robot.research.pose_env import pose_env_models
30 | from tensor2robot.utils import train_eval
31 | from tensor2robot.utils import train_eval_test_utils
32 | import tensorflow.compat.v1 as tf # tf
33 |
34 |
35 | BATCH_SIZE = 1
36 | MAX_TRAIN_STEPS = 1
37 | EVAL_STEPS = 1
38 |
39 | NUM_TRAIN_SAMPLES_PER_TASK = 1
40 | NUM_VAL_SAMPLES_PER_TASK = 1
41 |
42 | FLAGS = tf.app.flags.FLAGS
43 |
44 |
45 | class PoseEnvModelsTest(parameterized.TestCase):
46 |
47 | def setUp(self):
48 | super(PoseEnvModelsTest, self).setUp()
49 | base_dir = 'tensor2robot'
50 | test_data = os.path.join(FLAGS.test_srcdir,
51 | base_dir,
52 | 'test_data/pose_env_test_data.tfrecord')
53 | self._train_log_dir = FLAGS.test_tmpdir
54 | if tf.io.gfile.exists(self._train_log_dir):
55 | tf.io.gfile.rmtree(self._train_log_dir)
56 | gin.bind_parameter('train_eval_model.max_train_steps', 3)
57 | gin.bind_parameter('train_eval_model.eval_steps', 2)
58 |
59 | self._record_input_generator = (
60 | default_input_generator.DefaultRecordInputGenerator(
61 | batch_size=BATCH_SIZE, file_patterns=test_data))
62 |
63 | self._meta_record_input_generator_train = (
64 | default_input_generator.DefaultRandomInputGenerator(
65 | batch_size=BATCH_SIZE))
66 | self._meta_record_input_generator_eval = (
67 | default_input_generator.DefaultRandomInputGenerator(
68 | batch_size=BATCH_SIZE))
69 |
70 | def test_mc(self):
71 | train_eval.train_eval_model(
72 | t2r_model=pose_env_models.PoseEnvContinuousMCModel(),
73 | input_generator_train=self._record_input_generator,
74 | input_generator_eval=self._record_input_generator,
75 | create_exporters_fn=None)
76 |
77 | def test_regression(self):
78 | train_eval.train_eval_model(
79 | t2r_model=pose_env_models.PoseEnvRegressionModel(),
80 | input_generator_train=self._record_input_generator,
81 | input_generator_eval=self._record_input_generator,
82 | create_exporters_fn=None)
83 |
84 | def test_regression_maml(self):
85 | maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
86 | base_model=pose_env_models.PoseEnvRegressionModel())
87 | train_eval.train_eval_model(
88 | t2r_model=maml_model,
89 | input_generator_train=self._meta_record_input_generator_train,
90 | input_generator_eval=self._meta_record_input_generator_eval,
91 | create_exporters_fn=None)
92 |
93 | def _test_policy_interface(self, policy, restore=True):
94 | urdf_root = pose_env.get_pybullet_urdf_root()
95 | self.assertTrue(os.path.exists(urdf_root))
96 | env = pose_env.PoseToyEnv(
97 | urdf_root=urdf_root, render_mode='DIRECT')
98 | env.reset_task()
99 | obs = env.reset()
100 | if restore:
101 | policy.restore()
102 | policy.reset_task()
103 | action = policy.SelectAction(obs, None, 0)
104 |
105 | new_obs, rew, done, env_debug = env.step(action)
106 | episode_data = [[(obs, action, rew, new_obs, done, env_debug)]]
107 | policy.adapt(episode_data)
108 |
109 | policy.SelectAction(new_obs, None, 1)
110 |
111 | def test_regression_maml_policy_interface(self):
112 | t2r_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
113 | base_model=pose_env_models.PoseEnvRegressionModel(),
114 | preprocessor_cls=preprocessors.FixedLenMetaExamplePreprocessor)
115 | predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=t2r_model)
116 | predictor.init_randomly()
117 | policy = meta_policies.MAMLRegressionPolicy(t2r_model, predictor=predictor)
118 | self._test_policy_interface(policy, restore=False)
119 |
120 | @parameterized.parameters(
121 | ('run_train_reg_maml.gin',),
122 | ('run_train_reg.gin',))
123 | def test_train_eval_gin(self, gin_file):
124 | base_dir = 'tensor2robot'
125 | full_gin_path = os.path.join(
126 | FLAGS.test_srcdir, base_dir, 'research/pose_env/configs', gin_file)
127 | model_dir = os.path.join(FLAGS.test_tmpdir, 'test_train_eval_gin', gin_file)
128 | train_eval_test_utils.test_train_eval_gin(
129 | test_case=self,
130 | model_dir=model_dir,
131 | full_gin_path=full_gin_path,
132 | max_train_steps=MAX_TRAIN_STEPS,
133 | eval_steps=EVAL_STEPS)
134 |
135 |
136 | if __name__ == '__main__':
137 | absltest.main()
138 |
--------------------------------------------------------------------------------
/research/pose_env/pose_env_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.research.pose_env.pose_env."""
17 |
18 | import os
19 | from absl.testing import absltest
20 | from six.moves import range
21 | from tensor2robot.research.pose_env import pose_env
22 |
23 |
24 | class PoseEnvTest(absltest.TestCase):
25 |
26 | def test_PoseEnv(self):
27 | urdf_root = pose_env.get_pybullet_urdf_root()
28 | self.assertTrue(os.path.exists(urdf_root))
29 | env = pose_env.PoseToyEnv(urdf_root=urdf_root)
30 | obs = env.reset()
31 | policy = pose_env.PoseEnvRandomPolicy()
32 | action, _ = policy.sample_action(obs, 0)
33 | for _ in range(3):
34 | obs, _, done, _ = env.step(action)
35 | if done:
36 | obs = env.reset()
37 |
38 | if __name__ == '__main__':
39 | absltest.main()
40 |
--------------------------------------------------------------------------------
/research/qtopt/README.md:
--------------------------------------------------------------------------------
1 | # QT-Opt
2 |
3 | This directory contains the network architecture definitions for the grasping
4 | critic described in [QT-Opt: Scalable Deep Reinforcement Learning for
5 | Vision-Based Robotic Manipulation](https://arxiv.org/abs/1806.10293).
6 |
7 | ## Running the code
8 |
9 | The following commands train the QT-Opt critic architecture for a few gradient
10 | steps with mock data (real data is not included in this repo). The learning
11 | objective resembles supervised learning, since Bellman targets in QT-Opt are
12 | computed in a separate process (not open-sourced).
13 |
14 | ```
15 | git clone https://github.com/google/tensor2robot
16 | # Optional: Create a virtualenv
17 | python3 -m venv ~/venv
18 | source ~/venv/bin/activate
19 | pip install -r tensor2robot/requirements.txt
20 | python -m tensor2robot.research.qtopt.t2r_models_test
21 | ```
22 |
23 | ## PCGrad
24 |
25 | This directory also contains a multi-task optimization method,
26 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf), implemented as an optimizer
27 | wrapper. It is based on the open-source implementation
28 | [here](https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py).
29 |
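The sketch below assumes the wrapper follows the upstream `PCGrad_tf` interface it is based on, i.e. a `PCGrad` class that wraps a standard `tf.train` optimizer and takes a list of per-task losses in `compute_gradients`; treat the exact names as assumptions rather than this repo's documented API.

```
# Hedged sketch of the PCGrad optimizer wrapper (interface assumed from the
# upstream PCGrad_tf implementation linked above).
import tensorflow.compat.v1 as tf
from tensor2robot.research.qtopt import pcgrad

tf.disable_eager_execution()
w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
loss_task_a = tf.reduce_sum(tf.square(w - 1.0))  # toy per-task losses
loss_task_b = tf.reduce_sum(tf.square(w + 1.0))

optimizer = pcgrad.PCGrad(tf.train.AdamOptimizer(learning_rate=1e-4))
grads_and_vars = optimizer.compute_gradients([loss_task_a, loss_task_b])
train_op = optimizer.apply_gradients(grads_and_vars)
```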
--------------------------------------------------------------------------------
/research/qtopt/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/research/qtopt/optimizer_builder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Build optimizer with the given hyperparamaters.
17 | """
18 |
19 | from absl import logging
20 | import tensorflow.compat.v1 as tf
21 | from tensorflow.contrib import opt as contrib_opt
22 | from tensorflow.contrib.tpu.python.tpu import tpu_function
23 |
24 |
25 | def BuildOpt(hparams):
26 | """Constructs the optimizer.
27 |
28 | Args:
29 | hparams: An instance of tf.HParams, with these parameters:
30 | - batch_size
31 | - examples_per_epoch
32 | - learning_rate
33 | - learning_rate_decay_factor
34 | - model_weights_averaging
35 | - momentum
36 | - num_epochs_per_decay
37 | - optimizer
38 |       - rmsprop_decay, rmsprop_epsilon
39 | - use_avg_model_params
40 |
41 | Returns:
42 | opt: The optimizer.
43 | """
44 | logging.info('Hyperparameters: %s', hparams)
45 | batch_size = hparams.batch_size
46 | examples_per_epoch = hparams.examples_per_epoch
47 | learning_rate_decay_factor = hparams.learning_rate_decay_factor
48 | learning_rate = hparams.learning_rate
49 | model_weights_averaging = hparams.model_weights_averaging
50 | momentum = hparams.momentum
51 | num_epochs_per_decay = hparams.num_epochs_per_decay
52 | optimizer = hparams.optimizer
53 | rmsprop_decay = hparams.rmsprop_decay
54 | rmsprop_epsilon = hparams.rmsprop_epsilon
55 | adam_beta2 = hparams.get('adam_beta2', 0.999)
56 | adam_epsilon = hparams.get('adam_epsilon', 1e-8)
57 | use_avg_model_params = hparams.use_avg_model_params
58 |
59 | global_step = tf.train.get_or_create_global_step()
60 |
61 |   # Configure the learning rate using an exponential decay.
62 | decay_steps = int(examples_per_epoch / batch_size *
63 | num_epochs_per_decay)
64 |
65 | learning_rate = tf.train.exponential_decay(
66 | learning_rate,
67 | global_step,
68 | decay_steps,
69 | learning_rate_decay_factor,
70 | staircase=True)
71 | if not tpu_function.get_tpu_context():
72 | tf.summary.scalar('Learning Rate', learning_rate)
73 |
74 | if optimizer == 'momentum':
75 | opt = tf.train.MomentumOptimizer(learning_rate, momentum)
76 | elif optimizer == 'rmsprop':
77 | opt = tf.train.RMSPropOptimizer(
78 | learning_rate,
79 | decay=rmsprop_decay,
80 | momentum=momentum,
81 | epsilon=rmsprop_epsilon)
82 | else:
83 | opt = tf.train.AdamOptimizer(
84 | learning_rate,
85 | beta1=momentum,
86 | beta2=adam_beta2,
87 | epsilon=adam_epsilon)
88 |
89 | if use_avg_model_params:
90 | # Callers of BuildOpt() with use_avg_model_params=True expect the
91 | # MovingAverageOptimizer to be the last optimizer returned by this function
92 | # so that the swapping_saver can be constructed from it.
93 | return contrib_opt.MovingAverageOptimizer(
94 | opt, average_decay=model_weights_averaging)
95 |
96 | return opt
97 |
--------------------------------------------------------------------------------
/research/qtopt/pcgrad_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.research.qtopt.pcgrad."""
17 |
18 | from absl.testing import parameterized
19 | import numpy as np
20 | from tensor2robot.research.qtopt import pcgrad
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | class PcgradTest(tf.test.TestCase, parameterized.TestCase):
25 |
26 | @parameterized.parameters(
27 | (None, None, [0, 1]),
28 | (None, ['*var*'], [0, 1]),
29 | (['second*'], None, [0]),
30 | (None, ['first*'], [0]),
31 | (None, ['*0'], [0]),
32 | (['first*'], None, [1]),
33 | (['*var*'], None, []),
34 | )
35 | def testPCgradBasic(self,
36 | denylist,
37 | allowlist,
38 | pcgrad_var_idx):
39 | tf.disable_eager_execution()
40 | for dtype in [tf.dtypes.float32, tf.dtypes.float64]:
41 | with self.session(graph=tf.Graph()):
42 | var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
43 | var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
44 | const0_np = np.array([1., 0.], dtype=dtype.as_numpy_dtype)
45 | const1_np = np.array([-1., -1.], dtype=dtype.as_numpy_dtype)
46 | const2_np = np.array([-1., 1.], dtype=dtype.as_numpy_dtype)
47 |
48 | var0 = tf.Variable(var0_np, dtype=dtype, name='first_var/var0')
49 | var1 = tf.Variable(var1_np, dtype=dtype, name='second_var/var1')
50 | const0 = tf.constant(const0_np)
51 | const1 = tf.constant(const1_np)
52 | const2 = tf.constant(const2_np)
53 | loss0 = tf.tensordot(var0, const0, 1) + tf.tensordot(var1, const2, 1)
54 | loss1 = tf.tensordot(var0, const1, 1) + tf.tensordot(var1, const0, 1)
55 |
56 | learning_rate = lambda: 0.001
57 | opt = tf.train.GradientDescentOptimizer(learning_rate)
58 | losses = loss0 + loss1
59 | opt_grads = opt.compute_gradients(losses, var_list=[var0, var1])
60 |
61 | pcgrad_opt = pcgrad.PCGrad(
62 | tf.train.GradientDescentOptimizer(learning_rate),
63 | denylist=denylist,
64 | allowlist=allowlist)
65 | pcgrad_col_opt = pcgrad.PCGrad(
66 | tf.train.GradientDescentOptimizer(learning_rate),
67 | use_collection_losses=True,
68 | denylist=denylist,
69 | allowlist=allowlist)
70 | losses = [loss0, loss1]
71 | pcgrad_grads = pcgrad_opt.compute_gradients(
72 | losses, var_list=[var0, var1])
73 | tf.add_to_collection(pcgrad.PCGRAD_LOSSES_COLLECTION, loss0)
74 | tf.add_to_collection(pcgrad.PCGRAD_LOSSES_COLLECTION, loss1)
75 | pcgrad_grads_collection = pcgrad_col_opt.compute_gradients(
76 | None, var_list=[var0, var1])
77 |
78 | with tf.Graph().as_default():
79 | # Shouldn't return non-slot variables from other graphs.
80 | self.assertEmpty(opt.variables())
81 |
82 | self.evaluate(tf.global_variables_initializer())
83 | grad_vec, pcgrad_vec, pcgrad_col_vec = self.evaluate(
84 | [opt_grads, pcgrad_grads, pcgrad_grads_collection])
85 | # Make sure that both methods take grads of the same vars.
86 | self.assertAllCloseAccordingToType(pcgrad_vec, pcgrad_col_vec)
87 |
88 | results = [{
89 | 'var': var0,
90 | 'pcgrad_vec': [0.5, -1.5],
91 | 'result': [0.9995, 2.0015]
92 | }, {
93 | 'var': var1,
94 | 'pcgrad_vec': [0.5, 1.5],
95 | 'result': [2.9995, 3.9985]
96 | }]
97 | grad_var_idx = {0, 1}.difference(pcgrad_var_idx)
98 |
99 | self.assertAllCloseAccordingToType(
100 | grad_vec[0][0], [0.0, -1.0], atol=1e-5)
101 | self.assertAllCloseAccordingToType(
102 | grad_vec[1][0], [0.0, 1.0], atol=1e-5)
103 | pcgrad_vec_idx = 0
104 | for var_idx in pcgrad_var_idx:
105 | self.assertAllCloseAccordingToType(
106 | pcgrad_vec[pcgrad_vec_idx][0],
107 | results[var_idx]['pcgrad_vec'],
108 | atol=1e-5)
109 | pcgrad_vec_idx += 1
110 |
111 | for var_idx in grad_var_idx:
112 | self.assertAllCloseAccordingToType(
113 | pcgrad_vec[pcgrad_vec_idx][0], grad_vec[var_idx][0], atol=1e-5)
114 | pcgrad_vec_idx += 1
115 |
116 | self.evaluate(opt.apply_gradients(pcgrad_grads))
117 | self.assertAllCloseAccordingToType(
118 | self.evaluate([results[idx]['var'] for idx in pcgrad_var_idx]),
119 | [results[idx]['result'] for idx in pcgrad_var_idx])
120 |
121 |
122 | if __name__ == '__main__':
123 | tf.test.main()
124 |
--------------------------------------------------------------------------------
/research/qtopt/pcgrad_tpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.research.qtopt.pcgrad."""
17 |
18 | from tensor2robot.research.qtopt import pcgrad
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | class PcgradTest(tf.test.TestCase):
23 |
24 | def testPCgradNetworkTPU(self):
25 | tf.reset_default_graph()
26 | tf.disable_eager_execution()
27 | learning_rate = lambda: 0.001
28 | def pcgrad_computation():
29 | x = tf.constant(1., shape=[64, 472, 472, 3])
30 | layers = [
31 | tf.keras.layers.Conv2D(filters=64, kernel_size=3),
32 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)),
33 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)),
34 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)),
35 | tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2)),
36 | ]
37 | y = x
38 | for layer in layers:
39 | y = layer(y)
40 | n_tasks = 10
41 | task_loss_0 = tf.reduce_sum(y)
42 | task_losses = [task_loss_0 * (1. + (n / 10.)) for n in range(n_tasks)]
43 |
44 | pcgrad_opt = pcgrad.PCGrad(
45 | tf.train.GradientDescentOptimizer(learning_rate))
46 | pcgrad_grads_and_vars = pcgrad_opt.compute_gradients(
47 | task_losses, var_list=tf.trainable_variables())
48 | return pcgrad_opt.apply_gradients(pcgrad_grads_and_vars)
49 |
50 | tpu_computation = tf.compat.v1.tpu.batch_parallel(pcgrad_computation,
51 | num_shards=2)
52 | self.evaluate(tf.compat.v1.tpu.initialize_system())
53 | self.evaluate(tf.compat.v1.global_variables_initializer())
54 | self.evaluate(tpu_computation)
55 | self.evaluate(tf.compat.v1.tpu.shutdown_system())
56 |
57 | if __name__ == "__main__":
58 | tf.test.main()
59 |
--------------------------------------------------------------------------------
/research/qtopt/t2r_models_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests that we can run a few training steps (mock data) on T2R model."""
17 |
18 | from absl import flags
19 | from absl.testing import parameterized
20 | from tensor2robot.research.qtopt import t2r_models
21 | from tensor2robot.utils.t2r_test_fixture import T2RModelFixture
22 | import tensorflow.compat.v1 as tf
23 | from tensorflow.compat.v1 import estimator as tf_estimator
24 |
25 | FLAGS = flags.FLAGS
26 | TRAIN = tf_estimator.ModeKeys.TRAIN
27 | MODEL_NAME = 'Grasping44E2EOpenCloseTerminateGripperStatusHeightToBottom'
28 |
29 |
30 | class GraspPredictT2RTest(parameterized.TestCase):
31 |
32 | def setUp(self):
33 | super(GraspPredictT2RTest, self).setUp()
34 | self._fixture = T2RModelFixture(
35 | test_case=self,
36 | use_tpu=False,
37 | )
38 |
39 | @parameterized.parameters(
40 | (MODEL_NAME,))
41 | def test_random_train(self, model_name):
42 | self._fixture.random_train(
43 | module_name=t2r_models, model_name=model_name)
44 |
45 | @parameterized.parameters(
46 | (MODEL_NAME,))
47 | def test_inference(self, model_name):
48 | result = self._fixture.random_predict(
49 | t2r_models, model_name, action_batch_size=64)
50 | self.assertIsNotNone(result)
51 | self.assertDictContainsSubset({'global_step': 0}, result)
52 |
53 |
54 | if __name__ == '__main__':
55 | tf.test.main()
56 |
--------------------------------------------------------------------------------
/research/vrgripper/README.md:
--------------------------------------------------------------------------------
1 | # VRGripper Environment Models
2 |
3 | Contains code for training models in the VRGripper environment from
4 | "Watch, Try, Learn: Meta-Learning from Demonstrations and Rewards."
5 |
6 | Includes the models used in the Watch, Try, Learn (WTL) gripping experiments.
7 |
8 | ## Links
9 |
10 | - [Project Website](https://sites.google.com/corp/view/watch-try-learn-project)
11 | - [Paper Preprint](https://arxiv.org/abs/1906.03352)
12 |
13 | ## Authors
14 |
15 | Allan Zhou¹, Eric Jang¹, Daniel Kappler²,
16 | Alex Herzog², Mohi Khansari²,
17 | Paul Wohlhart², Yunfei Bai²,
18 | Mrinal Kalakrishnan², Sergey Levine¹·³,
19 | Chelsea Finn¹
20 |
21 | ¹Google Brain, ²X, ³UC Berkeley
22 |
23 | ## Training the WTL gripping experiment models
24 |
25 | WTL experiment models are located in `vrgripper_env_wtl_models.py`.
26 | Data is not included in this repository, so you will have to provide your own
27 | training/eval datasets. Training is configured by the following gin configs:
28 |
29 | * `configs/run_train_wtl_statespace_trial.gin`: Train a trial policy on
30 | state-space observations.
31 | * `configs/run_train_wtl_statespace_retrial.gin`: Train a retrial policy
32 | on state-space observations.
33 | * `configs/run_train_wtl_vision_trial.gin`: Train a trial policy on image
34 | observations.
35 | * `configs/run_train_wtl_vision_retrial.gin`: Train a retrial policy on
36 | image observations.
37 |
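38 | As a rough sketch, training with one of the configs above would look like the
39 | invocation below. The trainer module path and flag names (`--gin_configs`,
40 | `--gin_bindings`) are assumptions based on the standard T2R trainer entry
41 | point, and the dataset paths are placeholders for your own TFRecords.
42 |
43 | ```
44 | python -m tensor2robot.bin.run_t2r_trainer \
45 |   --logtostderr \
46 |   --gin_configs="tensor2robot/research/vrgripper/configs/run_train_wtl_statespace_trial.gin" \
47 |   --gin_bindings="TRAIN_DATA='/path/to/train*.tfrecord'" \
48 |   --gin_bindings="EVAL_DATA='/path/to/eval*.tfrecord'"
49 | ```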
--------------------------------------------------------------------------------
/research/vrgripper/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/research/vrgripper/configs/common_imports.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensor2robot.utils.continuous_collect_eval
17 | import tensor2robot.input_generators.default_input_generator
18 | import tensor2robot.meta_learning.meta_policies
19 | import tensor2robot.meta_learning.run_meta_env
20 | import tensor2robot.models.abstract_model
21 | import tensor2robot.predictors.checkpoint_predictor
22 | import tensor2robot.research.vrgripper.episode_to_transitions
23 | import tensor2robot.research.vrgripper.vrgripper_env_models
24 | import tensor2robot.research.vrgripper.maf
25 | import tensor2robot.research.vrgripper.mse_decoder
26 | import tensor2robot.research.vrgripper.discrete
27 | import tensor2robot.research.vrgripper.vrgripper_env_meta_models
28 | import tensor2robot.research.vrgripper.vrgripper_env_wtl_models
29 | import tensor2robot.utils.train_eval
30 |
--------------------------------------------------------------------------------
/research/vrgripper/configs/run_train_wtl_statespace_retrial.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Train a state-space-based WTL retrial policy.
17 |
18 | import tensor2robot.input_generators.default_input_generator
19 |
20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin'
21 |
22 | # Input Generation.
23 | TRAIN_DATA = ''
24 | EVAL_DATA = ''
25 |
26 | TRAIN_BATCH_SIZE = 8
27 | EVAL_BATCH_SIZE = 8
28 |
29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator()
30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA
31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
32 |
33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
36 |
37 | #######################################
38 | # MODEL
39 | #######################################
40 |
41 | train_eval_model.t2r_model = @retrial/VRGripperEnvSimpleTrialModel()
42 | retrial/VRGripperEnvSimpleTrialModel.use_sync_replicas_optimizer = True
43 | retrial/VRGripperEnvSimpleTrialModel.retrial = True
44 | retrial/VRGripperEnvSimpleTrialModel.num_condition_samples_per_task = 2
45 | BuildImageFeaturesToPoseModel.bias_transform_size = 0
46 |
47 | reduce_temporal_embeddings.conv1d_layers = (64, 32)
48 | reduce_temporal_embeddings.fc_hidden_layers = (100,)
49 | default_create_optimizer_fn.learning_rate = 1e-3
50 |
51 | train_eval_model.max_train_steps = 100
52 | train_eval_model.eval_steps = 1000
53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
55 | train_eval_model.create_exporters_fn = @create_default_exporters
56 |
57 |
--------------------------------------------------------------------------------
/research/vrgripper/configs/run_train_wtl_statespace_trial.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Train a state-space-based WTL trial policy.
17 |
18 | import tensor2robot.input_generators.default_input_generator
19 |
20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin'
21 |
22 | # Input Generation.
23 | TRAIN_DATA = ''
24 | EVAL_DATA = ''
25 |
26 | TRAIN_BATCH_SIZE = 8
27 | EVAL_BATCH_SIZE = 8
28 |
29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator()
30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA
31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
32 |
33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
36 |
37 | #######################################
38 | # MODEL
39 | #######################################
40 |
41 | train_eval_model.t2r_model = @VRGripperEnvSimpleTrialModel()
42 | VRGripperEnvSimpleTrialModel.num_mixture_components = 10
43 | VRGripperEnvSimpleTrialModel.use_sync_replicas_optimizer = True
44 | VRGripperEnvSimpleTrialModel.embed_type = 'temporal'
45 | BuildImageFeaturesToPoseModel.bias_transform_size = 0
46 |
47 | reduce_temporal_embeddings.conv1d_layers = (64, 32)
48 | reduce_temporal_embeddings.fc_hidden_layers = (100,)
49 | default_create_optimizer_fn.learning_rate = 1e-3
50 |
51 | train_eval_model.max_train_steps = 100
52 | train_eval_model.eval_steps = 1000
53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
55 | train_eval_model.create_exporters_fn = @create_default_exporters
56 |
57 |
--------------------------------------------------------------------------------
/research/vrgripper/configs/run_train_wtl_vision_retrial.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Train a vision-based WTL retrial policy.
17 |
18 | import tensor2robot.input_generators.default_input_generator
19 |
20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin'
21 |
22 | # Input Generation.
23 | TRAIN_DATA = ''
24 | EVAL_DATA = ''
25 |
26 | TRAIN_BATCH_SIZE = 8
27 | EVAL_BATCH_SIZE = 8
28 |
29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator()
30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA
31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
32 |
33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
36 |
37 | #######################################
38 | # MODEL
39 | #######################################
40 |
41 | train_eval_model.t2r_model = @retrial/VRGripperEnvVisionTrialModel()
42 | retrial/VRGripperEnvVisionTrialModel.use_sync_replicas_optimizer = True
43 | retrial/VRGripperEnvVisionTrialModel.num_condition_samples_per_task = 2
44 | BuildImageFeaturesToPoseModel.bias_transform_size = 0
45 |
46 | reduce_temporal_embeddings.conv1d_layers = (64, 32)
47 | reduce_temporal_embeddings.fc_hidden_layers = (100,)
48 | default_create_optimizer_fn.learning_rate = 1e-3
49 |
50 | train_eval_model.max_train_steps = 100
51 | train_eval_model.eval_steps = 1000
52 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
53 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
54 | train_eval_model.create_exporters_fn = @create_default_exporters
55 |
56 |
--------------------------------------------------------------------------------
/research/vrgripper/configs/run_train_wtl_vision_trial.gin:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Train a vision-based WTL trial policy.
17 |
18 | import tensor2robot.input_generators.default_input_generator
19 |
20 | include 'tensor2robot/research/vrgripper/configs/common_imports.gin'
21 |
22 | # Input Generation.
23 | TRAIN_DATA = ''
24 | EVAL_DATA = ''
25 |
26 | TRAIN_BATCH_SIZE = 8
27 | EVAL_BATCH_SIZE = 8
28 |
29 | TRAIN_INPUT_GENERATOR = @train_input_generator/DefaultRecordInputGenerator()
30 | train_input_generator/DefaultRecordInputGenerator.file_patterns = %TRAIN_DATA
31 | train_input_generator/DefaultRecordInputGenerator.batch_size = %TRAIN_BATCH_SIZE
32 |
33 | EVAL_INPUT_GENERATOR = @eval_input_generator/DefaultRecordInputGenerator()
34 | eval_input_generator/DefaultRecordInputGenerator.file_patterns = %EVAL_DATA
35 | eval_input_generator/DefaultRecordInputGenerator.batch_size = %EVAL_BATCH_SIZE
36 |
37 | #######################################
38 | # MODEL
39 | #######################################
40 |
41 | train_eval_model.t2r_model = @VRGripperEnvVisionTrialModel()
42 | VRGripperEnvVisionTrialModel.num_mixture_components = 20
43 | VRGripperEnvVisionTrialModel.use_sync_replicas_optimizer = True
44 | BuildImageFeaturesToPoseModel.bias_transform_size = 0
45 |
46 | embed_condition_images.fc_layers = (100, 64)
47 | reduce_temporal_embeddings.conv1d_layers = (32,)
48 | reduce_temporal_embeddings.fc_hidden_layers = (100,)
49 | default_create_optimizer_fn.learning_rate = 5e-4
50 |
51 | train_eval_model.max_train_steps = 5000
52 | train_eval_model.eval_steps = 1000
53 | train_eval_model.input_generator_train = %TRAIN_INPUT_GENERATOR
54 | train_eval_model.input_generator_eval = %EVAL_INPUT_GENERATOR
55 | train_eval_model.create_exporters_fn = @create_default_exporters
56 |
57 |
--------------------------------------------------------------------------------
/research/vrgripper/episode_to_transitions.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Functions for converting env episode data to tfrecords of transitions."""
17 |
18 | import collections
19 |
20 | import gin
21 | import numpy as np
22 | from PIL import Image
23 | import six
24 | from six.moves import range
25 | import tensorflow.compat.v1 as tf
26 |
27 |
28 |
29 | _bytes_feature = (
30 | lambda v: tf.train.Feature(bytes_list=tf.train.BytesList(value=v)))
31 | _int64_feature = (
32 | lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v)))
33 | _float_feature = (
34 | lambda v: tf.train.Feature(float_list=tf.train.FloatList(value=v)))
35 |
36 | _IMAGE_KEY_PREFIX = 'image'
37 |
38 |
39 | @gin.configurable
40 | def make_fixed_length(
41 | input_list,
42 | fixed_length,
43 | always_include_endpoints=True,
44 | randomized=True):
45 | """Create a fixed length list by sampling entries from input_list.
46 |
47 | Args:
48 | input_list: The original list we sample entries from.
49 | fixed_length: An integer: the desired length of the output list.
50 | always_include_endpoints: If True, always include the first and last entries
51 | of input_list in the output.
52 | randomized: If True, select entries from input_list by random sampling with
53 | replacement. If False, select entries from input_list deterministically.
54 | Returns:
55 |     A fixed_length list sampled from input_list, or None if len(input_list) <= 2.
56 | """
57 | original_length = len(input_list)
58 | if original_length <= 2:
59 | return None
60 | if not randomized:
61 | indices = np.sort(np.mod(np.arange(fixed_length), original_length))
62 | return [input_list[i] for i in indices]
63 | if always_include_endpoints:
64 | # Always include entries 0 and N-1.
65 | endpoint_indices = np.array([0, original_length - 1])
66 | # The remaining (fixed_length-2) frames are sampled with replacement
67 | # from entries [1, N-1) of input_list.
68 | other_indices = 1 + np.random.choice(
69 | original_length - 2, fixed_length-2, replace=True)
70 | indices = np.concatenate(
71 | (endpoint_indices, other_indices),
72 | axis=0)
73 | else:
74 | indices = np.random.choice(
75 | original_length, fixed_length, replace=True)
76 | indices = np.sort(indices)
77 | return [input_list[i] for i in indices]
78 |
79 |
80 |
81 |
82 | @gin.configurable
83 | def episode_to_transitions_reacher(episode_data, is_demo=False):
84 | """Converts reacher env data to transition examples."""
85 | transitions = []
86 | for i, transition in enumerate(episode_data):
87 | del i
88 | feature_dict = {}
89 | (obs_t, action, reward, obs_tp1, done, debug) = transition
90 | del debug
91 | feature_dict['pose_t'] = _float_feature(obs_t)
92 | feature_dict['pose_tp1'] = _float_feature(obs_tp1)
93 | feature_dict['action'] = _float_feature(action)
94 | feature_dict['reward'] = _float_feature([reward])
95 | feature_dict['done'] = _int64_feature([int(done)])
96 | feature_dict['is_demo'] = _int64_feature([int(is_demo)])
97 | example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
98 | transitions.append(example)
99 | return transitions
100 |
101 |
102 | @gin.configurable
103 | def episode_to_transitions_metareacher(episode_data):
104 | """Converts metareacher env data to transition examples."""
105 | context_features = {}
106 | feature_lists = collections.defaultdict(list)
107 |
108 | context_features['is_demo'] = _int64_feature(
109 | [int(episode_data[0][-1]['is_demo'])])
110 | context_features['target_idx'] = _int64_feature(
111 | [episode_data[0][-1]['target_idx']])
112 |
113 | for i, transition in enumerate(episode_data):
114 | del i
115 | (obs_t, action, reward, obs_tp1, done, debug) = transition
116 | del debug
117 | feature_lists['pose_t'].append(_float_feature(obs_t))
118 | feature_lists['pose_tp1'].append(_float_feature(obs_tp1))
119 | feature_lists['action'].append(_float_feature(action))
120 | feature_lists['reward'].append(_float_feature([reward]))
121 | feature_lists['done'].append(_int64_feature([int(done)]))
122 |
123 | tf_feature_lists = {}
124 | for key in feature_lists:
125 | tf_feature_lists[key] = tf.train.FeatureList(feature=feature_lists[key])
126 |
127 | return [tf.train.SequenceExample(
128 | context=tf.train.Features(feature=context_features),
129 | feature_lists=tf.train.FeatureLists(feature_list=tf_feature_lists))]
130 |
131 |
132 |
--------------------------------------------------------------------------------
/research/vrgripper/episode_to_transitions_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for episodes_to_transitions."""
17 |
18 | from six.moves import range
19 | from tensor2robot.research.vrgripper import episode_to_transitions
20 | import tensorflow.compat.v1 as tf
21 |
22 |
23 | class EpisodeToTransitionsTest(tf.test.TestCase):
24 |
25 | def test_make_fixed_length(self):
26 | fixed_length = 10
27 | dummy_feature_dict_lists = [
28 | [{'dummy_feature': i} for i in range(5)],
29 | [{'dummy_feature': i} for i in range(20)],
30 | ]
31 |
32 | for feature_dict_list in dummy_feature_dict_lists:
33 | filtered_feature_dict_list = episode_to_transitions.make_fixed_length(
34 | feature_dict_list,
35 | fixed_length=fixed_length,
36 | always_include_endpoints=True)
37 | self.assertLen(filtered_feature_dict_list, fixed_length)
38 |
39 | # The first and last entries of the original list should be present in
40 | # the filtered list.
41 | self.assertEqual(feature_dict_list[0], filtered_feature_dict_list[0])
42 | self.assertEqual(feature_dict_list[-1], filtered_feature_dict_list[-1])
43 |
--------------------------------------------------------------------------------
/research/vrgripper/maf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Conditional density estimation with masked autoregressive flow.
17 | """
18 |
19 | import gin
20 | import numpy as np
21 | from six.moves import range
22 | import tensorflow.compat.v1 as tf
23 | import tensorflow_probability as tfp
24 | from tensorflow.contrib import slim
25 | tfd = tfp.distributions
26 | tfb = tfp.bijectors
27 |
28 |
29 | def init_once(x, name):
30 | """Return a variable initialized with a constant value.
31 |
32 | This is used to initialize Permutation bijectors. See [1] for more information
33 | on why returning Permute(np.random.permutation(event_size)) is unsafe.
34 |
35 | Args:
36 |     x: A TF variable initializer or constant-valued tensor.
37 | name: String name for the returned variable.
38 |
39 | Returns:
40 | Variable copy of the tensor.
41 |
42 | References:
43 |
44 | [1] https://www.tensorflow.org/probability/api_docs/python/
45 | tfp/bijectors/Permute
46 | """
47 | return tf.get_variable(name, initializer=x, trainable=False)
48 |
49 |
50 | def maf_bijector(event_size, num_flows, hidden_layers):
51 | """Construct a chain of MAF flows into a single bijector."""
52 | bijectors = []
53 | for i in range(num_flows):
54 | bijectors.append(tfb.MaskedAutoregressiveFlow(
55 | shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
56 | hidden_layers=hidden_layers)))
57 | bijectors.append(
58 | tfb.Permute(
59 | permutation=init_once(
60 | np.random.permutation(event_size).astype('int32'),
61 | name='permute_%d' % i)))
62 | # Chain the bijectors, leaving out the last permutation bijector.
63 | return tfb.Chain(list(reversed(bijectors[:-1])))
64 |
65 |
66 | @gin.configurable
67 | class MAFDecoder(object):
68 | """Decoder using a Masked Autoregressive Flow.
69 |
70 | Conditioning is specified by warping the centers of the base isotropic normal
71 |   distributions, e.g. MAF(N(mu, 1)), where mu is a linear projection of the
72 |   incoming conditioning parameters. This avoids incorporating conditioning into
73 | the actual bijector.
74 | """
75 |
76 | def __init__(self, num_flows=1, hidden_layers=None):
77 | self._num_flows = num_flows
78 | self._hidden_layers = hidden_layers or [512, 512]
79 |
80 | def __call__(self, params, output_size):
81 | mus = slim.fully_connected(
82 | params, output_size, activation_fn=None, scope='maf_mus')
83 | base_dist = tfd.MultivariateNormalDiag(
84 | loc=mus, scale_diag=tf.ones_like(mus))
85 | event_shape = base_dist.event_shape.as_list()
86 | if np.any([event_shape[0] > l for l in self._hidden_layers]):
87 | raise ValueError(
88 | 'MAF hidden layers have to be at least as wide as event size.')
89 | self._maf = tfd.TransformedDistribution(
90 | distribution=base_dist,
91 | bijector=maf_bijector(
92 | event_shape[0], self._num_flows, self._hidden_layers))
93 | return self._maf.sample()
94 |
95 | def loss(self, labels):
96 | nll_local = -self._maf.log_prob(labels.action)
97 | # Average across batch, sequence.
98 | return tf.reduce_mean(nll_local)
99 |
--------------------------------------------------------------------------------
/research/vrgripper/mse_decoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract decoder and MSE decoder.
17 | """
18 |
19 | import gin
20 |
21 | import tensorflow.compat.v1 as tf
22 | from tensorflow.contrib import slim
23 |
24 |
25 | @gin.configurable
26 | class MSEDecoder(object):
27 | """Default MSE decoder."""
28 |
29 | def __call__(self, params, output_size):
30 | self._predictions = slim.fully_connected(
31 | params, output_size, activation_fn=None, scope='pose')
32 | return self._predictions
33 |
34 | def loss(self, labels):
35 | return tf.losses.mean_squared_error(labels=labels.action,
36 | predictions=self._predictions)
37 |
--------------------------------------------------------------------------------
/test_data/mock_exported_savedmodel/assets.extra/t2r_assets.pbtxt:
--------------------------------------------------------------------------------
1 | feature_spec {
2 | key_value {
3 | key: "x"
4 | value {
5 | shape: 3
6 | dtype: 1
7 | name: "measured_position"
8 | is_optional: false
9 | is_extracted: false
10 | }
11 | }
12 | }
13 | label_spec {
14 | key_value {
15 | key: "y"
16 | value {
17 | shape: 1
18 | dtype: 1
19 | name: "valid_position"
20 | is_optional: false
21 | is_extracted: false
22 | }
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/test_data/mock_exported_savedmodel/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/saved_model.pb
--------------------------------------------------------------------------------
/test_data/mock_exported_savedmodel/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/test_data/mock_exported_savedmodel/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/mock_exported_savedmodel/variables/variables.index
--------------------------------------------------------------------------------
/test_data/pose_env_test_data.tfrecord:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/tensor2robot/034f49c435168151930c6a07abea48b0df58b74b/test_data/pose_env_test_data.tfrecord
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/utils/continuous_collect_eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Collect/Eval a policy on the live environment."""
17 |
18 | import os
19 | import time
20 | from typing import Text
21 |
22 | import gin
23 | import gym
24 | import tensorflow.compat.v1 as tf
25 |
26 |
27 | @gin.configurable
28 | def collect_eval_loop(
29 | collect_env,
30 | eval_env,
31 | policy_class,
32 | num_collect = 2000,
33 | num_eval = 100,
34 | run_agent_fn=None,
35 | root_dir = '',
36 | continuous = False,
37 | min_collect_eval_step = 0,
38 | max_steps = 1,
39 | pre_collect_eval_fn=None,
40 | record_eval_env_video = False,
41 | init_with_random_variables = False):
42 | """Like dql_grasping.collect_eval, but can run continuously.
43 |
44 | Args:
45 | collect_env: (gym.Env) Gym environment to collect data from (and train the
46 | policy on).
47 | eval_env: (gym.Env) Gym environment to evaluate the policy on. Can be
48 | another instance of collect_env, or a different environment if one
49 | wishes to evaluate generalization capability. The only constraint is
50 | that the action and observation spaces have to be equivalent. If None,
51 | eval_env is not evaluated.
52 | policy_class: Policy class that we want to train.
53 | num_collect: (int) Number of episodes to collect from collect_env.
54 | num_eval: (int) Number of episodes to evaluate from eval_env.
55 | run_agent_fn: (Optional) Python function that executes the interaction of
56 | the policy with the environment. Defaults to run_env.run_env.
57 | root_dir: Base directory where collect data and eval data are
58 | stored.
59 | continuous: If True, loop and wait for new ckpt to load a policy from
60 | (up until the ckpt number exceeds max_steps).
61 | min_collect_eval_step: An integer which specifies the lowest ckpt step
62 | number that we will collect/evaluate.
63 | max_steps: (Ignored unless continuous=True). An integer controlling when
64 |       to stop looping: once we see a policy with global_step >= max_steps, we
65 | stop.
66 | pre_collect_eval_fn: This callable will be run prior to the start of this
67 | collect/eval loop. Example use: pushing a record dataset into a replay
68 | buffer at the start of training.
69 | record_eval_env_video: Whether to enable video recording in our eval env.
70 |     init_with_random_variables: If True, initializes the policy model with
71 |       random variables instead of restoring a checkpoint (useful for testing).
72 | """
73 | if pre_collect_eval_fn:
74 | pre_collect_eval_fn()
75 |
76 | collect_dir = os.path.join(root_dir, 'policy_collect')
77 | eval_dir = os.path.join(root_dir, 'eval')
78 |
79 | policy = policy_class()
80 | prev_global_step = -1
81 | while True:
82 | global_step = None
83 | if hasattr(policy, 'restore'):
84 | if init_with_random_variables:
85 | policy.init_randomly()
86 | else:
87 | policy.restore()
88 | global_step = policy.global_step
89 |
90 | if global_step is None or global_step < min_collect_eval_step \
91 | or global_step <= prev_global_step:
92 | time.sleep(10)
93 | continue
94 |
95 | if collect_env:
96 | run_agent_fn(collect_env, policy=policy, num_episodes=num_collect,
97 | root_dir=collect_dir, global_step=global_step, tag='collect')
98 | if eval_env:
99 | if record_eval_env_video and hasattr(eval_env, 'set_video_output_dir'):
100 | eval_env.set_video_output_dir(
101 | os.path.join(root_dir, 'videos', str(global_step)))
102 | run_agent_fn(eval_env, policy=policy, num_episodes=num_eval,
103 | root_dir=eval_dir, global_step=global_step, tag='eval')
104 | if not continuous or global_step >= max_steps:
105 | tf.logging.info('Completed collect/eval on final ckpt.')
106 | break
107 |
108 | prev_global_step = global_step
109 |
--------------------------------------------------------------------------------
/utils/continuous_collect_eval_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for run_collect_eval."""
17 |
18 | import os
19 | from absl import flags
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import gin
23 | from tensor2robot.research.pose_env import pose_env
24 | from tensor2robot.utils import continuous_collect_eval
25 | import tensorflow.compat.v1 as tf
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | class PoseEnvModelsTest(parameterized.TestCase):
30 |
31 | @parameterized.parameters(
32 | (pose_env.PoseEnvRandomPolicy,),
33 | )
34 | def test_run_pose_env_collect(self, demo_policy_cls):
35 | urdf_root = pose_env.get_pybullet_urdf_root()
36 |
37 | config_dir = 'research/pose_env/configs'
38 | gin_config = os.path.join(
39 | FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
40 | gin.parse_config_file(gin_config)
41 | tmp_dir = absltest.get_default_test_tmpdir()
42 | root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
43 | gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
44 | gin.bind_parameter(
45 | 'collect_eval_loop.root_dir', root_dir)
46 | gin.bind_parameter('run_meta_env.num_tasks', 2)
47 | gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
48 | gin.bind_parameter(
49 | 'collect_eval_loop.policy_class', demo_policy_cls)
50 | continuous_collect_eval.collect_eval_loop()
51 | output_files = tf.io.gfile.glob(os.path.join(
52 | root_dir, 'policy_collect', '*.tfrecord'))
53 | self.assertLen(output_files, 2)
54 |
55 |
56 | if __name__ == '__main__':
57 | absltest.main()
58 |
--------------------------------------------------------------------------------
/utils/convert_pkl_assets_to_proto_assets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Convert existing pickle based assets to t2r_pb2 based assets."""
17 |
18 | import os
19 | from typing import Text
20 |
21 | from absl import app
22 | from absl import flags
23 |
24 | from tensor2robot.proto import t2r_pb2
25 | from tensor2robot.utils import tensorspec_utils
26 | import tensorflow.compat.v1 as tf
27 |
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string('assets_filepath', None,
32 | 'The path to the exported savedmodel assets directory.')
33 |
34 |
35 | def convert(assets_filepath):
36 | """Converts existing asset pickle based files to t2r proto based assets."""
37 |
38 | t2r_assets = t2r_pb2.T2RAssets()
39 | input_spec_filepath = os.path.join(assets_filepath, 'input_specs.pkl')
40 | if not tf.io.gfile.exists(input_spec_filepath):
41 | raise ValueError('No file exists for {}.'.format(input_spec_filepath))
42 | feature_spec, label_spec = tensorspec_utils.load_input_spec_from_file(
43 | input_spec_filepath)
44 |
45 | t2r_assets.feature_spec.CopyFrom(feature_spec.to_proto())
46 | t2r_assets.label_spec.CopyFrom(label_spec.to_proto())
47 |
48 | global_step_filepath = os.path.join(assets_filepath, 'global_step.pkl')
49 | if tf.io.gfile.exists(global_step_filepath):
50 | global_step = tensorspec_utils.load_input_spec_from_file(
51 | global_step_filepath)
52 | t2r_assets.global_step = global_step
53 |
54 | t2r_assets_filepath = os.path.join(assets_filepath,
55 | tensorspec_utils.T2R_ASSETS_FILENAME)
56 | tensorspec_utils.write_t2r_assets_to_file(t2r_assets, t2r_assets_filepath)
57 |
58 |
59 | def main(unused_argv):
60 | flags.mark_flag_as_required('assets_filepath')
61 | convert(FLAGS.assets_filepath)
62 |
63 |
64 | if __name__ == '__main__':
65 | app.run(main)
66 |
--------------------------------------------------------------------------------
/utils/global_step_functions.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Gin configurable functions returning tf.Tensors based on the global_step.
17 | """
18 |
19 | from typing import Optional, Sequence, Text, Union
20 |
21 | import gin
22 | import numpy as np
23 | import tensorflow.compat.v1 as tf
24 |
25 |
26 | @gin.configurable
27 | def piecewise_linear(boundaries,
28 | values,
29 | name = None):
30 | """Piecewise linear function assuming given values at given boundaries.
31 |
32 | Args:
33 | boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
34 | increasing entries. The first entry must be 0.
35 | values: A list of `Tensor`s or float`s or `int`s that specifies the values
36 | at the `boundaries`. It must have the same number of elements as
37 | `boundaries`, and all elements should have the same type.
38 | name: A string. Optional name of the operation. Defaults to
39 |       'PiecewiseLinear'.
40 |
41 | Returns:
42 | A 0-D Tensor. Its value is `values[0]` if `x < boundaries[0]` and
43 |     `values[-1]` if `x >= boundaries[-1]`. If `boundaries[i] <= x <
44 | boundaries[i+1]` it is the linear interpolation between `values[i]` and
45 | `values[i+1]`: `values[i] + (values[i+1]-values[i]) * (x-boundaries[i]) /
46 | (boundaries[i+1]-boundaries[i])`.
47 |
48 | Raises:
49 | AssertionError: if values or boundaries is empty, or not the same size.
50 | """
51 | global_step = tf.train.get_or_create_global_step()
52 | with tf.name_scope(name, 'PiecewiseLinear', [global_step, boundaries, values,
53 | name]) as name:
54 | values = tf.convert_to_tensor(values)
55 | x = tf.cast(tf.convert_to_tensor(global_step), values.dtype)
56 | boundaries = tf.cast(tf.convert_to_tensor(boundaries), values.dtype)
57 |
58 | num_boundaries = np.prod(boundaries.shape.as_list())
59 | num_values = np.prod(values.shape.as_list())
60 | assert num_boundaries > 0, 'Need more than 0 boundaries'
61 | assert num_values > 0, 'Need more than 0 values'
62 | assert num_values == num_boundaries, ('boundaries and values must be of '
63 | 'same size')
64 |
65 | # Make sure there is an unmet last boundary with the same value as the
66 | # last one that was passed in, and at least one boundary was met.
67 | values = tf.concat([values, tf.reshape(values[-1], [1])], 0)
68 | boundaries = tf.concat(
69 | [boundaries,
70 | tf.reshape(tf.maximum(x + 1, boundaries[-1]), [1])], 0)
71 |
72 | # Make sure there is at least one boundary that was already met, with the
73 | # same value as the first one that was passed in.
74 | values = tf.concat([tf.reshape(values[0], [1]), values], 0)
75 | boundaries = tf.concat(
76 | [tf.reshape(tf.minimum(x - 1, boundaries[0]), [1]), boundaries], 0)
77 |
78 | # Identify index of the last boundary that was passed.
79 | unreached_boundaries = tf.reshape(
80 | tf.where(tf.greater(boundaries, x)), [-1])
81 | unreached_boundaries = tf.concat(
82 | [unreached_boundaries, [tf.cast(tf.size(boundaries), tf.int64)]], 0)
83 | index = tf.reshape(tf.reduce_min(unreached_boundaries), [1])
84 |
85 | # Get values at last and next boundaries.
86 | value_left = tf.reshape(tf.slice(values, index - 1, [1]), [])
87 | left_boundary = tf.reshape(tf.slice(boundaries, index - 1, [1]), [])
88 | value_right = tf.reshape(tf.slice(values, index, [1]), [])
89 | right_boundary = tf.reshape(tf.slice(boundaries, index, [1]), [])
90 |
91 | # Calculate linear interpolation.
92 | a = (value_right - value_left) / (right_boundary - left_boundary)
93 | b = value_left - a * left_boundary
94 | return a * x + b
95 |
96 |
97 | @gin.configurable
98 | def exponential_decay(initial_value = 0.0001,
99 | decay_steps = 10000,
100 | decay_rate = 0.9,
101 | staircase = True):
102 | """Create a value that decays exponentially with global_step.
103 |
104 | Args:
105 | initial_value: A scalar float32 or float64 Tensor or a Python
106 | number. The initial value returned for global_step == 0.
107 | decay_steps: A scalar int32 or int64 Tensor or a Python number. Must be
108 | positive. See the decay computation in `tf.exponential_decay`.
109 | decay_rate: A scalar float32 or float64 Tensor or a Python number. The decay
110 | rate.
111 | staircase: Boolean. If True, decay the value at discrete intervals.
112 |
113 | Returns:
114 | value: Scalar tf.Tensor with the value decaying based on the global_step.
115 | """
116 | global_step = tf.train.get_or_create_global_step()
117 | value = tf.compat.v1.train.exponential_decay(
118 | learning_rate=initial_value,
119 | global_step=global_step,
120 | decay_steps=decay_steps,
121 | decay_rate=decay_rate,
122 | staircase=staircase)
123 | return value
124 |
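A minimal usage sketch for `exponential_decay` (editorial illustration, not part of this module), again assuming a TF1 session and the same import path as the test file. With `staircase=True` the returned value is `initial_value * decay_rate ** floor(global_step / decay_steps)`; because the function is `gin.configurable`, the same parameters can also be bound from a gin config instead of being passed directly.

import tensorflow.compat.v1 as tf
from tensor2robot.utils import global_step_functions

tf.disable_v2_behavior()
global_step = tf.train.get_or_create_global_step()
value = global_step_functions.exponential_decay(
    initial_value=0.001, decay_steps=1000, decay_rate=0.9, staircase=True)

with tf.Session() as sess:
  sess.run(tf.assign(global_step, 2500))
  print(sess.run(value))  # 0.001 * 0.9**floor(2500/1000) = 0.00081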
--------------------------------------------------------------------------------
/utils/global_step_functions_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2robot.utils.global_step_functions."""
17 |
18 | from absl.testing import parameterized
19 | from tensor2robot.utils import global_step_functions
20 | import tensorflow.compat.v1 as tf
21 |
22 |
23 | class GlobalStepFunctionsTest(parameterized.TestCase, tf.test.TestCase):
24 |
25 | @parameterized.named_parameters({
26 | 'testcase_name': 'constant',
27 | 'boundaries': [1],
28 | 'values': [5.0],
29 | 'test_inputs': [0, 1, 10],
30 | 'expected_outputs': [5.0, 5.0, 5.0]
31 | }, {
32 | 'testcase_name': 'ramp_up',
33 | 'boundaries': [10, 20],
34 | 'values': [1.0, 11.0],
35 | 'test_inputs': [0, 10, 13, 15, 18, 20, 25],
36 | 'expected_outputs': [1.0, 1.0, 4.0, 6.0, 9.0, 11.0, 11.0]
37 | })
38 | def test_piecewise_linear(self, boundaries, values, test_inputs,
39 | expected_outputs):
40 | global_step = tf.train.get_or_create_global_step()
41 | global_step_value = tf.placeholder(tf.int64, [])
42 | set_global_step = tf.assign(global_step, global_step_value)
43 |
44 | test_function = global_step_functions.piecewise_linear(boundaries, values)
45 | with tf.Session() as sess:
46 | for x, y_expected in zip(test_inputs, expected_outputs):
47 | sess.run(set_global_step, {global_step_value: x})
48 | y = sess.run(test_function)
49 | self.assertEqual(y, y_expected)
50 |
51 | # Test the same with tensors as inputs
52 | test_function = global_step_functions.piecewise_linear(
53 | tf.convert_to_tensor(boundaries), tf.convert_to_tensor(values))
54 | with tf.Session() as sess:
55 | for x, y_expected in zip(test_inputs, expected_outputs):
56 | sess.run(set_global_step, {global_step_value: x})
57 | y = sess.run(test_function)
58 | self.assertEqual(y, y_expected)
59 |
60 |
61 | if __name__ == '__main__':
62 | tf.test.main()
63 |
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for encoding images."""
17 |
18 | import io
19 | import numpy as np
20 | from PIL import Image
21 | from PIL import ImageFile
22 |
23 |
24 | def jpeg_string(image, jpeg_quality = 90):
25 | """Returns given PIL.Image instance as jpeg string.
26 |
27 | Args:
28 | image: A PIL image.
29 | jpeg_quality: The image quality, on a scale from 1 (worst) to 95 (best).
30 |
31 | Returns:
32 | a jpeg_string.
33 | """
34 | # This fix to PIL makes sure that we don't get an error when saving large
35 | # jpeg files. This is a workaround for a bug in PIL. The value should be
36 | # substantially larger than the size of the image being saved.
37 | ImageFile.MAXBLOCK = 640 * 512 * 64
38 |
39 | output_jpeg = io.BytesIO()
40 | image.save(output_jpeg, 'jpeg', quality=jpeg_quality, optimize=True)
41 | return output_jpeg.getvalue()
42 |
43 |
44 | def numpy_to_image_string(image_array, image_format='jpeg',
45 | data_type=np.uint8):
46 | """Encodes a numpy array as an image byte string in the given format."""
47 | image = Image.fromarray(image_array.astype(data_type))
48 | output = io.BytesIO()
49 | image.save(output, image_format)
50 | return output.getvalue()
51 |
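A usage sketch for the two helpers above (editorial illustration, not part of image.py). It assumes the module is importable as `tensor2robot.utils.image`; the random array and the quality setting are made up for the example.

import io
import numpy as np
from PIL import Image
from tensor2robot.utils import image as image_lib

# A random RGB image as a uint8 array.
array = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

# Encode via the numpy helper (any PIL-supported format works, e.g. 'png').
encoded = image_lib.numpy_to_image_string(array, image_format='jpeg')

# Or encode an existing PIL image directly.
encoded_jpeg = image_lib.jpeg_string(Image.fromarray(array), jpeg_quality=75)

# Decode back; JPEG is lossy, so pixel values may differ slightly.
decoded = np.array(Image.open(io.BytesIO(encoded)))
assert decoded.shape == array.shape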
--------------------------------------------------------------------------------
/utils/writer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Tensor2Robot Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Write episode transitions to Recordio-backed replay buffer.
17 |
18 | TODO(T2R_CONTRIBUTORS) - re-base class using tf.python_io.TFRecordWriter.
19 | """
20 |
21 | import os
22 | import gin
23 | import tensorflow.compat.v1 as tf
24 |
25 |
26 | @gin.configurable
27 | class TFRecordReplayWriter(object):
28 | """Saves transitions to a TFRecord-backed replay buffer."""
29 |
30 | def __init__(self):
31 | self.writer = None
32 |
33 | def open(self, path):
34 | if self.writer is not None:
35 | raise ValueError('Writer is already open!')
36 |
37 | path_dirname = os.path.dirname(path)
38 | if not tf.gfile.IsDirectory(path_dirname):
39 | tf.gfile.MakeDirs(path_dirname)
40 |
41 | self.writer = tf.python_io.TFRecordWriter(path + '.tfrecord')
42 |
43 | def close(self):
44 | if self.writer is None:
45 | raise ValueError('Writer is not open!')
46 | self.writer.close()
47 | self.writer = None
48 |
49 | def write(self, transitions):
50 | """Writes entire episode to a TFRecord file.
51 |
52 | Args:
53 | transitions: List of tf.Examples.
54 |
55 | Raises:
56 | ValueError: If writer has not been opened.
57 | """
58 | if self.writer is None:
59 | raise ValueError('Writer is not open!')
60 | for transition in transitions:
61 | self.writer.write(transition.SerializeToString())
62 |
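A usage sketch for TFRecordReplayWriter (editorial illustration, not part of writer.py). It assumes the module is importable as `tensor2robot.utils.writer`; the 'reward' feature key and the /tmp/replay path are hypothetical and chosen only for the example. The writer appends '.tfrecord' to the given path and creates the parent directory if needed.

import tensorflow.compat.v1 as tf
from tensor2robot.utils import writer as writer_lib

def make_example(reward):
  # 'reward' is a made-up feature key used purely for illustration.
  return tf.train.Example(features=tf.train.Features(feature={
      'reward': tf.train.Feature(
          float_list=tf.train.FloatList(value=[reward])),
  }))

replay_writer = writer_lib.TFRecordReplayWriter()
replay_writer.open('/tmp/replay/episode_000')  # Writes episode_000.tfrecord.
replay_writer.write([make_example(0.0), make_example(1.0)])
replay_writer.close()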
--------------------------------------------------------------------------------