├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── real.jpeg └── sim.jpeg ├── language_table ├── __init__.py ├── common │ ├── __init__.py │ ├── clip_tokenizer.py │ └── clip_tokenizer_test.py ├── environments │ ├── __init__.py │ ├── assets │ │ ├── blocks │ │ │ ├── blue_cube.urdf │ │ │ ├── blue_moon.urdf │ │ │ ├── blue_pentagon.urdf │ │ │ ├── blue_star.urdf │ │ │ ├── cube.obj │ │ │ ├── green_cube.urdf │ │ │ ├── green_moon.urdf │ │ │ ├── green_pentagon.urdf │ │ │ ├── green_star.urdf │ │ │ ├── moon.obj │ │ │ ├── pentagon.obj │ │ │ ├── purple_pole.urdf │ │ │ ├── red_cube.urdf │ │ │ ├── red_moon.urdf │ │ │ ├── red_pentagon.urdf │ │ │ ├── red_star.urdf │ │ │ ├── star.obj │ │ │ ├── yellow_cube.urdf │ │ │ ├── yellow_moon.urdf │ │ │ ├── yellow_pentagon.urdf │ │ │ └── yellow_star.urdf │ │ ├── plane.obj │ │ ├── suction │ │ │ ├── base.obj │ │ │ ├── cylinder.urdf │ │ │ ├── cylinder_real.urdf │ │ │ ├── head.obj │ │ │ ├── mid.obj │ │ │ ├── suction-base.urdf │ │ │ ├── suction-head-long.urdf │ │ │ ├── suction-head.urdf │ │ │ └── tip.obj │ │ └── workspace_real.urdf │ ├── blocks.py │ ├── constants.py │ ├── language_table.py │ ├── language_table_test.py │ ├── oracles │ │ ├── __init__.py │ │ ├── oriented_push_oracle.py │ │ ├── plot.py │ │ ├── push_oracle_rrt_slowdown.py │ │ └── rrt_star.py │ ├── rewards │ │ ├── __init__.py │ │ ├── block1_to_corner.py │ │ ├── block2absolutelocation.py │ │ ├── block2block.py │ │ ├── block2block_relative_location.py │ │ ├── block2relativelocation.py │ │ ├── constants.py │ │ ├── instructions.py │ │ ├── instructions_test.py │ │ ├── play.py │ │ ├── point2block.py │ │ ├── reward.py │ │ ├── separate_blocks.py │ │ ├── synonyms.py │ │ └── task_info.py │ └── utils │ │ ├── __init__.py │ │ ├── pose3d.py │ │ ├── utils_pybullet.py │ │ ├── xarm_sim_robot.py │ │ └── xarm_sim_robot_test.py ├── eval │ ├── __init__.py │ ├── main.py │ └── wrappers.py ├── examples │ ├── dataset_example.py │ ├── environment_example.py │ └── 
language_table_tutorial.ipynb └── train │ ├── __init__.py │ ├── bc.py │ ├── configs │ ├── language_table_resnet_sim_local.py │ └── language_table_sim_local.py │ ├── download_clip_flax_ckpt.py │ ├── input_pipeline_rlds.py │ ├── main.py │ ├── networks │ ├── dense_resnet.py │ ├── lava.py │ ├── pixel.py │ ├── resnet_v1.py │ └── resnet_v1_test.py │ ├── normalization.py │ ├── policy.py │ └── train.py ├── requirements.txt ├── requirements_static.txt └── setup.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include language_table/environments/assets * 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language Table 2 | 3 | Language-Table is a suite of human-collected datasets and a multi-task continuous control benchmark for open vocabulary visuolinguomotor learning. 4 | 5 | ![](./docs/real.jpeg) | ![](./docs/sim.jpeg) 6 | :-------------------------:|:-------------------------:| 7 | 8 | ## Installation 9 | 10 | Installation with `pip`. `requirements.txt` contains dependencies for running 11 | the environment and simple dataset examples. 12 | 13 | ``` 14 | python3 -m venv ./ltvenv 15 | source ./ltvenv/bin/activate 16 | pip install -r ./requirements.txt 17 | export PYTHONPATH=${PWD}:$PYTHONPATH 18 | ``` 19 | 20 | For running the full train script, install using `requirements_static.txt`, as 21 | this contains pinned versions for running the full train script. 
22 | 23 | ``` 24 | python3 -m venv ./ltvenvtrain 25 | source ./ltvenvtrain/bin/activate 26 | pip install --no-deps -r ./requirements_static.txt 27 | export PYTHONPATH=${PWD}:$PYTHONPATH 28 | ``` 29 | ## Quickstart 30 | 31 | ### Examples 32 | #### Scripts 33 | Run and edit the following examples: 34 | 35 | Load the environment and run 5 random steps: 36 | 37 | ``` 38 | python3 language_table/examples/environment_example.py 39 | ``` 40 | 41 | Load dataset and print first 5 elements: 42 | 43 | ``` 44 | python3 language_table/examples/dataset_example.py 45 | ``` 46 | 47 | #### Train 48 | 49 | ``` 50 | source ./ltvenvtrain/bin/activate 51 | mkdir -p /tmp/language_table_train/ 52 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python language_table/train/main.py --config=./language_table/train/configs/language_table_sim_local.py --workdir=/tmp/language_table_train/ 53 | ``` 54 | 55 | #### Colab 56 | See the [colab](https://colab.research.google.com/github/google-research/language-table/blob/main/language_table/examples/language_table_tutorial.ipynb) for a more complete tutorial. 57 | 58 | ### Data 59 | ``` 60 | import tensorflow_datasets as tfds 61 | data_directory = 'gs://gresearch/robotics/language_table/0.0.1/' 62 | dataset = tfds.builder_from_directory(data_directory).as_dataset() 63 | ``` 64 | 65 | ### Environment 66 | ``` 67 | from language_table.environments import blocks 68 | from language_table.environments import language_table 69 | from language_table.environments.rewards import block2block 70 | 71 | env = language_table.LanguageTable( 72 | block_mode=blocks.LanguageTableBlockVariants.BLOCK_8, 73 | reward_factory=block2block.BlockToBlockReward, 74 | control_frequency=10.0, 75 | ) 76 | obs = env.reset() 77 | ``` 78 | 79 | ## Datasets 80 | 81 | ### Descriptions 82 | 83 | * **Real Robot** 84 | * **language_table**: 442,226 episodes of real robot relabeled data. 
85 | * **Simulation (human)** 86 | * **language_table_sim**: 181,020 episodes of simulation relabeled data. 87 | * **language_table_blocktoblock_sim**: 8,000 episodes of single task "block to block" data. 88 | * **language_table_blocktoblock_4block_sim**: 8,298 episodes of single task "block to block" data in the 4 block configuration. 89 | * **Simulation (oracle)** 90 | * **language_table_blocktoblock_oracle_sim**: 200,000 episodes of single task "block to block" data from an oracle scripted agent. 91 | * **language_table_blocktoblockrelative_oracle_sim**: 200,000 episodes of single task "block-to-block-relative" data from an oracle scripted agent. 92 | * **language_table_blocktoabsolute_oracle_sim**: 200,000 episodes of single task "block to absolute location" data from an oracle scripted agent. 93 | * **language_table_blocktorelative_oracle_sim**: 200,000 episodes of single task "block to relative location" data from an oracle scripted agent. 94 | * **language_table_separate_oracle_sim**: 200,000 episodes of single task "separate blocks" data from an oracle scripted agent. 
95 | 96 | ### Summary Table 97 | 98 | Dataset | Real/sim | Controlled by | Language-labeled by | # episodes 99 | --------| --------- | ------------- | ----------------- | --------: 100 | language_table | real | human | human | 442,226 101 | language_table_sim | sim | human | human | 181,020 102 | language_table_blocktoblock_sim | sim | human | scripted | 8,000 103 | language_table_blocktoblock_4block_sim | sim | human | scripted | 8,298 104 | language_table_blocktoblock_oracle_sim | sim | oracle | scripted | 200,000 105 | language_table_blocktoblockrelative_oracle_sim | sim | oracle | scripted | 200,000 106 | language_table_blocktoabsolute_oracle_sim | sim | oracle | scripted | 200,000 107 | language_table_blocktorelative_oracle_sim | sim | oracle | scripted | 200,000 108 | language_table_separate_oracle_sim | sim | oracle | scripted | 200,000 109 | 110 | ### Paths 111 | 112 | Dataset | Data Location 113 | --------| -------------- 114 | language_table | [gs://gresearch/robotics/language_table](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table/0.0.1/) 115 | language_table_sim | [gs://gresearch/robotics/language_table_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_sim/0.0.1/) 116 | language_table_blocktoblock_sim | [gs://gresearch/robotics/language_table_blocktoblock_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_sim/0.0.1/) 117 | language_table_blocktoblock_4block_sim | [gs://gresearch/robotics/language_table_blocktoblock_4block_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_4block_sim/0.0.1/) 118 | language_table_blocktoblock_oracle_sim | [gs://gresearch/robotics/language_table_blocktoblock_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_oracle_sim/0.0.1/) 119 | language_table_blocktoblockrelative_oracle_sim | 
[gs://gresearch/robotics/language_table_blocktoblockrelative_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblockrelative_oracle_sim/0.0.1/) 120 | language_table_blocktoabsolute_oracle_sim | [gs://gresearch/robotics/language_table_blocktoabsolute_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoabsolute_oracle_sim/0.0.1/) 121 | language_table_blocktorelative_oracle_sim | [gs://gresearch/robotics/language_table_blocktorelative_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktorelative_oracle_sim/0.0.1/) 122 | language_table_separate_oracle_sim | [gs://gresearch/robotics/language_table_separate_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_separate_oracle_sim/0.0.1/) 123 | 124 | ## Checkpoints 125 | 126 | Name | Config | Checkpoint Location 127 | -----| -------| ------------------- 128 | BC+ResNet Sim| language_table/train/configs/language_table_resnet_sim_local.py | [gs://gresearch/robotics/language_table_checkpoints/bc_resnet_sim_checkpoint_955000](https://storage.googleapis.com/gresearch/robotics/language_table_checkpoints/bc_resnet_sim_checkpoint_955000) 129 | 130 | ## Interactive Language: Talking to Robots in Real Time 131 | [Project Website](https://interactive-language.github.io/)  •  [PDF](https://arxiv.org/pdf/2210.06407.pdf) 132 | 133 | *Corey Lynch, Ayzaan Wahid, Jonathan Tompson, Tianli Ding, James Betker, Robert Baruch, Travis Armstrong, Pete Florence* 134 | 135 | **Abstract.** We present a framework for building interactive, real-time, natural language-instructable robots in the real world, and we open source related assets (dataset, environment, benchmark, and policies). 
Trained with behavioral cloning on a dataset of hundreds of thousands of language-annotated trajectories, a produced policy can proficiently execute an order of magnitude more commands than previous works: specifically we estimate a 93.5% success rate on a set of 87,000 unique natural language strings specifying raw end-to-end visuolinguo-motor skills in the real world. We find that the same policy is capable of being guided by a human via real-time language to address a wide range of precise long-horizon rearrangement goals, e.g. "make a smiley face out of blocks". The dataset we release comprises nearly 600,000 language-labeled trajectories, an order of magnitude larger than prior available datasets. We hope the demonstrated results and associated assets enable further advancement of helpful, capable, natural-language-interactable robots. 136 | 137 | ## Note 138 | 139 | This is not an officially supported Google product. 140 | -------------------------------------------------------------------------------- /docs/real.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/language-table/ee3104061f35e1261a2e28cc2a29ea84e7156661/docs/real.jpeg -------------------------------------------------------------------------------- /docs/sim.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/language-table/ee3104061f35e1261a2e28cc2a29ea84e7156661/docs/sim.jpeg -------------------------------------------------------------------------------- /language_table/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /language_table/common/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /language_table/common/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """An implementation of an in-graph TF Tokenizer. 17 | 18 | This is based on the SimpleTokenizer implementation from CLIP. 19 | 20 | Note: while this returns similar results for many strings used in 21 | language_table, there is no guarantee that this returns equivalent tokens for 22 | every text input in general. Particularly, escaped HTML input is not 23 | correctly handled. See test cases for more details. 24 | 25 | General usage is: 26 | ``` 27 | vocab_lookup = create_vocab(bpe_path=default_bpe()) 28 | tokenizer = ClipTokenizer(vocab_lookup) 29 | tokens = tokenize_text("example input text", tokenizer) 30 | ``` 31 | """ 32 | 33 | import gzip 34 | 35 | from clip.simple_tokenizer import bytes_to_unicode 36 | from clip.simple_tokenizer import default_bpe 37 | 38 | import tensorflow as tf 39 | import tensorflow_text as tf_text 40 | 41 | 42 | class ClipTokenizer(tf_text.TokenizerWithOffsets): 43 | """An in-graph TF implementation similar to SimpleTokenizer.""" 44 | 45 | def __init__( 46 | self, 47 | vocab_lookup_table, 48 | max_bytes_per_word=100, 49 | max_chars_per_token=None, 50 | token_out_type=tf.int64, 51 | unknown_token="[UNK]", 52 | split_unknown_characters=False, 53 | lower_case=False, 54 | keep_whitespace=False, 55 | normalization_form=None, 56 | preserve_unused_token=False, 57 | ): 58 | super(ClipTokenizer, self).__init__() 59 | 60 | self._re = r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""" # pylint: disable=line-too-long 61 | self._wordpiece_tokenizer = 
tf_text.WordpieceTokenizer( 62 | vocab_lookup_table, 63 | suffix_indicator="", 64 | max_bytes_per_word=max_bytes_per_word, 65 | max_chars_per_token=max_chars_per_token, 66 | token_out_type=token_out_type, 67 | unknown_token=unknown_token, 68 | split_unknown_characters=split_unknown_characters, 69 | ) 70 | 71 | def tokenize_with_offsets(self, text_input): 72 | # Do basic cleaning of text inputs. 73 | # the CLIP simple tokenizer does the following: 74 | # 1. text = html.unescape(html.unescape(text)) 75 | # (we skip this for tf, which means *escaped HTML inputs will not return 76 | # the same tokens as SimpleTokenizer.) 77 | # 2. text = text.strip() 78 | text_input = tf.strings.strip(text_input) 79 | # 3. text = re.sub(r'\s+', ' ', text) 80 | text_input = tf.strings.regex_replace(text_input, r"\s+", " ") 81 | # 4. text = text.strip() 82 | text_input = tf.strings.strip(text_input) 83 | # 5. text = text.lower() 84 | text_input = tf.strings.lower(text_input) 85 | 86 | tokens, begin, _ = tf_text.regex_split_with_offsets( 87 | text_input, 88 | delim_regex_pattern=self._re + r"|\s", 89 | keep_delim_regex_pattern=self._re, 90 | ) 91 | 92 | # Add the end character to each token after the regex split. 93 | num_tokens = tf.shape(tokens.values)[0] 94 | end_char = tf.tile([""], [num_tokens]) 95 | tokens_with_end = tf.strings.join([tokens, end_char]) 96 | 97 | begin = tf.cast(begin, tf.int64) 98 | 99 | # The wordpiece tokenizer has the same logic as the bpe() routine. 
100 | ( 101 | wordpieces, 102 | wp_begin, 103 | wp_end, 104 | ) = self._wordpiece_tokenizer.tokenize_with_offsets(tokens_with_end) 105 | 106 | begin_expanded = tf.expand_dims(begin, axis=2) 107 | final_begin = begin_expanded + wp_begin 108 | final_end = begin_expanded + wp_end 109 | 110 | return wordpieces, final_begin, final_end 111 | 112 | def tokenize(self, text_input): 113 | tokens, _, _ = self.tokenize_with_offsets(text_input) 114 | return tokens.flat_values 115 | 116 | 117 | def create_vocab(*, bpe_path=default_bpe()): 118 | """Creates the input vocabulary table for the tokenizer.""" 119 | with tf.io.gfile.GFile(bpe_path, "rb") as f: 120 | merges = gzip.open(f).read().decode("utf-8").split("\n") 121 | merges = merges[1 : 49152 - 256 - 2 + 1] 122 | merges = [tuple(merge.split()) for merge in merges] 123 | vocab = list(bytes_to_unicode().values()) 124 | # add '##' prefix to indicate sub-word 125 | vocab = vocab + [v + "" for v in vocab] 126 | for merge in merges: 127 | vocab.append("".join(merge)) 128 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 129 | bpe_lookup = tf.lookup.StaticVocabularyTable( 130 | num_oov_buckets=1, 131 | initializer=tf.lookup.KeyValueTensorInitializer( 132 | keys=vocab, values=tf.range(len(vocab), dtype=tf.int64) 133 | ), 134 | ) 135 | return bpe_lookup 136 | 137 | 138 | def tokenize_text(text, tokenizer, vocab_size=49408): 139 | """Tokenizes the input text given a tokenizer.""" 140 | tokens, start_idx, end_idx = tokenizer.tokenize_with_offsets(text) 141 | # flatten sub-word tokenization 142 | tokens = tokens.merge_dims(-2, -1) 143 | start_idx = start_idx.merge_dims(-2, -1) 144 | end_idx = end_idx.merge_dims(-2, -1) 145 | count = tokens.bounding_shape()[0] 146 | # pad sparse tokens tensor with start and end tokens 147 | starts = tf.cast(tf.fill([count, 1], vocab_size - 2), dtype=tf.int64) 148 | ends = tf.cast(tf.fill([count, 1], vocab_size - 1), dtype=tf.int64) 149 | tokens = tf.concat([starts, tokens, ends], axis=1) 150 | # 
convert sparse tensor to zero padded dense tensor 151 | tokens = tokens.to_tensor(shape=(count, 77)) 152 | return tokens 153 | -------------------------------------------------------------------------------- /language_table/common/clip_tokenizer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for clip_tokenizer.""" 17 | 18 | from clip.simple_tokenizer import default_bpe 19 | from clip.simple_tokenizer import SimpleTokenizer 20 | 21 | from language_table.common import clip_tokenizer 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | class ClipTokenizerTest(tf.test.TestCase): 28 | 29 | def test_matches_simple_tokenizer(self): 30 | vocab_lookup = clip_tokenizer.create_vocab(bpe_path=default_bpe()) 31 | tokenizer = clip_tokenizer.ClipTokenizer(vocab_lookup) 32 | 33 | simple_tokenizer = SimpleTokenizer(default_bpe()) 34 | 35 | all_instructions = [ 36 | "pull the red moon apart from the star, red pentagon, and blue moon", 37 | "hexagon cube star crescent moon block", 38 | "I must've gone, to where you're going and place where he'll go,", 39 | "push the blue hexagon to the left!!!", 40 | " a thing with lots of spaces", 41 | "a picture of an elephantcat ", 42 | " a picture of an \n elephantcat ", 43 | "A PICTURE of AN elephantCAT", 44 | 45 | 
# Current known failure cases. 46 | # "<html>I want to escape HTML </html>" 47 | ] 48 | 49 | for instruction in all_instructions: 50 | result = clip_tokenizer.tokenize_text(instruction, tokenizer) 51 | simple_result = _simple_tokenize(simple_tokenizer, [instruction]) 52 | np.testing.assert_equal(result, simple_result, instruction) 53 | 54 | 55 | def _simple_tokenize(tokenizer, texts, context_length = 77): 56 | sot_token = tokenizer.encoder["<|startoftext|>"] 57 | eot_token = tokenizer.encoder["<|endoftext|>"] 58 | all_tokens = [ 59 | [sot_token] + tokenizer.encode(text) + [eot_token] for text in texts 60 | ] 61 | result = np.zeros((len(all_tokens), context_length), dtype=int) 62 | 63 | for i, tokens in enumerate(all_tokens): 64 | if len(tokens) > context_length: 65 | raise RuntimeError( 66 | f"Input {texts[i]} is too long for context length {context_length}") 67 | result[i, :len(tokens)] = np.asarray(tokens) 68 | 69 | return result 70 | 71 | 72 | if __name__ == "__main__": 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /language_table/environments/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/blue_cube.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/blue_moon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/blue_pentagon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/blue_star.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/green_cube.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/green_moon.urdf: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/green_pentagon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/green_star.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/purple_pole.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/red_cube.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/red_moon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 
18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/red_pentagon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/red_star.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/yellow_cube.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/yellow_moon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/blocks/yellow_pentagon.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- 
/language_table/environments/assets/blocks/yellow_star.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /language_table/environments/assets/plane.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.66 (sub 1) OBJ File: '' 2 | # www.blender.org 3 | mtllib plane.mtl 4 | o Plane 5 | v 15.000000 -15.000000 0.000000 6 | v 15.000000 15.000000 0.000000 7 | v -15.000000 15.000000 0.000000 8 | v -15.000000 -15.000000 0.000000 9 | 10 | vt 15.000000 0.000000 11 | vt 15.000000 15.000000 12 | vt 0.000000 15.000000 13 | vt 0.000000 0.000000 14 | 15 | usemtl Material 16 | s off 17 | f 1/1 2/2 3/3 18 | f 1/1 3/3 4/4 19 | -------------------------------------------------------------------------------- /language_table/environments/assets/suction/cylinder.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /language_table/environments/assets/suction/cylinder_real.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 
56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /language_table/environments/assets/suction/suction-base.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /language_table/environments/assets/suction/suction-head-long.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /language_table/environments/assets/suction/suction-head.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /language_table/environments/assets/workspace_real.urdf: 
# ------------------------------------------------------------------------------
# (end of /language_table/environments/assets/workspace_real.urdf: its lines
# 1-30 carry no recoverable text in this dump)
# ------------------------------------------------------------------------------
# /language_table/environments/blocks.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the n choose k blocks on Language Table."""
import collections
import enum
import itertools

import numpy as np


class LanguageTableBlockVariants(enum.Enum):
  BLOCK_1 = 'BLOCK_1'  # 1 green star. Just for debugging.
  BLOCK_4 = 'BLOCK_4'  # The original 4 blocks.
  BLOCK_8 = 'BLOCK_8'  # 2 of each color, 2 of each shape, 8 total.
  BLOCK_4_WPOLE = 'BLOCK_4_WPOLE'  # original 4 blocks with purple pole as goal
  BLOCK_8_WPOLE = 'BLOCK_8_WPOLE'  # 8 blocks with purple pole as goal
  N_CHOOSE_K = 'N_CHOOSE_K'  # Combinatorial.


BLOCK_VARIANTS = [i.value for i in LanguageTableBlockVariants]


def get_all_block_subsets(mode, training):
  """Returns all subsets of blocks that can be placed on the table for `mode`.

  Args:
    mode: a `LanguageTableBlockVariants` member.
    training: only consulted for `N_CHOOSE_K`; True selects the training split
      of combinations, False the held-out test split.

  Returns:
    A list of block-name sequences. Fixed-layout modes return a one-element
    list holding their single combination.

  Raises:
    ValueError: if `mode` is not a supported variant.
  """
  if mode == LanguageTableBlockVariants.BLOCK_1:
    return [FIXED_1_COMBINATION]
  elif mode == LanguageTableBlockVariants.BLOCK_4:
    return [FIXED_4_COMBINATION]
  elif mode == LanguageTableBlockVariants.BLOCK_8:
    return [FIXED_8_COMBINATION]
  elif mode == LanguageTableBlockVariants.N_CHOOSE_K:
    if training:
      return TRAIN_COMBINATIONS
    else:
      return TEST_COMBINATIONS
  elif mode == LanguageTableBlockVariants.BLOCK_4_WPOLE:
    return [FIXED_4_COMBINATION_WPOLE]
  elif mode == LanguageTableBlockVariants.BLOCK_8_WPOLE:
    return [FIXED_8_COMBINATION_WPOLE]
  else:
    raise ValueError('Unsupported block mode')


def get_block_set(mode):
  """Returns the unique set of block names used by `mode`.

  For every fixed-layout mode this is that mode's single combination, matching
  `get_all_block_subsets`. The pole variants therefore include 'purple_pole'.
  They used to raise here even though `get_all_block_subsets` supports them.

  Args:
    mode: a `LanguageTableBlockVariants` member.

  Returns:
    A sequence of block names.

  Raises:
    ValueError: if `mode` is not a supported variant.
  """
  if mode == LanguageTableBlockVariants.BLOCK_1:
    return FIXED_1_COMBINATION
  elif mode == LanguageTableBlockVariants.BLOCK_4:
    return FIXED_4_COMBINATION
  elif mode == LanguageTableBlockVariants.BLOCK_8:
    return FIXED_8_COMBINATION
  elif mode == LanguageTableBlockVariants.BLOCK_4_WPOLE:
    return FIXED_4_COMBINATION_WPOLE
  elif mode == LanguageTableBlockVariants.BLOCK_8_WPOLE:
    return FIXED_8_COMBINATION_WPOLE
  elif mode == LanguageTableBlockVariants.N_CHOOSE_K:
    return ALL_BLOCKS
  else:
    raise ValueError('Unsupported block mode')


def get_all_block_pairs(mode):
  """Defines all pairs of blocks. Useful for generating all instructions.

  Returns:
    A lazy iterator over ordered (block_a, block_b) pairs with a != b.
  """
  all_blocks = get_block_set(mode)
  all_pairs = itertools.permutations(all_blocks, 2)
  return all_pairs


def get_blocks_text_descriptions(mode):
  """Returns human-readable names, e.g. 'red_moon' -> 'red moon', for `mode`."""
  blocks = get_block_set(mode)
  blocks_text = [' '.join(i.split('_')) for i in blocks]
  return blocks_text


BLOCK_URDF_PATHS = collections.OrderedDict(
    # Red blocks.
    red_moon='third_party/py/language_table/environments/assets/blocks/red_moon.urdf',
    red_cube='third_party/py/language_table/environments/assets/blocks/red_cube.urdf',
    red_star='third_party/py/language_table/environments/assets/blocks/red_star.urdf',
    red_pentagon='third_party/py/language_table/environments/assets/blocks/red_pentagon.urdf',
    # Blue blocks.
    blue_moon='third_party/py/language_table/environments/assets/blocks/blue_moon.urdf',
    blue_cube='third_party/py/language_table/environments/assets/blocks/blue_cube.urdf',
    blue_star='third_party/py/language_table/environments/assets/blocks/blue_star.urdf',
    blue_pentagon='third_party/py/language_table/environments/assets/blocks/blue_pentagon.urdf',
    # Yellow blocks.
    yellow_moon='third_party/py/language_table/environments/assets/blocks/yellow_moon.urdf',
    yellow_cube='third_party/py/language_table/environments/assets/blocks/yellow_cube.urdf',
    yellow_star='third_party/py/language_table/environments/assets/blocks/yellow_star.urdf',
    yellow_pentagon='third_party/py/language_table/environments/assets/blocks/yellow_pentagon.urdf',
    # Green blocks.
    green_moon='third_party/py/language_table/environments/assets/blocks/green_moon.urdf',
    green_cube='third_party/py/language_table/environments/assets/blocks/green_cube.urdf',
    green_star='third_party/py/language_table/environments/assets/blocks/green_star.urdf',
    green_pentagon='third_party/py/language_table/environments/assets/blocks/green_pentagon.urdf',
)

POLE_URDF_PATHS = collections.OrderedDict(
    # Purple Pole.
    purple_pole='third_party/py/language_table/environments/assets/blocks/purple_pole.urdf',
)

# Use this just to define the observation space.
DUMMY_START_BLOCK = list(BLOCK_URDF_PATHS.keys())[0]
COLORS = ['red', 'blue', 'green', 'yellow']
SHAPES = ['moon', 'cube', 'star', 'pentagon']
ALL_BLOCKS = ['_'.join(i) for i in itertools.product(COLORS, SHAPES)]
MIN_K = 4
MAX_K = 10
ALL_COMBINATIONS = []
for k in range(MIN_K, MAX_K+1):
  k_combos = list(itertools.combinations(ALL_BLOCKS, k))
  ALL_COMBINATIONS.extend(k_combos)
# Seeded shuffle, so the train / test split below is reproducible.
combo_rng = np.random.RandomState(seed=0)
combo_rng.shuffle(ALL_COMBINATIONS)
# Divide combinations by train / test (90% / 10%).
TRAIN_COMBINATIONS = ALL_COMBINATIONS[:int(len(ALL_COMBINATIONS)*0.9)]
TEST_COMBINATIONS = ALL_COMBINATIONS[int(len(ALL_COMBINATIONS)*0.9):]

# 8 total, 2 of each color, 2 of each shape.
FIXED_8_COMBINATION = (
    'red_moon',
    'red_pentagon',
    'blue_moon',
    'blue_cube',
    'green_cube',
    'green_star',
    'yellow_star',
    'yellow_pentagon')

# The original "4-block" environment.
FIXED_4_COMBINATION = (
    'red_moon',
    'blue_cube',
    'green_star',
    'yellow_pentagon'
)


# 8 total blocks + 1 goal purple pole, 2 of each color, 2 of each shape.
FIXED_8_COMBINATION_WPOLE = ('red_moon', 'red_pentagon', 'blue_moon',
                             'blue_cube', 'green_cube', 'green_star',
                             'yellow_star', 'yellow_pentagon', 'purple_pole')

# The original "4-block" environment + 1 goal purple pole.
FIXED_4_COMBINATION_WPOLE = ('red_moon', 'blue_cube', 'green_star',
                             'yellow_pentagon', 'purple_pole')
# 1-block debugging environment.
FIXED_1_COMBINATION = ['green_star']

# ------------------------------------------------------------------------------
# /language_table/environments/constants.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared language table constants."""

import math
import numpy as np
from scipy.spatial import transform

# --- Asset locations ----------------------------------------------------------
PLANE_URDF_PATH = (
    'third_party/bullet/examples/pybullet/gym/pybullet_data/plane.urdf')
WORKSPACE_URDF_PATH = (
    'third_party/py/language_table/environments/assets/workspace_real.urdf')

# --- End effector -------------------------------------------------------------
EFFECTOR_HEIGHT = 0.145
# Rotation by pi about the y axis.
EFFECTOR_DOWN_ROTATION = transform.Rotation.from_rotvec([0, math.pi, 0])

# --- Workspace extents --------------------------------------------------------
X_MIN, X_MAX = 0.15, 0.6
Y_MIN, Y_MAX = -0.3048, 0.3048
# Midpoints of each axis range.
CENTER_X = X_MIN + (X_MAX - X_MIN) / 2.
CENTER_Y = Y_MIN + (Y_MAX - Y_MIN) / 2.
# Row 0 holds the (x, y) minima, row 1 the (x, y) maxima.
WORKSPACE_BOUNDS = np.array(((X_MIN, Y_MIN), (X_MAX, Y_MAX)))
WORKSPACE_BOUNDS_BUFFER = 0.08

# --- Distance thresholds ------------------------------------------------------
BLOCK_DISTANCE_THRESHOLD = 0.0175
ARM_DISTANCE_THRESHOLD = 0.06

INSTRUCTION_LENGTH = 512  # max number of chars in instruction

# --- Camera -------------------------------------------------------------------
CAMERA_POSE = (0.75, 0, 0.5)
CAMERA_ORIENTATION = (np.pi / 5, np.pi, -np.pi / 2)

IMAGE_WIDTH, IMAGE_HEIGHT = 320, 180
# Flattened row-major 3x3 intrinsics matrix. Note that fy is also derived from
# the image width.
CAMERA_INTRINSICS = (
    0.803 * IMAGE_WIDTH, 0, IMAGE_WIDTH / 2.,    # fx, 0, cx
    0, 0.803 * IMAGE_WIDTH, IMAGE_HEIGHT / 2.,   # 0, fy, cy
    0, 0, 1,
)

# --- Robot --------------------------------------------------------------------
# Corresponds to:
#   rotation = transform.Rotation.from_rotvec([0, math.pi, 0])
#   translation = np.array([0.3, -0.2, 0.145])
INITIAL_JOINT_POSITIONS = np.array([
    -0.5875016909413221,
    0.15985553866983415,
    -0.4992862770497537,
    0.0017427885915130214,
    0.33927183830553914,
    -3.7249551487437524,
])

# ------------------------------------------------------------------------------
# /language_table/environments/language_table_test.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for language_table."""

from language_table.environments import blocks
from language_table.environments import constants
from language_table.environments import language_table
from language_table.environments.rewards import block2block
import numpy as np
import tensorflow as tf

# Number of random actions replayed by the save/restore tests.
_NUM_ACTIONS = 40


class LanguageTableTest(tf.test.TestCase):

  def _random_actions(self):
    """Returns `_NUM_ACTIONS` small random 2D actions.

    A fixed-seed generator keeps the save/restore tests reproducible across
    runs; the previous global `np.random` calls were unseeded.
    """
    rng = np.random.RandomState(0)
    return [
        rng.uniform(size=(2,), low=-0.03, high=0.03)
        for _ in range(_NUM_ACTIONS)
    ]

  def _assert_observations_match(self, observations, reconstructed_obs):
    """Checks that replayed observations match the recorded ones.

    Args:
      observations: observation dicts recorded while stepping the env.
      reconstructed_obs: observation dicts rebuilt from saved pybullet states,
        in the same order as `observations`.
    """
    self.assertEqual(len(observations), len(reconstructed_obs))
    for i, (orig_obs, recon_obs) in enumerate(
        zip(observations, reconstructed_obs)):
      for k in orig_obs:
        msg = f'Observation {k} at step {i} was not equal.'
        # Subtract in float64 so unsigned integer dtypes cannot wrap around.
        error = np.abs(
            np.asarray(orig_obs[k], dtype=np.float64) -
            np.asarray(recon_obs[k], dtype=np.float64))
        if k == 'rgb':
          # pybullet rgb is sometimes flaky as the rendering is
          # non-deterministic. Since the individual pixel errors
          # are high magnitude, just make sure the average error
          # is low.
          self.assertLess(error.mean(), 1e-6, msg)
        elif k == 'effector_translation':
          self.assertLess(error.max(), 1e-3, msg)
        else:
          self.assertLess(error.max(), 1e-6, msg)
        self.assertEqual(orig_obs[k].dtype, recon_obs[k].dtype)

  def test_save_restore_pybullet_state(self):
    block_modes = [
        blocks.LanguageTableBlockVariants.BLOCK_4,
        blocks.LanguageTableBlockVariants.BLOCK_8
    ]
    for block_mode in block_modes:
      env = language_table.LanguageTable(
          block_mode=block_mode,
          reward_factory=block2block.BlockToBlockReward,
          control_frequency=10.0,
          seed=0)
      obs0 = env.reset()
      observations = [obs0]
      pbstates = [env.get_pybullet_state()]

      # Take actions in environment, storing obs and pbstates.
      for act in self._random_actions():
        obs, _, _, _ = env.step(act)
        observations.append(obs)
        pbstates.append(env.get_pybullet_state())

      # Reset env.
      env.reset()

      # Replay states into observations.
      reconstructed_obs = []
      for pb in pbstates:
        env.set_pybullet_state(pb)
        obs = env._compute_state(request_task_update=False)
        reconstructed_obs.append(obs)
      self._assert_observations_match(observations, reconstructed_obs)

  def test_save_restore_pybullet_state_instruction_len(self):
    block_modes = [
        blocks.LanguageTableBlockVariants.BLOCK_4,
        blocks.LanguageTableBlockVariants.BLOCK_8
    ]
    for block_mode in block_modes:
      env = language_table.LanguageTable(
          block_mode=block_mode,
          reward_factory=block2block.BlockToBlockReward,
          control_frequency=10.0,
          seed=0)
      obs0 = env.reset()
      observations = [obs0]
      pbstates = [env.get_pybullet_state()]

      # Take actions in environment, storing obs and pbstates.
      for act in self._random_actions():
        obs, _, _, _ = env.step(act)
        observations.append(obs)

        pybullet_state = env.get_pybullet_state()

        # Modify the pybullet state to make the instruction length different.
        pybullet_state['instruction'] = pybullet_state['instruction'][:128]

        pbstates.append(pybullet_state)

      # Reset env.
      env.reset()

      # Replay states into observations.
      reconstructed_obs = []
      for pb in pbstates:
        env.set_pybullet_state(pb)
        state = env._compute_state(request_task_update=False)
        obs = env._compute_observation(state=state)

        self.assertTrue(env.observation_space.contains(obs))

        # Even after restoring a truncated instruction, the observation must
        # keep the full fixed instruction length.
        self.assertEqual(obs['instruction'].size, constants.INSTRUCTION_LENGTH)

        reconstructed_obs.append(obs)
      self._assert_observations_match(observations, reconstructed_obs)

  def test_environment_initializes_and_resets(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0)
    env.reset()

  def test_environment_obs_space_contains_obs(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0)
    obs = env.reset()
    self.assertTrue(env.observation_space.contains(obs))
    for k in obs:
      self.assertEqual(obs[k].dtype, env.observation_space[k].dtype)

  def test_environment_steps_block4(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))

  def test_environment_steps_block8(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))

  def test_environment_steps_nchoosek(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.N_CHOOSE_K,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))


if __name__ == '__main__':
  tf.test.main()

# ------------------------------------------------------------------------------
# /language_table/environments/oracles/__init__.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# ------------------------------------------------------------------------------
# /language_table/environments/oracles/oriented_push_oracle.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Oracle for pushing task which orients the block then pushes it."""
import dataclasses

from typing import Any

import numpy as np
from tf_agents.policies import py_policy
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts


def _safe_unit_vector(vector):
  """Returns `vector` divided by its Euclidean norm, or zeros if the norm is 0.

  A plain `vector / np.linalg.norm(vector)` yields NaNs when two points
  coincide, e.g. the block sits exactly on the target or the effector is
  exactly at the block centre. Those NaNs would propagate into every waypoint
  and into the returned action.
  """
  norm = np.linalg.norm(vector)
  if norm == 0:
    return np.zeros(np.shape(vector))
  return vector / norm


@dataclasses.dataclass
class PushingInfo:
  """Holds onto info necessary for pushing state machine."""
  xy_block: Any = None  # Block (x, y).
  xy_ee: Any = None  # Effector target (x, y).
  xy_pre_block: Any = None  # Waypoint 5 cm behind the block.
  xy_dir_block_to_target: Any = None  # Unit vector block -> target.
  xy_delta_to_nexttoblock: Any = None  # Effector -> point 3 cm behind block.
  xy_delta_to_touchingblock: Any = None  # Effector -> point 1 cm behind block.
  xy_dir_block_to_ee: Any = None  # Unit vector block -> effector.
  theta_threshold_to_orient: Any = None
  theta_threshold_flat_enough: Any = None
  theta_error: Any = None  # Yaw error folded into [-pi/4, pi/4].
  obstacle_poses: Any = None
  distance_to_target: Any = None


class OrientedPushOracle(py_policy.PyPolicy):
  """Oracle for pushing task which orients the block then pushes it.

  The oracle is a state machine kept in `self.phase`. The phases are
  "move_to_pre_block", "move_to_block", "push_block", "orient_block_left" and
  "orient_block_right".
  """

  def __init__(self, env, action_noise_std=0.0):
    super(OrientedPushOracle, self).__init__(env.time_step_spec(),
                                             env.action_spec())
    self._env = env
    self._np_random_state = np.random.RandomState(0)
    self.phase = "move_to_pre_block"
    self._action_noise_std = action_noise_std

  def reset(self):
    """Restarts the state machine at its initial phase."""
    self.phase = "move_to_pre_block"

  def get_theta_from_vector(self, vector):
    """Returns the planar angle (radians) of a 2D vector via arctan2."""
    return np.arctan2(vector[1], vector[0])

  def theta_to_rotation2d(self, theta):
    """Returns the 2x2 counter-clockwise rotation matrix for `theta`."""
    r = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    return r

  def rotate(self, theta, xy_dir_block_to_ee):
    """Rotates a 2D vector by `theta` radians."""
    rot_2d = self.theta_to_rotation2d(theta)
    return rot_2d @ xy_dir_block_to_ee

  def _get_action_info(self, raw_state, block, target):
    """Computes the geometry the state machine needs for this step.

    Args:
      raw_state: mapping read with keys "<block>_translation",
        "<block>_orientation", "<target>_translation" and
        "effector_target_translation". Only the first two components of each
        translation are used.
      block: key prefix of the block being pushed.
      target: key prefix of the target.

    Returns:
      A populated `PushingInfo`.
    """
    xy_block = raw_state["%s_translation" % block][:2]
    theta_block = raw_state["%s_orientation" % block]
    xy_target = raw_state["%s_translation" % target][:2]
    xy_ee = raw_state["effector_target_translation"][:2]

    xy_block_to_target = xy_target - xy_block
    xy_dir_block_to_target = _safe_unit_vector(xy_block_to_target)
    theta_to_target = self.get_theta_from_vector(xy_dir_block_to_target)

    theta_error = theta_to_target - theta_block
    # Block has 4-way symmetry.
    while theta_error > np.pi/4:
      theta_error -= np.pi/2.
    while theta_error < -np.pi/4:
      theta_error += np.pi/2.

    # Waypoints on the target->block line, on the far side of the block from
    # the target, at 5 cm, 3 cm and 1 cm from the block.
    xy_pre_block = xy_block + -xy_dir_block_to_target * 0.05
    xy_nexttoblock = xy_block + -xy_dir_block_to_target * 0.03
    xy_touchingblock = xy_block + -xy_dir_block_to_target * 0.01
    xy_delta_to_nexttoblock = xy_nexttoblock - xy_ee
    xy_delta_to_touchingblock = xy_touchingblock - xy_ee

    xy_block_to_ee = xy_ee - xy_block
    xy_dir_block_to_ee = _safe_unit_vector(xy_block_to_ee)

    theta_threshold_to_orient = 0.2
    theta_threshold_flat_enough = 0.03
    return PushingInfo(
        xy_block=xy_block,
        xy_ee=xy_ee,
        xy_pre_block=xy_pre_block,
        xy_dir_block_to_target=xy_dir_block_to_target,
        xy_delta_to_nexttoblock=xy_delta_to_nexttoblock,
        xy_delta_to_touchingblock=xy_delta_to_touchingblock,
        xy_dir_block_to_ee=xy_dir_block_to_ee,
        theta_threshold_to_orient=theta_threshold_to_orient,
        theta_threshold_flat_enough=theta_threshold_flat_enough,
        theta_error=theta_error)

  def _get_move_to_preblock(self, xy_pre_block, xy_ee):
    """Heads for the pre-block waypoint; returns (xy_delta, max velocity)."""
    max_step_velocity = 0.3
    # Go 5 cm away from the block, on the line between the block and target.
    xy_delta_to_preblock = xy_pre_block - xy_ee
    diff = np.linalg.norm(xy_delta_to_preblock)
    if diff < 0.001:
      self.phase = "move_to_block"
    xy_delta = xy_delta_to_preblock
    return xy_delta, max_step_velocity

  def _get_move_to_block(
      self, xy_delta_to_nexttoblock, theta_threshold_to_orient, theta_error):
    """Approaches the block; may switch to pushing or re-orienting."""
    diff = np.linalg.norm(xy_delta_to_nexttoblock)
    if diff < 0.001:
      self.phase = "push_block"
    # If need to re-orient, then re-orient.
    if theta_error > theta_threshold_to_orient:
      self.phase = "orient_block_left"
    if theta_error < -theta_threshold_to_orient:
      self.phase = "orient_block_right"
    # Otherwise, push into the block.
    xy_delta = xy_delta_to_nexttoblock
    return xy_delta

  def _get_push_block(
      self, theta_error, theta_threshold_to_orient, xy_delta_to_touchingblock):
    """Pushes through the block; restarts the approach if yaw drifts too far."""
    # If need to reorient, go back to move_to_pre_block, move_to_block first.
    if theta_error > theta_threshold_to_orient:
      self.phase = "move_to_pre_block"
    if theta_error < -theta_threshold_to_orient:
      self.phase = "move_to_pre_block"
    xy_delta = xy_delta_to_touchingblock
    return xy_delta

  def _get_orient_block_left(self,
                             xy_dir_block_to_ee,
                             orient_circle_diameter,
                             xy_block,
                             xy_ee,
                             theta_error,
                             theta_threshold_flat_enough):
    """Circles the effector +0.2 rad around the block to reduce theta_error."""
    xy_dir_block_to_ee = self.rotate(0.2, xy_dir_block_to_ee)
    xy_block_to_ee = xy_dir_block_to_ee * orient_circle_diameter
    xy_push_left_spot = xy_block + xy_block_to_ee
    xy_delta = xy_push_left_spot - xy_ee
    if theta_error < theta_threshold_flat_enough:
      self.phase = "move_to_pre_block"
    return xy_delta

  def _get_orient_block_right(self,
                              xy_dir_block_to_ee,
                              orient_circle_diameter,
                              xy_block,
                              xy_ee,
                              theta_error,
                              theta_threshold_flat_enough):
    """Circles the effector -0.2 rad around the block to reduce theta_error."""
    xy_dir_block_to_ee = self.rotate(-0.2, xy_dir_block_to_ee)
    xy_block_to_ee = xy_dir_block_to_ee * orient_circle_diameter
    xy_push_right_spot = xy_block + xy_block_to_ee
    xy_delta = xy_push_right_spot - xy_ee
    if theta_error > -theta_threshold_flat_enough:
      self.phase = "move_to_pre_block"
    return xy_delta

  def _get_action_for_block_target(self,
                                   raw_state,
                                   block="block",
                                   target="target"):
    """Returns the clipped (x, y) effector delta for the current phase."""
    # Specifying this as velocity makes it independent of control frequency.
    max_step_velocity = 0.35
    info = self._get_action_info(raw_state, block, target)

    # Sequential `if`s, not `elif`s: a handler that changes `self.phase` lets
    # any later handler in this list override `xy_delta` within the same call.
    if self.phase == "move_to_pre_block":
      xy_delta, max_step_velocity = self._get_move_to_preblock(
          info.xy_pre_block, info.xy_ee)

    if self.phase == "move_to_block":
      xy_delta = self._get_move_to_block(
          info.xy_delta_to_nexttoblock, info.theta_threshold_to_orient,
          info.theta_error)

    if self.phase == "push_block":
      xy_delta = self._get_push_block(
          info.theta_error, info.theta_threshold_to_orient,
          info.xy_delta_to_touchingblock)

    orient_circle_diameter = 0.025

    if self.phase == "orient_block_left" or self.phase == "orient_block_right":
      max_step_velocity = 0.15

    if self.phase == "orient_block_left":
      xy_delta = self._get_orient_block_left(
          info.xy_dir_block_to_ee,
          orient_circle_diameter,
          info.xy_block,
          info.xy_ee,
          info.theta_error,
          info.theta_threshold_flat_enough)

    if self.phase == "orient_block_right":
      xy_delta = self._get_orient_block_right(
          info.xy_dir_block_to_ee,
          orient_circle_diameter,
          info.xy_block,
          info.xy_ee,
          info.theta_error,
          info.theta_threshold_flat_enough)

    if self._action_noise_std != 0.0:
      # Rebind rather than `+=` so arrays shared with `info` are not mutated.
      xy_delta = xy_delta + (self._np_random_state.randn(2) *
                             self._action_noise_std)

    # Limit the step length to what the phase's velocity allows in one step.
    max_step_distance = max_step_velocity * (1 /
                                             self._env.get_control_frequency())
    length = np.linalg.norm(xy_delta)
    if length > max_step_distance:
      xy_direction = xy_delta / length
      xy_delta = xy_direction * max_step_distance
    return xy_delta

  def _action(self,
              time_step,
              policy_state,
              seed=None):
    if time_step.is_first():
      self.reset()
    raw_state = self._env.compute_state()
    xy_delta = self._get_action_for_block_target(
        raw_state, block="block", target="target")
    return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32))

# ------------------------------------------------------------------------------
# /language_table/environments/oracles/plot.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for plotting RRT and writing debug images."""

import imageio
from matplotlib import patches
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf


DEBUG_IMAGE_PATH = '/tmp/rrt_debug_img.png'


class PlotRRT:
  """Plotting helper for visualizing RRT."""

  def __init__(self,
               x_start,
               x_goal,
               obs_boundary,
               obs_circle,
               image_path=DEBUG_IMAGE_PATH):
    """Stores the scene to draw.

    Args:
      x_start: start point; only x_start[0] and x_start[1] are read.
      x_goal: goal point; only x_goal[0] and x_goal[1] are read.
      obs_boundary: iterable of (x, y, width, height) rectangle obstacles.
      obs_circle: sequence of (x, y, radius) circle obstacles. `plot_grid`
        reuses the first circle's radius for the start/goal markers, so this
        must be non-empty when plotting.
      image_path: destination file written by `save_debug_image`.
    """
    self._x_start, self._x_goal = x_start, x_goal
    self._obs_bound = obs_boundary
    self._obs_circle = obs_circle
    self._image_path = image_path

  def animation(self, nodelist, path, name, animation=False):
    """Draws obstacles, the explored tree and the final path in one figure."""
    self.plot_grid(name)
    self.plot_visited(nodelist, animation)
    self.plot_path(path)

  def plot_grid(self, name):
    """Draws obstacles plus start (blue) and goal (green) markers.

    Creates a new figure, stores it in `self.fig`, and titles it `name`.
    """
    self.fig, ax = plt.subplots()
    for (ox, oy, w, h) in self._obs_bound:
      ax.add_patch(
          patches.Rectangle(
              (ox, oy), w, h,
              edgecolor='black',
              facecolor='black',
              fill=True)
      )
    for (ox, oy, r) in self._obs_circle:
      ax.add_patch(
          patches.Circle(
              (ox, oy), r,
              edgecolor='black',
              facecolor='gray',
              fill=True)
      )

    # Start/goal markers are drawn with the first circle obstacle's radius.
    radius = self._obs_circle[0][2]
    ax.add_patch(
        patches.Circle(
            (self._x_start[0], self._x_start[1]), radius,
            edgecolor='black',
            facecolor='blue',
            fill=True)
    )
    ax.add_patch(
        patches.Circle(
            (self._x_goal[0], self._x_goal[1]), radius,
            edgecolor='black',
            facecolor='green',
            fill=True)
    )
    plt.title(name)
    plt.axis('equal')

  def plot_visited(self, nodelist, animation):
    """Draws every tree edge (parent -> node) in green.

    Args:
      nodelist: nodes exposing `.x`, `.y` and `.parent` (falsy for the root).
      animation: kept for API compatibility. Both modes previously drew the
        same static edges, so it has no effect.
    """
    del animation  # Unused; see docstring.
    for node in nodelist:
      if node.parent:
        plt.plot([node.parent.x, node.x], [node.parent.y, node.y], '-g')

  def plot_path(self, path):
    """Draws `path` (a sequence of (x, y) points) in red, then shows the plot."""
    if path:
      plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
    plt.show()

  def fig_to_array(self, figure):
    """Converts a matplotlib figure to an (H, W, 3) uint8 numpy array."""
    figure.canvas.draw()
    # `buffer_rgba()` is the supported Agg accessor and already has shape
    # (height, width, 4). It replaces `tostring_rgb()` (deprecated in
    # Matplotlib 3.8) and the deprecated binary mode of `np.fromstring`.
    # Copy so the result does not alias the renderer's buffer.
    rgba = np.asarray(figure.canvas.buffer_rgba())
    return np.array(rgba[..., :3], dtype=np.uint8)

  def save_debug_image(self):
    """Writes the current figure as a PNG to `self._image_path`."""
    array = self.fig_to_array(self.fig)
    with tf.io.gfile.GFile(self._image_path, 'wb') as f:
      # `array` is already uint8 in [0, 255]. The former `array * 255.0`
      # produced float values far outside the 8-bit range and corrupted the
      # written image.
      imageio.imwrite(f, array, format='png')

# ------------------------------------------------------------------------------
# /language_table/environments/rewards/__init__.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/block1_to_corner.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines block1_to_corner reset and reward."""
import enum
from typing import Any, List, Mapping

from absl import logging
from language_table.environments import blocks as blocks_module
from language_table.environments.rewards import reward as base_reward
from language_table.environments.rewards import synonyms
from language_table.environments.rewards import task_info
import numpy as np

BUFFER = 0.08

X_MIN = 0.15
X_MAX = 0.6
Y_MIN = -0.3048
Y_MAX = 0.3048
CENTER_X = (X_MAX - X_MIN) / 2. + X_MIN
CENTER_Y = (Y_MAX - Y_MIN) / 2. + Y_MIN

BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE = 0.08


class Locations(enum.Enum):
  BOTTOM_LEFT = 'bottom_left'


ABSOLUTE_LOCATIONS = {
    'bottom_left': [X_MAX - BUFFER, Y_MIN + BUFFER],
}

LOCATION_SYNONYMS = {
    'bottom_left': [
        'bottom left of the board', 'bottom left', 'bottom left corner'
    ],
}

BLOCK2ABSOLUTELOCATION_VERBS = [
    'move the',
    'push the',
    'slide the',
]


def generate_all_instructions(block_mode):
  """Generates all instructions for block1_to_corner.

  NOTE(review): verbs come from `BLOCK2ABSOLUTELOCATION_VERBS`, whereas
  `Block1ToCornerLocationReward._sample_instruction` draws verbs from
  `synonyms.PUSH_VERBS`; confirm the two vocabularies agree.
  """
  all_instructions = []
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  for block_text in all_block_text_descriptions:
    for location in ABSOLUTE_LOCATIONS:
      for location_syn in LOCATION_SYNONYMS[location]:
        for verb in BLOCK2ABSOLUTELOCATION_VERBS:
          # Add instruction.
          inst = (f'{verb} {block_text} to the {location_syn}')
          all_instructions.append(inst)
  return all_instructions


class Block1ToCornerLocationReward(base_reward.LanguageTableReward):
  """Calculates reward/instructions for 'push 1 block to corner'."""

  def __init__(self, goal_reward, rng, delay_reward_steps,
               block_mode):
    super(Block1ToCornerLocationReward,
          self).__init__(goal_reward, rng, delay_reward_steps, block_mode)
    self._block = None
    self._instruction = None
    self._location = None
    self._target_translation = None

  def _sample_instruction(self, block, blocks_on_table,
                          location):
    """Randomly samples an instruction to push `block` to `location`."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Get some synonym for block.
    block_text = self._rng.choice(
        synonyms.get_block_synonyms(block, blocks_on_table))
    # Get some synonym for location.
    location_syn = self._rng.choice(LOCATION_SYNONYMS[location])
    return f'{verb} {block_text} to the {location_syn}'

  def reset(self, state, blocks_on_table):
    """Chooses a new target block and location.

    Returns:
      The task info from `reset_to`, or `task_info.FAILURE` if the chosen block
      already sits in the goal region (so the caller can resample the board).
    """
    # Choose a random block.
    block = self._sample_object(blocks_on_table)

    # Choose a location randomly.
    location = self._rng.choice(list(sorted(ABSOLUTE_LOCATIONS.keys())))

    info = self.reset_to(state, block, location, blocks_on_table)
    # Check the distance directly instead of calling `self.reward`: `reward`
    # only pays out after `delay_reward_steps` steps in the zone (so it would
    # miss an already-solved board) and it increments `_in_reward_zone_steps`
    # as a side effect.
    if self._in_goal_region(state, self._block, self._target_translation):
      # Try again with a new board configuration.
      return task_info.FAILURE
    return info

  def reset_to(self, state, block, location,
               blocks_on_table):
    """Resets to a particular task definition."""
    self._block = block
    # Sample an instruction.
    self._instruction = self._sample_instruction(self._block, blocks_on_table,
                                                 location)
    # Get the corresponding target_translation.
    target_translation = ABSOLUTE_LOCATIONS[location]
    # Cache the target location corresponding to the instruction.
    self._target_translation = np.copy(target_translation)
    self._location = location
    info = self.get_current_task_info(state)
    self._in_reward_zone_steps = 0
    return info

  @property
  def target_translation(self):
    return self._target_translation

  def reward(self, state):
    """Calculates reward given state."""
    reward, done = self.reward_for(state, self._block, self._target_translation)
    return reward, done

  def reward_for(self, state, pushing_block,
                 target_translation):
    """Returns (goal_reward, True) once pushing_block has stayed in location.

    The block must be within `BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE` of
    `target_translation` for more than `delay_reward_steps` calls; otherwise
    returns (0.0, False) and, while in the zone, counts the step.
    """
    reward = 0.0
    done = False
    if self._in_goal_region(state, pushing_block, target_translation):
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1
    return reward, done

  def _in_goal_region(self, state, pushing_block, target_translation):
    """Returns True if pushing_block is close enough to target_translation."""
    # Get current location of the target block.
    current_translation, _ = self._get_pose_for_block(pushing_block, state)
    # Compute distance between current translation and target.
    dist = np.linalg.norm(
        np.array(current_translation) - np.array(target_translation))
    return bool(dist < BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE)

  def reward_for_info(self, state, info):
    return self.reward_for(state, info.block, info.target_translation)

  def debug_info(self, state):
    """Returns the distance from the current block to the target location."""
    # Get current location of the target block.
    current_translation, _ = self._get_pose_for_block(self._block, state)
    # Compute distance between current translation and target.
    dist = np.linalg.norm(
        np.array(current_translation) - np.array(self._target_translation))
    return dist

  def get_current_task_info(self, state):
    return task_info.Block2LocationTaskInfo(
        instruction=self._instruction,
        block=self._block,
        location=self._location,
        target_translation=self._target_translation)


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/block2absolutelocation.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines block2absolutelocation reset and reward."""
import enum

from typing import Any, List
from absl import logging
from language_table.environments import blocks as blocks_module
from language_table.environments.rewards import reward as base_reward
from language_table.environments.rewards import synonyms
from language_table.environments.rewards import task_info
import numpy as np


# There's a small offset in the X direction to subtract.
# The red dots represent the bounds of the arm, which are not exactly in the
# center of the boards.
# This should only matter for this reward, which deals with absolute locations.
X_BUFFER = 0.025

X_MIN_REAL = 0.15
X_MAX_REAL = 0.6
Y_MIN_REAL = -0.3048
Y_MAX_REAL = 0.3048
X_MIN = X_MIN_REAL - X_BUFFER
X_MAX = X_MAX_REAL - X_BUFFER
Y_MIN = Y_MIN_REAL
Y_MAX = Y_MAX_REAL
CENTER_X = (X_MAX - X_MIN) / 2. + X_MIN
CENTER_Y = (Y_MAX - Y_MIN) / 2. + Y_MIN

BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE = 0.115
BLOCK2ABSOLUTELOCATION_CENTER_TARGET_DISTANCE = 0.1


class Locations(enum.Enum):
  TOP = 'top'
  TOP_LEFT = 'top_left'
  TOP_RIGHT = 'top_right'
  CENTER = 'center'
  CENTER_LEFT = 'center_left'
  CENTER_RIGHT = 'center_right'
  BOTTOM = 'bottom'
  BOTTOM_LEFT = 'bottom_left'
  BOTTOM_RIGHT = 'bottom_right'


ABSOLUTE_LOCATIONS = {
    'top': [X_MIN, CENTER_Y],
    'top_left': [X_MIN, Y_MIN],
    'top_right': [X_MIN, Y_MAX],
    'center': [CENTER_X, CENTER_Y],
    'center_left': [CENTER_X, Y_MIN],
    'center_right': [CENTER_X, Y_MAX],
    'bottom': [X_MAX, CENTER_Y],
    'bottom_left': [X_MAX, Y_MIN],
    'bottom_right': [X_MAX, Y_MAX],
}

LOCATION_SYNONYMS = {
    'top': ['top side', 'top', 'towards your base'],
    'top_left': ['top left of the board', 'top left',
                 'upper left corner', 'top left corner'],
    'top_right': ['top right of the board', 'top right',
                  'upper right corner', 'top right corner'],
    'center': ['middle of the board', 'center of the board',
               'center', 'middle'],
    'center_left': ['left side of the board', 'center left', 'left side'],
    'center_right': ['right side of the board', 'center right', 'right side'],
    'bottom': ['bottom side', 'bottom'],
    'bottom_left': ['bottom left of the board', 'bottom left',
                    'lower left corner', 'bottom left corner'],
    'bottom_right': ['bottom right of the board', 'bottom right',
                     'lower right corner', 'bottom right corner'],
}

BLOCK2ABSOLUTELOCATION_VERBS = [
    'move the',
    'push the',
    'slide the',
]


def generate_all_instructions(block_mode):
  """Generates all instructions for block2absolutelocation."""
  all_instructions = []
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  for block_text in all_block_text_descriptions:
    for location in ABSOLUTE_LOCATIONS:
      for location_syn in LOCATION_SYNONYMS[location]:
        for verb in BLOCK2ABSOLUTELOCATION_VERBS:
          # Add instruction.
          inst = (f'{verb} {block_text} to the {location_syn}')
          all_instructions.append(inst)
  return all_instructions


class BlockToAbsoluteLocationReward(base_reward.LanguageTableReward):
  """Calculates reward/instructions for 'push block to absolute location'."""

  def __init__(self, goal_reward, rng, delay_reward_steps,
               block_mode):
    super(BlockToAbsoluteLocationReward, self).__init__(
        goal_reward=goal_reward,
        rng=rng,
        delay_reward_steps=delay_reward_steps,
        block_mode=block_mode)
    self._block = None
    self._instruction = None
    self._location = None
    self._target_translation = None

  def _sample_instruction(
      self, block, blocks_on_table, location):
    """Randomly samples an instruction to push `block` to `location`."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Get some synonym for block.
    block_text = self._rng.choice(
        synonyms.get_block_synonyms(block, blocks_on_table))
    # Get some synonym for location.
    location_syn = self._rng.choice(LOCATION_SYNONYMS[location])
    return f'{verb} {block_text} to the {location_syn}'

  def reset(self, state, blocks_on_table):
    """Chooses a new target block and location.

    Returns:
      The task info from `reset_to`, or `task_info.FAILURE` if the chosen block
      already sits in the goal region.
    """
    # Choose a random block.
    block = self._sample_object(blocks_on_table)

    # Choose a location randomly.
    location = self._rng.choice(list(sorted(ABSOLUTE_LOCATIONS.keys())))

    info = self.reset_to(state, block, location, blocks_on_table)
    # If the state of the board already triggers the reward, try to reset
    # again with a new configuration.
    if self._in_goal_region(state, self._block, self._target_translation):
      # Try again with a new board configuration.
      return task_info.FAILURE
    return info

  def reset_to(
      self, state, block, location, blocks_on_table):
    """Resets to a particular task definition."""
    self._block = block
    # Sample an instruction.
    self._instruction = self._sample_instruction(
        self._block, blocks_on_table, location)
    # Get the corresponding target_translation.
    target_translation = ABSOLUTE_LOCATIONS[location]
    # Cache the target location corresponding to the instruction.
    self._target_translation = np.copy(target_translation)
    self._location = location
    info = self.get_current_task_info(state)
    self._in_reward_zone_steps = 0
    return info

  @property
  def target_translation(self):
    return self._target_translation

  def _target_distance(self):
    """Returns the success radius for the current location.

    The center uses a tighter radius than the edges and corners.
    """
    if self._location == Locations.CENTER.value:
      return BLOCK2ABSOLUTELOCATION_CENTER_TARGET_DISTANCE
    return BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE

  def get_goal_region(self):
    """Returns (target_translation, success radius) for the current task."""
    return self._target_translation, self._target_distance()

  def reward(self, state):
    """Calculates reward given state."""
    reward, done = self.reward_for(state, self._block, self._target_translation)
    return reward, done

  def reward_for(
      self, state, pushing_block, target_translation):
    """Returns (goal_reward, True) once pushing_block has stayed in location.

    Returns (0.0, False) until the block has been inside the goal region for
    more than `delay_reward_steps` calls.
    """
    reward = 0.0
    done = False

    in_goal_region = self._in_goal_region(state, pushing_block,
                                          target_translation)

    if in_goal_region:
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1
    return reward, done

  def _in_goal_region(self, state, pushing_block,
                      target_translation):
    """Returns True if pushing_block is within the success radius."""
    # Get current location of the target block.
    current_translation, _ = self._get_pose_for_block(pushing_block, state)
    # Compute distance between current translation and target.
    dist = np.linalg.norm(
        np.array(current_translation) - np.array(target_translation))
    # The radius follows the current task's `self._location`, not the
    # `target_translation` argument.
    return bool(dist < self._target_distance())

  def reward_for_info(self, state, info):
    return self.reward_for(state, info.block, info.target_translation)

  def debug_info(self, state):
    """Returns the distance from the current block to the target location."""
    # Get current location of the target block.
    current_translation, _ = self._get_pose_for_block(
        self._block, state)
    # Compute distance between current translation and target.
    dist = np.linalg.norm(
        np.array(current_translation) - np.array(self._target_translation))
    return dist

  def get_current_task_info(self, state):
    return task_info.Block2LocationTaskInfo(
        instruction=self._instruction,
        block=self._block,
        location=self._location,
        target_translation=self._target_translation)


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/block2block.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines block2block reset and reward."""
import itertools

from absl import logging
from language_table.environments import blocks as blocks_module
from language_table.environments.rewards import constants
from language_table.environments.rewards import reward as base_reward
from language_table.environments.rewards import synonyms
from language_table.environments.rewards import task_info
import numpy as np


def generate_all_instructions(block_mode):
  """Generates all block2block instructions.

  Covers every ordered pair of distinct block descriptions combined with every
  verb in `synonyms.PUSH_VERBS` and preposition in `synonyms.PREPOSITIONS`.
  """
  all_instructions = []
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  for start_block_text, target_block_text in itertools.permutations(
      all_block_text_descriptions, 2):
    for verb in synonyms.PUSH_VERBS:
      for preposition in synonyms.PREPOSITIONS:
        inst = f'{verb} {start_block_text} {preposition} {target_block_text}'
        all_instructions.append(inst)
  return all_instructions


# pytype: skip-file
class BlockToBlockReward(base_reward.LanguageTableReward):
  """Block2block reward: bring a start block close to a target block.

  Success means the two blocks are within `constants.TARGET_BLOCK_DISTANCE`.
  """

  def _sample_instruction(
      self, start_block, target_block, blocks_on_table):
    """Randomly samples an instruction involving the two blocks."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Sample synonyms for start and target blocks.
    start_syn = self._rng.choice(
        synonyms.get_block_synonyms(start_block, blocks_on_table))
    target_syn = self._rng.choice(
        synonyms.get_block_synonyms(target_block, blocks_on_table))
    preposition = self._rng.choice(synonyms.PREPOSITIONS)
    return f'{verb} {start_syn} {preposition} {target_syn}'

  def reset(self, state, blocks_on_table):
    """Resets the start/target objects and returns a text instruction.

    Returns:
      A `task_info.Block2BlockTaskInfo`, or `task_info.FAILURE` if no pair of
      blocks far enough apart was found within the attempt budget.
    """
    # Pick two objects sufficiently far away that the task is not already
    # solved (`reward` succeeds below TARGET_BLOCK_DISTANCE).
    min_separation = constants.TARGET_BLOCK_DISTANCE + 0.01
    max_attempts = 10
    # One initial draw plus `max_attempts` retries.
    for _ in range(max_attempts + 1):
      start_block, target_block = self._sample_objects(blocks_on_table)
      start_translation, _ = self._get_pose_for_block(
          start_block, state)
      target_translation, _ = self._get_pose_for_block(
          target_block, state)
      dist = np.linalg.norm(
          np.array(start_translation) - np.array(target_translation))
      if dist < min_separation:
        continue
      break
    else:
      logging.info(
          'Exceeded max number of attempts to find start/target blocks. '
          'No valid reward found for the current object configuration.')
      return task_info.FAILURE

    self._start_block = start_block
    self._target_block = target_block
    # Cache the target pose now so `get_goal_region` is valid right after
    # reset instead of missing or left over from the previous episode; it is
    # refreshed on every `reward` call.
    self._target_translation = target_translation
    self._instruction = self._sample_instruction(
        start_block, target_block, blocks_on_table)
    self._in_reward_zone_steps = 0
    return task_info.Block2BlockTaskInfo(
        instruction=self._instruction,
        block1=self._start_block,
        block2=self._target_block)

  def get_goal_region(self):
    """Returns (target block translation, success radius)."""
    return self._target_translation, constants.TARGET_BLOCK_DISTANCE

  def reward(self, state):
    """Calculates reward given state."""
    # For now only have sparse reward.
    start_translation, _ = self._get_pose_for_block(self._start_block, state)
    target_translation, _ = self._get_pose_for_block(self._target_block, state)

    self._target_translation = target_translation

    # This check only looks at the distance between the two blocks, so it is
    # also satisfied if the target block was moved towards the start block
    # (rather than the start block being moved).
    # TODO(ayzaan): Add smarter logic here.
    dist = np.linalg.norm(
        np.array(start_translation) - np.array(target_translation))
    reward = 0.0
    done = False
    if dist < constants.TARGET_BLOCK_DISTANCE:
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1

    return reward, done


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/block2relativelocation.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines block2relativelocation reset and reward."""

from absl import logging
from language_table.environments import blocks as blocks_module
from language_table.environments.rewards import reward as base_reward
from language_table.environments.rewards import synonyms
from language_table.environments.rewards import task_info
import numpy as np


MAGNITUDES = {
    'near': 0.15,
    'far': 0.25
}

# Cardinal directions you can push the block. The top left of board
# is (height=0, width=0), so UP == -1, LEFT == -1, etc.
UP = -1.
DOWN = 1.
LEFT = -1.
RIGHT = 1.

DIRECTIONS = {
    'up': [UP, 0.],  # good.
    'down': [DOWN, 0.],  # good.
    'left': [0., LEFT],  # good.
    'right': [0., RIGHT],  # good.
    'diagonal_up_left': [UP, LEFT] / np.linalg.norm([UP, LEFT]),
    'diagonal_up_right': [UP, RIGHT] / np.linalg.norm([UP, RIGHT]),
    'diagonal_down_left': [DOWN, LEFT] / np.linalg.norm([DOWN, LEFT]),
    'diagonal_down_right': [DOWN, RIGHT] / np.linalg.norm([DOWN, RIGHT]),
}


BLOCK2RELATIVELOCATION_VERBS = [
    'move the',
    'push the',
    'slide the',
]


# These cover [up, down, left, right].
SLIGHTLY_PREPOSITION_SYNONYMS = [
    'slightly',
    'slightly to the',
    'a bit',
    'a bit to the',
    'a little',
    'a little to the',
    'a little bit to the',
    'somewhat',
    'somewhat to the',
]


SLIGHTLY_SYNONYMS = [
    'slightly',
    'a bit',
    'a little',
    'a little bit',
    'somewhat',
]

BLOCK2RELATIVELOCATION_MODE_TO_PREPOSITIONS = {
    'near': SLIGHTLY_PREPOSITION_SYNONYMS,
    # empty string because 'push the red block left' is valid.
    'far': ['to the', '']
}


DIRECTION_SYNONYMS = {
    'up': ['up', 'upwards'],
    'down': ['down', 'downwards'],
    'left': ['to the left', 'left'],
    'right': ['to the right', 'right']
}

DIAGONAL_PREPOSITIONS = [
    '%s and %s',
    '%s and then %s',
    'diagonally %s and %s',
    '%s and %s diagonally',
]


BLOCK2RELATIVELOCATION_TARGET_DISTANCE = 0.1


def create_slightly_instruction(rng, verb, block, direction):
  """Creates a 'near' instruction with a randomly placed 'slightly' modifier."""
  mode = rng.choice(['slightly_first', 'prefix', 'suffix'])
  if mode == 'slightly_first':
    # e.g. 'slightly push the blue cube down and to the right'
    inst = f'slightly {verb} {block} {direction}'
  elif mode == 'prefix':
    # e.g. 'push the blue cube slightly down and to the right'
    slightly_syn = rng.choice(SLIGHTLY_SYNONYMS)
    inst = f'{verb} {block} {slightly_syn} {direction}'
  else:
    # e.g. 'push the blue cube down and right slightly'
    slightly_syn = rng.choice(SLIGHTLY_SYNONYMS)
    inst = f'{verb} {block} {direction} {slightly_syn}'
  return inst


def enumerate_slightly_instruction(verb, block, direction):
  """Yields every instruction `create_slightly_instruction` can produce."""
  # 'slightly_first' mode.
  yield f'slightly {verb} {block} {direction}'
  for slightly_syn in SLIGHTLY_SYNONYMS:
    # 'prefix' mode.
    yield f'{verb} {block} {slightly_syn} {direction}'
    # 'suffix' mode.
    yield f'{verb} {block} {direction} {slightly_syn}'


def get_diagonal_direction(rng, direction):
  """Map canonical diagonal direction `diagonal_up_right' to natural lang."""
  # Break out e.g. 'up', and 'right'
  _, first_key, second_key = direction.split('_')
  # Choose synonyms, e.g. 'right' -> 'to the right'.
  first_dir = rng.choice(DIRECTION_SYNONYMS[first_key])
  second_dir = rng.choice(DIRECTION_SYNONYMS[second_key])
  # Choose a diagonal preposition, e.g. ' and then '.
  diagonal_prep = rng.choice(DIAGONAL_PREPOSITIONS)
  # Insert the actual directions, e.g. 'up and then to the right'.
  diagonal_direction = diagonal_prep % (first_dir, second_dir)
  return diagonal_direction


def enumerate_diagonal_direction(direction):
  """Yields every phrasing of a canonical diagonal like `diagonal_up_right`."""
  # Break out e.g. 'up', and 'right'
  _, first_key, second_key = direction.split('_')
  # Keep the lookup keys separate from the loop variables: reusing one name
  # for both made the inner lookup depend on the last synonym of each list
  # happening to equal its own key.
  for first_syn in DIRECTION_SYNONYMS[first_key]:
    for second_syn in DIRECTION_SYNONYMS[second_key]:
      for diagonal_prep in DIAGONAL_PREPOSITIONS:
        # e.g. 'up and then to the right'.
        yield diagonal_prep % (first_syn, second_syn)


def generate_all_instructions(block_mode):
  """Generates all instructions for block2relativelocation."""
  all_instructions = []
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  for block_text in all_block_text_descriptions:
    for verb in BLOCK2RELATIVELOCATION_VERBS:
      for direction in DIRECTIONS:
        if 'diagonal' in direction:
          for direction_syn in enumerate_diagonal_direction(direction):
            # Add 'near diagonal' instructions.
            for near_inst in enumerate_slightly_instruction(
                verb, block_text, direction_syn):
              all_instructions.append(near_inst)
            # Add 'far diagonal' instructions.
            inst = f'{verb} {block_text} {direction_syn}'
            all_instructions.append(inst)
        else:
          for direction_syn in DIRECTION_SYNONYMS[direction]:
            # Add 'near' instructions.
            for near_inst in enumerate_slightly_instruction(
                verb, block_text, direction_syn):
              all_instructions.append(near_inst)
            # Add 'far' instructions.
            inst = f'{verb} {block_text} {direction_syn}'
            all_instructions.append(inst)
  return all_instructions


class BlockToRelativeLocationReward(base_reward.LanguageTableReward):
  """Calculates reward/instructions for 'push block to relative location'."""

  def _sample_instruction(
      self, block, distance_mode, direction, blocks_on_table):
    """Randomly samples an instruction to push `block` in `direction`."""
    verb = self._rng.choice(BLOCK2RELATIVELOCATION_VERBS)
    # Sample synonym for block.
    block_syn = self._rng.choice(
        synonyms.get_block_synonyms(block, blocks_on_table))
    if 'diagonal' in direction:
      # Map canonical diagonal dir to natural language. E.g.
      # diagonal_up_right -> 'up and to the right diagonally'.
      direction = get_diagonal_direction(self._rng, direction)
    else:
      direction = self._rng.choice(DIRECTION_SYNONYMS[direction])
    if distance_mode == 'near':
      # Modify this with language like 'slightly'.
      inst = create_slightly_instruction(self._rng, verb, block_syn, direction)
    else:
      inst = f'{verb} {block_syn} {direction}'
    return inst

  def reset(self, state, blocks_on_table):
    """Chooses new target block, direction, and distance.

    Returns:
      A `task_info.Block2RelativeLocationTaskInfo`, or `task_info.FAILURE` if
      no in-bounds target was found after `max_tries` rejected samples.
    """
    cnt = 0
    max_tries = 100
    while True:
      # Choose a random block.
      self._block = self._sample_object(blocks_on_table)

      # Get the current block location.
      block_translation, _ = self._get_pose_for_block(
          self._block, state)

      # Choose a direction.
      direction = self._rng.choice(sorted(list(DIRECTIONS.keys())))

      # Choose a magnitude.
      distance_mode = self._rng.choice(sorted(list(MAGNITUDES.keys())))
      magnitude = MAGNITUDES[distance_mode]

      # Define target_vector as % change in H,W.
      target_vector = np.array(DIRECTIONS[direction]) * magnitude
      # Define target_translation = direction * magnitude.
      target_translation = block_translation + target_vector

      # Only keep if target_translation is inside workspace bounds.
      if base_reward.target_inside_bounds(target_translation):
        break
      cnt += 1

      if cnt > max_tries:
        # Try again with a new board configuration.
        return task_info.FAILURE
    # Choose an instruction.
    self._instruction = self._sample_instruction(
        self._block, distance_mode, direction, blocks_on_table)

    # Cache the target location corresponding to the instruction.
    self._target_translation = np.copy(target_translation)

    self._in_reward_zone_steps = 0
    return task_info.Block2RelativeLocationTaskInfo(
        instruction=self._instruction,
        block=self._block,
        location=direction,
        target_translation=self._target_translation)

  def get_goal_region(self):
    """Returns (target_translation, success radius)."""
    return self._target_translation, BLOCK2RELATIVELOCATION_TARGET_DISTANCE

  def reward(self, state):
    """Calculates reward given state."""
    # Get current location of the target block.
    current_translation, _ = self._get_pose_for_block(
        self._block, state)

    # Compute distance between current translation and target.
    dist = np.linalg.norm(
        np.array(current_translation) - np.array(self._target_translation))
    reward = 0.0
    done = False
    if dist < BLOCK2RELATIVELOCATION_TARGET_DISTANCE:
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1

    return reward, done


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/constants.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define constants shared across all 2d board environment rewards."""
TARGET_BLOCK_DISTANCE = 0.05


# ------------------------------------------------------------------------------
# /language_table/environments/rewards/instructions.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for getting all instructions from the different 2d board envs.""" 17 | from language_table.environments.rewards import block2absolutelocation 18 | from language_table.environments.rewards import block2block 19 | from language_table.environments.rewards import block2block_relative_location 20 | from language_table.environments.rewards import block2relativelocation 21 | from language_table.environments.rewards import point2block 22 | from language_table.environments.rewards import separate_blocks 23 | 24 | 25 | CLIP_VOCAB_SIZE = 49408 # Presumably the CLIP tokenizer's vocabulary size; not referenced in this module. 26 | 27 | 28 | def generate_all_instructions(block_mode): 29 | """Returns the concatenated instruction lists of every reward family for block_mode.""" 30 | return (block2block.generate_all_instructions(block_mode) + 31 | point2block.generate_all_instructions(block_mode) + 32 | block2relativelocation.generate_all_instructions(block_mode) + 33 | block2absolutelocation.generate_all_instructions(block_mode) + 34 | block2block_relative_location.generate_all_instructions(block_mode) + 35 | separate_blocks.generate_all_instructions(block_mode)) 36 | 37 | 38 | def vocab_size(block_mode): # Counts distinct space-separated words across all instructions for block_mode. 39 | words = set() 40 | for instruction in generate_all_instructions(block_mode): 41 | for word in instruction.split(' '): 42 | words.add(word) 43 | return len(words) 44 | -------------------------------------------------------------------------------- /language_table/environments/rewards/instructions_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The
Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for language_table.environments.rewards.instructions.""" 17 | 18 | from language_table.environments import blocks 19 | from language_table.environments.rewards import instructions 20 | import tensorflow as tf 21 | 22 | 23 | class InstructionsTest(tf.test.TestCase): # Pins generate_all_instructions output sizes per block variant. 24 | 25 | def test_expected_instructions_generated(self): 26 | # Regression guard: each block mode must yield exactly this many 27 | # instructions, so template or synonym changes are caught here. 28 | inst_block4 = instructions.generate_all_instructions( 29 | blocks.LanguageTableBlockVariants.BLOCK_4) 30 | self.assertLen(inst_block4, 12652) 31 | inst_block8 = instructions.generate_all_instructions( 32 | blocks.LanguageTableBlockVariants.BLOCK_8) 33 | self.assertLen(inst_block8, 30264) 34 | inst_n_choose_k = instructions.generate_all_instructions( 35 | blocks.LanguageTableBlockVariants.N_CHOOSE_K) 36 | self.assertLen(inst_n_choose_k, 80368) 37 | 38 | if __name__ == '__main__': 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /language_table/environments/rewards/point2block.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines point2block reset and reward.""" 17 | from absl import logging 18 | from language_table.environments import blocks as blocks_module 19 | from language_table.environments.rewards import constants 20 | from language_table.environments.rewards import reward as base_reward 21 | from language_table.environments.rewards import synonyms 22 | from language_table.environments.rewards import task_info 23 | import numpy as np 24 | 25 | 26 | def generate_all_instructions(block_mode): # Every '<point preposition> <block text>' pair for block_mode. 27 | all_instructions = [] 28 | all_block_text_descriptions = blocks_module.get_blocks_text_descriptions( 29 | block_mode) 30 | for block_text in all_block_text_descriptions: 31 | for preposition in synonyms.POINT_PREPOSITIONS: 32 | inst = f'{preposition} {block_text}' 33 | all_instructions.append(inst) 34 | return all_instructions 35 | 36 | 37 | class PointToBlockReward(base_reward.LanguageTableReward): 38 | """Rewards bringing the effector target close to a single chosen block.""" 39 | 40 | def _sample_instruction(self, block, blocks_on_table): 41 | """Randomly sample a pointing instruction that refers to a single block.""" 42 | # Get some synonym for block. 43 | block_text = self._rng.choice( 44 | synonyms.get_block_synonyms(block, blocks_on_table)) 45 | preposition = self._rng.choice(synonyms.POINT_PREPOSITIONS) 46 | return f'{preposition} {block_text}' 47 | 48 | def reset(self, state, blocks_on_table): 49 | """Resets to new pointing task; returns its task info or task_info.FAILURE.""" 50 | # Pick a block that is not already within TARGET_BLOCK_DISTANCE (plus a
51 | # 0.01 margin) of the effector target, so the task is not solved at reset. 52 | max_attempts = 10 53 | num_attempts = 0 54 | while True: 55 | block = self._sample_object(blocks_on_table) 56 | start_translation, _ = self._get_pose_for_block( 57 | block, state) 58 | dist = np.linalg.norm( 59 | np.array(start_translation) - 60 | np.array(state['effector_target_translation'])) 61 | if dist < constants.TARGET_BLOCK_DISTANCE + 0.01: 62 | num_attempts += 1 63 | if num_attempts > max_attempts: 64 | logging.info( 65 | 'Exceeded max number of attempts to find start/target blocks. ' 66 | 'No valid reward found for the current object configuration.') 67 | return task_info.FAILURE 68 | continue 69 | else: 70 | self._block = block 71 | break 72 | self._instruction = self._sample_instruction(block, blocks_on_table) 73 | self._in_reward_zone_steps = 0 74 | return task_info.Point2BlockTaskInfo( 75 | instruction=self._instruction, 76 | block_target=self._block) 77 | 78 | def reward(self, state): 79 | """Returns (reward, done); sparse goal reward once the effector target is near the block.""" 80 | # For now only have sparse reward. 81 | object_translation, _ = self._get_pose_for_block(self._block, state) 82 | 83 | # This check ignores whether the block was moved toward the effector 84 | # (rather than the effector being moved toward the block). 85 | # TODO(ayzaan): Add smarter logic here.
86 | dist = np.linalg.norm( 87 | np.array(object_translation) - 88 | np.array(state['effector_target_translation'])) 89 | reward = 0.0 90 | done = False 91 | if dist < constants.TARGET_BLOCK_DISTANCE: 92 | if self._in_reward_zone_steps >= self._delay_reward_steps: # Cumulative count: it is only reset in reset(), not when leaving the zone. 93 | reward = self._goal_reward 94 | done = True 95 | else: 96 | logging.info('In reward zone for %d steps', self._in_reward_zone_steps) 97 | self._in_reward_zone_steps += 1 98 | return reward, done 99 | -------------------------------------------------------------------------------- /language_table/environments/rewards/reward.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """LanguageTable base reward class.""" 17 | from typing import Any, List, Tuple 18 | from language_table.environments import blocks as blocks_module 19 | from language_table.environments import constants 20 | from language_table.environments.rewards import synonyms 21 | import numpy as np 22 | 23 | 24 | class LanguageTableReward(object): 25 | """Base class for all 2d board rewards.""" 26 | 27 | def __init__(self, goal_reward, rng, delay_reward_steps, 28 | block_mode): 29 | self._block_mode = block_mode 30 | self._goal_reward = goal_reward 31 | self._rng = rng 32 | # TODO(tding): Handle this in all rewards 33 | self._delay_reward_steps = delay_reward_steps 34 | self._in_reward_zone_steps = None # Subclasses set this to 0 in reset(). 35 | 36 | self._target_translation = None 37 | 38 | def seed(self, rng): # Replaces the RNG used for all subsequent sampling. 39 | self._rng = rng 40 | 41 | def get_goal_region(self): 42 | """Returns the (target translation, radius) tuple; (None, None) in the base class.""" 43 | return None, None 44 | 45 | def _get_block_synonym(self, block, blocks_on_table): 46 | return self._rng.choice(synonyms.get_block_synonyms(block, blocks_on_table)) 47 | 48 | def _get_pose_for_block(self, block, state): # Reads state['block_<name>_translation'] and state['block_<name>_orientation']. 49 | state_translation_key = 'block_%s_translation' % block 50 | state_rotation_key = 'block_%s_orientation' % block 51 | return state[state_translation_key], state[state_rotation_key] 52 | 53 | def _get_translation_for_block(self, block, state): 54 | return np.array(self._get_pose_for_block(block, state)[0]) 55 | 56 | def _sample_object(self, blocks_on_table): 57 | """Choose one of the blocks randomly.""" 58 | block = self._rng.choice(blocks_on_table) 59 | return block 60 | 61 | def _sample_objects(self, blocks_on_table): 62 | """Randomly sample two distinct blocks, returned as (start, target).""" 63 | start_block, target_block = self._rng.choice( 64 | blocks_on_table, 2, replace=False) 65 | return start_block, target_block 66 | 67 | 68 | def target_inside_bounds( # True iff (x, y) lies strictly inside the workspace shrunk by `buffer` on every side. 69 | target, buffer=constants.WORKSPACE_BOUNDS_BUFFER): 70 | target_x, target_y = target 71 | return (target_x > constants.X_MIN +
buffer and 72 | target_x < constants.X_MAX - buffer and 73 | target_y > constants.Y_MIN + buffer and 74 | target_y < constants.Y_MAX - buffer) 75 | -------------------------------------------------------------------------------- /language_table/environments/rewards/synonyms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines synonyms for the structured language 2d board environment.""" 17 | import collections 18 | 19 | 20 | def get_block_synonyms(block, blocks_on_table): 21 | """Returns the referring expressions for block that are unambiguous given blocks_on_table.""" 22 | color, shape = get_color_shape(block) 23 | color_counts, shape_counts = count_color_shape(blocks_on_table) 24 | synonyms = [] 25 | if color_counts[color] == 1: 26 | # There is only one 'red' block, so feel free to refer to it as 27 | # 'red block'. 28 | synonyms.append('%s block' % color) 29 | if shape_counts[shape] == 1: 30 | # There is only one 'star' block, so feel free to refer to it as 31 | # 'star'. 32 | synonyms.append(shape) 33 | # (color, shape) is always unique.
34 | synonyms.append('%s %s' % (color, shape)) 35 | return synonyms 36 | 37 | 38 | def count_color_shape(blocks_on_table): # Counters of colors and shapes; an empty blocks_on_table makes the unpacking raise ValueError. 39 | colors, shapes = zip(*[get_color_shape(i) for i in blocks_on_table]) 40 | color_counts = collections.Counter(colors) 41 | shape_counts = collections.Counter(shapes) 42 | return color_counts, shape_counts 43 | 44 | 45 | def get_color_shape(block): # Expects '<color>_<shape>' ids (e.g. 'red_moon'); any other number of '_' parts raises ValueError. 46 | color, shape = block.split('_') 47 | return color, shape 48 | 49 | 50 | PUSH_VERBS = [ 51 | 'push the', 52 | 'move the', 53 | 'slide the', 54 | 'put the', 55 | ] 56 | 57 | PREPOSITIONS = [ 58 | 'to the', 59 | 'towards the', 60 | 'close to the', 61 | 'next to the', 62 | ] 63 | 64 | POINT_PREPOSITIONS = [ 65 | 'point next to the', 66 | 'point close to the', 67 | 'point to the', 68 | 'point at the', 69 | 'move the arm next to the', 70 | 'move the arm close to the', 71 | 'move the arm to the', 72 | 'move your arm next to the', 73 | 'move your arm close to the', 74 | 'move your arm to the', 75 | 'move next to the', 76 | 'move close to the', 77 | 'move to the', 78 | ] 79 | -------------------------------------------------------------------------------- /language_table/environments/rewards/task_info.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Data classes holding info returned to environment for each reset.""" 17 | 18 | import dataclasses 19 | from typing import List 20 | 21 | import numpy as np 22 | 23 | 24 | @dataclasses.dataclass 25 | class Block2BlockTaskInfo: 26 | """Data class defining a chosen block2block task after reset.""" 27 | instruction: str 28 | block1: str 29 | block2: str 30 | 31 | 32 | @dataclasses.dataclass 33 | class Block2LocationTaskInfo: 34 | """Data class defining a chosen block-to-location task after reset.""" 35 | instruction: str 36 | block: str 37 | target_translation: np.ndarray 38 | location: str 39 | 40 | 41 | @dataclasses.dataclass 42 | class Block2LineTaskInfo: 43 | """Data class defining a chosen block-to-line task after reset.""" 44 | instruction: str 45 | block: str 46 | target_translation: np.ndarray 47 | 48 | 49 | @dataclasses.dataclass 50 | class Block2PoleTaskInfo: 51 | """Data class defining a chosen block2pole task after reset.""" 52 | instruction: str 53 | block1: str 54 | goal: str 55 | 56 | 57 | @dataclasses.dataclass 58 | class Block2RelativeLocationTaskInfo: 59 | """Data class defining a chosen block-to-relative-location task after reset.""" 60 | instruction: str 61 | block: str 62 | target_translation: np.ndarray 63 | location: str 64 | 65 | 66 | @dataclasses.dataclass 67 | class Block2BlockRelativeLocationTaskInfo: 68 | """Data class defining a chosen block-to-block-relative-location task after reset.""" 69 | instruction: str 70 | block: str 71 | target_block: str 72 | direction: str 73 | target_translation: np.ndarray 74 | 75 | 76 | @dataclasses.dataclass 77 | class SeparateBlocksTaskInfo: 78 | """Class defining a chosen "separate blocks" task after reset.""" 79 | instruction: str 80 | block: str 81 | avoid_blocks: List[str] 82 | target_translation: np.ndarray 83 | 84 | 85 | @dataclasses.dataclass 86 | class Point2BlockTaskInfo: 87 | """Data class defining a chosen point2block task after reset.""" 88 | instruction: str 89 | block_target: str 90 | 91 | 92 | ALL_TASKS = [ # Every task-info dataclass defined in this module. 93 | Block2BlockTaskInfo, 94 |
Block2LocationTaskInfo, 95 | Block2RelativeLocationTaskInfo, 96 | Block2BlockRelativeLocationTaskInfo, 97 | SeparateBlocksTaskInfo, 98 | Point2BlockTaskInfo, 99 | Block2LineTaskInfo, 100 | Block2PoleTaskInfo, 101 | ] 102 | 103 | # Returned by a reward's reset() when no valid task fits the current board; the board should be re-randomized and reset again. 104 | FAILURE = 'failure' 105 | -------------------------------------------------------------------------------- /language_table/environments/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /language_table/environments/utils/pose3d.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A simple 6DOF pose container. 17 | """ 18 | 19 | import dataclasses 20 | import numpy as np 21 | from scipy.spatial import transform 22 | 23 | 24 | class NoCopyAsDict(object): 25 | """Base class for dataclasses. Avoids a copy in the asdict() call.""" 26 | 27 | def asdict(self): 28 | """Replacement for dataclasses.asdict. 29 | 30 | TF Dataset does not handle dataclasses.asdict, which uses copy.deepcopy when 31 | setting values in the output dict. This causes issues with tf.Dataset. 32 | Instead, shallow copy contents. 33 | 34 | Returns: 35 | dict containing contents of dataclass. 36 | """ 37 | return {k.name: getattr(self, k.name) for k in dataclasses.fields(self)} # pytype: disable=wrong-arg-types # re-none 38 | 39 | 40 | @dataclasses.dataclass 41 | class Pose3d(NoCopyAsDict): 42 | """Simple container for translation and rotation.""" 43 | 44 | rotation: transform.Rotation 45 | translation: np.ndarray 46 | 47 | @property 48 | def vec7(self): 49 | return np.concatenate([self.translation, self.rotation.as_quat()]) 50 | 51 | def serialize(self): 52 | return {'rotation': self.rotation.as_quat().tolist(), 53 | 'translation': self.translation.tolist()} 54 | 55 | @staticmethod 56 | def deserialize(data): 57 | return Pose3d(rotation=transform.Rotation.from_quat(data['rotation']), 58 | translation=np.array(data['translation'])) 59 | 60 | def __eq__(self, other): 61 | return (np.array_equal(self.rotation.as_quat(), 62 | other.rotation.as_quat()) and 63 | np.array_equal(self.translation, 64 | other.translation)) 65 | 66 | def __ne__(self, other): 67 | return not self.__eq__(other) 68 | -------------------------------------------------------------------------------- /language_table/environments/utils/xarm_sim_robot.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language 
Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """XArm Robot Kinematics.""" 17 | from language_table.environments.utils import utils_pybullet 18 | from language_table.environments.utils.pose3d import Pose3d 19 | import numpy as np 20 | from scipy.spatial import transform 21 | import pybullet 22 | 23 | 24 | XARM_URDF_PATH = ('third_party/bullet/examples/pybullet/gym/pybullet_data/' 25 | 'xarm/xarm6_robot.urdf') 26 | XARM_WHITE_URDF_PATH = ('third_party/bullet/examples/pybullet/gym/' 27 | 'pybullet_data/xarm/xarm6_robot_white.urdf') 28 | SUCTION_URDF_PATH = ( 29 | 'third_party/py/language_table/environments/assets/suction/' 30 | 'suction-head-long.urdf') 31 | CYLINDER_URDF_PATH = ( 32 | 'third_party/py/language_table/environments/assets/suction/' 33 | 'cylinder.urdf') 34 | CYLINDER_REAL_URDF_PATH = ( 35 | 'third_party/py/language_table/environments/assets/suction/' 36 | 'cylinder_real.urdf') 37 | HOME_JOINT_POSITIONS = np.deg2rad([0, -20, -80, 0, 100, -30]) 38 | 39 | 40 | class XArmSimRobot(): 41 | """A simulated PyBullet XArm robot, mostly for forward/inverse kinematics.""" 42 | 43 | def __init__(self, 44 | pybullet_client, 45 | initial_joint_positions = HOME_JOINT_POSITIONS, 46 | end_effector = 'none', 47 | color = 'default'): 48 | self._pybullet_client = pybullet_client 49 | self.initial_joint_positions = initial_joint_positions 50 | 51 | if color == 'default': 52 | self.xarm = 
utils_pybullet.load_urdf(pybullet_client, XARM_URDF_PATH, 53 | [0, 0, 0]) 54 | elif color == 'white': 55 | self.xarm = utils_pybullet.load_urdf(pybullet_client, 56 | XARM_WHITE_URDF_PATH, [0, 0, 0]) 57 | else: 58 | raise ValueError('Unrecognized xarm color %s' % color) 59 | 60 | # Get revolute joints of robot (skip fixed joints). 61 | joints = [] 62 | joint_indices = [] 63 | for i in range(self._pybullet_client.getNumJoints(self.xarm)): 64 | joint_info = self._pybullet_client.getJointInfo(self.xarm, i) 65 | if joint_info[2] == pybullet.JOINT_REVOLUTE: 66 | joints.append(joint_info[0]) 67 | joint_indices.append(i) 68 | # Note examples in pybullet do this, but it is not clear what the 69 | # benefits are. 70 | self._pybullet_client.changeDynamics( 71 | self.xarm, i, linearDamping=0, angularDamping=0) 72 | 73 | self._n_joints = len(joints) 74 | self._joints = tuple(joints) 75 | self._joint_indices = tuple(joint_indices) 76 | 77 | # Move robot to home joint configuration 78 | self.reset_joints(self.initial_joint_positions) 79 | self.set_target_joint_positions(self.initial_joint_positions) 80 | self.effector_link = 6 81 | 82 | if (end_effector == 'suction' 83 | or end_effector == 'cylinder' 84 | or end_effector == 'cylinder_real'): 85 | self.end_effector = self._setup_end_effector(end_effector) 86 | else: 87 | if end_effector != 'none': 88 | raise ValueError('end_effector "%s" is not supported.' 
% end_effector) 89 | self.end_effector = None 90 | 91 | def _setup_end_effector(self, end_effector): 92 | """Adds a suction or cylinder end effector.""" 93 | pose = self.forward_kinematics() 94 | if end_effector == 'suction': 95 | body = utils_pybullet.load_urdf(self._pybullet_client, SUCTION_URDF_PATH, 96 | pose.translation, pose.rotation.as_quat()) 97 | elif end_effector == 'cylinder': 98 | body = utils_pybullet.load_urdf(self._pybullet_client, CYLINDER_URDF_PATH, 99 | pose.translation, pose.rotation.as_quat()) 100 | elif end_effector == 'cylinder_real': 101 | body = utils_pybullet.load_urdf(self._pybullet_client, 102 | CYLINDER_REAL_URDF_PATH, pose.translation, 103 | pose.rotation.as_quat()) 104 | else: 105 | raise ValueError('end_effector "%s" is not supported.' % end_effector) 106 | 107 | constraint_id = self._pybullet_client.createConstraint( 108 | parentBodyUniqueId=self.xarm, 109 | parentLinkIndex=6, 110 | childBodyUniqueId=body, 111 | childLinkIndex=-1, 112 | jointType=pybullet.JOINT_FIXED, 113 | jointAxis=(0, 0, 0), 114 | parentFramePosition=(0, 0, 0), 115 | childFramePosition=(0, 0, 0)) 116 | self._pybullet_client.changeConstraint(constraint_id, maxForce=1000) 117 | 118 | return body 119 | 120 | def reset_joints(self, joint_values): 121 | """Sets the position of the Robot's joints. 122 | 123 | *Note*: This should only be used at the start while not running the 124 | simulation resetJointState overrides all physics simulation. 125 | 126 | Args: 127 | joint_values: Iterable with desired joint positions. 
128 | """ 129 | for i in range(self._n_joints): 130 | self._pybullet_client.resetJointState(self.xarm, self._joints[i], 131 | joint_values[i]) 132 | 133 | def get_joints_measured(self): 134 | joint_states = self._pybullet_client.getJointStates(self.xarm, 135 | self._joint_indices) 136 | joint_positions = np.array([state[0] for state in joint_states]) 137 | joint_velocities = np.array([state[1] for state in joint_states]) 138 | joint_torques = np.array([state[3] for state in joint_states]) 139 | return joint_positions, joint_velocities, joint_torques 140 | 141 | def get_joint_positions(self): 142 | joint_states = self._pybullet_client.getJointStates(self.xarm, 143 | self._joint_indices) 144 | joint_positions = np.array([state[0] for state in joint_states]) 145 | return joint_positions 146 | 147 | def forward_kinematics(self): 148 | """Forward kinematics.""" 149 | effector_state = self._pybullet_client.getLinkState( 150 | self.xarm, self.effector_link, computeForwardKinematics=1) 151 | return Pose3d(translation=np.array(effector_state[0]), 152 | rotation=transform.Rotation.from_quat(effector_state[1])) 153 | 154 | def inverse_kinematics(self, 155 | world_effector_pose, 156 | max_iterations=100, 157 | residual_threshold=1e-10): 158 | """Inverse kinematics. 159 | 160 | Args: 161 | world_effector_pose: Target Pose3d for the robot's end effector. 162 | max_iterations: Refine the IK solution until the distance between target 163 | and actual end effector position is below this threshold, or the 164 | maxNumIterations is reached. Default is 20 iterations. 165 | residual_threshold: Refine the IK solution until the distance between 166 | target and actual end effector position is below this threshold, or the 167 | maxNumIterations is reached. 168 | 169 | Returns: 170 | Numpy array with required joint angles to reach the requested pose. 
171 | """ 172 | return np.array( 173 | self._pybullet_client.calculateInverseKinematics( 174 | self.xarm, 175 | self.effector_link, 176 | world_effector_pose.translation, 177 | world_effector_pose.rotation.as_quat(), # as_quat returns xyzw. 178 | # TODO(peteflorence): use real limits 179 | lowerLimits=[-17] * 6, 180 | upperLimits=[17] * 6, 181 | jointRanges=[17] * 6, 182 | # TODO(oars): Understand why examples don't use actual positions for 183 | # the first two joints. Taken from 184 | # `pybullet/gym/pybullet_robots/xarm/xarm_sim.py` 185 | restPoses=[0, 0] + self.get_joint_positions()[2:].tolist(), 186 | maxNumIterations=max_iterations, 187 | residualThreshold=residual_threshold)) 188 | 189 | def set_target_effector_pose(self, world_effector_pose): 190 | target_joint_positions = self.inverse_kinematics(world_effector_pose) 191 | self.set_target_joint_positions(target_joint_positions) 192 | 193 | def set_target_joint_velocities(self, target_joint_velocities): 194 | self._pybullet_client.setJointMotorControlArray( 195 | self.xarm, 196 | self._joint_indices, 197 | pybullet.VELOCITY_CONTROL, 198 | targetVelocities=target_joint_velocities, 199 | forces=[5 * 240.] * 6) 200 | 201 | def set_target_joint_positions(self, target_joint_positions): 202 | # TODO(oars, peteflorence): Probably should adjust gains to get reasonable 203 | # speeds. It's moving kinda fast. 204 | self._pybullet_client.setJointMotorControlArray( 205 | self.xarm, 206 | self._joint_indices, 207 | pybullet.POSITION_CONTROL, 208 | targetPositions=target_joint_positions, 209 | forces=[5 * 240.] * 6) 210 | 211 | def set_alpha_transparency(self, alpha): 212 | visual_shape_data = self._pybullet_client.getVisualShapeData(self.xarm) 213 | 214 | for i in range(self._pybullet_client.getNumJoints(self.xarm)): 215 | object_id, link_index, _, _, _, _, _, rgba_color = visual_shape_data[i] 216 | assert object_id == self.xarm, 'xarm id mismatch.' 
217 | assert link_index == i, 'Link visual data was returned out of order.' 218 | rgba_color = list(rgba_color[0:3]) + [alpha] 219 | self._pybullet_client.changeVisualShape( 220 | self.xarm, linkIndex=i, rgbaColor=rgba_color) 221 | -------------------------------------------------------------------------------- /language_table/environments/utils/xarm_sim_robot_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for ibc.environments.xarm_sim_robot.""" 17 | import math 18 | 19 | from language_table.environments.utils import xarm_sim_robot 20 | from language_table.environments.utils.pose3d import Pose3d 21 | import numpy as np 22 | from scipy.spatial import transform 23 | import tensorflow.compat.v2 as tf 24 | import pybullet 25 | import pybullet_utils.bullet_client as bullet_client 26 | 27 | 28 | class XArmSimRobotTest(tf.test.TestCase): 29 | 30 | def setUp(self): 31 | super(XArmSimRobotTest, self).setUp() 32 | 33 | # To debug we can use the SHARED_MEMORY connection. 
34 | # pybullet.connect(pybullet.SHARED_MEMORY) 35 | connection_mode = pybullet.SHARED_MEMORY 36 | connection_mode = pybullet.DIRECT 37 | self._pybullet_client = bullet_client.BulletClient(connection_mode) 38 | self._pybullet_client.resetSimulation() 39 | self._pybullet_client.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0) 40 | self._pybullet_client.setPhysicsEngineParameter(enableFileCaching=0) 41 | 42 | def test_arm_loads(self): 43 | xarm_sim_robot.XArmSimRobot(self._pybullet_client) 44 | 45 | def test_arm_loads_suction(self): 46 | xarm_sim_robot.XArmSimRobot(self._pybullet_client, end_effector='suction') 47 | 48 | def test_forward_kinematics(self): 49 | robot = xarm_sim_robot.XArmSimRobot(self._pybullet_client) 50 | 51 | # Pointing down X Axis 52 | robot.reset_joints([0, math.pi / 2, math.pi, 0, 0, 0]) 53 | x, y, _ = robot.forward_kinematics().translation 54 | 55 | self.assertAlmostEqual(0.714479, x, places=3) 56 | self.assertAlmostEqual(-0.0006, y, places=3) 57 | 58 | # Pointing down Y Axis 59 | robot.reset_joints([math.pi / 2, math.pi / 2, math.pi, 0, 0, 0]) 60 | x, y, _ = robot.forward_kinematics().translation 61 | 62 | self.assertAlmostEqual(0.0006, x, places=3) 63 | self.assertAlmostEqual(0.714479, y, places=3) 64 | 65 | def test_inverse_kinematics(self): 66 | robot = xarm_sim_robot.XArmSimRobot(self._pybullet_client) 67 | initial_pose = robot.forward_kinematics() 68 | 69 | rotation = transform.Rotation.from_rotvec([0, math.pi / 2, 0]) 70 | translation = np.array([0.5, 0.0, 0.10]) 71 | target_pose = Pose3d(rotation=rotation, translation=translation) 72 | 73 | robot.reset_joints(robot.inverse_kinematics(target_pose)) 74 | pose = robot.forward_kinematics() 75 | 76 | self.assertFalse(np.all(initial_pose.vec7 == pose.vec7)) 77 | np.testing.assert_almost_equal(pose.vec7, target_pose.vec7, decimal=2) 78 | 79 | 80 | if __name__ == '__main__': 81 | tf.test.main() 82 | -------------------------------------------------------------------------------- 
# /language_table/eval/__init__.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ------------------------------------------------------------------------------
# /language_table/eval/main.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple offline evaluation script for language table sim."""

import collections
from collections.abc import Sequence
import os

from absl import app
from absl import flags
from absl import logging

import jax

from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.oracles import push_oracle_rrt_slowdown
from language_table.environments.rewards import block2absolutelocation
from language_table.environments.rewards import block2block
from language_table.environments.rewards import block2block_relative_location
from language_table.environments.rewards import block2relativelocation
from language_table.environments.rewards import separate_blocks
from language_table.eval import wrappers as env_wrappers
from language_table.train import policy as jax_policy
from language_table.train.networks import lava

import mediapy as mediapy_lib
from ml_collections import config_flags

import tensorflow as tf
from tf_agents.environments import gym_wrapper
from tf_agents.environments import wrappers as tfa_wrappers

_CONFIG = config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
_WORKDIR = flags.DEFINE_string("workdir", None, "Evaluation result directory.")
_CHECKPOINT_PATH = flags.DEFINE_string("checkpoint_path", None,
                                       "FLAX checkpoint path.")


def evaluate_checkpoint(checkpoint_path, workdir, config):
  """Evaluates the given checkpoint and writes results to workdir.

  For every reward type, runs `num_evals_per_reward` episodes of the restored
  policy in the 8-block sim. It writes one mp4 rollout per episode to
  `<workdir>/videos/` and prints the number of successes per reward.

  Args:
    checkpoint_path: FLAX checkpoint path handed to `BCJaxPyPolicy`.
    workdir: Output directory; a `videos` subdirectory is created in it.
    config: Training config. Reads `model`, `data_target_width`,
      `data_target_height`, `random_crop_factor` and `sequence_length`.
  """
  video_dir = os.path.join(workdir, "videos")
  if not tf.io.gfile.exists(video_dir):
    tf.io.gfile.makedirs(video_dir)
  rewards = {
      "blocktoblock":
          block2block.BlockToBlockReward,
      "blocktoabsolutelocation":
          block2absolutelocation.BlockToAbsoluteLocationReward,
      "blocktoblockrelativelocation":
          block2block_relative_location.BlockToBlockRelativeLocationReward,
      "blocktorelativelocation":
          block2relativelocation.BlockToRelativeLocationReward,
      "separate":
          separate_blocks.SeparateBlocksReward,
  }

  num_evals_per_reward = 50
  max_episode_steps = 200

  # Built from the first reward's wrapped env specs, then reused.
  policy = None
  model = lava.SequenceLAVMSE(action_size=2, **config.model)

  # Successful-episode count per reward. Every reward is listed up front, so
  # rewards with zero successes still appear in the printed summary.
  results = {reward_name: 0 for reward_name in rewards}
  for reward_name, reward_factory in rewards.items():
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=reward_factory,
        seed=0)
    env = gym_wrapper.GymWrapper(env)
    env = env_wrappers.ClipTokenWrapper(env)
    env = env_wrappers.CentralCropImageWrapper(
        env,
        target_width=config.data_target_width,
        target_height=config.data_target_height,
        random_crop_factor=config.random_crop_factor)
    env = tfa_wrappers.HistoryWrapper(
        env, history_length=config.sequence_length, tile_first_step_obs=True)

    if policy is None:
      policy = jax_policy.BCJaxPyPolicy(
          env.time_step_spec(),
          env.action_spec(),
          model=model,
          checkpoint_path=checkpoint_path,
          rng=jax.random.PRNGKey(0))

    for ep_num in range(num_evals_per_reward):
      # Reset env. Choose new init if oracle cannot find valid motion plan.
      # Get an oracle. We use this at the moment to decide whether an
      # environment initialization is valid. If oracle can motion plan,
      # init is valid.
      oracle_policy = push_oracle_rrt_slowdown.ObstacleOrientedPushOracleBoard2dRRT(
          env, use_ee_planner=True)
      plan_success = False
      while not plan_success:
        ts = env.reset()
        raw_state = env.compute_state()
        plan_success = oracle_policy.get_plan(raw_state)
        if not plan_success:
          logging.info(
              "Resetting environment because the "
              "initialization was invalid (could not find motion plan).")

      frames = [env.render()]

      episode_steps = 0
      while not ts.is_last():
        policy_step = policy.action(ts, ())
        ts = env.step(policy_step.action)
        frames.append(env.render())
        episode_steps += 1

        # Cap the rollout at exactly max_episode_steps policy steps.
        if episode_steps >= max_episode_steps:
          break

      if env.succeeded:
        results[reward_name] += 1
        logging.info("Episode %d: success.", ep_num)
        success_str = "success"
      else:
        logging.info("Episode %d: failure.", ep_num)
        success_str = "failure"

      # Write out video of rollout.
      video_path = os.path.join(video_dir,
                                f"{reward_name}_{ep_num}_{success_str}.mp4")
      mediapy_lib.write_video(video_path, frames, fps=10)

  print(results)


def main(argv: Sequence[str]) -> None:
  """Runs `evaluate_checkpoint` with values from the command-line flags."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  evaluate_checkpoint(
      checkpoint_path=_CHECKPOINT_PATH.value,
      workdir=_WORKDIR.value,
      config=_CONFIG.value,
  )


if __name__ == "__main__":
  # Both are dereferenced unconditionally (os.path.join(workdir, ...) and
  # config.model). Fail fast with a usage error instead of a later crash.
  flags.mark_flags_as_required(["config", "workdir"])
  app.run(main)

# ------------------------------------------------------------------------------
# /language_table/eval/wrappers.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Environment wrappers."""

from typing import Any, Optional

from language_table.common import clip_tokenizer
import numpy as np
import tensorflow as tf
from tf_agents.environments import wrappers


class ClipTokenWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that adds CLIP tokens to the obs.

  The tokens are stored in every observation under
  'instruction_tokenized_clip'.
  """

  def __init__(self, env: Any, context_length: int = 77):
    """Wraps `env` and builds a CLIP tokenizer.

    Args:
      env: Environment to wrap. Its observations are dicts whose
        'instruction' entry is an integer array of UTF-8 byte values. Zero
        entries are dropped before decoding.
      context_length: Stored as `self._context_length`. This wrapper does not
        pass it to the tokenizer.
    """
    super(ClipTokenWrapper, self).__init__(env)
    self._context_length = context_length
    vocab_lookup = clip_tokenizer.create_vocab()
    self._tokenizer = clip_tokenizer.ClipTokenizer(vocab_lookup)
    self._current_tokens = None

  def _reset(self):
    time_step = self._env.reset()
    # Tokens are computed only on reset; `_step` reuses them.
    self._current_tokens = self._tokenize(time_step.observation['instruction'])
    new_obs = time_step.observation
    new_obs['instruction_tokenized_clip'] = self._current_tokens
    return time_step._replace(observation=new_obs)

  def _step(self, action):
    time_step = self._env.step(action)
    new_obs = time_step.observation
    new_obs['instruction_tokenized_clip'] = self._current_tokens
    return time_step._replace(observation=new_obs)

  def _tokenize(self, instruction):
    """Decodes a byte-value array into text and returns its CLIP tokens."""
    bytes_list = instruction
    non_zero = bytes_list[np.where(bytes_list != 0)]
    if non_zero.shape[0] == 0:
      decoded = ''
    else:
      bytes_list = bytes(non_zero.tolist())
      decoded = bytes_list.decode('utf-8')
    tokens = clip_tokenizer.tokenize_text(decoded, self._tokenizer)[0]
    return tokens


class CentralCropImageWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that crops image observations."""

  def __init__(self,
               env: Any,
               target_height: int,
               target_width: int,
               random_crop_factor: Optional[float] = None):
    """Centrally crops an image from a dict observation.

    Args:
      env: Environment to wrap. Its observations are dicts with an 'rgb'
        image entry.
      target_height: Height the (cropped) image is resized to.
      target_width: Width the (cropped) image is resized to.
      random_crop_factor: Fraction of height and width kept by the central
        crop. `None` disables cropping, so the image is only resized.
    """
    super(CentralCropImageWrapper, self).__init__(env)
    self._target_height = target_height
    self._target_width = target_width
    self._random_crop_factor = random_crop_factor

  def _reset(self):
    time_step = self._env.reset()
    new_obs = self._crop_observation(time_step.observation)
    return time_step._replace(observation=new_obs)

  def _step(self, action):
    time_step = self._env.step(action)
    new_obs = self._crop_observation(time_step.observation)
    return time_step._replace(observation=new_obs)

  def _crop_observation(self, obs):
    """Crops and resizes obs['rgb'] in place and returns the same dict."""
    new_obs = obs
    image = obs['rgb']
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Apply average crop augmentation.
    image = crop_test_image(image, self._random_crop_factor)
    image = resize_images(image, self._target_height, self._target_width)
    new_obs['rgb'] = image.numpy()
    return new_obs


def crop_test_image(images, random_crop_factor):
  """Get the average crop applied during crop training augmentation.

  Args:
    images: A rank-3 image (height, width, channels) or a batch of them.
      Batches are cropped per image with `tf.map_fn`.
    random_crop_factor: Fraction of height and width to keep, centered. `None`
      means no crop augmentation, and `images` is returned unchanged.

  Returns:
    The centrally cropped image(s).
  """
  if random_crop_factor is None:
    return images

  def take_center_crop_consistent_with_random(im):
    im_raw_size = tf.shape(im)
    raw_height = tf.cast(im_raw_size[0], tf.float32)
    raw_width = tf.cast(im_raw_size[1], tf.float32)
    scaled_height = raw_height * random_crop_factor
    scaled_width = raw_width * random_crop_factor
    offset_height = tf.cast((raw_height - scaled_height) // 2, tf.int32)
    offset_width = tf.cast((raw_width - scaled_width) // 2, tf.int32)
    target_height = tf.cast(scaled_height, tf.int32)
    target_width = tf.cast(scaled_width, tf.int32)
    im = tf.image.crop_to_bounding_box(
        im,
        offset_height=offset_height,
        offset_width=offset_width,
        target_height=target_height,
        target_width=target_width)
    return im

  if len(images.shape) == 3:
    return take_center_crop_consistent_with_random(images)
  images = tf.map_fn(take_center_crop_consistent_with_random, images)
  return images


def resize_images(images, target_height=None, target_width=None):
  """Resizes images to target_height, target_width.

  Both targets are required despite their `None` defaults. A missing value
  triggers an AssertionError.
  """
  assert target_height
  assert target_width

  # Resize to target height and width.
  def _resize(im):
    return tf.image.resize(im, [target_height, target_width])

  images = _resize(images)

  return images

# ------------------------------------------------------------------------------
# /language_table/examples/dataset_example.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example for loading the Language-Table dataset.

Language-Table data is in the [RLDS](https://github.com/google-research/rlds)
format.
See the [RLDS Tutorial](https://colab.research.google.com/github/
google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb)
for more details on how to use RLDS datasets.
"""

from collections.abc import Sequence

from absl import app

import tensorflow_datasets as tfds

dataset_paths = {
    'language_table': 'gs://gresearch/robotics/language_table/0.0.1/',
    'language_table_sim': 'gs://gresearch/robotics/language_table_sim/0.0.1/',
}


def main(argv: Sequence[str]) -> None:
  """Prints the first five steps of the real and then the sim dataset."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  for dataset_name in ('language_table', 'language_table_sim'):
    builder = tfds.builder_from_directory(dataset_paths[dataset_name])
    episodes = builder.as_dataset(split='train')
    # Flatten episodes into a dataset of individual steps.
    steps = episodes.flat_map(lambda episode: episode['steps'])
    for step in iter(steps.take(5)):
      print(step)


if __name__ == '__main__':
  app.run(main)

# ------------------------------------------------------------------------------
# /language_table/examples/environment_example.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example for running the Language-Table environment."""

from collections.abc import Sequence

from absl import app

from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.rewards import block2block

from matplotlib import pyplot as plt


def main(argv: Sequence[str]) -> None:
  """Steps the 8-block env with random actions and saves one rendered frame."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  env = language_table.LanguageTable(
      block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
      reward_factory=block2block.BlockToBlockReward,
      control_frequency=10.0,
  )
  env.reset()

  # Take a few random actions.
  num_random_steps = 5
  for _ in range(num_random_steps):
    random_action = env.action_space.sample()
    env.step(random_action)

  # Save a rendered image.
  frame = env.render()
  plt.imsave('/tmp/language_table_render.png', frame)


if __name__ == '__main__':
  app.run(main)

# ------------------------------------------------------------------------------
# /language_table/train/__init__.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ------------------------------------------------------------------------------
# /language_table/train/bc.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Behavioral Cloning Agent."""

from typing import Any, List, Optional

from absl import logging

from clu import checkpoint
from clu import metrics
from clu import parameter_overview

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax


@flax.struct.dataclass
class TrainState:
  """Training state threaded through `BCAgent.train`."""
  # Number of completed training steps.
  step: int
  # Model parameters (the "params" collection).
  params: Any
  opt_state: optax.OptState
  # The "batch_stats" collection; {} when the model has none.
  batch_stats: Any
  # Normalization statistics and action bounds given to the agent.
  norm_info: Any


@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  """Train metrics for the BC Agent."""

  learning_rate: metrics.LastValue.from_output("learning_rate")
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")


class BCAgent(object):
  """Behavioral Cloning agent that regresses actions with a squared-error loss."""

  def __init__(self,
               model,
               sequence_length,
               learning_rate,
               observation_statistics,
               action_statistics,
               action_min,
               action_max,
               pretrained_checkpoints: Optional[List[Any]] = None,
               freeze_keys: Optional[List[str]] = None):
    """Creates the agent.

    Args:
      model: Flax module mapping observations to predicted actions.
      sequence_length: Observation sequence length; stored on the agent.
      learning_rate: Adam learning rate.
      observation_statistics: Observation normalization statistics, copied
        into `TrainState.norm_info`.
      action_statistics: Action normalization statistics, copied into
        `TrainState.norm_info`.
      action_min: Action lower bound, copied into `TrainState.norm_info`.
      action_max: Action upper bound, copied into `TrainState.norm_info`.
      pretrained_checkpoints: Optional list of `(checkpoint_path,
        replacements)` pairs. `replacements` is a list of
        `(checkpoint_prefix, model_prefix)` pairs; checkpoint variables whose
        "/"-joined name starts with `checkpoint_prefix` are copied into the
        model variable with that prefix swapped for `model_prefix`.
      freeze_keys: Optional list of substrings. Parameters whose "/"-joined
        path contains any of them receive zero updates.
    """
    self.model = model

    self.sequence_length = sequence_length

    self.pretrained_checkpoints = pretrained_checkpoints or []
    self.freeze_keys = freeze_keys or []

    # Built in `create_train_state`.
    self.optimizer = None

    self.learning_rate = learning_rate

    self.observation_statistics = observation_statistics
    self.action_statistics = action_statistics

    self.action_min = action_min
    self.action_max = action_max

    logging.info("action_min=%s,action_max=%s", action_min, action_max)

  def create_train_state(self, batch, rng):
    """Creates the train state and initial metrics for agent.

    Initializes the model on `batch["observation"]` and overlays variables
    from any pretrained checkpoints. It then builds the optimizer, mapping
    frozen parameters to zero updates when `freeze_keys` is set.

    Args:
      batch: Dict with an "observation" entry used to initialize the model.
      rng: JAX PRNG key.

    Returns:
      A `(TrainState, TrainMetrics)` tuple.
    """
    obs_input = batch["observation"]

    rng, encoder_rng = jax.random.split(rng)
    variables = self.model.init(encoder_rng, obs_input, train=False)

    # Try to restore.
    flat_variables = flax.traverse_util.flatten_dict(variables, sep="/")

    for pretrained_checkpoint in self.pretrained_checkpoints:
      checkpoint_path, replacements = pretrained_checkpoint
      variable_dict_ckpt = checkpoint.load_state_dict(checkpoint_path)
      flat_variable_ckpt = flax.traverse_util.flatten_dict(
          variable_dict_ckpt, sep="/")
      for ckpt_variable_name, to_variable_name in replacements:
        for key in flat_variable_ckpt:
          if not key.startswith(ckpt_variable_name):
            continue
          # Swap only the leading prefix. str.replace would also rewrite any
          # later occurrence of the prefix inside the key.
          variable_key = to_variable_name + key[len(ckpt_variable_name):]
          if variable_key in flat_variables:
            new_value = flat_variable_ckpt[key]
            flat_variables[variable_key] = new_value
            logging.info("Loading %s into %s: shape %s", key, variable_key,
                         new_value.shape)

    variables = flax.traverse_util.unflatten_dict(flat_variables, sep="/")

    params = variables["params"]
    if variables.get("batch_stats"):
      batch_stats = variables["batch_stats"]
    else:
      batch_stats = {}

    # Optionally freeze variables.
    if self.freeze_keys:

      def _should_freeze(path):
        for freeze_key in self.freeze_keys:
          if freeze_key in path:
            logging.info("Freezing param: %s", path)
            return True
        logging.info("Not freezing param: %s", path)
        return False

      label_fn = flattened_traversal(
          lambda path, _: "zero" if _should_freeze("/".join(path)) else "adam")

      optimizer = optax.multi_transform(
          {
              "adam": optax.adam(learning_rate=self.learning_rate, eps=1e-7),
              "zero": optax.set_to_zero()
          }, label_fn)

      self.optimizer = optimizer
    else:
      self.optimizer = optax.adam(learning_rate=self.learning_rate, eps=1e-7)

    parameter_overview.log_parameter_overview(params)
    train_state = TrainState(
        step=0,
        params=params,
        opt_state=self.optimizer.init(params),
        batch_stats=batch_stats,
        norm_info={
            "observation_statistics":
                dict(
                    jax.tree_util.tree_map(
                        jnp.asarray,
                        self.observation_statistics,
                        is_leaf=lambda x: isinstance(x, list))),
            "action_statistics":
                dict(
                    jax.tree_util.tree_map(
                        jnp.asarray,
                        self.action_statistics,
                        is_leaf=lambda x: isinstance(x, list))),
            "action_min":
                self.action_min,
            "action_max":
                self.action_max
        })
    # Only "loss" and "learning_rate" are read by TrainMetrics.
    initial_metrics = TrainMetrics.single_from_model_output(
        loss=jnp.zeros((1,)),
        learning_rate=jnp.zeros((1,)))
    return (train_state, initial_metrics)

  def train(self, batch, state, rng):
    """Performs a single training step.

    The gradient and loss are averaged with `jax.lax.pmean` over the named
    axis "batch". This method must therefore run inside a computation that
    binds that axis (e.g. `jax.pmap(..., axis_name="batch")`).

    Args:
      batch: Dict with "observation" and "action" entries.
      state: Current `TrainState`.
      rng: JAX PRNG key.

    Returns:
      A `(new_state, metrics_update)` tuple.
    """
    logging.info("train_step(batch=%s)", batch)

    rng, loss_rng = jax.random.split(rng)
    def loss_fn(params):
      variables = {"params": params, "batch_stats": state.batch_stats}
      per_example_loss, new_variables = self.bc_loss(
          self.model, batch=batch, variables=variables, rng=loss_rng)
      loss = jnp.mean(per_example_loss)
      return loss, new_variables["batch_stats"]

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_batch_stats), grad = grad_fn(state.params)

    # Compute average gradient across multiple workers.
    grad = jax.lax.pmean(grad, axis_name="batch")
    # Also get the average loss.
    loss = jax.lax.pmean(loss, axis_name="batch")
    updates, new_opt_state = self.optimizer.update(grad, state.opt_state,
                                                   state.params)

    new_params = optax.apply_updates(state.params, updates)
    new_state = state.replace(  # pytype: disable=attribute-error
        step=state.step + 1,
        params=flax.core.unfreeze(new_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
        opt_state=flax.core.unfreeze(new_opt_state),
        batch_stats=flax.core.unfreeze(new_batch_stats))

    metrics_update = TrainMetrics.gather_from_model_output(
        loss=loss, learning_rate=self.learning_rate)
    return new_state, metrics_update

  def bc_loss(
      self,
      model,
      batch,
      variables,
      rng,
  ):
    """Implements the BC loss.

    Args:
      model: Flax module to apply.
      batch: Dict with "observation" and "action" entries; the leading axis
        is treated as the example axis.
      variables: Dict with "params" and "batch_stats" collections.
      rng: JAX PRNG key used for the "params" and "dropout" streams.

    Returns:
      A `(per_example_loss, new_variables)` tuple. `per_example_loss` holds
      the mean squared action error of each example (averaged over all
      non-leading axes). `new_variables` carries the updated "batch_stats".
    """
    observation = batch["observation"]
    action = batch["action"]

    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    predicted_actions, new_variables = model.apply(
        variables,
        observation,
        train=True,
        mutable=["batch_stats"],
        rngs={
            "params": params_rng,
            "dropout": dropout_rng
        })

    squared_error = jnp.square(predicted_actions - action)
    # Reduce every axis except the leading one. The caller's mean over the
    # result equals the overall mean squared error.
    per_example_loss = jnp.mean(
        squared_error, axis=tuple(range(1, squared_error.ndim)))
    return per_example_loss, new_variables


def flattened_traversal(fn):
  """Returns function that is called with `(path, param)` instead of pytree.

  `path` is the tuple of keys produced by `flax.traverse_util.flatten_dict`.
  """

  def mask(tree):
    flat = flax.traverse_util.flatten_dict(tree)
    return flax.traverse_util.unflatten_dict(
        {k: fn(k, v) for k, v in flat.items()})

  return mask

# ------------------------------------------------------------------------------
# /language_table/train/configs/language_table_resnet_sim_local.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A config for training sim human 8 block with Pretrained Resnet + BC."""

import ml_collections


dataset_paths = {
    "language_table": "gs://gresearch/robotics/language_table/0.0.1/",
    "language_table_sim": "gs://gresearch/robotics/language_table_sim/0.0.1/",
}


def get_config():
  """Config for training sim human 8 block with BC locally.

  The model uses the "resnet" image encoder and the "clip" language encoder.
  The CLIP text encoder is restored from a local checkpoint and frozen.
  """
  history_length = 4

  config = ml_collections.ConfigDict()
  config.binary = "language_table/train/main"

  config.sequence_length = history_length

  # Model.
  config.model_name = "sequence_lav_mse"
  config.model = ml_collections.ConfigDict()
  config.model.dense_resnet_width = 1024
  config.model.dense_resnet_num_blocks = 2
  config.model.lava_sequence_length = history_length
  config.model.lava_num_layers = 4
  config.model.lava_temporal_transformer_num_layers = 2
  config.model.lava_d_model = 128
  config.model.lava_num_heads = 2
  config.model.lava_pyramid_fuse_layers = (2, 3, 4)
  config.model.lava_image_encoder = "resnet"
  config.model.lava_lang_encoder = "clip"

  # Agent.
  config.agent_name = "bc"
  config.agent = ml_collections.ConfigDict()
  config.agent.learning_rate = 1e-3
  # CHANGEME: Change this to a local path by running
  # download_clip_flax_ckpt.py.
  clip_checkpoint_dir = "/tmp/scenic_clip_ckpt/"
  clip_text_renames = [("params/text", "params/encoder/TextEncoder_0")]
  config.agent.pretrained_checkpoints = [
      (clip_checkpoint_dir, clip_text_renames),
  ]
  config.agent.freeze_keys = ["TextEncoder_0"]

  # Data.
  config.dataset_path = dataset_paths["language_table_sim"]
  config.data_target_width = 320
  config.data_target_height = 180
  config.image_photometric_distortions = True
  config.image_augment_crop = True
  config.random_crop_factor = 0.95
  config.data_normalization_num_samples = 32
  config.data_skip_normalize_keys = ["rgb", "instruction"]
  config.synthetic_data = False

  # Training loop.
  config.num_train_steps = 1_000_000
  config.per_device_batch_size = 4  # 4096 is used for 64 TPUs.
  config.replay_capacity = 5_000
  config.num_steps_per_train_iter = 1

  # Logging and checkpointing.
  config.log_loss_every_steps = 50
  config.checkpoint_every_steps = 50

  config.seed = 42

  config.trial = 0  # Dummy for repeated runs.
  return config


def get_hyper(h):
  """Returns a sweep over a single trial."""
  trial_sweep = h.sweep("config.trial", range(1))
  return h.product([trial_sweep])

# ------------------------------------------------------------------------------
# /language_table/train/configs/language_table_sim_local.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A config for training sim human 8 block with BC."""

import ml_collections


dataset_paths = {
    "language_table": "gs://gresearch/robotics/language_table/0.0.1/",
    "language_table_sim": "gs://gresearch/robotics/language_table_sim/0.0.1/",
}


def get_config():
  """Config for training sim human 8 block with BC locally.

  The model uses the "conv_maxpool" image encoder and the "clip" language
  encoder. The CLIP text encoder is restored from a local checkpoint and
  frozen.
  """
  history_length = 4

  config = ml_collections.ConfigDict()
  config.binary = "language_table/train/main"

  config.sequence_length = history_length

  # Model.
  config.model_name = "sequence_lav_mse"
  config.model = ml_collections.ConfigDict()
  config.model.dense_resnet_width = 1024
  config.model.dense_resnet_num_blocks = 2
  config.model.lava_sequence_length = history_length
  config.model.lava_num_layers = 4
  config.model.lava_temporal_transformer_num_layers = 2
  config.model.lava_d_model = 128
  config.model.lava_num_heads = 2
  config.model.lava_pyramid_fuse_layers = (2, 3, 4)
  config.model.lava_image_encoder = "conv_maxpool"
  config.model.lava_lang_encoder = "clip"

  # Agent.
  config.agent_name = "bc"
  config.agent = ml_collections.ConfigDict()
  config.agent.learning_rate = 1e-3
  # CHANGEME: Change this to a local path by running
  # download_clip_flax_ckpt.py.
  clip_checkpoint_dir = "/tmp/scenic_clip_ckpt/"
  clip_text_renames = [("params/text", "params/encoder/TextEncoder_0")]
  config.agent.pretrained_checkpoints = [
      (clip_checkpoint_dir, clip_text_renames),
  ]
  config.agent.freeze_keys = ["TextEncoder_0"]

  # Data.
  config.dataset_path = dataset_paths["language_table_sim"]
  config.data_target_width = 320
  config.data_target_height = 180
  config.image_photometric_distortions = True
  config.image_augment_crop = True
  config.random_crop_factor = 0.95
  config.data_normalization_num_samples = 32
  config.data_skip_normalize_keys = ["rgb", "instruction"]
  config.synthetic_data = False

  # Training loop.
  config.num_train_steps = 1_000_000
  config.per_device_batch_size = 4  # 4096 is used for 64 TPUs.
  config.replay_capacity = 5_000
  config.num_steps_per_train_iter = 1

  # Logging and checkpointing.
  config.log_loss_every_steps = 50
  config.checkpoint_every_steps = 50

  config.seed = 42

  config.trial = 0  # Dummy for repeated runs.
  return config


def get_hyper(h):
  """Returns a sweep over a single trial."""
  trial_sweep = h.sweep("config.trial", range(1))
  return h.product([trial_sweep])

# ------------------------------------------------------------------------------
# /language_table/train/download_clip_flax_ckpt.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Saves a scenic CLIP checkpoint."""

from collections.abc import Sequence

from absl import app
from absl import flags

from clu import checkpoint
from scenic.projects.baselines.clip import model
import tensorflow as tf

_CHECKPOINT_DIRECTORY = flags.DEFINE_string(
    'checkpoint_directory', None, 'The directory to save the checkpoint.')


def main(argv: Sequence[str]) -> None:
  """Loads the scenic CLIP ViT-B/16 variables and saves them as a checkpoint.

  The output directory is created if it does not exist yet.

  Args:
    argv: Positional command-line arguments left after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  model_vars = model.load_model_vars(model_name='vit_b16')

  out_directory = _CHECKPOINT_DIRECTORY.value
  if not tf.io.gfile.exists(out_directory):
    tf.io.gfile.makedirs(out_directory)
  ckpt = checkpoint.Checkpoint(base_directory=out_directory)

  ckpt.save(model_vars)


if __name__ == '__main__':
  # The flag defaults to None; without this, a missing flag only surfaces as
  # an obscure error from tf.io.gfile after the model has been loaded.
  flags.mark_flag_as_required('checkpoint_directory')
  app.run(main)

# ------------------------------------------------------------------------------
# /language_table/train/main.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main file for running the trainer."""

from absl import app
from absl import flags
from absl import logging

from clu import platform
import jax
from language_table.train import train
from ml_collections import config_flags
import tensorflow as tf


FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True
)
flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_string("tf_data_service_address", None, "TF Data address.")
flags.mark_flags_as_required(["config", "workdir"])
# Flags --jax_backend_target and --jax_xla_backend are available through JAX.


def main(argv):
  """Logs the JAX setup, records the work unit status and runs training.

  Args:
    argv: Positional command-line arguments; unused.
  """
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  if FLAGS.jax_backend_target:
    logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
    jax_xla_backend = (
        "None" if FLAGS.jax_xla_backend is None else FLAGS.jax_xla_backend
    )
    logging.info("Using JAX XLA backend %s", jax_xla_backend)

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  # Surface the host index/count in the work unit UI and link the workdir.
  platform.work_unit().set_task_status(
      f"process_index: {jax.process_index()}, "
      f"process_count: {jax.process_count()}"
  )
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir"
  )

  train.train(FLAGS.config, FLAGS.workdir, FLAGS.tf_data_service_address)


if __name__ == "__main__":
  jax.config.config_with_absl()
  app.run(main)

# ------------------------------------------------------------------------------
# /language_table/train/networks/dense_resnet.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Residual dense block."""

from flax import linen as nn
import jax


class ResnetDenseBlock(nn.Module):
  """Single dense resnet block.

  Computes `x + Dense(width)(relu(Dense(width // 4)(relu(Dense(width // 4)(
  relu(x))))))`, so the last dimension of `x` must equal `width`.

  Attributes:
    width: Output (and residual) feature size; the two hidden layers use
      `width // 4` units.
  """
  width: int

  @nn.compact
  def __call__(self, x, *, train):
    # `train` is accepted for a uniform interface; this block does not use it.
    normal_initializer = jax.nn.initializers.normal(stddev=0.05)
    y = nn.relu(x)
    y = nn.Dense(
        self.width // 4,
        kernel_init=normal_initializer,
        bias_init=normal_initializer)(
            y)
    y = nn.relu(y)
    y = nn.Dense(
        self.width // 4,
        kernel_init=normal_initializer,
        bias_init=normal_initializer)(
            y)
    y = nn.relu(y)
    y = nn.Dense(
        self.width,
        kernel_init=normal_initializer,
        bias_init=normal_initializer)(
            y)

    return x + y


class DenseResnet(nn.Module):
  """Dense Resnet module.

  Attributes:
    width: Size of the input projection and of every residual block.
    num_blocks: Number of `ResnetDenseBlock`s applied after the projection.
    value_net: If True, a final Dense layer maps the features to a single
      output unit.
  """

  width: int
  num_blocks: int
  value_net: bool

  @nn.compact
  def __call__(self, x, *, train):
    normal_initializer = jax.nn.initializers.normal(stddev=0.05)
    # Project the input to `width` so the residual additions line up.
    x = nn.Dense(
        self.width,
        kernel_init=normal_initializer,
        bias_init=normal_initializer)(
            x)
    for _ in range(self.num_blocks):
      x = ResnetDenseBlock(self.width)(x, train=train)

    if self.value_net:
      x = nn.Dense(
          1, kernel_init=normal_initializer, bias_init=normal_initializer)(
              x)
    return x

# ------------------------------------------------------------------------------
# /language_table/train/networks/pixel.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple Pixel + Language network implementations."""

import flax.linen as nn
import jax
import jax.numpy as jnp

from language_table.train.networks import dense_resnet


class LanguageFusion(nn.Module):
  """Fuses language information multiplicatively."""

  @nn.compact
  def __call__(self, lang, image):
    """Scales every pixel of `image` by a learned projection of `lang`.

    Args:
      lang: Language embedding of shape (batch, lang_dim).
      image: Feature map of shape (batch, height, width, channels).

    Returns:
      The fused feature map, with the shape of `image`.
    """
    norm_init = jax.nn.initializers.normal(stddev=0.05)
    lang = nn.Dense(
        jnp.shape(image)[-1], kernel_init=norm_init, bias_init=norm_init)(
            lang)

    # Fuse. Broadcasting over the spatial axes gives the same result as tiling
    # `lang` to (height, width) without materializing the tiled array.
    return image * lang[:, None, None, :]


class ConvMaxpoolLanguageEncoder(nn.Module):
  """Simple Conv + Maxpool encoder that multiplicatively fuses language.

  Four 3x3 conv stages (32, 64, 128, 256 channels), each followed by ReLU and
  2x2 max pooling. Language is fused into the conv output of every stage from
  the second one on. The result is spatially averaged, gated by a projection
  of the language embedding, then passed through ReLU and LayerNorm.
  """

  @nn.compact
  def __call__(self, rgb, lang_embedding, *, train):
    # `train` is accepted for a uniform interface; this encoder does not use it.
    x = rgb

    # 1-based index of the first conv stage that receives language fusion.
    fuse_from = 2
    conv_channels = [32, 64, 128, 256]
    for idx, ch in enumerate(conv_channels):
      x = nn.Conv(features=ch, kernel_size=(3, 3), padding="SAME")(x)

      if fuse_from <= idx + 1:
        x = LanguageFusion()(lang_embedding, x)

      x = nn.relu(x)
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")

    # Global average pool over the spatial axes -> (batch, conv_channels[-1]).
    x = jnp.mean(x, axis=(1, 2), keepdims=False)

    norm_init = jax.nn.initializers.normal(stddev=0.05)
    if fuse_from <= len(conv_channels) + 1:
      lang_info = nn.Dense(
          conv_channels[-1], kernel_init=norm_init, bias_init=norm_init)(
              lang_embedding)
      x *= lang_info

    x = nn.relu(x)

    x = nn.LayerNorm()(x)

    return x


class PixelLangMSE(nn.Module):
  """Simple Pixel Language network.

  Stacks the frames of an observation sequence along the channel axis, encodes
  them with `ConvMaxpoolLanguageEncoder` conditioned on the CLIP embedding of
  the last timestep, and regresses `action_size` outputs via a `DenseResnet`.
  """

  action_size: int

  dense_resnet_width: int
  dense_resnet_num_blocks: int

  def setup(self):
    self.encoder = ConvMaxpoolLanguageEncoder()
    self.dense_resnet = dense_resnet.DenseResnet(
        width=self.dense_resnet_width,
        num_blocks=self.dense_resnet_num_blocks,
        value_net=False)
    norm_init = jax.nn.initializers.normal(stddev=0.05)
    self.action_projection = nn.Dense(
        self.action_size, kernel_init=norm_init, bias_init=norm_init)

  def __call__(self, obs, *, train):
    rgb = obs["rgb"]
    # Stack the images channelwise: (b, n, h, w, c) -> (b, h, w, n * c).
    # Assumes axis 1 is the sequence axis (TODO confirm against the input
    # pipeline). The sequence axis must be moved next to the channel axis
    # before reshaping; a direct reshape would tile pixels of different frames
    # spatially instead of stacking them per pixel.
    b, n, h, w, c = jnp.shape(rgb)
    rgb = jnp.transpose(rgb, (0, 2, 3, 1, 4))
    rgb = jnp.reshape(rgb, (b, h, w, n * c))

    # Condition on the instruction embedding of the last timestep.
    lang = obs["clip_embedding"][:, -1]
    encoded_obs = self.encoder(rgb, lang, train=train)

    x = self.dense_resnet(encoded_obs, train=train)
    x = self.action_projection(x)
    return x

# ------------------------------------------------------------------------------
# /language_table/train/networks/resnet_v1.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of ResNet V1 in Flax.

"Deep Residual Learning for Image Recognition"
He et al., 2015, [https://arxiv.org/abs/1512.03385]
"""

# See issue #620.
# pytype: disable=wrong-arg-count
# pytype: disable=missing-parameter
# pytype: disable=wrong-keyword-args

import functools
from typing import Any, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp

Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False)
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False)


class ResNetBlock(nn.Module):
  """ResNet block without bottleneck used in ResNet-18 and ResNet-34.

  Attributes:
    filters: Number of output channels of both 3x3 convolutions.
    norm: Normalization layer constructor, called with a `name` (and optionally
      `scale_init`).
    strides: Strides of the first convolution (and of the shortcut projection).
  """

  filters: int
  norm: Any
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv3x3(self.filters, strides=self.strides, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, name="conv2")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Zhang et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x)

    if residual.shape != x.shape:
      # Project the shortcut when the spatial size or channel count changes.
      residual = Conv1x1(
          self.filters, strides=self.strides, name="proj_conv")(
              residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class BottleneckResNetBlock(ResNetBlock):
  """Bottleneck ResNet block used in ResNet-50 and larger.

  Produces `4 * filters` output channels.
  """

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv1x1(self.filters, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, strides=self.strides, name="conv2")(x)
    x = self.norm(name="bn2")(x)
    x = nn.relu(x)
    x = Conv1x1(4 * self.filters, name="conv3")(x)
    # NOTE: unlike `ResNetBlock`, this last BatchNorm keeps its default scale
    # initialization instead of the zero-scale practice from "Fixup
    # Initialization: Residual Learning Without Normalization" Zhang et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX]. Switching to
    # `scale_init=nn.initializers.zeros` would change how ResNet-50 and larger
    # models are initialized.
    x = self.norm(name="bn3")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          4 * self.filters, strides=self.strides, name="proj_conv")(
              residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class ResNetStage(nn.Module):
  """ResNet stage consisting of multiple ResNet blocks.

  Only the first block uses `first_block_strides`; the others use stride 1.
  """

  stage_size: int
  filters: int
  block_cls: Type[ResNetBlock]
  norm: Any
  first_block_strides: Tuple[int, int]

  @nn.compact
  def __call__(self, x):
    for i in range(self.stage_size):
      x = self.block_cls(
          filters=self.filters,
          norm=self.norm,
          strides=self.first_block_strides if i == 0 else (1, 1),
          name=f"block{i + 1}")(
              x)
    return x


class ResNet(nn.Module):
  """Construct ResNet V1 with `num_classes` outputs.

  Attributes:
    num_classes: Number of nodes in the final layer.
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  num_classes: int
  block_cls: Type[ResNetBlock]
  stage_sizes: Tuple[int, ...]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      The output head with `num_classes` entries.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    # Root block
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)
    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    # Stages
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)

    # Head: global average pool over the spatial axes, then a zero-initialized
    # linear classifier.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        self.num_classes, kernel_init=nn.initializers.zeros, name="head")(
            x)
    return x


ResNet18 = functools.partial(
    ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
ResNet34 = functools.partial(
    ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock)
ResNet50 = functools.partial(
    ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
ResNet101 = functools.partial(
    ResNet, stage_sizes=(3, 4, 23, 3), block_cls=BottleneckResNetBlock)
ResNet152 = functools.partial(
    ResNet, stage_sizes=(3, 8, 36, 3), block_cls=BottleneckResNetBlock)
ResNet200 = functools.partial(
    ResNet, stage_sizes=(3, 24, 36, 3), block_cls=BottleneckResNetBlock)


class MultiscaleResNet(nn.Module):
  """Construct a multiscale ResNet.

  Attributes:
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  block_cls: Type[ResNetBlock]
  stage_sizes: Tuple[int, ...]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x` and collect intermediate features.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      A list of feature maps, coarsest last: the output of the root
      convolution (before normalization), the output of the root max pooling,
      and then the output of every stage. There is no classification head.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    all_outputs = []

    # Root block
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)

    all_outputs.append(x)

    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    all_outputs.append(x)

    # Stages
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)
      all_outputs.append(x)

    return all_outputs

# ------------------------------------------------------------------------------
# /language_table/train/networks/resnet_v1_test.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
from clu import parameter_overview
import jax
import jax.numpy as jnp
from language_table.train.networks import resnet_v1
import tensorflow as tf


class ResNetV1Test(tf.test.TestCase, parameterized.TestCase):
  """Test cases for ResNet V1."""

  @parameterized.named_parameters(
      # (test name, model constructor, expected number of parameters)
      ("ResNet18", resnet_v1.ResNet18, 11_689_512),
      ("ResNet34", resnet_v1.ResNet34, 21_797_672),
      ("ResNet50", resnet_v1.ResNet50, 25_557_032),
      ("ResNet101", resnet_v1.ResNet101, 44_549_160),
      ("ResNet152", resnet_v1.ResNet152, 60_192_808),
      ("ResNet200", resnet_v1.ResNet200, 64_673_832),
  )
  def test_architecture(self, cls, param_count):
    """Builds a 1000-class model with `cls` and checks its parameter count."""
    net = cls(num_classes=1000)
    dummy_images = jnp.ones([2, 224, 224, 3])
    init_variables = net.init(jax.random.PRNGKey(0), dummy_images, train=False)
    num_params = parameter_overview.count_parameters(init_variables["params"])
    self.assertEqual(param_count, num_params)


if __name__ == "__main__":
  tf.test.main()

# ------------------------------------------------------------------------------
# /language_table/train/policy.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyPolicy for BC Jax."""

from flax.training import checkpoints
import jax
import jax.numpy as jnp
import numpy as np
from tf_agents.policies import py_policy
from tf_agents.trajectories import policy_step

EPS = jnp.finfo(jnp.float32).eps


class BCJaxPyPolicy(py_policy.PyPolicy):
  """Runs inference with a BC policy."""

  def __init__(self, time_step_spec, action_spec, model, checkpoint_path,
               rng, params=None, action_statistics=None):
    """Initializes the policy from in-memory params or a Flax checkpoint.

    Args:
      time_step_spec: Time step spec forwarded to `PyPolicy`.
      action_spec: Action spec; its `minimum`/`maximum` bound the actions.
      model: Flax module applied to a batched observation.
      checkpoint_path: Flax checkpoint restored unless both `params` and
        `action_statistics` are given.
      rng: Stored on the policy; not used for inference in this class.
      params: Optional model params. Only used together with
        `action_statistics`; if given alone they are ignored and the
        checkpoint is restored instead.
      action_statistics: Optional mapping with "mean" and "std" used to
        de-normalize model outputs. Defaults to the statistics stored in the
        checkpoint.
    """
    super(BCJaxPyPolicy, self).__init__(time_step_spec, action_spec)
    self.model = model
    self.rng = rng

    # Stays None when the policy is built from in-memory params.
    state_dict = None
    if params is not None and action_statistics is not None:
      variables = {
          "params": params,
          "batch_stats": {}
      }
    else:
      state_dict = checkpoints.restore_checkpoint(checkpoint_path, None)
      variables = {
          "params": state_dict["params"],
          "batch_stats": state_dict["batch_stats"]
      }

    if action_statistics is not None:
      self.action_mean = np.array(action_statistics["mean"])
      self.action_std = np.array(action_statistics["std"])
    else:
      # We can load the observation and action statistics from the state dict.
      self.action_mean = np.array(
          state_dict["norm_info"]["action_statistics"]["mean"])
      self.action_std = np.array(
          state_dict["norm_info"]["action_statistics"]["std"])

    # RGB statistics can only come from a restored checkpoint (and may be
    # missing from it). Never read `state_dict` when no checkpoint was
    # restored, and keep the attributes defined (as None) in every case.
    rgb_statistics = None
    if state_dict is not None:
      rgb_statistics = (
          state_dict.get("norm_info", {})
          .get("observation_statistics", {})
          .get("rgb"))
    if rgb_statistics is None:
      self._rgb_mean = None
      self._rgb_std = None
    else:
      self._rgb_mean = jnp.array(rgb_statistics["mean"])
      self._rgb_std = jnp.array(rgb_statistics["std"])

    self.variables = variables

    self._run_action_inference_jit = jax.jit(self._run_action_inference)

  def _run_action_inference(self, observation):
    """Returns the de-normalized, spec-clipped action for one observation.

    Args:
      observation: Observation pytree for a single time step; a batch axis of
        size 1 is added to every leaf before applying the model.

    Returns:
      The action with a leading batch axis of size 1.
    """
    # Add a batch dim.
    observation = jax.tree.map(lambda x: jnp.expand_dims(x, 0), observation)

    normalized_action = self.model.apply(
        self.variables, observation, train=False)
    # Undo the action normalization; EPS guards against a zero std.
    action = (
        normalized_action * jnp.maximum(self.action_std, EPS) +
        self.action_mean)

    # Clip the action to spec.
    action = jnp.clip(action, self.action_spec.minimum,
                      self.action_spec.maximum)

    return action

  def _action(self, time_step, policy_state=(), seed=0):
    """Computes the action for `time_step`; `policy_state`/`seed` are unused."""
    observation = time_step.observation
    # Drop the batch axis added by `_run_action_inference`.
    action = self._run_action_inference_jit(observation)[0]
    return policy_step.PolicyStep(action=action)

# ------------------------------------------------------------------------------
# /language_table/train/train.py:
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BC trainer for Language-Table."""
import functools
import os
from typing import Optional
from absl import logging
from clu import checkpoint
from clu import metric_writers
from clu import parameter_overview
from clu import periodic_actions
import flax.jax_utils as flax_utils
from flax.training import checkpoints as flax_checkpoints
import jax
from language_table.train import bc
from language_table.train import input_pipeline_rlds
from language_table.train.networks import lava
from language_table.train.networks import pixel
import ml_collections
import tensorflow as tf


def multi_train_step(state, batches, agent, initial_metrics, rng):
  """Runs multiple iterations of a train step at once.

  Must run under `jax.pmap` with `axis_name="batch"`, since the per-device
  RNG is derived from `jax.lax.axis_index("batch")`.

  Args:
    state: Per-device train state.
    batches: Pytree whose leaves have a leading axis with one entry per train
      step; slice `i` is fed to the `i`-th step.
    agent: Agent whose `train` method performs one update.
    initial_metrics: Metrics used to seed the loop carry.
    rng: Base PRNG key; folded with the state step and device index.

  Returns:
    `(new_state, metrics)`, where `metrics` are those of the last step only.
  """
  num_batches = batches["action"].shape[0]

  def _body_fun(step, state_and_metrics):
    state, _ = state_and_metrics
    # Important to call fold_in since `rng` is the same on every call.
    train_rng = jax.random.fold_in(rng, state.step)
    train_rng = jax.random.fold_in(train_rng, jax.lax.axis_index("batch"))
    new_state, metrics_update = agent.train(
        state=state,
        batch=jax.tree.map(lambda x: x[step], batches),
        rng=train_rng,
    )
    return new_state, metrics_update

  return jax.lax.fori_loop(
      lower=0,
      upper=num_batches,
      body_fun=_body_fun,
      init_val=(state, initial_metrics),
  )


def train(
    config: ml_collections.ConfigDict,
    workdir: str,
    tf_data_service_address: Optional[str],
):
  """Runs a training loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
    tf_data_service_address: Address of the TF Data Service. If None, all
      dataset computation will run locally.

  Raises:
    ValueError: If `config.num_train_steps` is not positive.
  """
  # Fail fast, before the expensive dataset and agent setup: there is no other
  # way to derive the number of training steps.
  if config.num_train_steps <= 0:
    raise ValueError(
        "config.num_train_steps must be positive, got "
        f"{config.num_train_steps}."
    )
  num_train_steps = config.num_train_steps
  steps_per_iter = config.num_steps_per_train_iter
  # Training runs in whole multi-steps of `steps_per_iter` steps each; any
  # remainder of `num_train_steps` is not run.
  num_multi_steps = num_train_steps // steps_per_iter

  tf.io.gfile.makedirs(workdir)
  rng = jax.random.PRNGKey(config.seed)
  logging.info(
      "Global batch size = %d",
      jax.device_count() * config.per_device_batch_size,
  )
  rng, data_rng = jax.random.split(rng)
  # Give every host its own data RNG.
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  train_ds, obs_statistics, act_statistics, min_actions, max_actions = (
      input_pipeline_rlds.create_datasets(
          data_rng,
          dataset_path=config.dataset_path,
          normalization_path=os.path.join(workdir, "norm_info"),
          sequence_length=config.sequence_length,
          per_device_batch_size=config.per_device_batch_size,
          num_steps_per_train_iter=config.num_steps_per_train_iter,
          target_width=config.data_target_width,
          target_height=config.data_target_height,
          random_crop_factor=config.random_crop_factor,
          normalization_num_samples=config.data_normalization_num_samples,
          skip_normalize_keys=config.data_skip_normalize_keys,
          cache=True,
          shuffle=True,
          shuffle_buffer_size=config.replay_capacity,
          prefetch_size=tf.data.AUTOTUNE,
          tf_data_service_address=tf_data_service_address,
      )
  )
  train_iter = train_ds.as_numpy_iterator()
  rng, agent_rng = jax.random.split(rng)
  # Batches are laid out as (device, step, ...) for the pmapped multi-step;
  # take one example slice for shape inference.
  sample_batch = jax.tree.map(lambda x: x[0][0], next(train_iter))
  agent = create_agent(
      config.agent_name,
      config.model_name,
      config.agent,
      config.model,
      config.sequence_length,
      sample_batch,
      obs_statistics,
      act_statistics,
      min_actions,
      max_actions,
  )
  state, initial_metrics = agent.create_train_state(sample_batch, agent_rng)
  # Save a file with the agent parameters.
  if jax.process_index() == 0:
    with tf.io.gfile.GFile(os.path.join(workdir, "parameters.txt"), "w") as f:
      f.write(parameter_overview.get_parameter_overview(state.params))
  # Set up checkpointing of the train state (CLU multihost checkpoints, plus
  # Flax checkpoints written by the first host).
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=1_000_000)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)
  initial_multi_step = initial_step // steps_per_iter
  flax_checkpoint_dir = os.path.join(workdir, "flax_checkpoints")
  if jax.process_index() == 0:
    tf.io.gfile.makedirs(flax_checkpoint_dir)
    flax_checkpoints.save_checkpoint(
        flax_checkpoint_dir,
        state,
        step=int(state.step),
        keep=10_000,
        keep_every_n_steps=config.checkpoint_every_steps,
        overwrite=True,
    )
  # Distribute training.
  state = flax_utils.replicate(state)
  rng, train_rng = jax.random.split(rng)
  # Create the pmapped multi_train_step.
  p_train_step = jax.pmap(
      functools.partial(
          multi_train_step,
          agent=agent,
          initial_metrics=initial_metrics,
          rng=train_rng,
      ),
      axis_name="batch",
  )
  logging.info("num_train_steps=%d", num_train_steps)
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0
  )
  if initial_step == 0:
    writer.write_hparams(dict(config))
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer
  )
  # Only the first host reports progress.
  if jax.process_index() == 0:
    hooks += [
        report_progress,
    ]

  def _save_checkpoints(state, step):
    """Synchronously saves the replicated `state` in both checkpoint formats."""
    with report_progress.timed("checkpoint"):
      unreplicated_state = flax_utils.unreplicate(state)
      ckpt.save(unreplicated_state)
      if jax.process_index() == 0:
        flax_checkpoints.save_checkpoint(
            flax_checkpoint_dir,
            unreplicated_state,
            step=step,
            keep=10_000,
            keep_every_n_steps=config.checkpoint_every_steps,
            overwrite=True,
        )

  train_metrics = None
  logging.info("Starting training loop at step %d.", initial_step)
  with metric_writers.ensure_flushes(writer):
    for multi_step in range(initial_multi_step, num_multi_steps):
      # Compute the actual step by multiplying the amount of steps we run
      # per train iteration.
      step = multi_step * steps_per_iter
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      # `step` is the first step of the current multi-step, so it never equals
      # `num_train_steps`; detect the final iteration by its index instead.
      is_last_step = multi_step == num_multi_steps - 1
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = next(train_iter)
        state, metrics_update = p_train_step(state=state, batches=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (
            metric_update
            if train_metrics is None
            else train_metrics.merge(metric_update)
        )
      logging.log_first_n(
          logging.INFO, "Finished multi training step %d.", 5, multi_step
      )
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)
      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        state = merge_batch_stats(state)
        _save_checkpoints(state, step)
    logging.info("Finishing training at step %d", num_train_steps)


def create_agent(
    agent_name,
    model_name,
    agent_config,
    model_config,
    sequence_length,
    sample_batch,
    obs_statistics,
    act_statistics,
    min_actions,
    max_actions,
):
  """Creates an agent using an agent and model config.

  Args:
    agent_name: Agent type; only "bc" is supported.
    model_name: Network type; "sequence_lav_mse" or "pixel_lang_mse".
    agent_config: Keyword arguments forwarded to the agent constructor.
    model_config: Keyword arguments forwarded to the model constructor.
    sequence_length: Sequence length passed to the agent.
    sample_batch: Example batch; the last dimension of its "action" entry sets
      the model's action size.
    obs_statistics: Observation statistics passed to the agent.
    act_statistics: Action statistics passed to the agent.
    min_actions: Passed to the agent as `action_min`.
    max_actions: Passed to the agent as `action_max`.

  Returns:
    The constructed agent.

  Raises:
    NotImplementedError: If `model_name` or `agent_name` is unknown.
  """
  # Automatically infer the action size.
  action = sample_batch["action"]
  action_size = action.shape[-1]
  if model_name == "sequence_lav_mse":
    model = lava.SequenceLAVMSE(action_size=action_size, **model_config)
  elif model_name == "pixel_lang_mse":
    model = pixel.PixelLangMSE(action_size=action_size, **model_config)
  else:
    raise NotImplementedError(f"{model_name} not implemented.")
  if agent_name == "bc":
    agent = bc.BCAgent(
        model=model,
        sequence_length=sequence_length,
        observation_statistics=obs_statistics,
        action_statistics=act_statistics,
        action_min=min_actions,
        action_max=max_actions,
        **agent_config,
    )
  else:
    raise NotImplementedError(f"{agent_name} not implemented.")
  return agent


def merge_batch_stats(replicated_state):
  """Averages the replicated BatchNorm statistics across devices.

  Args:
    replicated_state: Train state replicated across local devices.

  Returns:
    The state with `batch_stats` replaced by their cross-device mean, or the
    input unchanged when there are no batch stats.
  """
  if jax.tree.leaves(replicated_state.batch_stats):
    cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), "x")
    return replicated_state.replace(
        batch_stats=cross_replica_mean(replicated_state.batch_stats)
    )
  else:
    return replicated_state

# ------------------------------------------------------------------------------
# /requirements.txt:
# ------------------------------------------------------------------------------
# 1 | clu 2 | dm-reverb-nightly>=0.9.0.dev20221205 3 | gym<=0.23.0 4 | matplotlib 5 | numpy 6 | opencv-python
7 | protobuf 8 | pybullet 9 | rlds>=0.1.7 10 | scipy 11 | six 12 | tf-nightly>=2.12.0.dev20230201 13 | tensorflow_datasets 14 | tensorflow_text 15 | tf_agents>=0.14.0 16 | git+https://github.com/openai/CLIP.git 17 | git+https://github.com/google-research/scenic.git 18 | -------------------------------------------------------------------------------- /requirements_static.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | anyio==3.6.2 3 | argon2-cffi==21.3.0 4 | argon2-cffi-bindings==21.2.0 5 | asttokens==2.1.0 6 | astunparse==1.6.3 7 | backcall==0.2.0 8 | bleach==5.0.1 9 | cached-property==1.5.2 10 | cachetools==5.2.0 11 | certifi==2022.12.7 12 | cffi==1.15.1 13 | charset-normalizer==3.0.1 14 | chex==0.1.5 15 | clip @ git+https://github.com/openai/CLIP.git@3702849800aa56e2223035bccd1c6ef91c704ca8 16 | cloudpickle==2.2.0 17 | clu==0.0.8 18 | colorama==0.4.5 19 | commonmark==0.9.1 20 | contextlib2==21.6.0 21 | contourpy==1.0.6 22 | cycler==0.11.0 23 | debugpy==1.6.4 24 | decorator==5.1.1 25 | defusedxml==0.7.1 26 | dill==0.3.6 27 | dm-reverb-nightly==0.11.0.dev20230205 28 | dm-tree==0.1.7 29 | entrypoints==0.4 30 | etils==0.6.0 31 | executing==1.2.0 32 | fastjsonschema==2.16.2 33 | flatbuffers==22.10.26 34 | flax==0.5.3 35 | fonttools==4.38.0 36 | ftfy==6.1.1 37 | gast==0.4.0 38 | gin-config==0.5.0 39 | google-api-core==2.10.1 40 | google-auth==2.10.0 41 | google-auth-oauthlib==0.4.6 42 | google-cloud-speech==2.16.0 43 | google-pasta==0.2.0 44 | googleapis-common-protos==1.56.4 45 | grain==0.1.4 46 | grpcio==1.49.1 47 | grpcio-status==1.49.1 48 | gym==0.23.0 49 | gym-notices==0.0.8 50 | h5py==3.7.0 51 | idna==3.4 52 | imageio==2.28.0 53 | immutabledict==2.2.3 54 | importlib-resources==5.9.0 55 | install==1.3.5 56 | ipykernel==6.17.1 57 | ipython==8.7.0 58 | ipython-genutils==0.2.0 59 | ipywidgets==8.0.2 60 | jax==0.3.15 61 | jaxlib==0.3.15 62 | jedi==0.18.2 63 | Jinja2==3.1.2 64 | jsonschema==4.17.1 65 | 
jupyter==1.0.0 66 | jupyter-console==6.4.4 67 | jupyter-server==1.23.3 68 | jupyter_client==7.4.7 69 | jupyter_core==5.1.0 70 | jupyterlab-pygments==0.2.2 71 | jupyterlab-widgets==3.0.3 72 | keras-nightly==2.13.0.dev2023020508 73 | Keras-Preprocessing==1.1.2 74 | kiwisolver==1.4.4 75 | libclang==14.0.6 76 | Markdown==3.4.1 77 | MarkupSafe==2.1.1 78 | matplotlib==3.6.2 79 | matplotlib-inline==0.1.6 80 | mediapy==1.1.6 81 | mistune==2.0.4 82 | ml-collections==0.1.1 83 | msgpack==1.0.4 84 | mypy-protobuf==3.1.0 85 | nbclassic==0.4.8 86 | nbclient==0.7.0 87 | nbconvert==7.2.5 88 | nbformat==5.7.0 89 | nest-asyncio==1.5.6 90 | notebook==6.5.2 91 | notebook_shim==0.2.2 92 | numpy==1.23.5 93 | nvidia-cublas-cu11==11.10.3.66 94 | nvidia-cuda-nvrtc-cu11==11.7.99 95 | nvidia-cuda-runtime-cu11==11.7.99 96 | nvidia-cudnn-cu11==8.5.0.96 97 | oauthlib==3.2.2 98 | opencv-python==4.6.0.66 99 | opt-einsum==3.3.0 100 | optax @ git+https://github.com/deepmind/optax.git@d9157132b8661e1453d156f591784ff27d6d1141 101 | packaging==23.0 102 | pandocfilters==1.5.0 103 | parso==0.8.3 104 | pexpect==4.8.0 105 | pickleshare==0.7.5 106 | Pillow==9.4.0 107 | platformdirs==2.5.4 108 | portpicker==1.5.2 109 | prometheus-client==0.15.0 110 | promise==2.3 111 | prompt-toolkit==3.0.33 112 | proto-plus==1.22.1 113 | protobuf==4.21.12 114 | psutil==5.9.4 115 | ptyprocess==0.7.0 116 | pure-eval==0.2.2 117 | pyasn1==0.4.8 118 | pyasn1-modules==0.2.8 119 | PyAudio==0.2.12 120 | pybullet==3.2.5 121 | pycparser==2.21 122 | pygame==2.1.0 123 | Pygments==2.14.0 124 | pyparsing==3.0.9 125 | pyrsistent==0.19.2 126 | python-dateutil==2.8.2 127 | PyYAML==6.0 128 | pyzmq==24.0.1 129 | qtconsole==5.4.0 130 | QtPy==2.3.0 131 | regex==2022.10.31 132 | requests==2.28.2 133 | requests-oauthlib==1.3.1 134 | rich==11.2.0 135 | rlds==0.1.7 136 | rsa==4.9 137 | scenic @ git+https://github.com/google-research/scenic.git@ae21d9e884015aa7bc7cf1d489af53d16c249726 138 | scipy==1.9.3 139 | Send2Trash==1.8.0 140 | six==1.16.0 141 
| sniffio==1.3.0 142 | stack-data==0.6.2 143 | tb-nightly==2.12.0a20230205 144 | tensorboard==2.11.0 145 | tensorboard-data-server==0.7.0 146 | tensorboard-plugin-wit==1.8.1 147 | tensorflow-addons==0.19.0 148 | tensorflow-datasets==4.7.0 149 | tensorflow-io-gcs-filesystem==0.26.0 150 | tensorflow-metadata==1.10.0 151 | tensorflow-probability==0.17.0 152 | tensorflow-text-nightly==2.12.0.dev20230203 153 | tensorstore==0.1.22 154 | termcolor==1.1.0 155 | terminado==0.17.0 156 | tf-agents==0.14.0 157 | tf-estimator-nightly==2.13.0.dev2023020509 158 | tf-hub-nightly==0.13.0.dev202302070818 159 | tf-nightly==2.12.0.dev20230203 160 | toml==0.10.2 161 | toolz==0.12.0 162 | torch==1.13.1 163 | torchvision==0.14.1 164 | tornado==6.2 165 | tqdm==4.64.1 166 | traitlets==5.5.0 167 | typeguard==2.13.3 168 | types-protobuf==3.20.4 169 | typing_extensions==4.3.0 170 | urllib3==1.26.14 171 | wcwidth==0.2.5 172 | websocket-client==1.4.2 173 | Werkzeug==2.2.2 174 | widgetsnbextension==4.0.3 175 | wrapt==1.14.1 176 | zipp==3.8.1 177 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Language Tale Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Setup.""" 17 | 18 | from distutils import core 19 | import os 20 | 21 | from setuptools import find_namespace_packages 22 | 23 | 24 | here = os.path.abspath(os.path.dirname(__file__)) 25 | try: 26 | README = open(os.path.join(here, 'README.md'), encoding='utf-8').read() 27 | except IOError: 28 | README = '' 29 | 30 | 31 | install_requires = [ 32 | 'clu', 33 | 'dm-reverb-nightly>=0.9.0.dev20221205', 34 | 'gym<=0.23.0', 35 | 'matplotlib', 36 | 'numpy', 37 | 'opencv-python', 38 | 'protobuf', 39 | 'pybullet', 40 | 'rlds>=0.1.7', 41 | 'scipy', 42 | 'six', 43 | 'tf-nightly>=2.12.0.dev20230201', 44 | 'tensorflow_datasets>=4.7.0', 45 | 'tf_agents>=0.14.0', 46 | ] 47 | 48 | 49 | core.setup( 50 | name='language_table', 51 | version='0.1', 52 | description=( 53 | 'Language-Table is a suite of human-collected datasets and a multi-task' 54 | ' continuous control benchmark for open vocabulary visuolinguomotor' 55 | ' learning.' 56 | ), 57 | long_description='\n\n'.join([README]), 58 | long_description_content_type='text/markdown', 59 | author='Language Table Team', 60 | author_email='language-table-team@google.com', 61 | url='https://github.com/google-research/language-table', 62 | packages=find_namespace_packages(), 63 | install_requires=install_requires, 64 | include_package_data=True, 65 | ) 66 | --------------------------------------------------------------------------------