├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│ ├── real.jpeg
│ └── sim.jpeg
├── language_table
├── __init__.py
├── common
│ ├── __init__.py
│ ├── clip_tokenizer.py
│ └── clip_tokenizer_test.py
├── environments
│ ├── __init__.py
│ ├── assets
│ │ ├── blocks
│ │ │ ├── blue_cube.urdf
│ │ │ ├── blue_moon.urdf
│ │ │ ├── blue_pentagon.urdf
│ │ │ ├── blue_star.urdf
│ │ │ ├── cube.obj
│ │ │ ├── green_cube.urdf
│ │ │ ├── green_moon.urdf
│ │ │ ├── green_pentagon.urdf
│ │ │ ├── green_star.urdf
│ │ │ ├── moon.obj
│ │ │ ├── pentagon.obj
│ │ │ ├── purple_pole.urdf
│ │ │ ├── red_cube.urdf
│ │ │ ├── red_moon.urdf
│ │ │ ├── red_pentagon.urdf
│ │ │ ├── red_star.urdf
│ │ │ ├── star.obj
│ │ │ ├── yellow_cube.urdf
│ │ │ ├── yellow_moon.urdf
│ │ │ ├── yellow_pentagon.urdf
│ │ │ └── yellow_star.urdf
│ │ ├── plane.obj
│ │ ├── suction
│ │ │ ├── base.obj
│ │ │ ├── cylinder.urdf
│ │ │ ├── cylinder_real.urdf
│ │ │ ├── head.obj
│ │ │ ├── mid.obj
│ │ │ ├── suction-base.urdf
│ │ │ ├── suction-head-long.urdf
│ │ │ ├── suction-head.urdf
│ │ │ └── tip.obj
│ │ └── workspace_real.urdf
│ ├── blocks.py
│ ├── constants.py
│ ├── language_table.py
│ ├── language_table_test.py
│ ├── oracles
│ │ ├── __init__.py
│ │ ├── oriented_push_oracle.py
│ │ ├── plot.py
│ │ ├── push_oracle_rrt_slowdown.py
│ │ └── rrt_star.py
│ ├── rewards
│ │ ├── __init__.py
│ │ ├── block1_to_corner.py
│ │ ├── block2absolutelocation.py
│ │ ├── block2block.py
│ │ ├── block2block_relative_location.py
│ │ ├── block2relativelocation.py
│ │ ├── constants.py
│ │ ├── instructions.py
│ │ ├── instructions_test.py
│ │ ├── play.py
│ │ ├── point2block.py
│ │ ├── reward.py
│ │ ├── separate_blocks.py
│ │ ├── synonyms.py
│ │ └── task_info.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── pose3d.py
│ │ ├── utils_pybullet.py
│ │ ├── xarm_sim_robot.py
│ │ └── xarm_sim_robot_test.py
├── eval
│ ├── __init__.py
│ ├── main.py
│ └── wrappers.py
├── examples
│ ├── dataset_example.py
│ ├── environment_example.py
│ └── language_table_tutorial.ipynb
└── train
│ ├── __init__.py
│ ├── bc.py
│ ├── configs
│ │ ├── language_table_resnet_sim_local.py
│ │ └── language_table_sim_local.py
│ ├── download_clip_flax_ckpt.py
│ ├── input_pipeline_rlds.py
│ ├── main.py
│ ├── networks
│ │ ├── dense_resnet.py
│ │ ├── lava.py
│ │ ├── pixel.py
│ │ ├── resnet_v1.py
│ │ └── resnet_v1_test.py
│ ├── normalization.py
│ ├── policy.py
│ └── train.py
├── requirements.txt
├── requirements_static.txt
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | [cla.developers.google.com](https://cla.developers.google.com/) to see your
13 | current agreements on file or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include language_table/environments/assets *
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Language Table
2 |
3 | Language-Table is a suite of human-collected datasets and a multi-task continuous control benchmark for open vocabulary visuolinguomotor learning.
4 |
5 | ![Real robot](./docs/real.jpeg) | ![Simulation](./docs/sim.jpeg)
6 | :-------------------------:|:-------------------------:|
7 |
8 | ## Installation
9 |
10 | Install with `pip`. `requirements.txt` contains the dependencies needed to run
11 | the environment and the simple dataset examples.
12 |
13 | ```
14 | python3 -m venv ./ltvenv
15 | source ./ltvenv/bin/activate
16 | pip install -r ./requirements.txt
17 | export PYTHONPATH=${PWD}:$PYTHONPATH
18 | ```
19 |
20 | To run the full training script, install from `requirements_static.txt`
21 | instead; it pins the dependency versions that script needs.
22 |
23 | ```
24 | python3 -m venv ./ltvenvtrain
25 | source ./ltvenvtrain/bin/activate
26 | pip install --no-deps -r ./requirements_static.txt
27 | export PYTHONPATH=${PWD}:$PYTHONPATH
28 | ```
29 | ## Quickstart
30 |
31 | ### Examples
32 | #### Scripts
33 | Run and edit the following examples:
34 |
35 | Load the environment and run 5 random steps:
36 |
37 | ```
38 | python3 language_table/examples/environment_example.py
39 | ```
40 |
41 | Load the dataset and print its first 5 elements:
42 |
43 | ```
44 | python3 language_table/examples/dataset_example.py
45 | ```
46 |
47 | #### Train
48 |
49 | ```
50 | source ./ltvenvtrain/bin/activate
51 | mkdir -p /tmp/language_table_train/
52 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python language_table/train/main.py --config=./language_table/train/configs/language_table_sim_local.py --workdir=/tmp/language_table_train/
53 | ```
54 |
55 | #### Colab
56 | See the [colab](https://colab.research.google.com/github/google-research/language-table/blob/main/language_table/examples/language_table_tutorial.ipynb) for a more complete tutorial.
57 |
58 | ### Data
59 | ```
60 | import tensorflow_datasets as tfds
61 | data_directory = 'gs://gresearch/robotics/language_table/0.0.1/'
62 | dataset = tfds.builder_from_directory(data_directory).as_dataset()
63 | ```
64 |
65 | ### Environment
66 | ```
67 | from language_table.environments import blocks
68 | from language_table.environments import language_table
69 | from language_table.environments.rewards import block2block
70 |
71 | env = language_table.LanguageTable(
72 | block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
73 | reward_factory=block2block.BlockToBlockReward,
74 | control_frequency=10.0,
75 | )
76 | obs = env.reset()
77 | ```
78 |
79 | ## Datasets
80 |
81 | ### Descriptions
82 |
83 | * **Real Robot**
84 |     * **language_table**: 442,226 episodes of real robot relabeled data.
85 | * **Simulation (human)**
86 |     * **language_table_sim**: 181,020 episodes of simulation relabeled data.
87 |     * **language_table_blocktoblock_sim**: 8,000 episodes of single task "block to block" data.
88 |     * **language_table_blocktoblock_4block_sim**: 8,298 episodes of single task "block to block" data in the 4 block configuration.
89 | * **Simulation (oracle)**
90 |     * **language_table_blocktoblock_oracle_sim**: 200,000 episodes of single task "block to block" data from an oracle scripted agent.
91 |     * **language_table_blocktoblockrelative_oracle_sim**: 200,000 episodes of single task "block-to-block-relative" data from an oracle scripted agent.
92 |     * **language_table_blocktoabsolute_oracle_sim**: 200,000 episodes of single task "block to absolute location" data from an oracle scripted agent.
93 |     * **language_table_blocktorelative_oracle_sim**: 200,000 episodes of single task "block to relative location" data from an oracle scripted agent.
94 |     * **language_table_separate_oracle_sim**: 200,000 episodes of single task "separate blocks" data from an oracle scripted agent.
95 |
96 | ### Summary Table
97 |
98 | Dataset | Real/sim | Controlled by | Language-labeled by | # episodes
99 | --------| --------- | ------------- | ----------------- | --------:
100 | language_table | real | human | human | 442,226
101 | language_table_sim | sim | human | human | 181,020
102 | language_table_blocktoblock_sim | sim | human | scripted | 8,000
103 | language_table_blocktoblock_4block_sim | sim | human | scripted | 8,298
104 | language_table_blocktoblock_oracle_sim | sim | oracle | scripted | 200,000
105 | language_table_blocktoblockrelative_oracle_sim | sim | oracle | scripted | 200,000
106 | language_table_blocktoabsolute_oracle_sim | sim | oracle | scripted | 200,000
107 | language_table_blocktorelative_oracle_sim | sim | oracle | scripted | 200,000
108 | language_table_separate_oracle_sim | sim | oracle | scripted | 200,000
109 |
110 | ### Paths
111 |
112 | Dataset | Data Location
113 | --------| --------------
114 | language_table | [gs://gresearch/robotics/language_table](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table/0.0.1/)
115 | language_table_sim | [gs://gresearch/robotics/language_table_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_sim/0.0.1/)
116 | language_table_blocktoblock_sim | [gs://gresearch/robotics/language_table_blocktoblock_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_sim/0.0.1/)
117 | language_table_blocktoblock_4block_sim | [gs://gresearch/robotics/language_table_blocktoblock_4block_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_4block_sim/0.0.1/)
118 | language_table_blocktoblock_oracle_sim | [gs://gresearch/robotics/language_table_blocktoblock_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblock_oracle_sim/0.0.1/)
119 | language_table_blocktoblockrelative_oracle_sim | [gs://gresearch/robotics/language_table_blocktoblockrelative_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoblockrelative_oracle_sim/0.0.1/)
120 | language_table_blocktoabsolute_oracle_sim | [gs://gresearch/robotics/language_table_blocktoabsolute_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktoabsolute_oracle_sim/0.0.1/)
121 | language_table_blocktorelative_oracle_sim | [gs://gresearch/robotics/language_table_blocktorelative_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_blocktorelative_oracle_sim/0.0.1/)
122 | language_table_separate_oracle_sim | [gs://gresearch/robotics/language_table_separate_oracle_sim](https://console.cloud.google.com/storage/browser/gresearch/robotics/language_table_separate_oracle_sim/0.0.1/)
123 |
124 | ## Checkpoints
125 |
126 | Name | Config | Checkpoint Location
127 | -----| -------| -------------------
128 | BC+ResNet Sim| language_table/train/configs/language_table_resnet_sim_local.py | [gs://gresearch/robotics/language_table_checkpoints/bc_resnet_sim_checkpoint_955000](https://storage.googleapis.com/gresearch/robotics/language_table_checkpoints/bc_resnet_sim_checkpoint_955000)
129 |
130 | ## Interactive Language: Talking to Robots in Real Time
131 | [Project Website](https://interactive-language.github.io/) • [PDF](https://arxiv.org/pdf/2210.06407.pdf)
132 |
133 | *Corey Lynch, Ayzaan Wahid, Jonathan Tompson, Tianli Ding, James Betker, Robert Baruch, Travis Armstrong, Pete Florence*
134 |
135 | **Abstract.** We present a framework for building interactive, real-time, natural language-instructable robots in the real world, and we open source related assets (dataset, environment, benchmark, and policies). Trained with behavioral cloning on a dataset of hundreds of thousands of language-annotated trajectories, a produced policy can proficiently execute an order of magnitude more commands than previous works: specifically we estimate a 93.5% success rate on a set of 87,000 unique natural language strings specifying raw end-to-end visuolinguo-motor skills in the real world. We find that the same policy is capable of being guided by a human via real-time language to address a wide range of precise long-horizon rearrangement goals, e.g. "make a smiley face out of blocks". The dataset we release comprises nearly 600,000 language-labeled trajectories, an order of magnitude larger than prior available datasets. We hope the demonstrated results and associated assets enable further advancement of helpful, capable, natural-language-interactable robots.
136 |
137 | ## Note
138 |
139 | This is not an officially supported Google product.
140 |
--------------------------------------------------------------------------------
/docs/real.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/language-table/ee3104061f35e1261a2e28cc2a29ea84e7156661/docs/real.jpeg
--------------------------------------------------------------------------------
/docs/sim.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/language-table/ee3104061f35e1261a2e28cc2a29ea84e7156661/docs/sim.jpeg
--------------------------------------------------------------------------------
/language_table/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/common/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/common/clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An implementation of an in-graph TF Tokenizer.
17 |
18 | This is based on the SimpleTokenizer implementation from CLIP.
19 |
20 | Note: while this returns similar results for many strings used in
21 | language_table, there is no guarantee that this returns equivalent tokens for
22 | every text input in general. Particularly, escaped HTML input is not
23 | correctly handled. See test cases for more details.
24 |
25 | General usage is:
26 | ```
27 | vocab_lookup = create_vocab(bpe_path=default_bpe())
28 | tokenizer = ClipTokenizer(vocab_lookup)
29 | tokens = tokenize_text("example input text", tokenizer)
30 | ```
31 | """
32 |
33 | import gzip
34 |
35 | from clip.simple_tokenizer import bytes_to_unicode
36 | from clip.simple_tokenizer import default_bpe
37 |
38 | import tensorflow as tf
39 | import tensorflow_text as tf_text
40 |
41 |
class ClipTokenizer(tf_text.TokenizerWithOffsets):
  """An in-graph TF implementation similar to CLIP's SimpleTokenizer.

  Input text is normalized (whitespace stripped and collapsed, lower-cased),
  split into words with CLIP's word regex, suffixed with the end-of-word
  marker "</w>", and then segmented into vocabulary entries by a
  `tf_text.WordpieceTokenizer`.

  Args:
    vocab_lookup_table: Lookup table from vocabulary strings to ids, e.g. the
      table returned by `create_vocab`.
    max_bytes_per_word: Passed through to `tf_text.WordpieceTokenizer`. The
      length it checks includes the 4-byte "</w>" suffix.
    max_chars_per_token: Passed through to `tf_text.WordpieceTokenizer`.
    token_out_type: Passed through to `tf_text.WordpieceTokenizer`.
    unknown_token: Passed through to `tf_text.WordpieceTokenizer`.
    split_unknown_characters: Passed through to `tf_text.WordpieceTokenizer`.
    lower_case: Currently unused; input is always lower-cased.
    keep_whitespace: Currently unused.
    normalization_form: Currently unused.
    preserve_unused_token: Currently unused.
  """

  # Suffix CLIP appends to the last symbol of every word. The vocabulary built
  # by `create_vocab` holds the matching "</w>"-suffixed entries; without this
  # suffix, word-final pieces can never match their vocabulary ids.
  _END_OF_WORD = "</w>"

  def __init__(
      self,
      vocab_lookup_table,
      max_bytes_per_word=100,
      max_chars_per_token=None,
      token_out_type=tf.int64,
      unknown_token="[UNK]",
      split_unknown_characters=False,
      lower_case=False,
      keep_whitespace=False,
      normalization_form=None,
      preserve_unused_token=False,
  ):
    super(ClipTokenizer, self).__init__()

    # CLIP's word-splitting pattern: special tokens, English contractions,
    # runs of letters, single digits, and runs of other non-space symbols.
    self._re = r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""  # pylint: disable=line-too-long
    self._wordpiece_tokenizer = tf_text.WordpieceTokenizer(
        vocab_lookup_table,
        # BPE pieces carry no continuation prefix; word ends are marked by the
        # "</w>" suffix instead.
        suffix_indicator="",
        max_bytes_per_word=max_bytes_per_word,
        max_chars_per_token=max_chars_per_token,
        token_out_type=token_out_type,
        unknown_token=unknown_token,
        split_unknown_characters=split_unknown_characters,
    )

  def tokenize_with_offsets(self, text_input):
    """Tokenizes `text_input`, returning wordpiece ids and byte offsets.

    Args:
      text_input: A string tensor; presumably a scalar or 1-D batch of
        strings (a scalar comes back as a batch of one).

    Returns:
      A `(wordpieces, begin, end)` tuple of ragged tensors, for 1-D input
      shaped `[batch, words, pieces]`. Offsets are byte offsets into the
      *normalized* text (after stripping, whitespace collapsing and
      lower-casing), not into the raw input.
      NOTE(review): wordpiece offsets are measured on the "</w>"-suffixed
      word, so the end offset of a word's last piece may run up to 4 bytes
      past the end of the word.
    """
    # Do basic cleaning of text inputs.
    # the CLIP simple tokenizer does the following:
    # 1. text = html.unescape(html.unescape(text))
    #    (we skip this for tf, which means *escaped HTML inputs will not return
    #    the same tokens as SimpleTokenizer.)
    # 2. text = text.strip()
    text_input = tf.strings.strip(text_input)
    # 3. text = re.sub(r'\s+', ' ', text)
    text_input = tf.strings.regex_replace(text_input, r"\s+", " ")
    # 4. text = text.strip()
    text_input = tf.strings.strip(text_input)
    # 5. text = text.lower()
    text_input = tf.strings.lower(text_input)

    # Split on the word pattern (matches are kept as tokens) and on
    # whitespace (matches are dropped).
    tokens, begin, _ = tf_text.regex_split_with_offsets(
        text_input,
        delim_regex_pattern=self._re + r"|\s",
        keep_delim_regex_pattern=self._re,
    )

    # Append the end-of-word marker to every word, as SimpleTokenizer.bpe()
    # does, so word-final pieces match the "</w>" vocabulary entries.
    num_tokens = tf.shape(tokens.values)[0]
    end_char = tf.tile([self._END_OF_WORD], [num_tokens])
    tokens_with_end = tf.strings.join([tokens, end_char])

    begin = tf.cast(begin, tf.int64)

    # Greedy longest-match-first segmentation approximates CLIP's bpe()
    # routine; results agree for many but not all words (see module docs).
    (
        wordpieces,
        wp_begin,
        wp_end,
    ) = self._wordpiece_tokenizer.tokenize_with_offsets(tokens_with_end)

    # Wordpiece offsets are relative to each word; shift them by the word's
    # start offset within the normalized text.
    begin_expanded = tf.expand_dims(begin, axis=2)
    final_begin = begin_expanded + wp_begin
    final_end = begin_expanded + wp_end

    return wordpieces, final_begin, final_end

  def tokenize(self, text_input):
    """Returns the wordpiece ids of `text_input` as one flat 1-D tensor.

    Batch and word boundaries are discarded (`flat_values`); use
    `tokenize_with_offsets` or `tokenize_text` to keep per-input structure.
    """
    tokens, _, _ = self.tokenize_with_offsets(text_input)
    return tokens.flat_values
115 |
116 |
def create_vocab(*, bpe_path=default_bpe()):
  """Creates the input vocabulary table for the tokenizer.

  Entries are laid out in the same order as CLIP's SimpleTokenizer encoder,
  with ids assigned sequentially:
    * the byte-level unicode characters from `bytes_to_unicode()`,
    * the same characters with the end-of-word suffix "</w>",
    * one concatenated entry per BPE merge (merges 1 .. 49152 - 256 - 2), and
    * the "<|startoftext|>" and "<|endoftext|>" special tokens, which are
      therefore the last two ids (as `tokenize_text` assumes).

  Args:
    bpe_path: Path, readable by `tf.io.gfile`, to the gzipped BPE merges file.
      Note that the default is evaluated once, at import time.

  Returns:
    A `tf.lookup.StaticVocabularyTable` mapping vocabulary strings to int64
    ids, with one out-of-vocabulary bucket.
  """
  with tf.io.gfile.GFile(bpe_path, "rb") as f, gzip.open(f) as gz:
    merges = gz.read().decode("utf-8").split("\n")
  # Skip the header line and keep exactly the merges CLIP uses.
  merges = merges[1 : 49152 - 256 - 2 + 1]

  byte_chars = list(bytes_to_unicode().values())
  # Word-final variants of every byte character. These keys must differ from
  # the plain characters; an empty suffix would duplicate every key.
  vocab = byte_chars + [char + "</w>" for char in byte_chars]
  vocab.extend("".join(merge.split()) for merge in merges)
  vocab.extend(["<|startoftext|>", "<|endoftext|>"])

  bpe_lookup = tf.lookup.StaticVocabularyTable(
      num_oov_buckets=1,
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=vocab, values=tf.range(len(vocab), dtype=tf.int64)
      ),
  )
  return bpe_lookup
136 |
137 |
def tokenize_text(text, tokenizer, vocab_size=49408, context_length=77):
  """Tokenizes `text` into zero-padded CLIP token ids.

  Each output row is `[<|startoftext|>, tokens..., <|endoftext|>]`,
  right-padded with zeros to `context_length`. The start/end ids are
  `vocab_size - 2` and `vocab_size - 1`, matching the order in which
  `create_vocab` appends the special tokens.

  Note: rows longer than `context_length` are silently truncated, which also
  drops the end token. CLIP's reference tokenizer raises instead.

  Args:
    text: A string or string tensor; a scalar string yields a single row.
    tokenizer: A `ClipTokenizer` (anything providing `tokenize_with_offsets`).
    vocab_size: Vocabulary size; determines the special-token ids.
    context_length: Width of each output row; defaults to CLIP's 77.

  Returns:
    A dense `[batch, context_length]` tensor of token ids, with the
    tokenizer's output dtype.
  """
  tokens, _, _ = tokenizer.tokenize_with_offsets(text)
  # Flatten the per-word wordpieces into one token sequence per input.
  tokens = tokens.merge_dims(-2, -1)
  count = tokens.bounding_shape()[0]
  # Build the start/end columns in the tokenizer's dtype so the concat also
  # works for tokenizers created with a non-int64 `token_out_type`.
  starts = tf.cast(tf.fill([count, 1], vocab_size - 2), dtype=tokens.dtype)
  ends = tf.cast(tf.fill([count, 1], vocab_size - 1), dtype=tokens.dtype)
  tokens = tf.concat([starts, tokens, ends], axis=1)
  # Convert the ragged rows to a zero-padded (and truncated) dense tensor.
  tokens = tokens.to_tensor(shape=(count, context_length))
  return tokens
153 |
--------------------------------------------------------------------------------
/language_table/common/clip_tokenizer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for clip_tokenizer."""
17 |
18 | from clip.simple_tokenizer import default_bpe
19 | from clip.simple_tokenizer import SimpleTokenizer
20 |
21 | from language_table.common import clip_tokenizer
22 |
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
class ClipTokenizerTest(tf.test.TestCase):
  """Compares the in-graph tokenizer against CLIP's SimpleTokenizer."""

  def test_matches_simple_tokenizer(self):
    """Each sample instruction must produce identical token arrays."""
    lookup_table = clip_tokenizer.create_vocab(bpe_path=default_bpe())
    tf_tokenizer = clip_tokenizer.ClipTokenizer(lookup_table)

    reference_tokenizer = SimpleTokenizer(default_bpe())

    samples = (
        "pull the red moon apart from the star, red pentagon, and blue moon",
        "hexagon cube star crescent moon block",
        "I must've gone, to where you're going and place where he'll go,",
        "push the blue hexagon to the left!!!",
        " a thing with lots of spaces",
        "a picture of an elephantcat ",
        " a picture of an \n elephantcat ",
        "A PICTURE of AN elephantCAT",
        # Known mismatch: escaped HTML is not unescaped in-graph, e.g.
        # "<html>I want to escape HTML </html>"
    )

    for sample in samples:
      actual = clip_tokenizer.tokenize_text(sample, tf_tokenizer)
      expected = _simple_tokenize(reference_tokenizer, [sample])
      np.testing.assert_equal(actual, expected, sample)
53 |
54 |
def _simple_tokenize(tokenizer, texts, context_length=77):
  """Reference tokenization using CLIP's SimpleTokenizer.

  Wraps each encoded text in the start/end-of-text tokens and right-pads the
  row with zeros, giving the layout the in-graph `tokenize_text` produces.

  Args:
    tokenizer: A `clip.simple_tokenizer.SimpleTokenizer`.
    texts: A sequence of strings.
    context_length: Width of each output row.

  Returns:
    An integer numpy array of shape `(len(texts), context_length)`.

  Raises:
    RuntimeError: If an encoded text plus its two special tokens is longer
      than `context_length`.
  """
  sot_id, eot_id = (
      tokenizer.encoder[special]
      for special in ("<|startoftext|>", "<|endoftext|>")
  )
  sequences = [[sot_id, *tokenizer.encode(text), eot_id] for text in texts]

  padded = np.zeros((len(sequences), context_length), dtype=int)
  for row, ids in enumerate(sequences):
    if context_length < len(ids):
      raise RuntimeError(
          f"Input {texts[row]} is too long for context length {context_length}")
    padded[row, :len(ids)] = np.asarray(ids)
  return padded
70 |
71 |
if __name__ == "__main__":
  # Executed directly (not imported): hand control to TF's test runner.
  tf.test.main()
74 |
--------------------------------------------------------------------------------
/language_table/environments/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/blue_cube.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/blue_moon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/blue_pentagon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/blue_star.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/green_cube.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/green_moon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/green_pentagon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/green_star.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/purple_pole.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/red_cube.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/red_moon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/red_pentagon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/red_star.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/yellow_cube.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/yellow_moon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/yellow_pentagon.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/blocks/yellow_star.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/language_table/environments/assets/plane.obj:
--------------------------------------------------------------------------------
1 | # Blender v2.66 (sub 1) OBJ File: ''
2 | # www.blender.org
3 | mtllib plane.mtl
4 | o Plane
5 | v 15.000000 -15.000000 0.000000
6 | v 15.000000 15.000000 0.000000
7 | v -15.000000 15.000000 0.000000
8 | v -15.000000 -15.000000 0.000000
9 |
10 | vt 15.000000 0.000000
11 | vt 15.000000 15.000000
12 | vt 0.000000 15.000000
13 | vt 0.000000 0.000000
14 |
15 | usemtl Material
16 | s off
17 | f 1/1 2/2 3/3
18 | f 1/1 3/3 4/4
19 |
--------------------------------------------------------------------------------
/language_table/environments/assets/suction/cylinder.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/language_table/environments/assets/suction/cylinder_real.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/language_table/environments/assets/suction/suction-base.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/language_table/environments/assets/suction/suction-head-long.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/language_table/environments/assets/suction/suction-head.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/language_table/environments/assets/workspace_real.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/language_table/environments/blocks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines the n choose k blocks on Language Table."""
17 | import collections
18 | import enum
19 | import itertools
20 |
21 | import numpy as np
22 |
23 |
class LanguageTableBlockVariants(enum.Enum):
  """Named block configurations the environment can be built with.

  Each member's value is the string form of its own name. The trailing comment
  on each member describes which blocks it selects; the concrete block lists
  are the FIXED_*_COMBINATION constants and the N-choose-k train/test split
  defined later in this module.
  """
  BLOCK_1 = 'BLOCK_1'  # 1 green star. Just for debugging.
  BLOCK_4 = 'BLOCK_4'  # The original 4 blocks.
  BLOCK_8 = 'BLOCK_8'  # 2 of each color, 2 of each shape, 8 total.
  BLOCK_4_WPOLE = 'BLOCK_4_WPOLE'  # original 4 blocks with purple pole as goal
  BLOCK_8_WPOLE = 'BLOCK_8_WPOLE'  # 8 blocks with purple pole as goal
  # Combinatorial: subsets of MIN_K..MAX_K blocks, split into train / test.
  N_CHOOSE_K = 'N_CHOOSE_K'


# String values of every variant, in declaration order.
BLOCK_VARIANTS = [i.value for i in LanguageTableBlockVariants]
34 |
35 |
def get_all_block_subsets(mode, training):
  """Lists every block subset that may be placed on the table for `mode`.

  Args:
    mode: A `LanguageTableBlockVariants` member.
    training: Only consulted for `N_CHOOSE_K`. A truthy value selects the
      training split; otherwise the held-out test split is returned.

  Returns:
    For `N_CHOOSE_K`, the train or test list of combinations. For every fixed
    variant, a one-element list wrapping that variant's fixed combination.

  Raises:
    ValueError: If `mode` is not a supported variant.
  """
  # Ordered (variant, fixed combination) table. `None` marks the
  # combinatorial variant, whose answer depends on `training`.
  candidates = (
      (LanguageTableBlockVariants.BLOCK_1, FIXED_1_COMBINATION),
      (LanguageTableBlockVariants.BLOCK_4, FIXED_4_COMBINATION),
      (LanguageTableBlockVariants.BLOCK_8, FIXED_8_COMBINATION),
      (LanguageTableBlockVariants.N_CHOOSE_K, None),
      (LanguageTableBlockVariants.BLOCK_4_WPOLE, FIXED_4_COMBINATION_WPOLE),
      (LanguageTableBlockVariants.BLOCK_8_WPOLE, FIXED_8_COMBINATION_WPOLE),
  )
  for variant, fixed in candidates:
    if mode == variant:
      if fixed is None:
        return TRAIN_COMBINATIONS if training else TEST_COMBINATIONS
      return [fixed]
  raise ValueError('Unsupported block mode')
55 |
56 |
def get_block_set(mode):
  """Returns the unique set of block names that `mode` can put on the table.

  Args:
    mode: A `LanguageTableBlockVariants` member.

  Returns:
    For a fixed variant, that variant's fixed combination (the *_WPOLE
    variants include 'purple_pole'). For `N_CHOOSE_K`, every block in
    `ALL_BLOCKS`.

  Raises:
    ValueError: If `mode` is not a supported variant.
  """
  if mode == LanguageTableBlockVariants.BLOCK_1:
    return FIXED_1_COMBINATION
  elif mode == LanguageTableBlockVariants.BLOCK_4:
    return FIXED_4_COMBINATION
  elif mode == LanguageTableBlockVariants.BLOCK_8:
    return FIXED_8_COMBINATION
  elif mode == LanguageTableBlockVariants.N_CHOOSE_K:
    return ALL_BLOCKS
  # The pole variants are accepted by get_all_block_subsets, so they must be
  # accepted here too; otherwise get_all_block_pairs and
  # get_blocks_text_descriptions fail for them.
  elif mode == LanguageTableBlockVariants.BLOCK_4_WPOLE:
    return FIXED_4_COMBINATION_WPOLE
  elif mode == LanguageTableBlockVariants.BLOCK_8_WPOLE:
    return FIXED_8_COMBINATION_WPOLE
  else:
    raise ValueError('Unsupported block mode')
69 |
70 |
def get_all_block_pairs(mode):
  """Enumerates ordered pairs of distinct blocks for `mode`.

  Useful for generating all instructions. The result is a lazy
  `itertools.permutations` iterator, so it can only be consumed once.
  """
  return itertools.permutations(get_block_set(mode), 2)
76 |
77 |
def get_blocks_text_descriptions(mode):
  """Returns readable names (e.g. 'red moon') for every block in `mode`."""
  return [name.replace('_', ' ') for name in get_block_set(mode)]
83 |
84 |
# Maps each '<color>_<shape>' block name to the URDF asset used to load it.
# Insertion order matters: DUMMY_START_BLOCK below takes the first key.
# NOTE(review): paths are rooted at 'third_party/py/'; confirm how the loader
# resolves them outside that layout.
BLOCK_URDF_PATHS = collections.OrderedDict(
    # Red blocks.
    red_moon='third_party/py/language_table/environments/assets/blocks/red_moon.urdf',
    red_cube='third_party/py/language_table/environments/assets/blocks/red_cube.urdf',
    red_star='third_party/py/language_table/environments/assets/blocks/red_star.urdf',
    red_pentagon='third_party/py/language_table/environments/assets/blocks/red_pentagon.urdf',
    # Blue blocks.
    blue_moon='third_party/py/language_table/environments/assets/blocks/blue_moon.urdf',
    blue_cube='third_party/py/language_table/environments/assets/blocks/blue_cube.urdf',
    blue_star='third_party/py/language_table/environments/assets/blocks/blue_star.urdf',
    blue_pentagon='third_party/py/language_table/environments/assets/blocks/blue_pentagon.urdf',
    # Yellow blocks.
    yellow_moon='third_party/py/language_table/environments/assets/blocks/yellow_moon.urdf',
    yellow_cube='third_party/py/language_table/environments/assets/blocks/yellow_cube.urdf',
    yellow_star='third_party/py/language_table/environments/assets/blocks/yellow_star.urdf',
    yellow_pentagon='third_party/py/language_table/environments/assets/blocks/yellow_pentagon.urdf',
    # Green blocks.
    green_moon='third_party/py/language_table/environments/assets/blocks/green_moon.urdf',
    green_cube='third_party/py/language_table/environments/assets/blocks/green_cube.urdf',
    green_star='third_party/py/language_table/environments/assets/blocks/green_star.urdf',
    green_pentagon='third_party/py/language_table/environments/assets/blocks/green_pentagon.urdf',
)

# URDF for the purple pole that serves as the goal in the *_WPOLE variants.
POLE_URDF_PATHS = collections.OrderedDict(
    # Purple Pole.
    purple_pole='third_party/py/language_table/environments/assets/blocks/purple_pole.urdf',
)
112 |
# Use this just to define the observation space.
# (First key of BLOCK_URDF_PATHS, i.e. 'red_moon'.)
DUMMY_START_BLOCK = list(BLOCK_URDF_PATHS.keys())[0]
COLORS = ['red', 'blue', 'green', 'yellow']
SHAPES = ['moon', 'cube', 'star', 'pentagon']
# All 16 '<color>_<shape>' names, color-major.
ALL_BLOCKS = ['_'.join(i) for i in itertools.product(COLORS, SHAPES)]
# Inclusive range of subset sizes enumerated for the N_CHOOSE_K variant.
MIN_K = 4
MAX_K = 10
ALL_COMBINATIONS = []
for k in range(MIN_K, MAX_K+1):
  k_combos = list(itertools.combinations(ALL_BLOCKS, k))
  ALL_COMBINATIONS.extend(k_combos)
# Seeded shuffle.
# The fixed seed makes the shuffle, and therefore the train / test split
# below, identical on every import. Changing the seed, the K range, or the
# order of COLORS / SHAPES changes which combinations are held out.
combo_rng = np.random.RandomState(seed=0)
combo_rng.shuffle(ALL_COMBINATIONS)
# Divide combinations by train / test (first 90% train, remaining 10% test).
TRAIN_COMBINATIONS = ALL_COMBINATIONS[:int(len(ALL_COMBINATIONS)*0.9)]
TEST_COMBINATIONS = ALL_COMBINATIONS[int(len(ALL_COMBINATIONS)*0.9):]
130 |
# 8 total, 2 of each color, 2 of each shape.
FIXED_8_COMBINATION = (
    'red_moon',
    'red_pentagon',
    'blue_moon',
    'blue_cube',
    'green_cube',
    'green_star',
    'yellow_star',
    'yellow_pentagon')

# The original "4-block" environment.
# One block of each color and one of each shape.
FIXED_4_COMBINATION = (
    'red_moon',
    'blue_cube',
    'green_star',
    'yellow_pentagon'
)


# 8 total blocks + 1 goal purple pole, 2 of each color, 2 of each shape.
# Same eight blocks as FIXED_8_COMBINATION, followed by 'purple_pole'.
FIXED_8_COMBINATION_WPOLE = ('red_moon', 'red_pentagon', 'blue_moon',
                             'blue_cube', 'green_cube', 'green_star',
                             'yellow_star', 'yellow_pentagon', 'purple_pole')

# The original "4-block" environment + 1 goal purple pole.
FIXED_4_COMBINATION_WPOLE = ('red_moon', 'blue_cube', 'green_star',
                             'yellow_pentagon', 'purple_pole')
# 1-block debugging environment.
# NOTE(review): unlike the other fixed combinations this is a list, not a
# tuple; confirm no caller relies on that before unifying.
FIXED_1_COMBINATION = ['green_star']
161 |
--------------------------------------------------------------------------------
/language_table/environments/constants.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Shared language table constants."""
17 |
18 | import math
19 | import numpy as np
20 | from scipy.spatial import transform
21 |
# Ground-plane URDF from pybullet_data.
PLANE_URDF_PATH = ('third_party/bullet/examples/pybullet/gym/pybullet_data/'
                   'plane.urdf')

# End-effector height and "down" orientation (rotation vector of pi radians
# about the y axis). NOTE(review): height presumably in meters -- confirm.
EFFECTOR_HEIGHT = 0.145
EFFECTOR_DOWN_ROTATION = transform.Rotation.from_rotvec([0, math.pi, 0])

# Workspace extents in the x/y plane. NOTE(review): presumably meters, in
# which case the y half-width of 0.3048 is exactly one foot -- confirm.
X_MIN = 0.15
X_MAX = 0.6
Y_MIN = -0.3048
Y_MAX = 0.3048
# Midpoint of the workspace along each axis.
CENTER_X = (X_MAX - X_MIN) / 2. + X_MIN
CENTER_Y = (Y_MAX - Y_MIN) / 2. + Y_MIN

# NOTE(review): margin related to the workspace bounds; its exact use lives
# outside this file -- confirm at the call sites.
WORKSPACE_BOUNDS_BUFFER = 0.08

# NOTE(review): distance thresholds whose semantics are defined by their
# consumers outside this file -- confirm before changing.
BLOCK_DISTANCE_THRESHOLD = 0.0175
ARM_DISTANCE_THRESHOLD = 0.06
INSTRUCTION_LENGTH = 512  # max number of chars in instruction

# Workspace corners as [[X_MIN, Y_MIN], [X_MAX, Y_MAX]].
WORKSPACE_BOUNDS = np.array(((X_MIN, Y_MIN), (X_MAX, Y_MAX)))
WORKSPACE_URDF_PATH = 'third_party/py/language_table/environments/assets/workspace_real.urdf'
# Camera position and orientation. NOTE(review): the orientation looks like
# Euler angles in radians; the axis convention is set by the consumer --
# confirm.
CAMERA_POSE = (0.75, 0, 0.5)
CAMERA_ORIENTATION = (np.pi / 5, np.pi, -np.pi / 2)

IMAGE_WIDTH = 320
IMAGE_HEIGHT = 180
# Flattened row-major 3x3 pinhole intrinsics matrix
# [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]. fx and fy are equal (both scaled by
# IMAGE_WIDTH) and the principal point is the image centre.
CAMERA_INTRINSICS = (
    0.803 * IMAGE_WIDTH,  # fx
    0,
    IMAGE_WIDTH / 2.,  # cx
    0,
    0.803 * IMAGE_WIDTH,  # fy
    IMAGE_HEIGHT / 2.,  # cy
    0,
    0,
    1)

# Corresponds to:
# rotation = transform.Rotation.from_rotvec([0, math.pi, 0])
# translation = np.array([0.3, -0.2, 0.145])
# (six joint values, one per arm joint -- presumably radians; confirm).
INITIAL_JOINT_POSITIONS = np.array([
    -0.5875016909413221, 0.15985553866983415, -0.4992862770497537,
    0.0017427885915130214, 0.33927183830553914, -3.7249551487437524
])
66 |
--------------------------------------------------------------------------------
/language_table/environments/language_table_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for language_table."""
17 |
18 | from language_table.environments import blocks
19 | from language_table.environments import language_table
20 | from language_table.environments.rewards import block2block
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
class LanguageTableTest(tf.test.TestCase):

  def _assert_observations_close(self, observations, reconstructed_obs):
    """Asserts that recorded and reconstructed observations match pairwise.

    Args:
      observations: Observations recorded while stepping the environment.
      reconstructed_obs: Observations recomputed after restoring the pybullet
        state that corresponds one-to-one to each entry of `observations`.
    """
    self.assertEqual(len(observations), len(reconstructed_obs))
    for i, (orig_obs, recon_obs) in enumerate(
        zip(observations, reconstructed_obs)):
      for k in orig_obs:
        msg = f'Observation {k} at step {i} was not equal.'
        # Compare in float64: subtracting unsigned arrays (such as uint8
        # images) directly wraps around, turning a difference of 1 into 255.
        abs_err = np.abs(
            np.asarray(orig_obs[k], dtype=np.float64) -
            np.asarray(recon_obs[k], dtype=np.float64))
        if k == 'rgb':
          # pybullet rgb is sometimes flaky as the rendering is
          # non-deterministic. Since the individual pixel errors
          # are high magnitude, just make sure the average error
          # is low.
          self.assertLess(abs_err.mean(), 1e-6, msg)
        elif k == 'effector_translation':
          self.assertLess(abs_err.max(), 1e-3, msg)
        else:
          self.assertLess(abs_err.max(), 1e-6, msg)
        self.assertEqual(orig_obs[k].dtype, recon_obs[k].dtype)

  def test_save_restore_pybullet_state(self):
    block_modes = [
        blocks.LanguageTableBlockVariants.BLOCK_4,
        blocks.LanguageTableBlockVariants.BLOCK_8
    ]
    for block_mode in block_modes:
      env = language_table.LanguageTable(
          block_mode=block_mode,
          reward_factory=block2block.BlockToBlockReward,
          control_frequency=10.0,
          seed=0)
      obs0 = env.reset()
      actions = [
          np.random.uniform(size=(2,), low=-0.03, high=0.03) for _ in range(40)
      ]
      observations = [obs0]
      pbstates = [env.get_pybullet_state()]

      # Take actions in environment, storing obs and pbstates.
      for act in actions:
        obs, _, _, _ = env.step(act)
        observations.append(obs)
        pbstates.append(env.get_pybullet_state())

      # Reset env.
      env.reset()

      # Replay states into observations.
      reconstructed_obs = []
      for pb in pbstates:
        env.set_pybullet_state(pb)
        obs = env._compute_state(request_task_update=False)
        reconstructed_obs.append(obs)
      self._assert_observations_close(observations, reconstructed_obs)

  def test_save_restore_pybullet_state_instruction_len(self):
    block_modes = [
        blocks.LanguageTableBlockVariants.BLOCK_4,
        blocks.LanguageTableBlockVariants.BLOCK_8
    ]
    for block_mode in block_modes:
      env = language_table.LanguageTable(
          block_mode=block_mode,
          reward_factory=block2block.BlockToBlockReward,
          control_frequency=10.0,
          seed=0)
      obs0 = env.reset()
      actions = [
          np.random.uniform(size=(2,), low=-0.03, high=0.03) for _ in range(40)
      ]
      pbstates = [env.get_pybullet_state()]

      # Take actions in environment. Only the initial and final pybullet
      # states are saved, so keep only the matching initial and final
      # observations; otherwise the observation after step 1 would be paired
      # with the state restored after the last step.
      final_obs = obs0
      for act in actions:
        final_obs, _, _, _ = env.step(act)
      observations = [obs0, final_obs]

      pybullet_state = env.get_pybullet_state()

      # Modify the pybullet state to make the instruction length different.
      pybullet_state['instruction'] = pybullet_state['instruction'][:128]

      pbstates.append(pybullet_state)

      # Reset env.
      env.reset()

      # Replay states into observations.
      reconstructed_obs = []
      for pb in pbstates:
        env.set_pybullet_state(pb)
        state = env._compute_state(request_task_update=False)
        obs = env._compute_observation(state=state)

        self.assertTrue(env.observation_space.contains(obs))

        # 512 is constants.INSTRUCTION_LENGTH: the truncated instruction must
        # come back at full length.
        self.assertEqual(obs['instruction'].size, 512)

        reconstructed_obs.append(obs)
      self._assert_observations_close(observations, reconstructed_obs)

  def test_environment_initializes_and_resets(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0)
    env.reset()

  def test_environment_obs_space_contains_obs(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0)
    obs = env.reset()
    self.assertTrue(env.observation_space.contains(obs))
    for k in obs:
      self.assertEqual(obs[k].dtype, env.observation_space[k].dtype)

  def test_environment_steps_block4(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))

  def test_environment_steps_block8(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))

  def test_environment_steps_nchoosek(self):
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.N_CHOOSE_K,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
        seed=0)
    obs = env.reset()
    for _ in range(5):
      obs, _, _, _ = env.step(env.action_space.sample())
      self.assertTrue(env.observation_space.contains(obs))


if __name__ == '__main__':
  tf.test.main()
202 |
--------------------------------------------------------------------------------
/language_table/environments/oracles/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/environments/oracles/oriented_push_oracle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Oracle for pushing task which orients the block then pushes it."""
17 | import dataclasses
18 |
19 | from typing import Any
20 |
21 | import numpy as np
22 | from tf_agents.policies import py_policy
23 | from tf_agents.trajectories import policy_step
24 | from tf_agents.trajectories import time_step as ts
25 |
26 |
@dataclasses.dataclass
class PushingInfo:
  """Holds onto info necessary for pushing state machine.

  All fields default to None; OrientedPushOracle._get_action_info fills in
  the ones it computes.
  """
  # xy position of the block being pushed.
  xy_block: Any = None
  # xy of the effector target ('effector_target_translation').
  xy_ee: Any = None
  # Approach waypoint behind the block on the block->target line.
  xy_pre_block: Any = None
  # Unit vector from the block towards the target.
  xy_dir_block_to_target: Any = None
  # Delta from the effector to a point just behind the block.
  xy_delta_to_nexttoblock: Any = None
  # Delta from the effector to the push point touching the block.
  xy_delta_to_touchingblock: Any = None
  # Unit vector from the block towards the effector.
  xy_dir_block_to_ee: Any = None
  # |theta_error| above which the block is re-oriented before pushing.
  theta_threshold_to_orient: Any = None
  # |theta_error| below which re-orientation is considered done.
  theta_threshold_flat_enough: Any = None
  # Block heading error relative to the block->target direction, wrapped for
  # the block's 4-way symmetry.
  theta_error: Any = None
  # NOTE(review): the two fields below are not set by OrientedPushOracle in
  # this file; presumably used by other oracles -- confirm.
  obstacle_poses: Any = None
  distance_to_target: Any = None
42 |
43 |
class OrientedPushOracle(py_policy.PyPolicy):
  """Oracle for pushing task which orients the block then pushes it.

  The policy is a state machine whose current state is `self.phase`:

    * "move_to_pre_block": go to a point 5 cm behind the block on the
      block->target line; switches to "move_to_block" once within 1 mm.
    * "move_to_block": approach a point 3 cm behind the block; once within
      1 mm, switches to "push_block", or to "orient_block_left" /
      "orient_block_right" if the heading error exceeds
      `theta_threshold_to_orient`.
    * "push_block": aim 1 cm behind the block centre (i.e. into the block);
      goes back to "move_to_pre_block" if the heading error grows too large.
    * "orient_block_left" / "orient_block_right": move around the block to
      rotate it; return to "move_to_pre_block" once it is flat enough.

  The phase handlers run in sequence within a single call, so a phase change
  made by one handler takes effect immediately for the next.
  """

  def __init__(self, env, action_noise_std=0.0):
    """Creates the oracle.

    Args:
      env: Environment providing `time_step_spec()`, `action_spec()`,
        `compute_state()` and `get_control_frequency()`.
      action_noise_std: Std-dev of Gaussian noise added to each xy action
        before the step-length limit is applied; 0.0 disables noise.
    """
    super(OrientedPushOracle, self).__init__(env.time_step_spec(),
                                             env.action_spec())
    self._env = env
    # Fixed seed so noisy rollouts are reproducible.
    self._np_random_state = np.random.RandomState(0)
    self.phase = "move_to_pre_block"
    self._action_noise_std = action_noise_std

  def reset(self):
    """Restarts the state machine at its first phase."""
    self.phase = "move_to_pre_block"

  def get_theta_from_vector(self, vector):
    """Returns the planar angle (radians) of `vector` measured from +x."""
    return np.arctan2(vector[1], vector[0])

  def theta_to_rotation2d(self, theta):
    """Returns the 2x2 counter-clockwise rotation matrix for `theta`."""
    r = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    return r

  def rotate(self, theta, xy_dir_block_to_ee):
    """Rotates the 2D vector `xy_dir_block_to_ee` by `theta` radians."""
    rot_2d = self.theta_to_rotation2d(theta)
    return rot_2d @ xy_dir_block_to_ee

  @staticmethod
  def _normalize(vector):
    """Returns `vector` scaled to unit length, or zeros if its length is 0.

    Dividing by a zero norm would produce NaNs that propagate into the
    emitted action, e.g. when the block sits exactly on the target or the
    effector is exactly above the block centre.
    """
    norm = np.linalg.norm(vector)
    if norm == 0.0:
      return np.zeros_like(vector, dtype=float)
    return vector / norm

  def _get_action_info(self, raw_state, block, target):
    """Derives the geometric quantities the state machine needs.

    Args:
      raw_state: Mapping from the environment's `compute_state()`. Reads
        '<block>_translation', '<block>_orientation',
        '<target>_translation' and 'effector_target_translation'.
      block: Key prefix of the block to push.
      target: Key prefix of the push target.

    Returns:
      A populated `PushingInfo`.
    """
    xy_block = raw_state["%s_translation" % block][:2]
    theta_block = raw_state["%s_orientation" % block]
    xy_target = raw_state["%s_translation" % target][:2]
    xy_ee = raw_state["effector_target_translation"][:2]

    xy_dir_block_to_target = self._normalize(xy_target - xy_block)
    theta_to_target = self.get_theta_from_vector(xy_dir_block_to_target)

    theta_error = theta_to_target - theta_block
    # Block has 4-way symmetry, so wrap the error into [-pi/4, pi/4].
    while theta_error > np.pi/4:
      theta_error -= np.pi/2.
    while theta_error < -np.pi/4:
      theta_error += np.pi/2.

    # Waypoints behind the block on the block->target line: 5 cm
    # (pre-block), 3 cm (next to block) and 1 cm (touching, i.e. pushing).
    xy_pre_block = xy_block + -xy_dir_block_to_target * 0.05
    xy_nexttoblock = xy_block + -xy_dir_block_to_target * 0.03
    xy_touchingblock = xy_block + -xy_dir_block_to_target * 0.01
    xy_delta_to_nexttoblock = xy_nexttoblock - xy_ee
    xy_delta_to_touchingblock = xy_touchingblock - xy_ee

    xy_dir_block_to_ee = self._normalize(xy_ee - xy_block)

    theta_threshold_to_orient = 0.2
    theta_threshold_flat_enough = 0.03
    return PushingInfo(
        xy_block=xy_block,
        xy_ee=xy_ee,
        xy_pre_block=xy_pre_block,
        xy_dir_block_to_target=xy_dir_block_to_target,
        xy_delta_to_nexttoblock=xy_delta_to_nexttoblock,
        xy_delta_to_touchingblock=xy_delta_to_touchingblock,
        xy_dir_block_to_ee=xy_dir_block_to_ee,
        theta_threshold_to_orient=theta_threshold_to_orient,
        theta_threshold_flat_enough=theta_threshold_flat_enough,
        theta_error=theta_error)

  def _get_move_to_preblock(self, xy_pre_block, xy_ee):
    """Heads for the pre-block waypoint.

    Returns:
      (xy delta to the waypoint, max step velocity for this phase).
    """
    max_step_velocity = 0.3
    # Go 5 cm away from the block, on the line between the block and target.
    xy_delta_to_preblock = xy_pre_block - xy_ee
    diff = np.linalg.norm(xy_delta_to_preblock)
    if diff < 0.001:
      self.phase = "move_to_block"
    xy_delta = xy_delta_to_preblock
    return xy_delta, max_step_velocity

  def _get_move_to_block(
      self, xy_delta_to_nexttoblock, theta_threshold_to_orient, theta_error):
    """Approaches the block; picks push vs. re-orient on arrival."""
    diff = np.linalg.norm(xy_delta_to_nexttoblock)
    if diff < 0.001:
      self.phase = "push_block"
      # If need to re-orient, then re-orient.
      if theta_error > theta_threshold_to_orient:
        self.phase = "orient_block_left"
      if theta_error < -theta_threshold_to_orient:
        self.phase = "orient_block_right"
      # Otherwise, push into the block.
    xy_delta = xy_delta_to_nexttoblock
    return xy_delta

  def _get_push_block(
      self, theta_error, theta_threshold_to_orient, xy_delta_to_touchingblock):
    """Pushes into the block unless it has rotated too far off-axis."""
    # If need to reorient, go back to move_to_pre_block, move_to_block first.
    if theta_error > theta_threshold_to_orient:
      self.phase = "move_to_pre_block"
    if theta_error < -theta_threshold_to_orient:
      self.phase = "move_to_pre_block"
    xy_delta = xy_delta_to_touchingblock
    return xy_delta

  def _get_orient_block_left(self,
                             xy_dir_block_to_ee,
                             orient_circle_diameter,
                             xy_block,
                             xy_ee,
                             theta_error,
                             theta_threshold_flat_enough):
    """Moves around the block (+0.2 rad step) to rotate it.

    `orient_circle_diameter` is used as the distance from the block centre
    of the aimed-for point (effectively a radius, despite its name). Returns
    the xy delta from the effector to that point, and switches back to
    "move_to_pre_block" once `theta_error` is below
    `theta_threshold_flat_enough`.
    """
    xy_dir_block_to_ee = self.rotate(0.2, xy_dir_block_to_ee)
    xy_block_to_ee = xy_dir_block_to_ee * orient_circle_diameter
    xy_push_left_spot = xy_block + xy_block_to_ee
    xy_delta = xy_push_left_spot - xy_ee
    if theta_error < theta_threshold_flat_enough:
      self.phase = "move_to_pre_block"
    return xy_delta

  def _get_orient_block_right(self,
                              xy_dir_block_to_ee,
                              orient_circle_diameter,
                              xy_block,
                              xy_ee,
                              theta_error,
                              theta_threshold_flat_enough):
    """Mirror of `_get_orient_block_left` using a -0.2 rad step.

    Switches back to "move_to_pre_block" once `theta_error` is above
    `-theta_threshold_flat_enough`.
    """
    xy_dir_block_to_ee = self.rotate(-0.2, xy_dir_block_to_ee)
    xy_block_to_ee = xy_dir_block_to_ee * orient_circle_diameter
    xy_push_right_spot = xy_block + xy_block_to_ee
    xy_delta = xy_push_right_spot - xy_ee
    if theta_error > -theta_threshold_flat_enough:
      self.phase = "move_to_pre_block"
    return xy_delta

  def _get_action_for_block_target(self,
                                   raw_state,
                                   block="block",
                                   target="target"):
    """Computes the xy displacement for pushing `block` towards `target`.

    Args:
      raw_state: Mapping returned by the environment's `compute_state()`.
      block: Key prefix of the block to push.
      target: Key prefix of the target location.

    Returns:
      The xy delta for this control step (optionally noised), scaled down if
      needed so its length is at most max_step_velocity / control_frequency.
    """
    # Specifying this as velocity makes it independent of control frequency.
    max_step_velocity = 0.35
    info = self._get_action_info(raw_state, block, target)

    if self.phase == "move_to_pre_block":
      xy_delta, max_step_velocity = self._get_move_to_preblock(
          info.xy_pre_block, info.xy_ee)

    if self.phase == "move_to_block":
      xy_delta = self._get_move_to_block(
          info.xy_delta_to_nexttoblock, info.theta_threshold_to_orient,
          info.theta_error)

    if self.phase == "push_block":
      xy_delta = self._get_push_block(
          info.theta_error, info.theta_threshold_to_orient,
          info.xy_delta_to_touchingblock)

    orient_circle_diameter = 0.025

    if self.phase == "orient_block_left" or self.phase == "orient_block_right":
      # Move more slowly while rotating the block.
      max_step_velocity = 0.15

    if self.phase == "orient_block_left":
      xy_delta = self._get_orient_block_left(
          info.xy_dir_block_to_ee,
          orient_circle_diameter,
          info.xy_block,
          info.xy_ee,
          info.theta_error,
          info.theta_threshold_flat_enough)

    if self.phase == "orient_block_right":
      xy_delta = self._get_orient_block_right(
          info.xy_dir_block_to_ee,
          orient_circle_diameter,
          info.xy_block,
          info.xy_ee,
          info.theta_error,
          info.theta_threshold_flat_enough)

    if self._action_noise_std != 0.0:
      xy_delta += (self._np_random_state.randn(2) *
                   self._action_noise_std)

    max_step_distance = max_step_velocity * (1 /
                                             self._env.get_control_frequency())
    length = np.linalg.norm(xy_delta)
    if length > max_step_distance:
      xy_direction = xy_delta / length
      xy_delta = xy_direction * max_step_distance
    return xy_delta

  def _action(self,
              time_step,
              policy_state,
              seed=None):
    """Returns the oracle's xy action for `time_step` as a float32 step."""
    if time_step.is_first():
      self.reset()
    raw_state = self._env.compute_state()
    xy_delta = self._get_action_for_block_target(
        raw_state, block="block", target="target")
    return policy_step.PolicyStep(action=np.asarray(xy_delta, dtype=np.float32))
241 |
--------------------------------------------------------------------------------
/language_table/environments/oracles/plot.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utils for plotting RRT and writing debug images."""
17 |
18 | import imageio
19 | from matplotlib import patches
20 | from matplotlib import pyplot as plt
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
DEBUG_IMAGE_PATH = '/tmp/rrt_debug_img.png'


class PlotRRT:
  """Plotting helper for visualizing RRT.

  `plot_grid` (called by `animation`) creates the figure that
  `save_debug_image` later writes, so one of them must run first.
  """

  def __init__(self,
               x_start,
               x_goal,
               obs_boundary,
               obs_circle,
               image_path=DEBUG_IMAGE_PATH):
    """Stores the scene to draw.

    Args:
      x_start: Start position; elements 0 and 1 are drawn as x, y.
      x_goal: Goal position; elements 0 and 1 are drawn as x, y.
      obs_boundary: Iterable of (x, y, width, height) rectangles.
      obs_circle: Sequence of (x, y, radius) circular obstacles. The first
        entry's radius is also used for the start / goal markers, so it must
        not be empty.
      image_path: Destination written by `save_debug_image`.
    """
    self._x_start, self._x_goal = x_start, x_goal
    self._obs_bound = obs_boundary
    self._obs_circle = obs_circle
    self._image_path = image_path

  def animation(self, nodelist, path, name, animation=False):
    """Draws the grid, the explored tree and the final path."""
    self.plot_grid(name)
    self.plot_visited(nodelist, animation)
    self.plot_path(path)

  def plot_grid(self, name):
    """Plot whole grid: obstacles, start (blue) and goal (green) markers."""
    self.fig, ax = plt.subplots()
    for (ox, oy, w, h) in self._obs_bound:
      ax.add_patch(
          patches.Rectangle(
              (ox, oy), w, h,
              edgecolor='black',
              facecolor='black',
              fill=True)
      )
    for (ox, oy, r) in self._obs_circle:
      ax.add_patch(
          patches.Circle(
              (ox, oy), r,
              edgecolor='black',
              facecolor='gray',
              fill=True)
      )

    radius = self._obs_circle[0][2]
    ax.add_patch(
        patches.Circle(
            (self._x_start[0], self._x_start[1]), radius,
            edgecolor='black',
            facecolor='blue',
            fill=True)
    )
    ax.add_patch(
        patches.Circle(
            (self._x_goal[0], self._x_goal[1]), radius,
            edgecolor='black',
            facecolor='green',
            fill=True)
    )
    plt.title(name)
    plt.axis('equal')

  def plot_visited(self, nodelist, animation):
    """Plot visited: a green edge from every node to its parent.

    `animation` is kept for API compatibility; both settings draw the same
    edges.
    """
    del animation  # Both former code paths were identical.
    for node in nodelist:
      if node.parent:
        plt.plot([node.parent.x, node.x], [node.parent.y, node.y], '-g')

  def plot_path(self, path):
    """Draws `path` (a sequence of (x, y)) in red, then calls plt.show()."""
    if path:
      plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
    plt.show()

  def fig_to_array(self, figure):
    """Converts a matplotlib figure to an (H, W, 3) uint8 RGB numpy array.

    Reads the Agg canvas through `buffer_rgba()`. The former path through
    `tostring_rgb()` and the binary mode of `np.fromstring` relies on
    deprecated APIs (`tostring_rgb` is gone from recent Matplotlib releases).
    Using the buffer's own shape also avoids a mis-sized reshape when the
    canvas's physical pixel size differs from `get_width_height()`.
    """
    figure.canvas.draw()
    rgba = np.asarray(figure.canvas.buffer_rgba())
    # Drop alpha and copy so the result owns contiguous memory.
    return np.array(rgba[..., :3], dtype=np.uint8)

  def save_debug_image(self):
    """Writes the current figure as a PNG to `image_path` via tf.io.gfile."""
    array = self.fig_to_array(self.fig)
    # `array` is already uint8 in [0, 255]. Write it unchanged: scaling by
    # 255 would push values out of range and force a lossy float rescale.
    with tf.io.gfile.GFile(self._image_path, 'wb') as f:
      imageio.imwrite(f, array, format='png')
115 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/block1_to_corner.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines block1_to_corner reset and reward."""
17 | import enum
18 | from typing import Any, List, Mapping
19 |
20 | from absl import logging
21 | from language_table.environments import blocks as blocks_module
22 | from language_table.environments.rewards import reward as base_reward
23 | from language_table.environments.rewards import synonyms
24 | from language_table.environments.rewards import task_info
25 | import numpy as np
26 |
# How far the corner target is pulled in from the board edges.
BUFFER = 0.08

# Workspace bounds and their midpoints.
X_MIN, X_MAX = 0.15, 0.6
Y_MIN, Y_MAX = -0.3048, 0.3048
CENTER_X = (X_MAX - X_MIN) / 2. + X_MIN
CENTER_Y = (Y_MAX - Y_MIN) / 2. + Y_MIN

# The block counts as "at the corner" when closer than this to the target.
BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE = 0.08


class Locations(enum.Enum):
  """Canonical corner names supported by the block-to-corner task."""
  BOTTOM_LEFT = 'bottom_left'


# Target (x, y) per location, inset by BUFFER from the actual corner.
ABSOLUTE_LOCATIONS = {
    Locations.BOTTOM_LEFT.value: [X_MAX - BUFFER, Y_MIN + BUFFER],
}

# Natural-language phrasings for each location.
LOCATION_SYNONYMS = {
    Locations.BOTTOM_LEFT.value: [
        'bottom left of the board',
        'bottom left',
        'bottom left corner',
    ],
}

# Verb prefixes enumerated by `generate_all_instructions`.
BLOCK2ABSOLUTELOCATION_VERBS = ['move the', 'push the', 'slide the']
58 |
59 |
def generate_all_instructions(block_mode):
  """Generates every block-to-corner instruction for `block_mode`.

  Enumerates block descriptions x locations x location synonyms x verbs, in
  that nesting order.

  NOTE(review): this enumerates BLOCK2ABSOLUTELOCATION_VERBS, while
  `Block1ToCornerLocationReward._sample_instruction` samples from
  `synonyms.PUSH_VERBS`. If those lists differ, sampled instructions may be
  missing from this set. TODO confirm.

  Args:
    block_mode: Passed to `blocks_module.get_blocks_text_descriptions`.

  Returns:
    A list of strings of the form '<verb> <block> to the <location synonym>'.
  """
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  return [
      f'{verb} {block_text} to the {location_syn}'
      for block_text in all_block_text_descriptions
      for location in ABSOLUTE_LOCATIONS
      for location_syn in LOCATION_SYNONYMS[location]
      for verb in BLOCK2ABSOLUTELOCATION_VERBS
  ]
73 |
74 |
class Block1ToCornerLocationReward(base_reward.LanguageTableReward):
  """Calculates reward/instructions for 'push 1 block to corner'."""

  def __init__(self, goal_reward, rng, delay_reward_steps,
               block_mode):
    super(Block1ToCornerLocationReward,
          self).__init__(goal_reward, rng, delay_reward_steps, block_mode)
    self._block = None
    self._instruction = None
    self._location = None
    self._target_translation = None
    # Number of reward checks so far with the block in the goal region that
    # did not yet pay out; reset to 0 by reset_to().
    self._in_reward_zone_steps = 0

  def _sample_instruction(self, block, blocks_on_table,
                          location):
    """Randomly samples an instruction to push `block` to `location`."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Get some synonym for block.
    block_text = self._rng.choice(
        synonyms.get_block_synonyms(block, blocks_on_table))
    # Get some synonym for location.
    location_syn = self._rng.choice(LOCATION_SYNONYMS[location])
    return f'{verb} {block_text} to the {location_syn}'

  def reset(self, state, blocks_on_table):
    """Chooses a new target block and corner.

    Returns:
      The task info, or task_info.FAILURE if the chosen block is already in
      the goal region (the caller should retry with a new board).
    """
    # Choose a random block.
    block = self._sample_object(blocks_on_table)

    # Choose a location randomly.
    location = self._rng.choice(sorted(ABSOLUTE_LOCATIONS))

    info = self.reset_to(state, block, location, blocks_on_table)
    # Reject already-solved boards with a side-effect-free check. Going
    # through self.reward() missed solved boards whenever
    # delay_reward_steps > 0 (it returns 0.0 while the delay counter is
    # below the threshold), and it also incremented that counter.
    if self._in_goal_region(state, self._block, self._target_translation):
      return task_info.FAILURE
    return info

  def reset_to(self, state, block, location,
               blocks_on_table):
    """Resets to a particular (block, location) task definition."""
    self._block = block
    # Sample an instruction.
    self._instruction = self._sample_instruction(self._block, blocks_on_table,
                                                 location)
    # Cache a copy of the target location corresponding to the instruction.
    self._target_translation = np.copy(ABSOLUTE_LOCATIONS[location])
    self._location = location
    info = self.get_current_task_info(state)
    self._in_reward_zone_steps = 0
    return info

  @property
  def target_translation(self):
    return self._target_translation

  def get_goal_region(self):
    """Returns (target_translation, success_radius) for the current task."""
    return self._target_translation, BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE

  def reward(self, state):
    """Calculates reward given state."""
    return self.reward_for(state, self._block, self._target_translation)

  def reward_for(self, state, pushing_block,
                 target_translation):
    """Returns (reward, done) for `pushing_block` relative to a target.

    The goal reward (and done=True) is returned on an in-region check once
    `delay_reward_steps` earlier in-region checks have been counted;
    otherwise (0.0, False).
    """
    reward = 0.0
    done = False
    if self._in_goal_region(state, pushing_block, target_translation):
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1
    return reward, done

  def _distance_to_target(self, state, block, target_translation):
    """Euclidean distance from `block`'s current position to the target."""
    current_translation, _ = self._get_pose_for_block(block, state)
    return np.linalg.norm(
        np.array(current_translation) - np.array(target_translation))

  def _in_goal_region(self, state, pushing_block, target_translation):
    """Side-effect-free test of whether the block is inside the goal region."""
    dist = self._distance_to_target(state, pushing_block, target_translation)
    return bool(dist < BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE)

  def reward_for_info(self, state, info):
    return self.reward_for(state, info.block, info.target_translation)

  def debug_info(self, state):
    """Returns the distance from the current block to the current target."""
    return self._distance_to_target(state, self._block,
                                    self._target_translation)

  def get_current_task_info(self, state):
    return task_info.Block2LocationTaskInfo(
        instruction=self._instruction,
        block=self._block,
        location=self._location,
        target_translation=self._target_translation)
176 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/block2absolutelocation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines block2absolutelocation reset and reward."""
17 | import enum
18 |
19 | from typing import Any, List
20 | from absl import logging
21 | from language_table.environments import blocks as blocks_module
22 | from language_table.environments.rewards import reward as base_reward
23 | from language_table.environments.rewards import synonyms
24 | from language_table.environments.rewards import task_info
25 | import numpy as np
26 |
27 |
# The arm's reachable area (marked by red dots on the real board) is not
# exactly centered on the board, so the absolute targets are shifted by a
# small offset. Only the first (x) coordinate is shifted. The original note
# called this the "Y direction", presumably meaning the board's vertical
# axis (TODO confirm). This only matters for this reward, which deals with
# absolute locations.
X_BUFFER = 0.025

X_MIN_REAL, X_MAX_REAL = 0.15, 0.6
Y_MIN_REAL, Y_MAX_REAL = -0.3048, 0.3048
X_MIN, X_MAX = X_MIN_REAL - X_BUFFER, X_MAX_REAL - X_BUFFER
Y_MIN, Y_MAX = Y_MIN_REAL, Y_MAX_REAL
CENTER_X = (X_MAX - X_MIN) / 2. + X_MIN
CENTER_Y = (Y_MAX - Y_MIN) / 2. + Y_MIN

# Success radius around the target: the default one, and the tighter one
# applied to the 'center' location.
BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE = 0.115
BLOCK2ABSOLUTELOCATION_CENTER_TARGET_DISTANCE = 0.1


class Locations(enum.Enum):
  """Canonical names of the nine absolute board locations."""
  TOP = 'top'
  TOP_LEFT = 'top_left'
  TOP_RIGHT = 'top_right'
  CENTER = 'center'
  CENTER_LEFT = 'center_left'
  CENTER_RIGHT = 'center_right'
  BOTTOM = 'bottom'
  BOTTOM_LEFT = 'bottom_left'
  BOTTOM_RIGHT = 'bottom_right'


# Target (x, y) per location: 'top' maps to X_MIN and 'bottom' to X_MAX;
# 'left' maps to Y_MIN and 'right' to Y_MAX.
ABSOLUTE_LOCATIONS = {
    Locations.TOP.value: [X_MIN, CENTER_Y],
    Locations.TOP_LEFT.value: [X_MIN, Y_MIN],
    Locations.TOP_RIGHT.value: [X_MIN, Y_MAX],
    Locations.CENTER.value: [CENTER_X, CENTER_Y],
    Locations.CENTER_LEFT.value: [CENTER_X, Y_MIN],
    Locations.CENTER_RIGHT.value: [CENTER_X, Y_MAX],
    Locations.BOTTOM.value: [X_MAX, CENTER_Y],
    Locations.BOTTOM_LEFT.value: [X_MAX, Y_MIN],
    Locations.BOTTOM_RIGHT.value: [X_MAX, Y_MAX],
}

# Natural-language phrasings for each location.
LOCATION_SYNONYMS = {
    Locations.TOP.value: ['top side', 'top', 'towards your base'],
    Locations.TOP_LEFT.value: ['top left of the board', 'top left',
                               'upper left corner', 'top left corner'],
    Locations.TOP_RIGHT.value: ['top right of the board', 'top right',
                                'upper right corner', 'top right corner'],
    Locations.CENTER.value: ['middle of the board', 'center of the board',
                             'center', 'middle'],
    Locations.CENTER_LEFT.value: ['left side of the board', 'center left',
                                  'left side'],
    Locations.CENTER_RIGHT.value: ['right side of the board', 'center right',
                                   'right side'],
    Locations.BOTTOM.value: ['bottom side', 'bottom'],
    Locations.BOTTOM_LEFT.value: ['bottom left of the board', 'bottom left',
                                  'lower left corner', 'bottom left corner'],
    Locations.BOTTOM_RIGHT.value: ['bottom right of the board', 'bottom right',
                                   'lower right corner', 'bottom right corner'],
}

# Verb prefixes enumerated by `generate_all_instructions`.
BLOCK2ABSOLUTELOCATION_VERBS = ['move the', 'push the', 'slide the']
95 |
96 |
def generate_all_instructions(block_mode):
  """Generates every block-to-absolute-location instruction for `block_mode`.

  Enumerates block descriptions x locations x location synonyms x verbs, in
  that nesting order.

  NOTE(review): this enumerates BLOCK2ABSOLUTELOCATION_VERBS, while
  `BlockToAbsoluteLocationReward._sample_instruction` samples from
  `synonyms.PUSH_VERBS`. If those lists differ, sampled instructions may be
  missing from this set. TODO confirm.

  Args:
    block_mode: Passed to `blocks_module.get_blocks_text_descriptions`.

  Returns:
    A list of strings of the form '<verb> <block> to the <location synonym>'.
  """
  all_block_text_descriptions = blocks_module.get_blocks_text_descriptions(
      block_mode)
  return [
      f'{verb} {block_text} to the {location_syn}'
      for block_text in all_block_text_descriptions
      for location in ABSOLUTE_LOCATIONS
      for location_syn in LOCATION_SYNONYMS[location]
      for verb in BLOCK2ABSOLUTELOCATION_VERBS
  ]
110 |
111 |
class BlockToAbsoluteLocationReward(base_reward.LanguageTableReward):
  """Calculates reward/instructions for 'push block to absolute location'."""

  def __init__(self, goal_reward, rng, delay_reward_steps,
               block_mode):
    super(BlockToAbsoluteLocationReward, self).__init__(
        goal_reward=goal_reward,
        rng=rng,
        delay_reward_steps=delay_reward_steps,
        block_mode=block_mode)
    self._block = None
    self._instruction = None
    self._location = None
    self._target_translation = None
    # Number of reward checks so far with the block in the goal region that
    # did not yet pay out; reset to 0 by reset_to().
    self._in_reward_zone_steps = 0

  def _sample_instruction(
      self, block, blocks_on_table, location):
    """Randomly samples an instruction to push `block` to `location`."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Get some synonym for block.
    block_text = self._rng.choice(
        synonyms.get_block_synonyms(block, blocks_on_table))
    # Get some synonym for location.
    location_syn = self._rng.choice(LOCATION_SYNONYMS[location])
    return f'{verb} {block_text} to the {location_syn}'

  def reset(self, state, blocks_on_table):
    """Chooses a new target block and location.

    Returns:
      The task info, or task_info.FAILURE if the chosen block is already in
      the goal region (the caller should retry with a new board).
    """
    # Choose a random block.
    block = self._sample_object(blocks_on_table)

    # Choose a location randomly.
    location = self._rng.choice(sorted(ABSOLUTE_LOCATIONS))

    info = self.reset_to(state, block, location, blocks_on_table)
    # If the board already satisfies the task, try again with a new one.
    if self._in_goal_region(state, self._block, self._target_translation):
      return task_info.FAILURE
    return info

  def reset_to(
      self, state, block, location, blocks_on_table):
    """Resets to a particular (block, location) task definition."""
    self._block = block
    # Sample an instruction.
    self._instruction = self._sample_instruction(
        self._block, blocks_on_table, location)
    # Cache a copy of the target location corresponding to the instruction.
    self._target_translation = np.copy(ABSOLUTE_LOCATIONS[location])
    self._location = location
    info = self.get_current_task_info(state)
    self._in_reward_zone_steps = 0
    return info

  @property
  def target_translation(self):
    return self._target_translation

  @staticmethod
  def _target_distance(location):
    """Returns the success radius for `location` ('center' is tighter)."""
    if location == Locations.CENTER.value:
      return BLOCK2ABSOLUTELOCATION_CENTER_TARGET_DISTANCE
    return BLOCK2ABSOLUTELOCATION_TARGET_DISTANCE

  def get_goal_region(self):
    """Returns (target_translation, success_radius) for the current task."""
    return self._target_translation, self._target_distance(self._location)

  def reward(self, state):
    """Calculates reward given state."""
    return self.reward_for(state, self._block, self._target_translation)

  def reward_for(
      self, state, pushing_block, target_translation, location=None):
    """Returns (reward, done) for `pushing_block` relative to a target.

    The goal reward (and done=True) is returned on an in-region check once
    `delay_reward_steps` earlier in-region checks have been counted;
    otherwise (0.0, False).

    Args:
      state: Environment state passed to `_get_pose_for_block`.
      pushing_block: Block whose position is tested.
      target_translation: Target position.
      location: Location name that selects the success radius. Defaults to
        the current task's location.
    """
    reward = 0.0
    done = False

    in_goal_region = self._in_goal_region(state, pushing_block,
                                          target_translation, location)

    if in_goal_region:
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1
    return reward, done

  def _distance_to_target(self, state, block, target_translation):
    """Euclidean distance from `block`'s current position to the target."""
    current_translation, _ = self._get_pose_for_block(block, state)
    return np.linalg.norm(
        np.array(current_translation) - np.array(target_translation))

  def _in_goal_region(self, state, pushing_block,
                      target_translation, location=None):
    """Side-effect-free test of whether the block is inside the goal region.

    `location` picks the success radius; None means the current task's
    location.
    """
    if location is None:
      location = self._location
    dist = self._distance_to_target(state, pushing_block, target_translation)
    return bool(dist < self._target_distance(location))

  def reward_for_info(self, state, info):
    """Computes the reward for the task described by `info`.

    The success radius comes from `info.location` when present, not from the
    current task. This way the tighter 'center' radius applies exactly when
    `info` itself targets the center.
    """
    return self.reward_for(state, info.block, info.target_translation,
                           location=getattr(info, 'location', None))

  def debug_info(self, state):
    """Returns the distance from the current block to the current target."""
    return self._distance_to_target(state, self._block,
                                    self._target_translation)

  def get_current_task_info(self, state):
    return task_info.Block2LocationTaskInfo(
        instruction=self._instruction,
        block=self._block,
        location=self._location,
        target_translation=self._target_translation)
238 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/block2block.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines block2block reset and reward."""
17 | import itertools
18 |
19 | from absl import logging
20 | from language_table.environments import blocks as blocks_module
21 | from language_table.environments.rewards import constants
22 | from language_table.environments.rewards import reward as base_reward
23 | from language_table.environments.rewards import synonyms
24 | from language_table.environments.rewards import task_info
25 | import numpy as np
26 |
27 |
def generate_all_instructions(block_mode):
  """Returns every '<verb> <start block> <preposition> <target block>' string.

  Start/target pairs are all ordered pairs of different entries from the
  block descriptions of `block_mode`. Each pair is combined with every push
  verb and then every preposition.
  """
  descriptions = blocks_module.get_blocks_text_descriptions(block_mode)
  return [
      f'{verb} {start_text} {preposition} {target_text}'
      for start_text, target_text in itertools.permutations(descriptions, 2)
      for verb in synonyms.PUSH_VERBS
      for preposition in synonyms.PREPOSITIONS
  ]
40 |
41 |
42 | # pytype: skip-file
class BlockToBlockReward(base_reward.LanguageTableReward):
  """Block2block reward: push a start block next to a target block."""

  def _sample_instruction(
      self, start_block, target_block, blocks_on_table):
    """Randomly samples an instruction to push `start_block` to `target_block`."""
    verb = self._rng.choice(synonyms.PUSH_VERBS)
    # Sample synonyms for start and target blocks.
    start_syn = self._rng.choice(
        synonyms.get_block_synonyms(start_block, blocks_on_table))
    target_syn = self._rng.choice(
        synonyms.get_block_synonyms(target_block, blocks_on_table))
    preposition = self._rng.choice(synonyms.PREPOSITIONS)
    return f'{verb} {start_syn} {preposition} {target_syn}'

  def reset(self, state, blocks_on_table):
    """Picks start/target blocks that are not already close together.

    Pairs are drawn with `_sample_objects` until one is at least
    TARGET_BLOCK_DISTANCE + 0.01 apart. The search gives up once more than
    `max_attempts` pairs have been rejected.

    Returns:
      A Block2BlockTaskInfo for the chosen pair, or task_info.FAILURE if no
      valid pair was found.
    """
    max_attempts = 10
    num_attempts = 0
    while True:
      start_block, target_block = self._sample_objects(blocks_on_table)
      start_translation, _ = self._get_pose_for_block(
          start_block, state)
      target_translation, _ = self._get_pose_for_block(
          target_block, state)
      dist = np.linalg.norm(
          np.array(start_translation) - np.array(target_translation))
      if dist < constants.TARGET_BLOCK_DISTANCE + 0.01:
        num_attempts += 1
        if num_attempts > max_attempts:
          logging.info(
              'Exceeded max number of attempts to find start/target blocks. '
              'No valid reward found for the current object configuration.')
          return task_info.FAILURE
        continue
      break
    self._start_block = start_block
    self._target_block = target_block
    # Seed the goal region with the target block's pose at reset. Previously
    # this was only assigned in reward(), so get_goal_region() raised
    # AttributeError before the first reward() call. reward() keeps it
    # updated as the target block moves.
    self._target_translation = np.copy(target_translation)
    self._instruction = self._sample_instruction(
        start_block, target_block, blocks_on_table)
    self._in_reward_zone_steps = 0
    return task_info.Block2BlockTaskInfo(
        instruction=self._instruction,
        block1=self._start_block,
        block2=self._target_block)

  def get_goal_region(self):
    """Returns (target block translation, TARGET_BLOCK_DISTANCE)."""
    return self._target_translation, constants.TARGET_BLOCK_DISTANCE

  def reward(self, state):
    """Calculates the sparse reward given state.

    Returns:
      (goal_reward, True) on an in-zone check once `delay_reward_steps`
      earlier in-zone checks have been counted; otherwise (0.0, False).
    """
    start_translation, _ = self._get_pose_for_block(self._start_block, state)
    target_translation, _ = self._get_pose_for_block(self._target_block, state)

    # The goal region follows the target block as it moves.
    self._target_translation = target_translation

    # This check ignores which block was moved: it also succeeds if the
    # target block was pushed towards the start block instead.
    # TODO(ayzaan): Add smarter logic here.
    dist = np.linalg.norm(
        np.array(start_translation) - np.array(target_translation))
    reward = 0.0
    done = False
    if dist < constants.TARGET_BLOCK_DISTANCE:
      if self._in_reward_zone_steps >= self._delay_reward_steps:
        reward = self._goal_reward
        done = True
      else:
        logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
        self._in_reward_zone_steps += 1

    return reward, done
119 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/block2relativelocation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines block2relativelocation reset and reward."""
17 |
18 | from absl import logging
19 | from language_table.environments import blocks as blocks_module
20 | from language_table.environments.rewards import reward as base_reward
21 | from language_table.environments.rewards import synonyms
22 | from language_table.environments.rewards import task_info
23 | import numpy as np
24 |
25 |
# Push distance (target offset length) for each distance mode.
MAGNITUDES = {
    'near': 0.15,
    'far': 0.25,
}

# Cardinal push directions. The top left of the board is
# (height=0, width=0), so moving up or left decreases a coordinate.
UP, DOWN = -1., 1.
LEFT, RIGHT = -1., 1.

# Unit direction vectors. Diagonals are normalized so that a distance mode
# means the same offset length in every direction.
DIRECTIONS = {
    'up': [UP, 0.],
    'down': [DOWN, 0.],
    'left': [0., LEFT],
    'right': [0., RIGHT],
    'diagonal_up_left': np.array([UP, LEFT]) / np.linalg.norm([UP, LEFT]),
    'diagonal_up_right': np.array([UP, RIGHT]) / np.linalg.norm([UP, RIGHT]),
    'diagonal_down_left': (
        np.array([DOWN, LEFT]) / np.linalg.norm([DOWN, LEFT])),
    'diagonal_down_right': (
        np.array([DOWN, RIGHT]) / np.linalg.norm([DOWN, RIGHT])),
}


BLOCK2RELATIVELOCATION_VERBS = ['move the', 'push the', 'slide the']


# 'slightly'-style prepositions; these cover [up, down, left, right].
SLIGHTLY_PREPOSITION_SYNONYMS = [
    'slightly', 'slightly to the',
    'a bit', 'a bit to the',
    'a little', 'a little to the', 'a little bit to the',
    'somewhat', 'somewhat to the',
]


# Softening words used by the 'near' instruction variants.
SLIGHTLY_SYNONYMS = ['slightly', 'a bit', 'a little', 'a little bit', 'somewhat']

BLOCK2RELATIVELOCATION_MODE_TO_PREPOSITIONS = {
    'near': SLIGHTLY_PREPOSITION_SYNONYMS,
    # The empty string is included because 'push the red block left' is valid.
    'far': ['to the', ''],
}


# Natural-language phrasings for each cardinal direction.
DIRECTION_SYNONYMS = {
    'up': ['up', 'upwards'],
    'down': ['down', 'downwards'],
    'left': ['to the left', 'left'],
    'right': ['to the right', 'right'],
}

# '%s'-templates joining two cardinal phrasings into a diagonal phrase.
DIAGONAL_PREPOSITIONS = [
    '%s and %s',
    '%s and then %s',
    'diagonally %s and %s',
    '%s and %s diagonally',
]


# Success radius around the target translation.
BLOCK2RELATIVELOCATION_TARGET_DISTANCE = 0.1
102 |
103 |
def create_slightly_instruction(rng, verb, block, direction):
  """Creates a 'near' instruction softened with a word like 'slightly'.

  The first `rng.choice` picks one of three phrasings:
    * 'slightly_first': 'slightly push the blue cube down and to the right'
    * 'prefix': 'push the blue cube slightly down and to the right'
    * 'suffix': 'push the blue cube down and right slightly'
  For 'prefix' and 'suffix', a second `rng.choice` picks the softening word
  from SLIGHTLY_SYNONYMS.
  """
  mode = rng.choice(['slightly_first', 'prefix', 'suffix'])
  if mode == 'slightly_first':
    return f'slightly {verb} {block} {direction}'
  softener = rng.choice(SLIGHTLY_SYNONYMS)
  if mode == 'prefix':
    return f'{verb} {block} {softener} {direction}'
  return f'{verb} {block} {direction} {softener}'
119 |
120 |
def enumerate_slightly_instruction(verb, block, direction):
  """Yields every 'slightly' phrasing of one instruction.

  Order: the 'slightly first' form, then, for each word in SLIGHTLY_SYNONYMS,
  its prefix form followed by its suffix form.
  """
  yield f'slightly {verb} {block} {direction}'
  for softener in SLIGHTLY_SYNONYMS:
    yield f'{verb} {block} {softener} {direction}'
    yield f'{verb} {block} {direction} {softener}'
131 |
132 |
def get_diagonal_direction(rng, direction):
  """Maps a canonical diagonal direction to a random natural-language phrase.

  Args:
    rng: Random generator exposing `choice`. It is called three times, in
      order: first-direction synonym, second-direction synonym, template.
    direction: Canonical name of the form 'diagonal_<dir1>_<dir2>', e.g.
      'diagonal_up_right'.

  Returns:
    A phrase such as 'up and then to the right'.

  Raises:
    ValueError: If `direction` does not split into exactly three '_' parts.
  """
  # Break out e.g. 'up' and 'right'.
  _, first_dir, second_dir = direction.split('_')
  # Choose synonyms, e.g. 'right' -> 'to the right'.
  first_syn = rng.choice(DIRECTION_SYNONYMS[first_dir])
  second_syn = rng.choice(DIRECTION_SYNONYMS[second_dir])
  # Choose a diagonal template, e.g. '%s and then %s'.
  diagonal_prep = rng.choice(DIAGONAL_PREPOSITIONS)
  # Insert the actual directions, e.g. 'up and then to the right'.
  return diagonal_prep % (first_syn, second_syn)
145 |
146 |
def enumerate_diagonal_direction(direction, direction_synonyms=None,
                                 diagonal_prepositions=None):
  """Yields every natural-language phrase for a canonical diagonal direction.

  Args:
    direction: Canonical name 'diagonal_<dir1>_<dir2>', e.g.
      'diagonal_up_right'.
    direction_synonyms: Mapping from a cardinal direction to its phrasings.
      Defaults to DIRECTION_SYNONYMS.
    diagonal_prepositions: '%s'-templates joining the two phrasings. Defaults
      to DIAGONAL_PREPOSITIONS.

  Yields:
    Phrases such as 'up and then to the right'. They are ordered by
    first-direction synonym, then second-direction synonym, then template.
  """
  if direction_synonyms is None:
    direction_synonyms = DIRECTION_SYNONYMS
  if diagonal_prepositions is None:
    diagonal_prepositions = DIAGONAL_PREPOSITIONS
  # Break out e.g. 'up' and 'right'.
  _, first_dir, second_dir = direction.split('_')
  # Resolve both synonym lists once, with distinct loop-variable names. The
  # previous version rebound `first_dir`/`second_dir` inside the loops, so
  # the inner lookup was re-run with a phrase as the key. It only worked
  # because the last synonym of 'left'/'right' equals the key itself.
  first_syns = direction_synonyms[first_dir]
  second_syns = direction_synonyms[second_dir]
  for first_syn in first_syns:
    for second_syn in second_syns:
      for template in diagonal_prepositions:
        # Insert the actual directions, e.g. 'up and then to the right'.
        yield template % (first_syn, second_syn)
159 |
160 |
def generate_all_instructions(block_mode):
  """Generates all block2relativelocation instructions for `block_mode`.

  Iterates block description, then verb, then direction (in DIRECTIONS
  order). Each phrasing of a direction contributes its 'near' (slightly)
  variants first, followed by the plain 'far' instruction.
  """
  instructions = []
  descriptions = blocks_module.get_blocks_text_descriptions(block_mode)
  for block_text in descriptions:
    for verb in BLOCK2RELATIVELOCATION_VERBS:
      for direction in DIRECTIONS:
        # Diagonal directions are phrased by combining two cardinal ones.
        if 'diagonal' in direction:
          phrasings = enumerate_diagonal_direction(direction)
        else:
          phrasings = DIRECTION_SYNONYMS[direction]
        for phrase in phrasings:
          instructions.extend(
              enumerate_slightly_instruction(verb, block_text, phrase))
          instructions.append(f'{verb} {block_text} {phrase}')
  return instructions
188 |
189 |
190 | class BlockToRelativeLocationReward(base_reward.LanguageTableReward):
191 | """Calculates reward/instructions for 'push block to relative location'."""
192 |
193 | def _sample_instruction(
194 | self, block, distance_mode, direction, blocks_on_table):
195 | """Randomly sample a task involving two objects."""
196 | verb = self._rng.choice(BLOCK2RELATIVELOCATION_VERBS)
197 | # Sample synonym for block.
198 | block_syn = self._rng.choice(
199 | synonyms.get_block_synonyms(block, blocks_on_table))
200 | if 'diagonal' in direction:
201 | # Map canonical diagonal dir to natural language. E.g.
202 | # diagonal_up_right -> 'up and to the right diagonally'.
203 | direction = get_diagonal_direction(self._rng, direction)
204 | else:
205 | direction = self._rng.choice(DIRECTION_SYNONYMS[direction])
206 | if distance_mode == 'near':
207 | # Modify this with language like 'slightly'.
208 | inst = create_slightly_instruction(self._rng, verb, block_syn, direction)
209 | else:
210 | inst = f'{verb} {block_syn} {direction}'
211 | return inst
212 |
def reset(self, state, blocks_on_table):
  """Picks a block, a direction and a distance mode, then builds the task.

  Candidate (block, direction, distance mode) triples are drawn until the
  implied target translation lies inside the workspace bounds. After
  `max_tries + 1` rejected draws the board configuration is given up on.

  Args:
    state: Environment state dict holding `block_<name>_translation` and
      `block_<name>_orientation` entries for the blocks on the table.
    blocks_on_table: Names of the blocks currently on the table.

  Returns:
    A `task_info.Block2RelativeLocationTaskInfo`, or `task_info.FAILURE` when
    no in-bounds target was found.
  """
  max_tries = 100
  # One initial draw plus `max_tries` retries.
  for _attempt in range(max_tries + 1):
    self._block = self._sample_object(blocks_on_table)
    block_translation, _ = self._get_pose_for_block(self._block, state)

    # Sorting fixes the candidate order independently of dict ordering.
    direction = self._rng.choice(sorted(DIRECTIONS))
    distance_mode = self._rng.choice(sorted(MAGNITUDES))
    magnitude = MAGNITUDES[distance_mode]

    # Offset = direction vector scaled by the chosen magnitude.
    target_vector = np.array(DIRECTIONS[direction]) * magnitude
    target_translation = block_translation + target_vector

    if base_reward.target_inside_bounds(target_translation):
      break
  else:
    # Every draw left the workspace: try again with a new board configuration.
    return task_info.FAILURE

  self._instruction = self._sample_instruction(
      self._block, distance_mode, direction, blocks_on_table)

  # Keep a private copy of the goal that corresponds to the instruction.
  self._target_translation = np.copy(target_translation)
  self._in_reward_zone_steps = 0
  return task_info.Block2RelativeLocationTaskInfo(
      instruction=self._instruction,
      block=self._block,
      location=direction,
      target_translation=self._target_translation)
258 |
def get_goal_region(self):
  """Returns the goal as a (target translation, success radius) pair."""
  radius = BLOCK2RELATIVELOCATION_TARGET_DISTANCE
  return self._target_translation, radius
261 |
def reward(self, state):
  """Scores `state` against the cached relative-location goal.

  While the target block is within `BLOCK2RELATIVELOCATION_TARGET_DISTANCE`
  of the goal, each call either grants the reward (once
  `self._in_reward_zone_steps` has reached `self._delay_reward_steps`) or
  increments that counter. The counter is not reset when the block leaves
  the zone.

  Args:
    state: Environment state dict with the target block's pose entries.

  Returns:
    `(self._goal_reward, True)` on success, otherwise `(0.0, False)`.
  """
  block_translation, _ = self._get_pose_for_block(self._block, state)
  offset = np.array(block_translation) - np.array(self._target_translation)
  dist = np.linalg.norm(offset)

  in_zone = dist < BLOCK2RELATIVELOCATION_TARGET_DISTANCE
  if not in_zone:
    return 0.0, False
  if self._in_reward_zone_steps >= self._delay_reward_steps:
    return self._goal_reward, True
  logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
  self._in_reward_zone_steps += 1
  return 0.0, False
282 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/constants.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Define constants shared across all 2d board environment rewards."""
# Distance threshold shared by the rewards; compared against Euclidean
# distances between translations (same units as the state translations).
TARGET_BLOCK_DISTANCE: float = 0.05
18 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/instructions.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for getting all instructions from the different 2d board envs."""
17 | from language_table.environments.rewards import block2absolutelocation
18 | from language_table.environments.rewards import block2block
19 | from language_table.environments.rewards import block2block_relative_location
20 | from language_table.environments.rewards import block2relativelocation
21 | from language_table.environments.rewards import point2block
22 | from language_table.environments.rewards import separate_blocks
23 |
24 |
# Number of token ids in the CLIP tokenizer vocabulary.
CLIP_VOCAB_SIZE: int = 49408
26 |
27 |
def generate_all_instructions(block_mode):
  """Concatenates the instruction lists of every reward family.

  Args:
    block_mode: Block variant forwarded unchanged to each family's
      `generate_all_instructions`.

  Returns:
    One list ordered block2block, point2block, block2relativelocation,
    block2absolutelocation, block2block_relative_location, separate_blocks.
  """
  generators = (
      block2block.generate_all_instructions,
      point2block.generate_all_instructions,
      block2relativelocation.generate_all_instructions,
      block2absolutelocation.generate_all_instructions,
      block2block_relative_location.generate_all_instructions,
      separate_blocks.generate_all_instructions,
  )
  combined = []
  for generate in generators:
    combined += generate(block_mode)
  return combined
36 |
37 |
def vocab_size(block_mode):
  """Counts distinct tokens across all instructions for `block_mode`.

  Tokens come from `str.split(' ')`, so consecutive spaces would contribute
  an empty-string token.
  """
  vocabulary = {
      token
      for instruction in generate_all_instructions(block_mode)
      for token in instruction.split(' ')
  }
  return len(vocabulary)
44 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/instructions_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for instructions."""
17 |
18 | from language_table.environments import blocks
19 | from language_table.environments.rewards import instructions
20 | import tensorflow as tf
21 |
22 |
class InstructionsTest(tf.test.TestCase):
  """Pins the number of instructions generated per block variant."""

  def test_expected_instructions_generated(self):
    # A changed count means the generated instruction set itself changed.
    expected_counts = (
        (blocks.LanguageTableBlockVariants.BLOCK_4, 12652),
        (blocks.LanguageTableBlockVariants.BLOCK_8, 30264),
        (blocks.LanguageTableBlockVariants.N_CHOOSE_K, 80368),
    )
    for block_mode, expected in expected_counts:
      generated = instructions.generate_all_instructions(block_mode)
      self.assertLen(generated, expected)


if __name__ == '__main__':
  tf.test.main()
40 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/point2block.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines point2block reset and reward."""
17 | from absl import logging
18 | from language_table.environments import blocks as blocks_module
19 | from language_table.environments.rewards import constants
20 | from language_table.environments.rewards import reward as base_reward
21 | from language_table.environments.rewards import synonyms
22 | from language_table.environments.rewards import task_info
23 | import numpy as np
24 |
25 |
def generate_all_instructions(block_mode):
  """Lists every pointing instruction for the given block variant.

  Each block description is paired with each phrase in
  `synonyms.POINT_PREPOSITIONS`, iterating descriptions in the outer loop.
  """
  descriptions = blocks_module.get_blocks_text_descriptions(block_mode)
  return [
      f'{preposition} {description}'
      for description in descriptions
      for preposition in synonyms.POINT_PREPOSITIONS
  ]
35 |
36 |
class PointToBlockReward(base_reward.LanguageTableReward):
  """Reward for bringing the end-effector target next to a named block."""

  def _sample_instruction(self, block, blocks_on_table):
    """Builds a random pointing instruction that refers to `block`."""
    # Same RNG order as before: block synonym first, then the phrase.
    block_text = self._get_block_synonym(block, blocks_on_table)
    preposition = self._rng.choice(synonyms.POINT_PREPOSITIONS)
    return f'{preposition} {block_text}'

  def reset(self, state, blocks_on_table):
    """Chooses a block that is not already at the effector target.

    Blocks closer than `constants.TARGET_BLOCK_DISTANCE + 0.01` to
    `state['effector_target_translation']` are rejected, presumably so the
    goal is not already met at reset. At most `max_attempts + 1` blocks are
    drawn.

    Returns:
      A `task_info.Point2BlockTaskInfo`, or `task_info.FAILURE` when every
      drawn block was too close.
    """
    max_attempts = 10
    for _attempt in range(max_attempts + 1):
      block = self._sample_object(blocks_on_table)
      block_translation, _ = self._get_pose_for_block(block, state)
      effector = np.array(state['effector_target_translation'])
      dist = np.linalg.norm(np.array(block_translation) - effector)
      if dist < constants.TARGET_BLOCK_DISTANCE + 0.01:
        continue
      self._block = block
      break
    else:
      logging.info(
          'Exceeded max number of attempts to find start/target blocks. '
          'No valid reward found for the current object configuration.')
      return task_info.FAILURE

    self._instruction = self._sample_instruction(block, blocks_on_table)
    self._in_reward_zone_steps = 0
    return task_info.Point2BlockTaskInfo(
        instruction=self._instruction,
        block_target=self._block)

  def reward(self, state):
    """Sparse reward for the effector target reaching the chosen block.

    Returns:
      `(self._goal_reward, True)` once the block has been within
      `constants.TARGET_BLOCK_DISTANCE` of `state['effector_target_translation']`
      for `self._delay_reward_steps` earlier in-zone calls; otherwise
      `(0.0, False)`.
    """
    block_translation, _ = self._get_pose_for_block(self._block, state)
    # TODO(ayzaan): Add smarter logic here. Only the distance to the
    # effector target is checked; whether the block was pushed is ignored.
    dist = np.linalg.norm(
        np.array(block_translation) -
        np.array(state['effector_target_translation']))
    if not dist < constants.TARGET_BLOCK_DISTANCE:
      return 0.0, False
    if self._in_reward_zone_steps >= self._delay_reward_steps:
      return self._goal_reward, True
    logging.info('In reward zone for %d steps', self._in_reward_zone_steps)
    self._in_reward_zone_steps += 1
    return 0.0, False
99 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/reward.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """LanguageTable base reward class."""
17 | from typing import Any, List, Tuple
18 | from language_table.environments import blocks as blocks_module
19 | from language_table.environments import constants
20 | from language_table.environments.rewards import synonyms
21 | import numpy as np
22 |
23 |
class LanguageTableReward(object):
  """Shared configuration and helpers for every language-table reward."""

  def __init__(self, goal_reward, rng, delay_reward_steps,
               block_mode):
    """Stores the reward configuration.

    Args:
      goal_reward: Value subclasses return when the goal is reached.
      rng: Random generator; must support `choice(seq)` and
        `choice(seq, 2, replace=False)` (presumably a numpy RandomState).
      delay_reward_steps: In-goal-zone steps subclasses wait before rewarding.
      block_mode: Block variant this reward is used with.
    """
    self._goal_reward = goal_reward
    self._rng = rng
    self._block_mode = block_mode
    # TODO(tding): Handle this in all rewards
    self._delay_reward_steps = delay_reward_steps
    # Both are filled in by subclasses once a task has been chosen.
    self._in_reward_zone_steps = None
    self._target_translation = None

  def seed(self, rng):
    """Replaces the random generator used for all sampling."""
    self._rng = rng

  def get_goal_region(self):
    """Returns (target translation, radius); (None, None) by default."""
    return (None, None)

  def _get_block_synonym(self, block, blocks_on_table):
    """Picks one allowed textual name for `block` at random."""
    candidates = synonyms.get_block_synonyms(block, blocks_on_table)
    return self._rng.choice(candidates)

  def _get_pose_for_block(self, block, state):
    """Reads `block`'s (translation, orientation) entries from `state`."""
    translation = state['block_%s_translation' % block]
    orientation = state['block_%s_orientation' % block]
    return translation, orientation

  def _get_translation_for_block(self, block, state):
    """Returns `block`'s translation from `state` as a numpy array."""
    translation, _ = self._get_pose_for_block(block, state)
    return np.array(translation)

  def _sample_object(self, blocks_on_table):
    """Draws one block uniformly at random."""
    return self._rng.choice(blocks_on_table)

  def _sample_objects(self, blocks_on_table):
    """Draws two distinct blocks, returned as (start_block, target_block)."""
    pair = self._rng.choice(blocks_on_table, 2, replace=False)
    start_block, target_block = pair
    return start_block, target_block
66 |
67 |
def target_inside_bounds(
    target, buffer=constants.WORKSPACE_BOUNDS_BUFFER):
  """Checks that a 2D target lies strictly inside the shrunken workspace.

  Args:
    target: Two-element (x, y) sequence.
    buffer: Margin removed from every side of the
      [X_MIN, X_MAX] x [Y_MIN, Y_MAX] workspace rectangle.

  Returns:
    True iff both coordinates are strictly inside the shrunken rectangle.
  """
  x, y = target
  return ((constants.X_MIN + buffer < x < constants.X_MAX - buffer) and
          (constants.Y_MIN + buffer < y < constants.Y_MAX - buffer))
75 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/synonyms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines synonyms for the structured language 2d board environment."""
17 | import collections
18 |
19 |
def get_block_synonyms(block, blocks_on_table):
  """Lists names that identify `block` unambiguously on the current table.

  The color-only name ('<color> block') is offered when exactly one block on
  the table has that color. The shape-only name is offered when exactly one
  block has that shape. The full '<color> <shape>' name is always offered
  last.
  """
  color, shape = get_color_shape(block)
  color_counts, shape_counts = count_color_shape(blocks_on_table)
  candidates = (
      ('%s block' % color, color_counts[color] == 1),
      (shape, shape_counts[shape] == 1),
      # (color, shape) is always unique.
      ('%s %s' % (color, shape), True),
  )
  return [name for name, allowed in candidates if allowed]
36 |
37 |
def count_color_shape(blocks_on_table):
  """Counts how many blocks on the table have each color and each shape.

  Args:
    blocks_on_table: Iterable of block ids of the form '<color>_<shape>'.

  Returns:
    A `(color_counts, shape_counts)` pair of `collections.Counter`s. Both are
    empty when `blocks_on_table` is empty.
  """
  # Counting incrementally (instead of `zip(*...)` plus tuple unpacking)
  # keeps an empty table from raising ValueError.
  color_counts = collections.Counter()
  shape_counts = collections.Counter()
  for block in blocks_on_table:
    color, shape = get_color_shape(block)
    color_counts[color] += 1
    shape_counts[shape] += 1
  return color_counts, shape_counts
43 |
44 |
def get_color_shape(block):
  """Splits a block id such as 'red_moon' into its (color, shape) parts.

  Raises:
    ValueError: If `block` does not contain exactly one underscore.
  """
  pieces = block.split('_')
  color, shape = pieces
  return color, shape
48 |
49 |
# Verb phrases that start a push instruction.
PUSH_VERBS = ['push the', 'move the', 'slide the', 'put the']

# Phrases linking the pushed block to its destination.
PREPOSITIONS = ['to the', 'towards the', 'close to the', 'next to the']

# Phrases for pointing tasks, grouped by how the arm is addressed.
POINT_PREPOSITIONS = [
    # Pointing verbs.
    'point next to the', 'point close to the', 'point to the', 'point at the',
    # "the arm" as explicit subject.
    'move the arm next to the', 'move the arm close to the',
    'move the arm to the',
    # "your arm" as explicit subject.
    'move your arm next to the', 'move your arm close to the',
    'move your arm to the',
    # No explicit subject.
    'move next to the', 'move close to the', 'move to the',
]
79 |
--------------------------------------------------------------------------------
/language_table/environments/rewards/task_info.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data classes holding info returned to environment for each reset."""
17 |
18 | import dataclasses
19 | from typing import List
20 |
21 | import numpy as np
22 |
23 |
@dataclasses.dataclass
class Block2BlockTaskInfo:
  """Data class defining a chosen block2block task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block1: str  # First block of the task (presumably the one to move).
  block2: str  # Second block of the task (presumably the destination).


@dataclasses.dataclass
class Block2LocationTaskInfo:
  """Data class defining a chosen block-to-location task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block: str  # Block the instruction refers to.
  target_translation: np.ndarray  # Goal translation for `block`.
  location: str  # Name of the target location.


@dataclasses.dataclass
class Block2LineTaskInfo:
  """Data class defining a chosen block-to-line task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block: str  # Block the instruction refers to.
  target_translation: np.ndarray  # Goal translation for `block`.


@dataclasses.dataclass
class Block2PoleTaskInfo:
  """Data class defining a chosen block2pole task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block1: str  # Block the instruction refers to.
  goal: str  # Goal identifier (presumably the pole) — confirm with producer.


@dataclasses.dataclass
class Block2RelativeLocationTaskInfo:
  """Data class defining a chosen block-to-relative-location task."""
  instruction: str  # Natural-language instruction for the episode.
  block: str  # Block to move.
  target_translation: np.ndarray  # Goal translation for `block`.
  location: str  # Canonical direction name the block is moved in.


@dataclasses.dataclass
class Block2BlockRelativeLocationTaskInfo:
  """Data class defining a chosen block-to-block-relative-location task."""
  instruction: str  # Natural-language instruction for the episode.
  block: str  # Block to move.
  target_block: str  # Reference block the direction is relative to.
  direction: str  # Direction relative to `target_block`.
  target_translation: np.ndarray  # Goal translation for `block`.


@dataclasses.dataclass
class SeparateBlocksTaskInfo:
  """Data class defining a chosen "separate blocks" task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block: str  # Block to move.
  avoid_blocks: List[str]  # Blocks that `block` should be separated from.
  target_translation: np.ndarray  # Goal translation for `block`.


@dataclasses.dataclass
class Point2BlockTaskInfo:
  """Data class defining a chosen point2block task after reset."""
  instruction: str  # Natural-language instruction for the episode.
  block_target: str  # Block the end effector should be brought next to.


# Every task-info type a reward's reset() may return.
ALL_TASKS = [
    Block2BlockTaskInfo,
    Block2LocationTaskInfo,
    Block2RelativeLocationTaskInfo,
    Block2BlockRelativeLocationTaskInfo,
    SeparateBlocksTaskInfo,
    Point2BlockTaskInfo,
    Block2LineTaskInfo,
    Block2PoleTaskInfo,
]

# Return this if cannot create a valid board state and need to reset.
FAILURE = 'failure'
105 |
--------------------------------------------------------------------------------
/language_table/environments/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/environments/utils/pose3d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple 6DOF pose container.
17 | """
18 |
19 | import dataclasses
20 | import numpy as np
21 | from scipy.spatial import transform
22 |
23 |
class NoCopyAsDict(object):
  """Base class for dataclasses. Avoids a copy in the asdict() call."""

  def asdict(self):
    """Replacement for dataclasses.asdict.

    TF Dataset does not handle dataclasses.asdict, which uses copy.deepcopy when
    setting values in the output dict. This causes issues with tf.Dataset.
    Instead, shallow copy contents.

    Returns:
      dict mapping each dataclass field name to the field's current value (the
      same objects, not copies).
    """
    return {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}  # pytype: disable=wrong-arg-types  # re-none


@dataclasses.dataclass
class Pose3d(NoCopyAsDict):
  """Simple container for translation and rotation."""

  rotation: transform.Rotation
  translation: np.ndarray

  @property
  def vec7(self):
    """Translation followed by the scalar-last (x, y, z, w) quaternion.

    Seven elements when `translation` is 3D.
    """
    return np.concatenate([self.translation, self.rotation.as_quat()])

  def serialize(self):
    """Returns a dict of plain lists suitable for JSON-like storage."""
    return {'rotation': self.rotation.as_quat().tolist(),
            'translation': self.translation.tolist()}

  @staticmethod
  def deserialize(data):
    """Inverse of `serialize`."""
    return Pose3d(rotation=transform.Rotation.from_quat(data['rotation']),
                  translation=np.array(data['translation']))

  def __eq__(self, other):
    """Exact element-wise comparison of quaternion and translation.

    Note that q and -q encode the same rotation but compare unequal here.
    Non-Pose3d operands return NotImplemented, so Python falls back to its
    default comparison instead of raising AttributeError.
    """
    if not isinstance(other, Pose3d):
      return NotImplemented
    return (np.array_equal(self.rotation.as_quat(),
                           other.rotation.as_quat()) and
            np.array_equal(self.translation,
                           other.translation))

  def __ne__(self, other):
    equal = self.__eq__(other)
    if equal is NotImplemented:
      # Propagate so Python can try the reflected comparison.
      return NotImplemented
    return not equal
68 |
--------------------------------------------------------------------------------
/language_table/environments/utils/xarm_sim_robot.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """XArm Robot Kinematics."""
17 | from language_table.environments.utils import utils_pybullet
18 | from language_table.environments.utils.pose3d import Pose3d
19 | import numpy as np
20 | from scipy.spatial import transform
21 | import pybullet
22 |
23 |
# Directory holding the pybullet_data xarm6 robot models.
_XARM_URDF_DIR = 'third_party/bullet/examples/pybullet/gym/pybullet_data/xarm/'
XARM_URDF_PATH = _XARM_URDF_DIR + 'xarm6_robot.urdf'
XARM_WHITE_URDF_PATH = _XARM_URDF_DIR + 'xarm6_robot_white.urdf'

# Directory holding the end-effector models shipped with language_table.
_SUCTION_ASSET_DIR = (
    'third_party/py/language_table/environments/assets/suction/')
SUCTION_URDF_PATH = _SUCTION_ASSET_DIR + 'suction-head-long.urdf'
CYLINDER_URDF_PATH = _SUCTION_ASSET_DIR + 'cylinder.urdf'
CYLINDER_REAL_URDF_PATH = _SUCTION_ASSET_DIR + 'cylinder_real.urdf'

# Home configuration of the six arm joints, given in degrees, stored in radians.
HOME_JOINT_POSITIONS = np.deg2rad([0, -20, -80, 0, 100, -30])
38 |
39 |
class XArmSimRobot():
  """A simulated PyBullet XArm robot, mostly for forward/inverse kinematics."""

  def __init__(self,
               pybullet_client,
               initial_joint_positions=HOME_JOINT_POSITIONS,
               end_effector='none',
               color='default'):
    """Loads the arm, moves it to its initial pose and attaches a tool.

    Args:
      pybullet_client: PyBullet client the robot is loaded into.
      initial_joint_positions: Joint angles the arm is reset to and held at,
        one per revolute joint (radians, like HOME_JOINT_POSITIONS).
      end_effector: One of 'none', 'suction', 'cylinder' or 'cylinder_real'.
      color: 'default' or 'white'; selects which arm URDF is loaded.

    Raises:
      ValueError: If `color` or `end_effector` is not recognized.
    """
    self._pybullet_client = pybullet_client
    self.initial_joint_positions = initial_joint_positions

    if color == 'default':
      self.xarm = utils_pybullet.load_urdf(pybullet_client, XARM_URDF_PATH,
                                           [0, 0, 0])
    elif color == 'white':
      self.xarm = utils_pybullet.load_urdf(pybullet_client,
                                           XARM_WHITE_URDF_PATH, [0, 0, 0])
    else:
      raise ValueError('Unrecognized xarm color %s' % color)

    # Get revolute joints of robot (skip fixed joints).
    joints = []
    joint_indices = []
    for i in range(self._pybullet_client.getNumJoints(self.xarm)):
      joint_info = self._pybullet_client.getJointInfo(self.xarm, i)
      if joint_info[2] == pybullet.JOINT_REVOLUTE:  # [2] is the joint type.
        joints.append(joint_info[0])
        joint_indices.append(i)
        # Note examples in pybullet do this, but it is not clear what the
        # benefits are.
        self._pybullet_client.changeDynamics(
            self.xarm, i, linearDamping=0, angularDamping=0)

    self._n_joints = len(joints)
    self._joints = tuple(joints)
    self._joint_indices = tuple(joint_indices)

    # Move robot to home joint configuration
    self.reset_joints(self.initial_joint_positions)
    self.set_target_joint_positions(self.initial_joint_positions)
    # Link used by forward/inverse kinematics and as the tool mount.
    self.effector_link = 6

    if end_effector in ('suction', 'cylinder', 'cylinder_real'):
      self.end_effector = self._setup_end_effector(end_effector)
    else:
      if end_effector != 'none':
        raise ValueError('end_effector "%s" is not supported.' % end_effector)
      self.end_effector = None

  def _setup_end_effector(self, end_effector):
    """Adds a suction or cylinder end effector fixed to the effector link.

    Returns:
      The PyBullet body id of the loaded end effector.

    Raises:
      ValueError: If `end_effector` is not a supported type.
    """
    urdf_paths = {
        'suction': SUCTION_URDF_PATH,
        'cylinder': CYLINDER_URDF_PATH,
        'cylinder_real': CYLINDER_REAL_URDF_PATH,
    }
    if end_effector not in urdf_paths:
      raise ValueError('end_effector "%s" is not supported.' % end_effector)

    # Spawn the tool at the current effector pose before constraining it.
    pose = self.forward_kinematics()
    body = utils_pybullet.load_urdf(self._pybullet_client,
                                    urdf_paths[end_effector],
                                    pose.translation, pose.rotation.as_quat())

    constraint_id = self._pybullet_client.createConstraint(
        parentBodyUniqueId=self.xarm,
        parentLinkIndex=self.effector_link,
        childBodyUniqueId=body,
        childLinkIndex=-1,
        jointType=pybullet.JOINT_FIXED,
        jointAxis=(0, 0, 0),
        parentFramePosition=(0, 0, 0),
        childFramePosition=(0, 0, 0))
    self._pybullet_client.changeConstraint(constraint_id, maxForce=1000)

    return body

  def reset_joints(self, joint_values):
    """Sets the position of the Robot's joints.

    *Note*: This should only be used at the start while not running the
    simulation resetJointState overrides all physics simulation.

    Args:
      joint_values: Indexable with one desired position per revolute joint.
    """
    for i, joint in enumerate(self._joints):
      self._pybullet_client.resetJointState(self.xarm, joint, joint_values[i])

  def get_joints_measured(self):
    """Returns (positions, velocities, torques) arrays for the revolute joints.

    Entries 0, 1 and 3 of each PyBullet joint-state tuple are used.
    """
    joint_states = self._pybullet_client.getJointStates(self.xarm,
                                                        self._joint_indices)
    joint_positions = np.array([state[0] for state in joint_states])
    joint_velocities = np.array([state[1] for state in joint_states])
    joint_torques = np.array([state[3] for state in joint_states])
    return joint_positions, joint_velocities, joint_torques

  def get_joint_positions(self):
    """Returns the current revolute joint positions as a numpy array."""
    joint_states = self._pybullet_client.getJointStates(self.xarm,
                                                        self._joint_indices)
    return np.array([state[0] for state in joint_states])

  def forward_kinematics(self):
    """Returns the current world Pose3d of `self.effector_link`."""
    effector_state = self._pybullet_client.getLinkState(
        self.xarm, self.effector_link, computeForwardKinematics=1)
    return Pose3d(translation=np.array(effector_state[0]),
                  rotation=transform.Rotation.from_quat(effector_state[1]))

  def inverse_kinematics(self,
                         world_effector_pose,
                         max_iterations=100,
                         residual_threshold=1e-10):
    """Inverse kinematics.

    Args:
      world_effector_pose: Target Pose3d for the robot's end effector.
      max_iterations: Maximum number of IK refinement iterations (passed as
        `maxNumIterations`). Default is 100.
      residual_threshold: Refine the IK solution until the distance between
        target and actual end effector position is below this threshold, or
        `max_iterations` is reached.

    Returns:
      Numpy array with required joint angles to reach the requested pose.
    """
    return np.array(
        self._pybullet_client.calculateInverseKinematics(
            self.xarm,
            self.effector_link,
            world_effector_pose.translation,
            world_effector_pose.rotation.as_quat(),  # as_quat returns xyzw.
            # TODO(peteflorence): use real limits
            lowerLimits=[-17] * 6,
            upperLimits=[17] * 6,
            jointRanges=[17] * 6,
            # TODO(oars): Understand why examples don't use actual positions for
            # the first two joints. Taken from
            # `pybullet/gym/pybullet_robots/xarm/xarm_sim.py`
            restPoses=[0, 0] + self.get_joint_positions()[2:].tolist(),
            maxNumIterations=max_iterations,
            residualThreshold=residual_threshold))

  def set_target_effector_pose(self, world_effector_pose):
    """Commands the joints toward the IK solution for `world_effector_pose`."""
    target_joint_positions = self.inverse_kinematics(world_effector_pose)
    self.set_target_joint_positions(target_joint_positions)

  def _max_joint_forces(self):
    """Force limit per controlled joint for the motor commands."""
    # Sized from the joint list so its length always matches
    # `self._joint_indices`, as setJointMotorControlArray requires.
    return [5 * 240.] * len(self._joint_indices)

  def set_target_joint_velocities(self, target_joint_velocities):
    """Commands joint velocities on all revolute joints (VELOCITY_CONTROL)."""
    self._pybullet_client.setJointMotorControlArray(
        self.xarm,
        self._joint_indices,
        pybullet.VELOCITY_CONTROL,
        targetVelocities=target_joint_velocities,
        forces=self._max_joint_forces())

  def set_target_joint_positions(self, target_joint_positions):
    """Commands joint positions on all revolute joints (POSITION_CONTROL)."""
    # TODO(oars, peteflorence): Probably should adjust gains to get reasonable
    # speeds. It's moving kinda fast.
    self._pybullet_client.setJointMotorControlArray(
        self.xarm,
        self._joint_indices,
        pybullet.POSITION_CONTROL,
        targetPositions=target_joint_positions,
        forces=self._max_joint_forces())

  def set_alpha_transparency(self, alpha):
    """Sets the alpha channel of each arm link's visual color to `alpha`.

    Assumes getVisualShapeData returns one entry per link, ordered by link
    index. The asserts check this, but they are skipped under `python -O`.
    """
    visual_shape_data = self._pybullet_client.getVisualShapeData(self.xarm)

    for i in range(self._pybullet_client.getNumJoints(self.xarm)):
      object_id, link_index, _, _, _, _, _, rgba_color = visual_shape_data[i]
      assert object_id == self.xarm, 'xarm id mismatch.'
      assert link_index == i, 'Link visual data was returned out of order.'
      rgba_color = list(rgba_color[0:3]) + [alpha]
      self._pybullet_client.changeVisualShape(
          self.xarm, linkIndex=i, rgbaColor=rgba_color)
221 |
--------------------------------------------------------------------------------
/language_table/environments/utils/xarm_sim_robot_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for language_table.environments.utils.xarm_sim_robot."""
17 | import math
18 |
19 | from language_table.environments.utils import xarm_sim_robot
20 | from language_table.environments.utils.pose3d import Pose3d
21 | import numpy as np
22 | from scipy.spatial import transform
23 | import tensorflow.compat.v2 as tf
24 | import pybullet
25 | import pybullet_utils.bullet_client as bullet_client
26 |
27 |
class XArmSimRobotTest(tf.test.TestCase):
    """Tests loading, forward kinematics and inverse kinematics of the xArm."""

    def setUp(self):
        super().setUp()
        # Tests run headless via DIRECT. To debug, connect with
        # pybullet.SHARED_MEMORY instead of pybullet.DIRECT.
        self._pybullet_client = bullet_client.BulletClient(pybullet.DIRECT)
        self._pybullet_client.resetSimulation()
        self._pybullet_client.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0)
        self._pybullet_client.setPhysicsEngineParameter(enableFileCaching=0)

    def tearDown(self):
        # Disconnect explicitly so each test's physics server is released
        # deterministically instead of relying on garbage collection.
        self._pybullet_client.disconnect()
        super().tearDown()

    def test_arm_loads(self):
        xarm_sim_robot.XArmSimRobot(self._pybullet_client)

    def test_arm_loads_suction(self):
        xarm_sim_robot.XArmSimRobot(self._pybullet_client, end_effector='suction')

    def test_forward_kinematics(self):
        robot = xarm_sim_robot.XArmSimRobot(self._pybullet_client)

        # Pointing down X Axis
        robot.reset_joints([0, math.pi / 2, math.pi, 0, 0, 0])
        x, y, _ = robot.forward_kinematics().translation

        self.assertAlmostEqual(0.714479, x, places=3)
        self.assertAlmostEqual(-0.0006, y, places=3)

        # Pointing down Y Axis
        robot.reset_joints([math.pi / 2, math.pi / 2, math.pi, 0, 0, 0])
        x, y, _ = robot.forward_kinematics().translation

        self.assertAlmostEqual(0.0006, x, places=3)
        self.assertAlmostEqual(0.714479, y, places=3)

    def test_inverse_kinematics(self):
        robot = xarm_sim_robot.XArmSimRobot(self._pybullet_client)
        initial_pose = robot.forward_kinematics()

        rotation = transform.Rotation.from_rotvec([0, math.pi / 2, 0])
        translation = np.array([0.5, 0.0, 0.10])
        target_pose = Pose3d(rotation=rotation, translation=translation)

        # Solve IK, apply the joints, and check FK reproduces the target pose.
        robot.reset_joints(robot.inverse_kinematics(target_pose))
        pose = robot.forward_kinematics()

        self.assertFalse(np.all(initial_pose.vec7 == pose.vec7))
        np.testing.assert_almost_equal(pose.vec7, target_pose.vec7, decimal=2)


if __name__ == '__main__':
    tf.test.main()
82 |
--------------------------------------------------------------------------------
/language_table/eval/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/eval/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple offline evaluation script for language table sim."""
17 |
18 | import collections
19 | from collections.abc import Sequence
20 | import os
21 |
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 |
26 | import jax
27 |
28 | from language_table.environments import blocks
29 | from language_table.environments import language_table
30 | from language_table.environments.oracles import push_oracle_rrt_slowdown
31 | from language_table.environments.rewards import block2absolutelocation
32 | from language_table.environments.rewards import block2block
33 | from language_table.environments.rewards import block2block_relative_location
34 | from language_table.environments.rewards import block2relativelocation
35 | from language_table.environments.rewards import separate_blocks
36 | from language_table.eval import wrappers as env_wrappers
37 | from language_table.train import policy as jax_policy
38 | from language_table.train.networks import lava
39 |
40 | import mediapy as mediapy_lib
41 | from ml_collections import config_flags
42 |
43 | import tensorflow as tf
44 | from tf_agents.environments import gym_wrapper
45 | from tf_agents.environments import wrappers as tfa_wrappers
46 |
# Command-line flags. The config file is loaded with lock_config=True; the
# other two flags are plain path strings.
_CONFIG = config_flags.DEFINE_config_file(
    "config",
    None,
    "Training configuration.",
    lock_config=True,
)
_WORKDIR = flags.DEFINE_string(
    "workdir",
    None,
    "Evaluation result directory.",
)
_CHECKPOINT_PATH = flags.DEFINE_string(
    "checkpoint_path",
    None,
    "FLAX checkpoint path.",
)
52 |
53 |
def _make_eval_env(reward_factory, config):
    """Builds the wrapped 8-block sim environment used to evaluate one reward."""
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=reward_factory,
        seed=0)
    env = gym_wrapper.GymWrapper(env)
    env = env_wrappers.ClipTokenWrapper(env)
    env = env_wrappers.CentralCropImageWrapper(
        env,
        target_width=config.data_target_width,
        target_height=config.data_target_height,
        random_crop_factor=config.random_crop_factor)
    return tfa_wrappers.HistoryWrapper(
        env, history_length=config.sequence_length, tile_first_step_obs=True)


def _reset_to_plannable_state(env):
    """Resets `env` until the RRT push oracle finds a plan; returns the time step."""
    # The oracle is only used to decide whether an initialization is valid:
    # if the oracle can motion plan, the init is kept.
    oracle_policy = push_oracle_rrt_slowdown.ObstacleOrientedPushOracleBoard2dRRT(
        env, use_ee_planner=True)
    while True:
        ts = env.reset()
        if oracle_policy.get_plan(env.compute_state()):
            return ts
        logging.info(
            "Resetting environment because the "
            "initialization was invalid (could not find motion plan).")


def evaluate_checkpoint(checkpoint_path, workdir, config):
    """Evaluates the given checkpoint and writes results to workdir.

    For each reward below, runs `num_evals_per_reward` episodes with the policy
    restored from `checkpoint_path`. It writes one MP4 per episode to
    `<workdir>/videos/` and prints the number of successes per reward.

    Args:
      checkpoint_path: FLAX checkpoint passed to `jax_policy.BCJaxPyPolicy`.
      workdir: Output directory; videos go to its `videos` sub-directory.
      config: Training config. Reads `model`, `data_target_width`,
        `data_target_height`, `random_crop_factor` and `sequence_length`.
    """
    video_dir = os.path.join(workdir, "videos")
    if not tf.io.gfile.exists(video_dir):
        tf.io.gfile.makedirs(video_dir)
    rewards = {
        "blocktoblock":
            block2block.BlockToBlockReward,
        "blocktoabsolutelocation":
            block2absolutelocation.BlockToAbsoluteLocationReward,
        "blocktoblockrelativelocation":
            block2block_relative_location.BlockToBlockRelativeLocationReward,
        "blocktorelativelocation":
            block2relativelocation.BlockToRelativeLocationReward,
        "separate":
            separate_blocks.SeparateBlocksReward,
    }

    num_evals_per_reward = 50
    max_episode_steps = 200

    policy = None
    model = lava.SequenceLAVMSE(action_size=2, **config.model)

    # Start every reward at zero so rewards with no successes still appear in
    # the summary (a defaultdict only lists keys that were incremented).
    results = dict.fromkeys(rewards, 0)
    for reward_name, reward_factory in rewards.items():
        env = _make_eval_env(reward_factory, config)

        if policy is None:
            # Every env is built by the same function, so the specs are
            # assumed identical across rewards and the policy is built once.
            policy = jax_policy.BCJaxPyPolicy(
                env.time_step_spec(),
                env.action_spec(),
                model=model,
                checkpoint_path=checkpoint_path,
                rng=jax.random.PRNGKey(0))

        for ep_num in range(num_evals_per_reward):
            ts = _reset_to_plannable_state(env)
            frames = [env.render()]

            episode_steps = 0
            while not ts.is_last():
                policy_step = policy.action(ts, ())
                ts = env.step(policy_step.action)
                frames.append(env.render())
                episode_steps += 1
                # Stop after exactly `max_episode_steps` policy steps.
                if episode_steps >= max_episode_steps:
                    break

            if env.succeeded:
                results[reward_name] += 1
                logging.info("Episode %d: success.", ep_num)
                success_str = "success"
            else:
                logging.info("Episode %d: failure.", ep_num)
                success_str = "failure"

            # Write out video of rollout.
            video_path = os.path.join(
                video_dir, f"{reward_name}_{ep_num}_{success_str}.mp4")
            mediapy_lib.write_video(video_path, frames, fps=10)

        # Release this reward's simulator before building the next one.
        env.close()

    logging.info("Successes per reward (out of %d episodes): %s",
                 num_evals_per_reward, results)
    print(results)
146 |
147 |
def main(argv):
    """Validates the command line and evaluates the requested checkpoint."""
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Fail fast with a usage error instead of a TypeError/AttributeError deep
    # inside evaluation when one of these flags was not provided.
    missing = [
        name for name, value in (
            ("config", _CONFIG.value),
            ("workdir", _WORKDIR.value),
            ("checkpoint_path", _CHECKPOINT_PATH.value),
        ) if value is None
    ]
    if missing:
        raise app.UsageError(
            "Missing required flag(s): " +
            ", ".join(f"--{name}" for name in missing))

    evaluate_checkpoint(
        checkpoint_path=_CHECKPOINT_PATH.value,
        workdir=_WORKDIR.value,
        config=_CONFIG.value,
    )


if __name__ == "__main__":
    app.run(main)
161 |
--------------------------------------------------------------------------------
/language_table/eval/wrappers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Environment wrappers."""
17 |
18 | from typing import Any, Optional
19 |
20 | from language_table.common import clip_tokenizer
21 | import numpy as np
22 | import tensorflow as tf
23 | from tf_agents.environments import wrappers
24 |
25 |
class ClipTokenWrapper(wrappers.PyEnvironmentBaseWrapper):
    """Environment wrapper that adds CLIP tokens to the obs.

    On reset, the `instruction` observation is decoded and tokenized once. It
    is a zero-padded array of UTF-8 byte values. The tokens are attached to
    every observation of the episode under `instruction_tokenized_clip`.

    NOTE(review): `observation_spec()` is inherited unchanged and does not
    list `instruction_tokenized_clip`; confirm downstream consumers do not
    rely on the spec containing it.
    """

    def __init__(self, env, context_length=77):
        """Wraps `env` and builds the CLIP tokenizer.

        Args:
          env: Environment whose observations are mappings containing an
            `instruction` entry.
          context_length: Stored on the wrapper. NOTE(review): it is not
            currently forwarded to `clip_tokenizer.tokenize_text`; confirm
            whether it should be.
        """
        super().__init__(env)
        self._context_length = context_length
        vocab_lookup = clip_tokenizer.create_vocab()
        self._tokenizer = clip_tokenizer.ClipTokenizer(vocab_lookup)
        self._current_tokens = None

    def _reset(self):
        time_step = self._env.reset()
        self._current_tokens = self._tokenize(time_step.observation['instruction'])
        return self._with_tokens(time_step)

    def _step(self, action):
        return self._with_tokens(self._env.step(action))

    def _with_tokens(self, time_step):
        """Returns `time_step` whose observation also holds the current tokens."""
        # Shallow-copy (preserving the mapping type) so the wrapped env's own
        # observation dict is not mutated.
        new_obs = time_step.observation.copy()
        new_obs['instruction_tokenized_clip'] = self._current_tokens
        return time_step._replace(observation=new_obs)

    def _tokenize(self, instruction):
        """Decodes a zero-padded byte-value array and returns its CLIP tokens."""
        # Zeros are padding; the remaining values are UTF-8 bytes. An all-zero
        # array decodes to the empty string.
        non_zero = instruction[instruction != 0]
        decoded = bytes(non_zero.tolist()).decode('utf-8')
        return clip_tokenizer.tokenize_text(decoded, self._tokenizer)[0]
60 |
61 |
class CentralCropImageWrapper(wrappers.PyEnvironmentBaseWrapper):
    """Environment wrapper that crops and resizes the `rgb` image observation.

    NOTE(review): `observation_spec()` is inherited unchanged, so it does not
    reflect the float32 dtype or the resized shape of `rgb`.
    """

    def __init__(self,
                 env,
                 target_height,
                 target_width,
                 random_crop_factor=None):
        """Wraps `env` so its `rgb` observation is center-cropped and resized.

        Args:
          env: Environment whose observations are mappings with an `rgb` image.
          target_height: Output image height after resizing.
          target_width: Output image width after resizing.
          random_crop_factor: Fraction of height/width kept by the central crop
            (the average of the training-time random crop). If None, no crop is
            applied and the image is only resized.
        """
        super().__init__(env)
        self._target_height = target_height
        self._target_width = target_width
        self._random_crop_factor = random_crop_factor

    def _reset(self):
        time_step = self._env.reset()
        new_obs = self._crop_observation(time_step.observation)
        return time_step._replace(observation=new_obs)

    def _step(self, action):
        time_step = self._env.step(action)
        new_obs = self._crop_observation(time_step.observation)
        return time_step._replace(observation=new_obs)

    def _crop_observation(self, obs):
        """Returns a copy of `obs` with `rgb` as a cropped, resized float32 array."""
        # Copy so the wrapped env's observation dict is not mutated in place.
        new_obs = obs.copy()
        image = tf.image.convert_image_dtype(obs['rgb'], dtype=tf.float32)
        if self._random_crop_factor is not None:
            # Apply average crop augmentation.
            image = crop_test_image(image, self._random_crop_factor)
        image = resize_images(image, self._target_height, self._target_width)
        new_obs['rgb'] = image.numpy()
        return new_obs
95 |
96 |
def crop_test_image(images, random_crop_factor):
    """Get the average crop applied during crop training augmentation.

    Each image is cropped to `random_crop_factor` of its height and width,
    centered in the frame.

    Args:
      images: A single rank-3 image, or a batch of images (cropped one at a
        time with `tf.map_fn`).
      random_crop_factor: Fraction of height and width to keep, or None to
        return `images` unchanged.

    Returns:
      The cropped image or batch of images.
    """
    if random_crop_factor is None:
        # No crop augmentation configured, so there is nothing to mirror.
        return images

    def take_center_crop_consistent_with_random(im):
        im_raw_size = tf.shape(im)
        raw_height = tf.cast(im_raw_size[0], tf.float32)
        raw_width = tf.cast(im_raw_size[1], tf.float32)
        scaled_height = raw_height * random_crop_factor
        scaled_width = raw_width * random_crop_factor
        # Split the removed margin evenly so the crop is centered.
        offset_height = tf.cast((raw_height - scaled_height) // 2, tf.int32)
        offset_width = tf.cast((raw_width - scaled_width) // 2, tf.int32)
        target_height = tf.cast(scaled_height, tf.int32)
        target_width = tf.cast(scaled_width, tf.int32)
        return tf.image.crop_to_bounding_box(
            im,
            offset_height=offset_height,
            offset_width=offset_width,
            target_height=target_height,
            target_width=target_width)

    if len(images.shape) == 3:
        return take_center_crop_consistent_with_random(images)
    return tf.map_fn(take_center_crop_consistent_with_random, images)
122 |
123 |
def resize_images(images, target_height=None, target_width=None):
    """Resizes images to target_height, target_width.

    Args:
      images: Image or batch of images accepted by `tf.image.resize`.
      target_height: Output height; must be a positive value.
      target_width: Output width; must be a positive value.

    Returns:
      The resized image(s), as returned by `tf.image.resize`.

    Raises:
      ValueError: If either target dimension is missing or zero.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not target_height or not target_width:
        raise ValueError(
            'resize_images requires non-zero target_height and target_width, '
            f'got target_height={target_height!r}, '
            f'target_width={target_width!r}.')
    return tf.image.resize(images, [target_height, target_width])
136 |
--------------------------------------------------------------------------------
/language_table/examples/dataset_example.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Example for loading the Language-Table dataset.
17 |
18 | Language-Table data is in the [RLDS](https://github.com/google-research/rlds)
19 | format.
20 | See the [RLDS Tutorial](https://colab.research.google.com/github/
21 | google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb)
22 | for more details on how to use RLDS datasets.
23 | """
24 |
25 | from collections.abc import Sequence
26 |
27 | from absl import app
28 |
29 | import tensorflow_datasets as tfds
30 |
# RLDS dataset locations on GCS, keyed by dataset name.
dataset_paths = dict(
    language_table='gs://gresearch/robotics/language_table/0.0.1/',
    language_table_sim='gs://gresearch/robotics/language_table_sim/0.0.1/',
)
35 |
36 |
def main(argv):
    """Prints the first five steps of the real and the sim Language-Table data."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Iterate through 5 items of language_table, then of language_table_sim.
    for dataset_name in ('language_table', 'language_table_sim'):
        builder = tfds.builder_from_directory(dataset_paths[dataset_name])
        episodes = builder.as_dataset(split='train')
        # Flatten episodes into a dataset of individual steps.
        steps = episodes.flat_map(lambda episode: episode['steps'])
        for step in steps.take(5):
            print(step)


if __name__ == '__main__':
    app.run(main)
58 |
--------------------------------------------------------------------------------
/language_table/examples/environment_example.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Example for running the Language-Table environment."""
17 |
18 | from collections.abc import Sequence
19 |
20 | from absl import app
21 |
22 | from language_table.environments import blocks
23 | from language_table.environments import language_table
24 | from language_table.environments.rewards import block2block
25 |
26 | from matplotlib import pyplot as plt
27 |
28 |
def main(argv):
    """Builds an 8-block env, takes 5 random actions, and saves a render."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=block2block.BlockToBlockReward,
        control_frequency=10.0,
    )
    env.reset()

    num_random_actions = 5
    for _ in range(num_random_actions):
        random_action = env.action_space.sample()
        env.step(random_action)

    # Save a rendered image of the resulting scene.
    frame = env.render()
    plt.imsave('/tmp/language_table_render.png', frame)


if __name__ == '__main__':
    app.run(main)
50 |
--------------------------------------------------------------------------------
/language_table/train/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/language_table/train/bc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Behavioral Cloning Agent."""
17 |
18 | from typing import Any, List, Optional
19 |
20 | from absl import logging
21 |
22 | from clu import checkpoint
23 | from clu import metrics
24 | from clu import parameter_overview
25 |
26 | import flax
27 | from flax import linen as nn
28 | import jax
29 | import jax.numpy as jnp
30 | import optax
31 |
32 |
@flax.struct.dataclass
class TrainState:
    """Training state built by `BCAgent.create_train_state`.

    `BCAgent.train` returns an updated copy of it via `replace`.
    """
    # Number of completed training steps (incremented by `BCAgent.train`).
    step: int
    # Model parameters (the "params" variable collection).
    params: Any
    # State of `BCAgent.optimizer` for `params`.
    opt_state: optax.OptState
    # "batch_stats" variable collection; `{}` when the model has none.
    batch_stats: Any
    # Dict of observation/action statistics and action_min/action_max.
    norm_info: Any
40 |
41 |
@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
    """Train metrics for the BC agent."""

    # Last reported learning rate.
    learning_rate: metrics.LastValue.from_output("learning_rate")
    # Average of the reported "loss" output.
    loss: metrics.Average.from_output("loss")
    # Standard deviation of the reported "loss" output.
    loss_std: metrics.Std.from_output("loss")
49 |
50 |
class BCAgent(object):
    """Behavioral cloning agent trained with a mean-squared-error action loss.

    Optionally initializes parts of the model from pretrained checkpoints and
    freezes parameters whose path matches `freeze_keys`. All other parameters
    are trained with Adam.
    """

    def __init__(self,
                 model,
                 sequence_length,
                 learning_rate,
                 observation_statistics,
                 action_statistics,
                 action_min,
                 action_max,
                 pretrained_checkpoints=None,
                 freeze_keys=None):
        """Creates the agent.

        Args:
          model: Flax module whose `apply` maps observations to actions.
          sequence_length: Observation sequence length; stored on the agent.
          learning_rate: Adam learning rate, also reported as a train metric.
          observation_statistics: Observation normalization statistics; stored
            in `TrainState.norm_info`.
          action_statistics: Action normalization statistics; stored in
            `TrainState.norm_info`.
          action_min: Action lower bound; stored in `TrainState.norm_info`.
          action_max: Action upper bound; stored in `TrainState.norm_info`.
          pretrained_checkpoints: Optional list of
            `(checkpoint_path, [(ckpt_prefix, model_prefix), ...])` pairs.
            Each checkpoint variable whose "/"-joined name starts with
            `ckpt_prefix` is copied into the model variable obtained by
            replacing that prefix with `model_prefix`, if such a variable
            exists.
          freeze_keys: Optional substrings. Parameters whose "/"-joined path
            contains any of them get zero updates.
        """
        self.model = model

        self.sequence_length = sequence_length

        self.pretrained_checkpoints = pretrained_checkpoints or []
        self.freeze_keys = freeze_keys or []

        # Built in `create_train_state`.
        self.optimizer = None

        self.learning_rate = learning_rate

        self.observation_statistics = observation_statistics
        self.action_statistics = action_statistics

        self.action_min = action_min
        self.action_max = action_max

        logging.info("action_min=%s,action_max=%s", action_min, action_max)

    def create_train_state(self, batch, rng):
        """Creates the train state and initial metrics for agent.

        Also builds `self.optimizer`.

        Args:
          batch: Dict with an "observation" entry used to initialize the model.
          rng: JAX PRNG key.

        Returns:
          A `(TrainState, TrainMetrics)` tuple.
        """
        obs_input = batch["observation"]

        rng, encoder_rng = jax.random.split(rng)
        variables = self.model.init(encoder_rng, obs_input, train=False)
        variables = self._restore_pretrained_variables(variables)

        params = variables["params"]
        if variables.get("batch_stats"):
            batch_stats = variables["batch_stats"]
        else:
            batch_stats = {}

        self.optimizer = self._create_optimizer()

        parameter_overview.log_parameter_overview(params)
        train_state = TrainState(
            step=0,
            params=params,
            opt_state=self.optimizer.init(params),
            batch_stats=batch_stats,
            norm_info={
                "observation_statistics":
                    dict(
                        jax.tree_util.tree_map(
                            jnp.asarray,
                            self.observation_statistics,
                            is_leaf=lambda x: isinstance(x, list))),
                "action_statistics":
                    dict(
                        jax.tree_util.tree_map(
                            jnp.asarray,
                            self.action_statistics,
                            is_leaf=lambda x: isinstance(x, list))),
                "action_min":
                    self.action_min,
                "action_max":
                    self.action_max
            })
        initial_metrics = TrainMetrics.single_from_model_output(
            loss=jnp.zeros((1,)),
            logits=jnp.zeros((1,)),
            negative_logits=jnp.zeros((1,)),
            learning_rate=jnp.zeros((1,)))
        return (train_state, initial_metrics)

    def _restore_pretrained_variables(self, variables):
        """Returns `variables` as a nested dict with pretrained values copied in."""
        flat_variables = flax.traverse_util.flatten_dict(variables, sep="/")

        for checkpoint_path, replacements in self.pretrained_checkpoints:
            variable_dict_ckpt = checkpoint.load_state_dict(checkpoint_path)
            flat_variable_ckpt = flax.traverse_util.flatten_dict(
                variable_dict_ckpt, sep="/")
            for ckpt_variable_name, to_variable_name in replacements:
                num_loaded = 0
                for key, new_value in flat_variable_ckpt.items():
                    if not key.startswith(ckpt_variable_name):
                        continue
                    # Rewrite only the matched prefix; str.replace would also
                    # rewrite any later occurrence of the same substring.
                    variable_key = (
                        to_variable_name + key[len(ckpt_variable_name):])
                    if variable_key in flat_variables:
                        flat_variables[variable_key] = new_value
                        num_loaded += 1
                        logging.info("Loading %s into %s: shape %s", key,
                                     variable_key, new_value.shape)
                if not num_loaded:
                    # A typo in either prefix would otherwise silently leave
                    # the model at its random initialization.
                    logging.warning(
                        "No variables from %s matched %s -> %s.",
                        checkpoint_path, ckpt_variable_name, to_variable_name)

        return flax.traverse_util.unflatten_dict(flat_variables, sep="/")

    def _create_optimizer(self):
        """Returns Adam, with zero updates for params matching `freeze_keys`."""
        adam = optax.adam(learning_rate=self.learning_rate, eps=1e-7)
        if not self.freeze_keys:
            return adam

        def _should_freeze(path):
            for freeze_key in self.freeze_keys:
                if freeze_key in path:
                    logging.info("Freezing param: %s", path)
                    return True
            logging.info("Not freezing param: %s", path)
            return False

        label_fn = flattened_traversal(
            lambda path, _: "zero" if _should_freeze("/".join(path)) else "adam")

        return optax.multi_transform(
            {
                "adam": adam,
                "zero": optax.set_to_zero()
            }, label_fn)

    def train(self, batch, state, rng):
        """Performs a single training step.

        Gradients and loss are averaged with `jax.lax.pmean` over the axis
        named "batch", so this must run inside a transform that binds that
        axis name (e.g. `jax.pmap(..., axis_name="batch")`).

        Args:
          batch: Dict with "observation" and "action" entries.
          state: Current `TrainState`.
          rng: JAX PRNG key for this step.

        Returns:
          `(new_state, metrics_update)`.
        """
        logging.info("train_step(batch=%s)", batch)

        rng, loss_rng = jax.random.split(rng)

        def loss_fn(params):
            variables = {"params": params, "batch_stats": state.batch_stats}
            per_example_loss, new_variables = self.bc_loss(
                self.model, batch=batch, variables=variables, rng=loss_rng)
            loss = jnp.mean(per_example_loss)
            return loss, new_variables["batch_stats"]

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, new_batch_stats), grad = grad_fn(state.params)

        # Compute average gradient across multiple workers.
        grad = jax.lax.pmean(grad, axis_name="batch")
        # Also get the average loss.
        loss = jax.lax.pmean(loss, axis_name="batch")
        updates, new_opt_state = self.optimizer.update(grad, state.opt_state,
                                                       state.params)

        new_params = optax.apply_updates(state.params, updates)
        new_state = state.replace(  # pytype: disable=attribute-error
            step=state.step + 1,
            params=flax.core.unfreeze(new_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
            opt_state=flax.core.unfreeze(new_opt_state),
            batch_stats=flax.core.unfreeze(new_batch_stats))

        metrics_update = TrainMetrics.gather_from_model_output(
            loss=loss, learning_rate=self.learning_rate)
        return new_state, metrics_update

    def bc_loss(
        self,
        model,
        batch,
        variables,
        rng,
    ):
        """Implements the BC (mean squared error) loss.

        Args:
          model: Flax module to apply.
          batch: Dict with "observation" and "action" entries.
          variables: Dict with "params" and "batch_stats" collections.
          rng: JAX PRNG key; split into "params" and "dropout" streams.

        Returns:
          `(loss, new_variables)`. `loss` is the mean squared error over every
          element of the batch; it is a scalar despite the local name
          `per_example_loss`. `new_variables` holds the updated "batch_stats".
        """
        observation = batch["observation"]
        action = batch["action"]

        # Predict actions in training mode. "batch_stats" is mutable so that
        # any normalization layers can update their running statistics.
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        predicted_actions, new_variables = model.apply(
            variables,
            observation,
            train=True,
            mutable=["batch_stats"],
            rngs={
                "params": params_rng,
                "dropout": dropout_rng
            })

        per_example_loss = jnp.mean(jnp.square(predicted_actions - action))
        return per_example_loss, new_variables
235 |
236 |
def flattened_traversal(fn):
    """Adapts a `fn(path, leaf)` callback into a function over a whole pytree.

    The returned callable flattens a nested dict into `{path_tuple: leaf}`,
    replaces every leaf with `fn(path_tuple, leaf)`, and rebuilds the nesting.
    """

    def mask(tree):
        traverse = flax.traverse_util
        mapped = {}
        for path, leaf in traverse.flatten_dict(tree).items():
            mapped[path] = fn(path, leaf)
        return traverse.unflatten_dict(mapped)

    return mask
246 |
--------------------------------------------------------------------------------
/language_table/train/configs/language_table_resnet_sim_local.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A config for training sim human 8 block with Pretrained Resnet + BC."""
17 |
18 | import ml_collections
19 |
20 |
# RLDS dataset locations on GCS, keyed by dataset name.
dataset_paths = dict(
    language_table="gs://gresearch/robotics/language_table/0.0.1/",
    language_table_sim="gs://gresearch/robotics/language_table_sim/0.0.1/",
)
25 |
26 |
def get_config():
    """Config for training sim human 8 block with BC locally.

    Uses the "resnet" image encoder and a "clip" language encoder. The CLIP
    text weights are loaded from a local checkpoint and frozen via
    `freeze_keys`.
    """
    config = ml_collections.ConfigDict()
    config.binary = "language_table/train/main"

    config.sequence_length = 4

    # Model architecture.
    config.model_name = "sequence_lav_mse"
    config.model = ml_collections.ConfigDict()
    model = config.model
    model.dense_resnet_width = 1024
    model.dense_resnet_num_blocks = 2
    model.lava_sequence_length = config.sequence_length
    model.lava_num_layers = 4
    model.lava_temporal_transformer_num_layers = 2
    model.lava_d_model = 128
    model.lava_num_heads = 2
    model.lava_pyramid_fuse_layers = (2, 3, 4)
    model.lava_image_encoder = "resnet"
    model.lava_lang_encoder = "clip"

    # Agent and optimization.
    config.agent_name = "bc"
    config.agent = ml_collections.ConfigDict()
    agent = config.agent
    agent.learning_rate = 1e-3
    agent.pretrained_checkpoints = [
        # CHANGEME: Change this to a local path by running
        # download_clip_flax_ckpt.py.
        (
            "/tmp/scenic_clip_ckpt/",
            [("params/text", "params/encoder/TextEncoder_0")],
        )
    ]
    agent.freeze_keys = ["TextEncoder_0"]

    # Data pipeline.
    config.dataset_path = dataset_paths["language_table_sim"]
    config.data_target_width = 320
    config.data_target_height = 180
    config.image_photometric_distortions = True
    config.image_augment_crop = True
    config.random_crop_factor = 0.95
    config.data_normalization_num_samples = 32
    config.data_skip_normalize_keys = ["rgb", "instruction"]
    config.synthetic_data = False

    # Training loop.
    config.num_train_steps = 1_000_000
    config.per_device_batch_size = 4  # 4096 is used for 64 TPUs.
    config.replay_capacity = 5_000
    config.num_steps_per_train_iter = 1

    # Logging and checkpointing cadence (in steps).
    config.log_loss_every_steps = 50
    config.checkpoint_every_steps = 50

    config.seed = 42

    config.trial = 0  # Dummy for repeated runs.
    return config
82 |
83 |
def get_hyper(h):
    """Returns the sweep: a single trial with `config.trial` swept over range(1)."""
    trial_sweep = h.sweep("config.trial", range(1))
    return h.product([trial_sweep])
90 |
--------------------------------------------------------------------------------
/language_table/train/configs/language_table_sim_local.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A config for training sim human 8 block with BC."""
17 |
18 | import ml_collections
19 |
20 |
# GCS locations of the Language-Table datasets; `get_config` picks one of
# these for `config.dataset_path`.
dataset_paths = dict(
    language_table="gs://gresearch/robotics/language_table/0.0.1/",
    language_table_sim="gs://gresearch/robotics/language_table_sim/0.0.1/",
)
25 |
26 |
def get_config():
  """Config for training sim human 8 block with BC locally.

  The model and agent sub-configs are filled in first and then attached to
  the top-level config, which keeps the same key order as building them in
  place.

  Returns:
    An `ml_collections.ConfigDict` describing the model, agent, data pipeline
    and training-loop settings.
  """
  sequence_length = 4

  # Sequence LAVA model with a small conv+maxpool image encoder and a CLIP
  # language encoder.
  model = ml_collections.ConfigDict()
  model.dense_resnet_width = 1024
  model.dense_resnet_num_blocks = 2
  model.lava_sequence_length = sequence_length
  model.lava_num_layers = 4
  model.lava_temporal_transformer_num_layers = 2
  model.lava_d_model = 128
  model.lava_num_heads = 2
  model.lava_pyramid_fuse_layers = (2, 3, 4)
  model.lava_image_encoder = "conv_maxpool"
  model.lava_lang_encoder = "clip"

  # BC agent; the CLIP text encoder is loaded from a checkpoint and frozen.
  agent = ml_collections.ConfigDict()
  agent.learning_rate = 1e-3
  agent.pretrained_checkpoints = [
      # CHANGEME: Change this to a local path by running
      # download_clip_flax_ckpt.py.
      (
          "/tmp/scenic_clip_ckpt/",
          [("params/text", "params/encoder/TextEncoder_0")],
      )
  ]
  agent.freeze_keys = ["TextEncoder_0"]

  config = ml_collections.ConfigDict()
  config.binary = "language_table/train/main"
  config.sequence_length = sequence_length
  config.model_name = "sequence_lav_mse"
  config.model = model
  config.agent_name = "bc"
  config.agent = agent

  # Input pipeline.
  config.dataset_path = dataset_paths["language_table_sim"]
  config.data_target_width = 320
  config.data_target_height = 180
  config.image_photometric_distortions = True
  config.image_augment_crop = True
  config.random_crop_factor = 0.95
  config.data_normalization_num_samples = 32
  config.data_skip_normalize_keys = ["rgb", "instruction"]
  config.synthetic_data = False

  # Training loop.
  config.num_train_steps = 1_000_000
  config.per_device_batch_size = 4  # 4096 is used for 64 TPUs.
  config.replay_capacity = 5_000
  config.num_steps_per_train_iter = 1

  config.log_loss_every_steps = 50
  config.checkpoint_every_steps = 50

  config.seed = 42

  config.trial = 0  # Dummy for repeated runs.
  return config
82 |
83 |
def get_hyper(h):
  """Returns the hyperparameter sweep for this config.

  Args:
    h: Sweep builder supplied by the launcher (assumed to expose `sweep` and
      `product`, which are the only members used here).

  Returns:
    The product of a single sweep over `config.trial` with one value (0).
  """
  trial_sweep = h.sweep("config.trial", range(1))
  return h.product([trial_sweep])
90 |
--------------------------------------------------------------------------------
/language_table/train/download_clip_flax_ckpt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Saves a scenic CLIP checkpoint."""
17 |
18 | from collections.abc import Sequence
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | from clu import checkpoint
24 | from scenic.projects.baselines.clip import model
25 | import tensorflow as tf
26 |
# Destination directory for the saved CLIP checkpoint. It has no default, so
# it must be passed on the command line.
_CHECKPOINT_DIRECTORY = flags.DEFINE_string(
    name='checkpoint_directory',
    default=None,
    help='The directory to save the checkpoint.',
)
29 |
30 |
def main(argv):
  """Loads the scenic CLIP `vit_b16` variables and saves them as a checkpoint.

  Args:
    argv: Positional command-line arguments left after absl flag parsing.
      Only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or
      --checkpoint_directory is not set.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  out_directory = _CHECKPOINT_DIRECTORY.value
  # The flag defaults to None. Fail fast with a clear usage error, rather than
  # an opaque tf.io.gfile error on a None path after the model has already
  # been loaded.
  if not out_directory:
    raise app.UsageError('--checkpoint_directory must be set.')

  model_vars = model.load_model_vars(model_name='vit_b16')

  if not tf.io.gfile.exists(out_directory):
    tf.io.gfile.makedirs(out_directory)
  ckpt = checkpoint.Checkpoint(base_directory=out_directory)

  ckpt.save(model_vars)


if __name__ == '__main__':
  app.run(main)
47 |
--------------------------------------------------------------------------------
/language_table/train/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Main file for running the trainer."""
17 |
18 | from absl import app
19 | from absl import flags
20 | from absl import logging
21 |
22 | from clu import platform
23 | import jax
24 | from language_table.train import train
25 | from ml_collections import config_flags
26 | import tensorflow as tf
27 |
28 |
FLAGS = flags.FLAGS

# Path to the ml_collections config file describing the run.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
# Working directory handed to `train.train`.
flags.DEFINE_string("workdir", None, "Work unit directory.")
# Optional TF Data Service address, forwarded to `train.train`.
flags.DEFINE_string("tf_data_service_address", None, "TF Data address.")
flags.mark_flags_as_required(["config", "workdir"])
# JAX itself provides the --jax_backend_target and --jax_xla_backend flags.
38 |
39 |
def main(argv):
  """Logs the JAX setup, registers the workdir artifact and runs training.

  Args:
    argv: Positional arguments left after absl flag parsing (unused).
  """
  del argv  # Unused.

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  backend_target = FLAGS.jax_backend_target
  if backend_target:
    logging.info("Using JAX backend target %s", backend_target)
    xla_backend = FLAGS.jax_xla_backend
    logging.info(
        "Using JAX XLA backend %s",
        "None" if xla_backend is None else xla_backend,
    )

  process_index = jax.process_index()
  process_count = jax.process_count()
  logging.info("JAX process: %d / %d", process_index, process_count)
  logging.info("JAX devices: %r", jax.devices())

  work_unit = platform.work_unit()
  work_unit.set_task_status(
      f"process_index: {process_index}, process_count: {process_count}"
  )
  work_unit.create_artifact(
      platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir"
  )

  train.train(FLAGS.config, FLAGS.workdir, FLAGS.tf_data_service_address)


if __name__ == "__main__":
  # Makes JAX's configuration options (such as --jax_backend_target) usable
  # as absl flags before parsing.
  jax.config.config_with_absl()
  app.run(main)
71 |
--------------------------------------------------------------------------------
/language_table/train/networks/dense_resnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Residual dense block."""
17 |
18 | from flax import linen as nn
19 | import jax
20 |
21 |
class ResnetDenseBlock(nn.Module):
  """Single dense resnet block.

  Three ReLU -> Dense steps with widths `width // 4`, `width // 4` and
  `width`, whose result is added back onto the input.

  Attributes:
    width: Output feature size of the block. For the residual sum to be
      well-formed, the input's last dimension should be `width` (or
      broadcastable to it).
  """
  width: int

  @nn.compact
  def __call__(self, x, *, train):
    """Applies the block; `train` is accepted but not used."""
    init = jax.nn.initializers.normal(stddev=0.05)
    bottleneck = self.width // 4
    branch = x
    for out_features in (bottleneck, bottleneck, self.width):
      branch = nn.Dense(
          out_features, kernel_init=init, bias_init=init)(nn.relu(branch))
    return x + branch
49 |
50 |
class DenseResnet(nn.Module):
  """Dense Resnet module.

  A Dense projection to `width`, followed by `num_blocks` residual dense
  blocks, and optionally a final Dense layer with a single output.

  Attributes:
    width: Hidden feature size.
    num_blocks: Number of `ResnetDenseBlock`s.
    value_net: If True, append a Dense layer with one output unit.
  """

  width: int
  num_blocks: int
  value_net: bool

  @nn.compact
  def __call__(self, x, *, train):
    init = jax.nn.initializers.normal(stddev=0.05)
    hidden = nn.Dense(self.width, kernel_init=init, bias_init=init)(x)
    for _ in range(self.num_blocks):
      hidden = ResnetDenseBlock(width=self.width)(hidden, train=train)
    if not self.value_net:
      return hidden
    return nn.Dense(1, kernel_init=init, bias_init=init)(hidden)
74 |
--------------------------------------------------------------------------------
/language_table/train/networks/pixel.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple Pixel + Language network implementations."""
17 |
18 | import flax.linen as nn
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from language_table.train.networks import dense_resnet
23 |
24 |
class LanguageFusion(nn.Module):
  """Fuses language information multiplicatively.

  The language embedding is projected to the channel count of `image` and
  multiplied into every spatial position of the feature map.
  """

  @nn.compact
  def __call__(self, lang, image):
    """Gates `image` channelwise with a projection of `lang`.

    Args:
      lang: Language embedding indexed as (batch, features).
      image: Feature map indexed as (batch, h, w, channels).

    Returns:
      An array with the shape of `image` (after broadcasting against the
      batch dimension of `lang`).
    """
    norm_init = jax.nn.initializers.normal(stddev=0.05)
    lang = nn.Dense(
        jnp.shape(image)[-1], kernel_init=norm_init, bias_init=norm_init)(
            lang)

    # Broadcasting over the spatial axes gives exactly the values of tiling
    # `lang` to (batch, h, w, channels), without materializing that copy.
    return image * lang[:, None, None, :]
43 |
44 |
class ConvMaxpoolLanguageEncoder(nn.Module):
  """Simple Conv + Maxpool encoder that multiplicatively fuses language.

  Four stages with 32, 64, 128 and 256 channels. Each stage runs a 3x3 SAME
  conv, then language fusion (from the second stage on), then ReLU, then a
  2x2 max-pool. The result is mean-pooled over space, gated once more by a
  language projection, and passed through ReLU and LayerNorm.
  """

  @nn.compact
  def __call__(self, rgb, lang_embedding, *, train):
    """Encodes an image batch conditioned on a language embedding.

    Args:
      rgb: Image batch with spatial axes 1 and 2 and channels last.
      lang_embedding: Language embedding indexed as (batch, features).
      train: Accepted for interface compatibility; not used.

    Returns:
      A (batch, 256) feature array (assuming a rank-4 `rgb`).
    """
    channels = [32, 64, 128, 256]
    first_fused_stage = 2  # 1-based index of the first stage that fuses.
    norm_init = jax.nn.initializers.normal(stddev=0.05)

    h = rgb
    for stage, num_channels in enumerate(channels, start=1):
      h = nn.Conv(features=num_channels, kernel_size=(3, 3), padding="SAME")(h)
      if stage >= first_fused_stage:
        h = LanguageFusion()(lang_embedding, h)
      h = nn.max_pool(
          nn.relu(h), window_shape=(2, 2), strides=(2, 2), padding="VALID")

    # Global average pool over the spatial axes.
    h = jnp.mean(h, axis=(1, 2), keepdims=False)

    # Post-pooling fusion; this condition holds for the settings above.
    if first_fused_stage <= len(channels) + 1:
      lang_gate = nn.Dense(
          channels[-1], kernel_init=norm_init, bias_init=norm_init)(
              lang_embedding)
      h = h * lang_gate

    return nn.LayerNorm()(nn.relu(h))
78 |
79 |
class PixelLangMSE(nn.Module):
  """Simple Pixel Language network.

  Stacks a sequence of frames channelwise, encodes it together with the
  language embedding of the last timestep, and regresses an action with a
  dense resnet followed by a linear projection.

  Attributes:
    action_size: Size of the predicted action vector.
    dense_resnet_width: Hidden width of the dense resnet head.
    dense_resnet_num_blocks: Number of residual blocks in the head.
  """

  action_size: int

  dense_resnet_width: int
  dense_resnet_num_blocks: int

  def setup(self):
    self.encoder = ConvMaxpoolLanguageEncoder()
    self.dense_resnet = dense_resnet.DenseResnet(
        width=self.dense_resnet_width,
        num_blocks=self.dense_resnet_num_blocks,
        value_net=False)
    norm_init = jax.nn.initializers.normal(stddev=0.05)
    self.action_projection = nn.Dense(
        self.action_size, kernel_init=norm_init, bias_init=norm_init)

  def __call__(self, obs, *, train):
    """Predicts an action from a batch of observation sequences.

    Args:
      obs: Mapping with "rgb" indexed as (batch, frames, height, width,
        channels) and "clip_embedding" indexed as (batch, frames, ...). Only
        the last frame's embedding is used.
      train: Forwarded to the encoder and dense resnet.

    Returns:
      Array of shape (batch, action_size).
    """
    rgb = obs["rgb"]
    # Stack the frames channelwise: (b, n, h, w, c) -> (b, h, w, n * c).
    # The frame axis has to be moved next to the channel axis before merging.
    # Reshaping (b, n, h, w, c) directly would interleave pixels from
    # different frames across the spatial grid instead of stacking channels.
    # NOTE: this changes the encoder's input layout relative to the old plain
    # reshape, so checkpoints trained with it will not behave identically.
    b, n, h, w, c = jnp.shape(rgb)
    rgb = jnp.moveaxis(rgb, 1, 3)  # (b, h, w, n, c)
    rgb = jnp.reshape(rgb, (b, h, w, n * c))

    # Language embedding of the last timestep.
    lang = obs["clip_embedding"]
    lang = lang[:, -1, Ellipsis]
    encoded_obs = self.encoder(rgb, lang, train=train)

    x = self.dense_resnet(encoded_obs, train=train)
    x = self.action_projection(x)
    return x
112 |
--------------------------------------------------------------------------------
/language_table/train/networks/resnet_v1.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of ResNet V1 in Flax.
17 |
18 | "Deep Residual Learning for Image Recognition"
19 | He et al., 2015, [https://arxiv.org/abs/1512.03385]
20 | """
21 |
22 | # See issue #620.
23 | # pytype: disable=wrong-arg-count
24 | # pytype: disable=missing-parameter
25 | # pytype: disable=wrong-keyword-args
26 |
27 | import functools
28 | from typing import Any, Tuple, Type
29 |
30 | import flax.linen as nn
31 | import jax.numpy as jnp
32 |
# Bias-free square convolutions. Each use in this file is followed by a
# normalization layer.
Conv1x1 = functools.partial(nn.Conv, use_bias=False, kernel_size=(1, 1))
Conv3x3 = functools.partial(nn.Conv, use_bias=False, kernel_size=(3, 3))
35 |
36 |
class ResNetBlock(nn.Module):
  """ResNet block without bottleneck used in ResNet-18 and ResNet-34.

  Two 3x3 convolutions, each followed by normalization, are added to the
  shortcut. If the shapes differ, the shortcut is first projected by a 1x1
  convolution and normalization. The sum goes through a ReLU.

  Attributes:
    filters: Output channels of both convolutions.
    norm: Normalization layer constructor, called with `name=` and
      optionally `scale_init=`.
    strides: Strides of the first convolution and of the shortcut projection.
  """

  filters: int
  norm: Any
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    shortcut = x

    y = Conv3x3(self.filters, strides=self.strides, name="conv1")(x)
    y = nn.relu(self.norm(name="bn1")(y))
    y = Conv3x3(self.filters, name="conv2")(y)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Tengyu et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    y = self.norm(scale_init=nn.initializers.zeros, name="bn2")(y)

    if shortcut.shape != y.shape:
      shortcut = Conv1x1(
          self.filters, strides=self.strides, name="proj_conv")(shortcut)
      shortcut = self.norm(name="proj_bn")(shortcut)

    return nn.relu(shortcut + y)
65 |
66 |
class BottleneckResNetBlock(ResNetBlock):
  """Bottleneck ResNet block used in ResNet-50 and larger.

  The block applies a 1x1 conv, a 3x3 conv carrying the stride, and a 1x1
  conv that expands to `4 * filters` channels. Each conv is followed by
  normalization. The result is added to the shortcut, which is projected by
  a 1x1 conv and normalization if the shapes differ, and the sum goes
  through a ReLU.
  """

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv1x1(self.filters, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, strides=self.strides, name="conv2")(x)
    x = self.norm(name="bn2")(x)
    x = nn.relu(x)
    x = Conv1x1(4 * self.filters, name="conv3")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Tengyu et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    # Previously this comment was present but `scale_init` was missing. It is
    # now applied, matching the "bn2" layer of ResNetBlock.
    x = self.norm(scale_init=nn.initializers.zeros, name="bn3")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          4 * self.filters, strides=self.strides, name="proj_conv")(
              residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x
94 |
95 |
class ResNetStage(nn.Module):
  """ResNet stage consisting of multiple ResNet blocks.

  Blocks are named `block1` ... `block{stage_size}`. Only the first block
  uses `first_block_strides`; all later blocks use stride (1, 1).

  Attributes:
    stage_size: Number of blocks in the stage.
    filters: `filters` passed to every block.
    block_cls: Block class to instantiate.
    norm: Normalization constructor passed to every block.
    first_block_strides: Strides of the first block.
  """

  stage_size: int
  filters: int
  block_cls: Type[ResNetBlock]
  norm: Any
  first_block_strides: Tuple[int, int]

  @nn.compact
  def __call__(self, x):
    for block_number in range(1, self.stage_size + 1):
      if block_number == 1:
        strides = self.first_block_strides
      else:
        strides = (1, 1)
      block = self.block_cls(
          filters=self.filters,
          norm=self.norm,
          strides=strides,
          name=f"block{block_number}")
      x = block(x)
    return x
115 |
116 |
class ResNet(nn.Module):
  """Construct ResNet V1 with `num_classes` outputs.

  Attributes:
    num_classes: Number of nodes in the final layer.
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  num_classes: int
  block_cls: Type[ResNetBlock]
  # Variable-length: one entry per stage.
  stage_sizes: Tuple[int, ...]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs, channels last (the head averages over axes 1 and 2).
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      The output head with `num_classes` entries.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    # Root block: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)
    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    # Stages: filters double per stage; every stage but the first downsamples.
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)

    # Head: global average pool, then a zero-initialized linear layer.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        self.num_classes, kernel_init=nn.initializers.zeros, name="head")(
            x)
    return x
180 |
181 |
# Standard ResNet V1 depths: basic blocks for 18/34 layers, bottleneck blocks
# for 50 layers and deeper.
ResNet18 = functools.partial(
    ResNet, block_cls=ResNetBlock, stage_sizes=(2, 2, 2, 2))
ResNet34 = functools.partial(
    ResNet, block_cls=ResNetBlock, stage_sizes=(3, 4, 6, 3))
ResNet50 = functools.partial(
    ResNet, block_cls=BottleneckResNetBlock, stage_sizes=(3, 4, 6, 3))
ResNet101 = functools.partial(
    ResNet, block_cls=BottleneckResNetBlock, stage_sizes=(3, 4, 23, 3))
ResNet152 = functools.partial(
    ResNet, block_cls=BottleneckResNetBlock, stage_sizes=(3, 8, 36, 3))
ResNet200 = functools.partial(
    ResNet, block_cls=BottleneckResNetBlock, stage_sizes=(3, 24, 36, 3))
194 |
195 |
class MultiscaleResNet(nn.Module):
  """Construct a multiscale ResNet.

  Same trunk as `ResNet`, but without the classification head. It returns
  the intermediate feature maps instead of logits.

  Attributes:
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  block_cls: Type[ResNetBlock]
  stage_sizes: Tuple[int, ...]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet trunk to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      A list of `len(stage_sizes) + 2` feature maps, in this order:
      1. the raw root-convolution output (before normalization and ReLU);
      2. the root block output after max-pooling;
      3. the output of each stage, in order.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    all_outputs = []

    # Root block.
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)

    all_outputs.append(x)

    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    all_outputs.append(x)

    # Stages: filters double per stage; every stage but the first downsamples.
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)
      all_outputs.append(x)

    return all_outputs
260 |
--------------------------------------------------------------------------------
/language_table/train/networks/resnet_v1_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from absl.testing import parameterized
17 | from clu import parameter_overview
18 | import jax
19 | import jax.numpy as jnp
20 | from language_table.train.networks import resnet_v1
21 | import tensorflow as tf
22 |
23 |
class ResNetV1Test(tf.test.TestCase, parameterized.TestCase):
  """Test cases for ResNet V1."""

  @parameterized.named_parameters(
      ("ResNet18", resnet_v1.ResNet18, 11_689_512),
      ("ResNet34", resnet_v1.ResNet34, 21_797_672),
      ("ResNet50", resnet_v1.ResNet50, 25_557_032),
      ("ResNet101", resnet_v1.ResNet101, 44_549_160),
      ("ResNet152", resnet_v1.ResNet152, 60_192_808),
      ("ResNet200", resnet_v1.ResNet200, 64_673_832),
  )
  def test_architecture(self, cls, param_count):
    """Builds `cls` with a 1000-way head and checks its parameter count."""
    net = cls(num_classes=1000)
    dummy_images = jnp.ones([2, 224, 224, 3])
    init_key = jax.random.PRNGKey(0)
    variables = net.init(init_key, dummy_images, train=False)
    actual_count = parameter_overview.count_parameters(variables["params"])
    self.assertEqual(param_count, actual_count)


if __name__ == "__main__":
  tf.test.main()
45 |
--------------------------------------------------------------------------------
/language_table/train/policy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """PyPolicy for BC Jax."""
17 |
18 | from flax.training import checkpoints
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from tf_agents.policies import py_policy
23 | from tf_agents.trajectories import policy_step
24 |
# float32 machine epsilon. Used as a floor for the action std when
# un-normalizing predicted actions.
EPS = jnp.finfo(np.float32).eps
26 |
27 |
class BCJaxPyPolicy(py_policy.PyPolicy):
  """Runs inference with a BC policy.

  The variables and action statistics come from the constructor when both
  `params` and `action_statistics` are given. Otherwise they are restored
  from the Flax checkpoint at `checkpoint_path`. Predicted actions are
  un-normalized with the action statistics and clipped to `action_spec`.
  """

  def __init__(self, time_step_spec, action_spec, model, checkpoint_path,
               rng, params=None, action_statistics=None):
    """Creates the policy.

    Args:
      time_step_spec: Time step spec, forwarded to `PyPolicy`.
      action_spec: Action spec. Its `minimum`/`maximum` bound emitted actions.
      model: Flax module, applied as `model.apply(variables, obs, train=False)`.
      checkpoint_path: Flax checkpoint, restored unless both `params` and
        `action_statistics` are given. It is read for its "params",
        "batch_stats" and "norm_info" entries.
      rng: Stored as `self.rng`. Not used by the inference code in this class.
      params: Optional model parameters. Ignored unless `action_statistics` is
        also given.
      action_statistics: Optional mapping with "mean" and "std" entries used
        to un-normalize actions.
    """
    super().__init__(time_step_spec, action_spec)
    self.model = model
    self.rng = rng

    # Remains None when no checkpoint is restored.
    state_dict = None
    if params is not None and action_statistics is not None:
      variables = {
          "params": params,
          "batch_stats": {}
      }
    else:
      state_dict = checkpoints.restore_checkpoint(checkpoint_path, None)
      variables = {
          "params": state_dict["params"],
          "batch_stats": state_dict["batch_stats"]
      }

    if action_statistics is not None:
      self.action_mean = np.array(action_statistics["mean"])
      self.action_std = np.array(action_statistics["std"])
    else:
      # state_dict is always restored on this path (see the branch above).
      action_stats = state_dict["norm_info"]["action_statistics"]
      self.action_mean = np.array(action_stats["mean"])
      self.action_std = np.array(action_stats["std"])

    # RGB statistics can only come from a restored checkpoint. Without one
    # there is nothing to read them from, so leave them unset (None) instead
    # of failing on the undefined state_dict.
    # They are not used by the inference code in this class.
    if state_dict is not None:
      rgb_stats = state_dict["norm_info"]["observation_statistics"]["rgb"]
      self._rgb_mean = jnp.array(rgb_stats["mean"])
      self._rgb_std = jnp.array(rgb_stats["std"])
    else:
      self._rgb_mean = None
      self._rgb_std = None

    self.variables = variables

    self._run_action_inference_jit = jax.jit(self._run_action_inference)

  def _run_action_inference(self, observation):
    """Predicts a batch-of-one action for a single unbatched observation."""
    # Add a batch dim.
    observation = jax.tree.map(lambda x: jnp.expand_dims(x, 0), observation)

    normalized_action = self.model.apply(
        self.variables, observation, train=False)
    # Un-normalize; the std is floored at EPS.
    action = (
        normalized_action * jnp.maximum(self.action_std, EPS) +
        self.action_mean)

    # Clip the action to spec.
    action = jnp.clip(action, self.action_spec.minimum,
                      self.action_spec.maximum)

    return action

  def _action(self, time_step, policy_state=(), seed=0):
    """Returns a `PolicyStep` whose action has the batch dim removed."""
    observation = time_step.observation
    action = self._run_action_inference_jit(observation)[0]
    return policy_step.PolicyStep(action=action)
88 |
--------------------------------------------------------------------------------
/language_table/train/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Tale Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """BC trainer for Language-Table."""
17 | import functools
18 | import os
19 | from typing import Optional
20 | from absl import logging
21 | from clu import checkpoint
22 | from clu import metric_writers
23 | from clu import parameter_overview
24 | from clu import periodic_actions
25 | import flax.jax_utils as flax_utils
26 | from flax.training import checkpoints as flax_checkpoints
27 | import jax
28 | from language_table.train import bc
29 | from language_table.train import input_pipeline_rlds
30 | from language_table.train.networks import lava
31 | from language_table.train.networks import pixel
32 | import ml_collections
33 | import tensorflow as tf
34 |
35 |
def multi_train_step(state, batches, agent, initial_metrics, rng):
  """Runs multiple iterations of a train step at once.

  Must be called under a mapped axis named "batch" (for example inside
  `jax.pmap(..., axis_name="batch")`), because the per-step RNG folds in
  `jax.lax.axis_index("batch")`.

  Args:
    state: Train state. Its `step` field is folded into the RNG.
    batches: Pytree whose leaves share a leading axis, one entry per step.
      The step count is read from `batches["action"].shape[0]`.
    agent: Object whose `train(state=, batch=, rng=)` returns
      `(new_state, metrics)`.
    initial_metrics: Initial metrics value for the loop carry.
    rng: Base PRNG key, identical on every call.

  Returns:
    `(final_state, metrics)`, where `metrics` is the value returned by the
    LAST `agent.train` call. Metrics from earlier steps are not accumulated.
  """

  def _one_step(index, carry):
    current_state = carry[0]  # Previous metrics are dropped.
    # `rng` is the same on every call, so fold in the step and device index.
    step_rng = jax.random.fold_in(
        jax.random.fold_in(rng, current_state.step),
        jax.lax.axis_index("batch"),
    )
    step_batch = jax.tree.map(lambda leaf: leaf[index], batches)
    next_state, step_metrics = agent.train(
        state=current_state, batch=step_batch, rng=step_rng)
    return next_state, step_metrics

  step_count = batches["action"].shape[0]
  return jax.lax.fori_loop(0, step_count, _one_step, (state, initial_metrics))
58 |
59 |
60 | def train(
61 | config,
62 | workdir,
63 | tf_data_service_address,
64 | ):
65 | """Runs a training loop.
66 |
67 | Args:
68 | config: Configuration to use.
69 | workdir: Working directory for checkpoints and TF summaries. If this
70 | contains checkpoint training will be resumed from the latest checkpoint.
71 | tf_data_service_address: Address of the TF Data Service. If None, all
72 | dataset computation will run locally.
73 | """
74 | tf.io.gfile.makedirs(workdir)
75 | rng = jax.random.PRNGKey(config.seed)
76 | logging.info(
77 | "Global batch size = %d",
78 | jax.device_count() * config.per_device_batch_size,
79 | )
80 | rng, data_rng = jax.random.split(rng)
81 | data_rng = jax.random.fold_in(data_rng, jax.process_index())
82 | train_ds, obs_statistics, act_statistics, min_actions, max_actions = (
83 | input_pipeline_rlds.create_datasets(
84 | data_rng,
85 | dataset_path=config.dataset_path,
86 | normalization_path=os.path.join(workdir, "norm_info"),
87 | sequence_length=config.sequence_length,
88 | per_device_batch_size=config.per_device_batch_size,
89 | num_steps_per_train_iter=config.num_steps_per_train_iter,
90 | target_width=config.data_target_width,
91 | target_height=config.data_target_height,
92 | random_crop_factor=config.random_crop_factor,
93 | normalization_num_samples=config.data_normalization_num_samples,
94 | skip_normalize_keys=config.data_skip_normalize_keys,
95 | cache=True,
96 | shuffle=True,
97 | shuffle_buffer_size=config.replay_capacity,
98 | prefetch_size=tf.data.AUTOTUNE,
99 | tf_data_service_address=tf_data_service_address,
100 | )
101 | )
102 | train_iter = train_ds.as_numpy_iterator()
103 | rng, agent_rng = jax.random.split(rng)
104 | sample_batch = jax.tree.map(lambda x: x[0][0], next(train_iter))
105 | agent = create_agent(
106 | config.agent_name,
107 | config.model_name,
108 | config.agent,
109 | config.model,
110 | config.sequence_length,
111 | sample_batch,
112 | obs_statistics,
113 | act_statistics,
114 | min_actions,
115 | max_actions,
116 | )
117 | state, initial_metrics = agent.create_train_state(sample_batch, agent_rng)
118 | # Save a file with the agent parameters.
119 | if jax.process_index() == 0:
120 | with tf.io.gfile.GFile(os.path.join(workdir, "parameters.txt"), "w") as f:
121 | f.write(parameter_overview.get_parameter_overview(state.params))
122 | # Set up checkpointing of the model and the input pipeline.
123 | checkpoint_dir = os.path.join(workdir, "checkpoints")
124 | ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=1_000_000)
125 | state = ckpt.restore_or_initialize(state)
126 | initial_step = int(state.step)
127 | initial_multi_step = initial_step // config.num_steps_per_train_iter
128 | if jax.process_index() == 0:
129 | flax_checkpoint_dir = os.path.join(workdir, "flax_checkpoints")
130 | tf.io.gfile.makedirs(flax_checkpoint_dir)
131 | flax_checkpoints.save_checkpoint(
132 | flax_checkpoint_dir,
133 | state,
134 | step=int(state.step),
135 | keep=10_000,
136 | keep_every_n_steps=config.checkpoint_every_steps,
137 | overwrite=True,
138 | )
139 | # Distribute training.
140 | state = flax_utils.replicate(state)
141 | rng, train_rng = jax.random.split(rng)
142 | # Create the pmapped multi_train_step.
143 | p_train_step = jax.pmap(
144 | functools.partial(
145 | multi_train_step,
146 | agent=agent,
147 | initial_metrics=initial_metrics,
148 | rng=train_rng,
149 | ),
150 | axis_name="batch",
151 | )
152 | if config.num_train_steps > 0:
153 | num_train_steps = config.num_train_steps
154 | logging.info("num_train_steps=%d", num_train_steps)
155 | writer = metric_writers.create_default_writer(
156 | workdir, just_logging=jax.process_index() > 0
157 | )
158 | if initial_step == 0:
159 | writer.write_hparams(dict(config))
160 | # We use a single thread threadpool for saving checkpoints, so that we're not
161 | # blocked.
162 | hooks = []
163 | report_progress = periodic_actions.ReportProgress(
164 | num_train_steps=num_train_steps, writer=writer
165 | )
166 | if jax.process_index() == 0:
167 | hooks += [
168 | report_progress,
169 | ]
170 | train_metrics = None
171 | logging.info("Starting training loop at step %d.", initial_step)
172 | with metric_writers.ensure_flushes(writer):
173 | for multi_step in range(
174 | initial_multi_step, num_train_steps // config.num_steps_per_train_iter
175 | ):
176 | # Compute the actual step by multiplying the amount of steps we run
177 | # per train iteration.
178 | step = multi_step * config.num_steps_per_train_iter
179 | # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
180 | # devices.
181 | is_last_step = step == num_train_steps
182 | with jax.profiler.StepTraceAnnotation("train", step_num=step):
183 | batch = next(train_iter)
184 | state, metrics_update = p_train_step(state=state, batches=batch)
185 | metric_update = flax_utils.unreplicate(metrics_update)
186 | train_metrics = (
187 | metric_update
188 | if train_metrics is None
189 | else train_metrics.merge(metric_update)
190 | )
191 | logging.log_first_n(
192 | logging.INFO, "Finished multi training step %d.", 5, multi_step
193 | )
194 | logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
195 | for h in hooks:
196 | h(step)
197 | if step % config.log_loss_every_steps == 0 or is_last_step:
198 | writer.write_scalars(step, train_metrics.compute())
199 | train_metrics = None
200 |
201 | def _save_checkpoints(state, step):
202 | with report_progress.timed("checkpoint"):
203 | unreplicated_state = flax_utils.unreplicate(state)
204 | ckpt.save(unreplicated_state)
205 | if jax.process_index() == 0:
206 | flax_checkpoints.save_checkpoint(
207 | flax_checkpoint_dir,
208 | unreplicated_state,
209 | step=step,
210 | keep=10_000,
211 | keep_every_n_steps=config.checkpoint_every_steps,
212 | overwrite=True,
213 | )
214 |
215 | if step % config.checkpoint_every_steps == 0 or is_last_step:
216 | state = merge_batch_stats(state)
217 | _save_checkpoints(state, step)
218 | logging.info("Finishing training at step %d", num_train_steps)
219 |
220 |
def create_agent(
    agent_name,
    model_name,
    agent_config,
    model_config,
    sequence_length,
    sample_batch,
    obs_statistics,
    act_statistics,
    min_actions,
    max_actions,
):
  """Instantiates the requested model and wraps it in the requested agent.

  The model's action dimensionality is not configured explicitly; it is read
  from the last axis of `sample_batch["action"]`.

  Args:
    agent_name: Agent identifier. Only "bc" (behavioral cloning agent,
      `bc.BCAgent`) is supported.
    model_name: Model identifier, either "sequence_lav_mse"
      (`lava.SequenceLAVMSE`) or "pixel_lang_mse" (`pixel.PixelLangMSE`).
    agent_config: Mapping of extra keyword arguments forwarded to the agent
      constructor.
    model_config: Mapping of extra keyword arguments forwarded to the model
      constructor (alongside the inferred `action_size`).
    sequence_length: Forwarded to the agent as `sequence_length`.
    sample_batch: Mapping holding at least an "action" array-like with a
      `.shape` attribute.
    obs_statistics: Forwarded to the agent as `observation_statistics`.
    act_statistics: Forwarded to the agent as `action_statistics`.
    min_actions: Forwarded to the agent as `action_min`.
    max_actions: Forwarded to the agent as `action_max`.

  Returns:
    The constructed agent.

  Raises:
    NotImplementedError: If `model_name` or `agent_name` is unknown. The model
      name is validated (and the model built) before the agent name is checked.
  """
  action_size = sample_batch["action"].shape[-1]

  # Pick the model constructor; unknown names fail before any agent checks.
  if model_name == "sequence_lav_mse":
    model_ctor = lava.SequenceLAVMSE
  elif model_name == "pixel_lang_mse":
    model_ctor = pixel.PixelLangMSE
  else:
    raise NotImplementedError(f"{model_name} not implemented.")
  model = model_ctor(action_size=action_size, **model_config)

  if agent_name != "bc":
    raise NotImplementedError(f"{agent_name} not implemented.")
  return bc.BCAgent(
      model=model,
      sequence_length=sequence_length,
      observation_statistics=obs_statistics,
      action_statistics=act_statistics,
      action_min=min_actions,
      action_max=max_actions,
      **agent_config,
  )
256 |
257 |
def merge_batch_stats(replicated_state):
  """Averages the batch statistics of a device-replicated state across devices.

  Args:
    replicated_state: Train state whose `batch_stats` pytree is mapped by
      `jax.pmap` over its leading axis (presumably the device axis, e.g. a
      state produced by `flax_utils.replicate`) and which provides a
      `replace(batch_stats=...)` method.

  Returns:
    `replicated_state` itself when `batch_stats` has no leaves; otherwise a new
    state whose `batch_stats` hold, on every mapped slot, the mean over the
    mapped axis.
  """
  stats = replicated_state.batch_stats
  if not jax.tree_util.tree_leaves(stats):
    # Nothing to average.
    return replicated_state
  # `pmean` reduces over a named mapped axis; the name only has to agree
  # between the `pmap` and the collective.
  axis = "x"
  mean_across_devices = jax.pmap(
      lambda tree: jax.lax.pmean(tree, axis), axis
  )
  return replicated_state.replace(batch_stats=mean_across_devices(stats))
267 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | clu
2 | dm-reverb-nightly>=0.9.0.dev20221205
3 | gym<=0.23.0
4 | matplotlib
5 | numpy
6 | opencv-python
7 | protobuf
8 | pybullet
9 | rlds>=0.1.7
10 | scipy
11 | six
12 | tf-nightly>=2.12.0.dev20230201
13 | tensorflow_datasets
14 | tensorflow_text
15 | tf_agents>=0.14.0
16 | git+https://github.com/openai/CLIP.git
17 | git+https://github.com/google-research/scenic.git
18 |
--------------------------------------------------------------------------------
/requirements_static.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | anyio==3.6.2
3 | argon2-cffi==21.3.0
4 | argon2-cffi-bindings==21.2.0
5 | asttokens==2.1.0
6 | astunparse==1.6.3
7 | backcall==0.2.0
8 | bleach==5.0.1
9 | cached-property==1.5.2
10 | cachetools==5.2.0
11 | certifi==2022.12.7
12 | cffi==1.15.1
13 | charset-normalizer==3.0.1
14 | chex==0.1.5
15 | clip @ git+https://github.com/openai/CLIP.git@3702849800aa56e2223035bccd1c6ef91c704ca8
16 | cloudpickle==2.2.0
17 | clu==0.0.8
18 | colorama==0.4.5
19 | commonmark==0.9.1
20 | contextlib2==21.6.0
21 | contourpy==1.0.6
22 | cycler==0.11.0
23 | debugpy==1.6.4
24 | decorator==5.1.1
25 | defusedxml==0.7.1
26 | dill==0.3.6
27 | dm-reverb-nightly==0.11.0.dev20230205
28 | dm-tree==0.1.7
29 | entrypoints==0.4
30 | etils==0.6.0
31 | executing==1.2.0
32 | fastjsonschema==2.16.2
33 | flatbuffers==22.10.26
34 | flax==0.5.3
35 | fonttools==4.38.0
36 | ftfy==6.1.1
37 | gast==0.4.0
38 | gin-config==0.5.0
39 | google-api-core==2.10.1
40 | google-auth==2.10.0
41 | google-auth-oauthlib==0.4.6
42 | google-cloud-speech==2.16.0
43 | google-pasta==0.2.0
44 | googleapis-common-protos==1.56.4
45 | grain==0.1.4
46 | grpcio==1.49.1
47 | grpcio-status==1.49.1
48 | gym==0.23.0
49 | gym-notices==0.0.8
50 | h5py==3.7.0
51 | idna==3.4
52 | imageio==2.28.0
53 | immutabledict==2.2.3
54 | importlib-resources==5.9.0
55 | install==1.3.5
56 | ipykernel==6.17.1
57 | ipython==8.7.0
58 | ipython-genutils==0.2.0
59 | ipywidgets==8.0.2
60 | jax==0.3.15
61 | jaxlib==0.3.15
62 | jedi==0.18.2
63 | Jinja2==3.1.2
64 | jsonschema==4.17.1
65 | jupyter==1.0.0
66 | jupyter-console==6.4.4
67 | jupyter-server==1.23.3
68 | jupyter_client==7.4.7
69 | jupyter_core==5.1.0
70 | jupyterlab-pygments==0.2.2
71 | jupyterlab-widgets==3.0.3
72 | keras-nightly==2.13.0.dev2023020508
73 | Keras-Preprocessing==1.1.2
74 | kiwisolver==1.4.4
75 | libclang==14.0.6
76 | Markdown==3.4.1
77 | MarkupSafe==2.1.1
78 | matplotlib==3.6.2
79 | matplotlib-inline==0.1.6
80 | mediapy==1.1.6
81 | mistune==2.0.4
82 | ml-collections==0.1.1
83 | msgpack==1.0.4
84 | mypy-protobuf==3.1.0
85 | nbclassic==0.4.8
86 | nbclient==0.7.0
87 | nbconvert==7.2.5
88 | nbformat==5.7.0
89 | nest-asyncio==1.5.6
90 | notebook==6.5.2
91 | notebook_shim==0.2.2
92 | numpy==1.23.5
93 | nvidia-cublas-cu11==11.10.3.66
94 | nvidia-cuda-nvrtc-cu11==11.7.99
95 | nvidia-cuda-runtime-cu11==11.7.99
96 | nvidia-cudnn-cu11==8.5.0.96
97 | oauthlib==3.2.2
98 | opencv-python==4.6.0.66
99 | opt-einsum==3.3.0
100 | optax @ git+https://github.com/deepmind/optax.git@d9157132b8661e1453d156f591784ff27d6d1141
101 | packaging==23.0
102 | pandocfilters==1.5.0
103 | parso==0.8.3
104 | pexpect==4.8.0
105 | pickleshare==0.7.5
106 | Pillow==9.4.0
107 | platformdirs==2.5.4
108 | portpicker==1.5.2
109 | prometheus-client==0.15.0
110 | promise==2.3
111 | prompt-toolkit==3.0.33
112 | proto-plus==1.22.1
113 | protobuf==4.21.12
114 | psutil==5.9.4
115 | ptyprocess==0.7.0
116 | pure-eval==0.2.2
117 | pyasn1==0.4.8
118 | pyasn1-modules==0.2.8
119 | PyAudio==0.2.12
120 | pybullet==3.2.5
121 | pycparser==2.21
122 | pygame==2.1.0
123 | Pygments==2.14.0
124 | pyparsing==3.0.9
125 | pyrsistent==0.19.2
126 | python-dateutil==2.8.2
127 | PyYAML==6.0
128 | pyzmq==24.0.1
129 | qtconsole==5.4.0
130 | QtPy==2.3.0
131 | regex==2022.10.31
132 | requests==2.28.2
133 | requests-oauthlib==1.3.1
134 | rich==11.2.0
135 | rlds==0.1.7
136 | rsa==4.9
137 | scenic @ git+https://github.com/google-research/scenic.git@ae21d9e884015aa7bc7cf1d489af53d16c249726
138 | scipy==1.9.3
139 | Send2Trash==1.8.0
140 | six==1.16.0
141 | sniffio==1.3.0
142 | stack-data==0.6.2
143 | tb-nightly==2.12.0a20230205
144 | tensorboard==2.11.0
145 | tensorboard-data-server==0.7.0
146 | tensorboard-plugin-wit==1.8.1
147 | tensorflow-addons==0.19.0
148 | tensorflow-datasets==4.7.0
149 | tensorflow-io-gcs-filesystem==0.26.0
150 | tensorflow-metadata==1.10.0
151 | tensorflow-probability==0.17.0
152 | tensorflow-text-nightly==2.12.0.dev20230203
153 | tensorstore==0.1.22
154 | termcolor==1.1.0
155 | terminado==0.17.0
156 | tf-agents==0.14.0
157 | tf-estimator-nightly==2.13.0.dev2023020509
158 | tf-hub-nightly==0.13.0.dev202302070818
159 | tf-nightly==2.12.0.dev20230203
160 | toml==0.10.2
161 | toolz==0.12.0
162 | torch==1.13.1
163 | torchvision==0.14.1
164 | tornado==6.2
165 | tqdm==4.64.1
166 | traitlets==5.5.0
167 | typeguard==2.13.3
168 | types-protobuf==3.20.4
169 | typing_extensions==4.3.0
170 | urllib3==1.26.14
171 | wcwidth==0.2.5
172 | websocket-client==1.4.2
173 | Werkzeug==2.2.2
174 | widgetsnbextension==4.0.3
175 | wrapt==1.14.1
176 | zipp==3.8.1
177 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Language Table Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup."""
17 |
from distutils import core
import os

import setuptools
from setuptools import find_namespace_packages
22 |
23 |
# Directory containing this setup.py. README.md is resolved relative to it so
# the lookup does not depend on the current working directory.
here = os.path.abspath(os.path.dirname(__file__))
# README.md becomes the package's long description. It is read inside a `with`
# block so the file handle is closed. If the file cannot be read, an empty
# description is used instead.
try:
  with open(os.path.join(here, 'README.md'), encoding='utf-8') as readme_file:
    README = readme_file.read()
except IOError:
  README = ''


install_requires = [
    'clu',
    'dm-reverb-nightly>=0.9.0.dev20221205',
    'gym<=0.23.0',
    'matplotlib',
    'numpy',
    'opencv-python',
    'protobuf',
    'pybullet',
    'rlds>=0.1.7',
    'scipy',
    'six',
    'tf-nightly>=2.12.0.dev20230201',
    'tensorflow_datasets>=4.7.0',
    'tf_agents>=0.14.0',
]


# Call setuptools' `setup` directly. `distutils` is deprecated (PEP 632) and
# removed from the standard library in Python 3.12. Its `setup` only
# understands setuptools keywords such as `install_requires`,
# `long_description_content_type` and `include_package_data` through
# setuptools' monkey-patching.
setuptools.setup(
    name='language_table',
    version='0.1',
    description=(
        'Language-Table is a suite of human-collected datasets and a multi-task'
        ' continuous control benchmark for open vocabulary visuolinguomotor'
        ' learning.'
    ),
    long_description=README,
    long_description_content_type='text/markdown',
    author='Language Table Team',
    author_email='language-table-team@google.com',
    url='https://github.com/google-research/language-table',
    packages=find_namespace_packages(),
    install_requires=install_requires,
    include_package_data=True,
)
66 |
--------------------------------------------------------------------------------