├── README.md
├── __init__.py
├── bin
│   ├── __init__.py
│   ├── decode_seq2act.sh
│   ├── eval_seq2act.sh
│   ├── seq2act_decode.py
│   ├── seq2act_train_eval.py
│   ├── setup_test.py
│   └── train_seq2act.sh
├── ckpt_hparams
│   ├── grounding
│   │   └── hparams.json
│   └── tuple_extract
│       └── hparams.json
├── data_generation
│   ├── README.md
│   ├── common.py
│   ├── config.py
│   ├── create_android_synthetic_dataset.py
│   ├── create_rico_sca.sh
│   ├── create_token_vocab.py
│   ├── filter.txt
│   ├── proto_utils.py
│   ├── requirements.txt
│   ├── resources.py
│   ├── string_utils.py
│   ├── synthetic_action_generator.py
│   └── view_hierarchy.py
├── layers
│   ├── __init__.py
│   ├── area_utils.py
│   ├── common_embed.py
│   └── encode_screen.py
├── models
│   ├── __init__.py
│   ├── input.py
│   ├── seq2act_estimator.py
│   ├── seq2act_grounding.py
│   ├── seq2act_model.py
│   └── seq2act_reference.py
├── requirements.txt
├── run.sh
└── utils
    ├── __init__.py
    └── decode_utils.py
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Seq2act: Mapping Natural Language Instructions to Mobile UI Action Sequences
2 | This repository contains the code for the models and the experimental framework for "Mapping Natural Language Instructions to Mobile UI Action Sequences" by Yang Li, Jiacong He, Xin Zhou, Yuan Zhang, and Jason Baldridge, which was accepted at the 2020 Annual Conference of the Association for Computational Linguistics (ACL 2020).
3 | 
4 | ## Datasets
5 | 
6 | The data pipelines will be available in future updates.
7 | 
8 | ## Setup
9 | 
10 | Install the packages required by our codebase, then verify the setup by running a minimal version of the model and the experimental framework.
11 | 
12 | ```
13 | sh seq2act/run.sh
14 | ```
15 | 
16 | ## Run Experiments
17 | 
18 | * Train (and continuously evaluate) seq2act Phrase Tuple Extraction models.
19 | 
20 | ```
21 | sh seq2act/bin/train_seq2act.sh --experiment_dir=your_exp_dir --train=parse --hparam_file=./seq2act/ckpt_hparams/tuple_extract
22 | ```
23 | 
24 | * Train (and continuously evaluate) seq2act Grounding models.
25 | 
26 | ```
27 | sh seq2act/bin/train_seq2act.sh --experiment_dir=your_exp_dir --train=ground --hparam_file=./seq2act/ckpt_hparams/grounding
28 | ```
29 | 
30 | * Test the grounding model (or the phrase tuple extraction model alone) by running the decoder.
31 | 
32 | ```
33 | sh seq2act/bin/decode_seq2act.sh
34 | ```
35 | 
36 | If you use any of the materials, please cite the following paper.
37 | 
38 | ```
39 | @inproceedings{seq2act,
40 | title = {Mapping Natural Language Instructions to Mobile UI Action Sequences},
41 | author = {Yang Li and Jiacong He and Xin Zhou and Yuan Zhang and Jason Baldridge},
42 | booktitle = {Annual Conference of the Association for Computational Linguistics (ACL 2020)},
43 | year = {2020},
44 | url = {https://arxiv.org/pdf/tbd.pdf},
45 | }
46 | ```
47 | 
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bin/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bin/decode_seq2act.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | source gbash.sh || exit 17 | 18 | DEFINE_string output_dir --required "" "Specify the output directory" 19 | DEFINE_string data_files "./seq2act/data/pixel_help/*.tfrecord" \ 20 | "Specify the test data files" 21 | DEFINE_string checkpoint_path "seq2act/ckpt_hparams/grounding" \ 22 | "Specify the checkpoint file" 23 | DEFINE_string problem "pixel_help" "Specify the dataset to decode" 24 | 25 | gbash::init_google "$@" 26 | 27 | set -e 28 | set -x 29 | 30 | virtualenv -p python3 . 31 | source ./bin/activate 32 | 33 | pip install tensorflow 34 | pip install -r seq2act/requirements.txt 35 | 36 | python -m seq2act.bin.seq2act_decode --problem ${FLAGS_problem} \ 37 | --data_files "${FLAGS_data_files}" \ 38 | --checkpoint_path "${FLAGS_checkpoint_path}" \ 39 | --output_dir "${FLAGS_output_dir}" 40 | -------------------------------------------------------------------------------- /bin/eval_seq2act.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | #!/bin/bash 16 | source gbash.sh || exit 17 | 18 | DEFINE_string experiment_dir --required "" "Specify the experimental directory" 19 | DEFINE_string eval_files "./seq2act/data/rico_sca/*0.tfrecord" \ 20 | "Specify the path to the eval dataset" 21 | DEFINE_string eval_data_source "rico_sca" "Specify eval data source" 22 | DEFINE_string eval_name "rico_sca" "Specify eval job name" 23 | DEFINE_string metric_types "final_accuracy,ref_accuracy,basic_accuracy" \ 24 | "Specify the eval metric types" 25 | DEFINE_int eval_steps 200 "Specify the eval steps" 26 | DEFINE_int eval_batch_size 16 "Specify the eval batch size" 27 | DEFINE_int decode_length 20 "Specify the decode length" 28 | 29 | gbash::init_google "$@" 30 | 31 | set -e 32 | set -x 33 | 34 | virtualenv -p python3 . 35 | source ./bin/activate 36 | 37 | pip install tensorflow 38 | pip install -r seq2act/requirements.txt 39 | 40 | python -m seq2act.bin.seq2act_train_eval --exp_mode "eval" \ 41 | --experiment_dir "${FLAGS_experiment_dir}" \ 42 | --eval_files "${FLAGS_eval_files}" \ 43 | --metric_types "${FLAGS_metric_types}" \ 44 | --decode_length "${FLAGS_decode_length}" \ 45 | --eval_name "${FLAGS_eval_name}" \ 46 | --eval_steps "${FLAGS_eval_steps}" \ 47 | --eval_data_source "${FLAGS_eval_data_source}" \ 48 | --eval_batch_size ${FLAGS_eval_batch_size} 49 | -------------------------------------------------------------------------------- /bin/seq2act_decode.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """seq2act decoder.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | from seq2act.models import input as input_utils 25 | from seq2act.models import seq2act_estimator 26 | from seq2act.models import seq2act_model 27 | from seq2act.utils import decode_utils 28 | 29 | flags = tf.flags 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_integer("beam_size", 1, "beam size") 33 | flags.DEFINE_string("problem", "android_howto", "problem") 34 | flags.DEFINE_string("data_files", "", "data_files") 35 | flags.DEFINE_string("checkpoint_path", "", "checkpoint_path") 36 | flags.DEFINE_string("output_dir", "", "output_dir") 37 | flags.DEFINE_integer("decode_batch_size", 1, "decode_batch_size") 38 | 39 | 40 | def get_input(hparams, data_files): 41 | """Get the input.""" 42 | if FLAGS.problem == "pixel_help": 43 | data_source = input_utils.DataSource.PIXEL_HELP 44 | elif FLAGS.problem == "android_howto": 45 | data_source = input_utils.DataSource.ANDROID_HOWTO 46 | elif FLAGS.problem == "rico_sca": 47 | data_source = input_utils.DataSource.RICO_SCA 48 | else: 49 | raise ValueError("Unrecognized test: %s" % FLAGS.problem) 50 | tf.logging.info("Testing data_source=%s data_files=%s" % ( 51 | FLAGS.problem, data_files)) 52 | dataset = input_utils.input_fn( 53 | data_files, 54 | FLAGS.decode_batch_size, 55 | repeat=1, 56 | data_source=data_source, 57 | max_range=hparams.max_span, 58 | max_dom_pos=hparams.max_dom_pos, 59 | max_pixel_pos=( 60 | hparams.max_pixel_pos), 61 | load_extra=True, 62 | load_dom_dist=(hparams.screen_encoder == "gcn")) 63 | iterator = tf.data.make_one_shot_iterator(dataset) 64 | features = iterator.get_next() 65 | return features 66 | 67 | 68 | def generate_action_mask(features): 69 | """Computes the decode mask from "task" and "verb_refs".""" 70 | eos_positions = tf.to_int32(tf.expand_dims( 71 | tf.where(tf.equal(features["task"], 1))[:, 1], 1)) 72 | decode_mask = tf.cumsum(tf.to_int32( 73 | tf.logical_and( 74 | tf.equal(features["verb_refs"][:, :, 0], eos_positions), 75 | tf.equal(features["verb_refs"][:, :, 1], eos_positions + 1))), 76 | axis=-1) 77 | decode_mask = tf.sequence_mask( 78 | tf.reduce_sum(tf.to_int32(tf.less(decode_mask, 1)), -1), 79 | maxlen=tf.shape(decode_mask)[1]) 80 | return decode_mask 81 | 82 | 83 | def _decode_common(hparams): 84 | """Common graph for decoding.""" 85 | features = get_input(hparams, FLAGS.data_files) 86 | decode_features = {} 87 | for key in features: 88 | if key.endswith("_refs"): 89 | continue 90 | decode_features[key] = features[key] 91 | _, _, _, references = seq2act_model.compute_logits( 92 | features, hparams, mode=tf.estimator.ModeKeys.EVAL) 93 | decode_utils.decode_n_step(seq2act_model.compute_logits, 94 | decode_features, references["areas"], 95 | hparams, n=20, 96 | beam_size=FLAGS.beam_size) 97 | decode_mask = generate_action_mask(decode_features) 98 | return decode_features, decode_mask, features 99 | 100 | 101 | def to_string(name, seq): 102 | steps = [] 103 | for step in seq: 104 | steps.append(",".join(map(str, step))) 105 | return name + " - ".join(steps) 106 | 107 | 108 | def ref_acc_to_string_list(task_seqs, ref_seqs, masks): 109 | """Convert a seqs of refs to strings.""" 110 | cra = 0. 111 | pra = 0. 
112 | string_list = [] 113 | for task, seq, mask in zip(task_seqs, ref_seqs, masks): 114 | # Assuming batch_size = 1 115 | string_list.append(task) 116 | string_list.append(to_string("gt_seq", seq["gt_seq"][0])) 117 | string_list.append(to_string("pred_seq", seq["pred_seq"][0][mask[0]])) 118 | string_list.append( 119 | "complete_seq_acc: " + str( 120 | seq["complete_seq_acc"]) + " partial_seq_acc: " + str( 121 | seq["partial_seq_acc"])) 122 | cra += seq["complete_seq_acc"] 123 | pra += seq["partial_seq_acc"] 124 | mcra = cra / len(ref_seqs) 125 | mpra = pra / len(ref_seqs) 126 | string_list.append("mean_complete_seq_acc: " + str(mcra) +( 127 | "mean_partial_seq_acc: " + str(mpra))) 128 | return string_list 129 | 130 | 131 | def save(task_seqs, seqs, masks, tag): 132 | string_list = ref_acc_to_string_list(task_seqs, seqs, masks) 133 | if not tf.gfile.Exists(FLAGS.output_dir): 134 | tf.gfile.MakeDirs(FLAGS.output_dir) 135 | with tf.gfile.GFile(os.path.join(FLAGS.output_dir, "decodes." + tag), 136 | mode="w") as f: 137 | for item in string_list: 138 | print(item) 139 | f.write(str(item)) 140 | f.write("\n") 141 | 142 | 143 | def decode_fn(hparams): 144 | """The main function.""" 145 | decode_dict, decode_mask, label_dict = _decode_common(hparams) 146 | if FLAGS.problem != "android_howto": 147 | decode_dict["input_refs"] = decode_utils.unify_input_ref( 148 | decode_dict["verbs"], decode_dict["input_refs"]) 149 | print_ops = [] 150 | for key in ["raw_task", "verbs", "objects", 151 | "verb_refs", "obj_refs", "input_refs"]: 152 | print_ops.append(tf.print(key, tf.shape(decode_dict[key]), decode_dict[key], 153 | label_dict[key], "decode_mask", decode_mask, 154 | summarize=100)) 155 | acc_metrics = decode_utils.compute_seq_metrics( 156 | label_dict, decode_dict, mask=None) 157 | saver = tf.train.Saver() 158 | with tf.Session() as session: 159 | session.run(tf.global_variables_initializer()) 160 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 161 | tf.logging.info("Restoring from the latest checkpoint: %s" % 162 | (latest_checkpoint)) 163 | saver.restore(session, latest_checkpoint) 164 | task_seqs = [] 165 | ref_seqs = [] 166 | act_seqs = [] 167 | mask_seqs = [] 168 | try: 169 | i = 0 170 | while True: 171 | tf.logging.info("Example %d" % i) 172 | task, acc, mask, label, decode = session.run([ 173 | decode_dict["raw_task"], acc_metrics, decode_mask, 174 | label_dict, decode_dict 175 | ]) 176 | ref_seq = {} 177 | ref_seq["gt_seq"] = np.concatenate([ 178 | label["verb_refs"], label["obj_refs"], label["input_refs"]], 179 | axis=-1) 180 | ref_seq["pred_seq"] = np.concatenate([ 181 | decode["verb_refs"], decode["obj_refs"], decode["input_refs"]], 182 | axis=-1) 183 | ref_seq["complete_seq_acc"] = acc["complete_refs_acc"] 184 | ref_seq["partial_seq_acc"] = acc["partial_refs_acc"] 185 | act_seq = {} 186 | act_seq["gt_seq"] = np.concatenate([ 187 | np.expand_dims(label["verbs"], 2), 188 | np.expand_dims(label["objects"], 2), 189 | label["input_refs"]], axis=-1) 190 | act_seq["pred_seq"] = np.concatenate([ 191 | np.expand_dims(decode["verbs"], 2), 192 | np.expand_dims(decode["objects"], 2), 193 | decode["input_refs"]], axis=-1) 194 | act_seq["complete_seq_acc"] = acc["complete_acts_acc"] 195 | act_seq["partial_seq_acc"] = acc["partial_acts_acc"] 196 | print("task", task) 197 | print("ref_seq", ref_seq) 198 | print("act_seq", act_seq) 199 | print("mask", mask) 200 | task_seqs.append(task) 201 | ref_seqs.append(ref_seq) 202 | act_seqs.append(act_seq) 203 | mask_seqs.append(mask) 204 | i 
+= 1 205 | except tf.errors.OutOfRangeError: 206 | pass 207 | save(task_seqs, ref_seqs, mask_seqs, "joint_refs") 208 | save(task_seqs, act_seqs, mask_seqs, "joint_act") 209 | 210 | 211 | def main(_): 212 | hparams = seq2act_estimator.load_hparams(FLAGS.checkpoint_path) 213 | hparams.set_hparam("batch_size", FLAGS.decode_batch_size) 214 | decode_fn(hparams) 215 | 216 | if __name__ == "__main__": 217 | tf.app.run() 218 | -------------------------------------------------------------------------------- /bin/seq2act_train_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Train and eval for the seq2act estimator.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensor2tensor.utils import trainer_lib 22 | import tensorflow.compat.v1 as tf 23 | 24 | from seq2act.models import input as input_utils 25 | from seq2act.models import seq2act_estimator 26 | 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("exp_mode", "train", "the running mode") 31 | flags.DEFINE_string("eval_name", "", "the eval name") 32 | flags.DEFINE_string("train_file_list", None, "the list of train files") 33 | flags.DEFINE_string("train_source_list", None, "the list of train sources") 34 | flags.DEFINE_string("train_batch_sizes", None, "the list of batch sizes") 35 | flags.DEFINE_string("eval_data_source", "android_howto", "the data source") 36 | flags.DEFINE_string("reference_checkpoint", "", 37 | "the reference_checkpoint") 38 | flags.DEFINE_string("hparam_file", "", "the hyper parameter file") 39 | flags.DEFINE_string("experiment_dir", "/tmp", 40 | "the directory for output checkpoints") 41 | flags.DEFINE_integer("eval_steps", 150, "eval_steps") 42 | flags.DEFINE_integer("decode_length", 20, "decode_length") 43 | flags.DEFINE_integer("eval_batch_size", 2, "eval_batch_size") 44 | flags.DEFINE_integer("shuffle_size", 2, "shuffle_size") 45 | flags.DEFINE_boolean("boost_input", False, "boost_input") 46 | 47 | 48 | def continuous_eval(experiment_dir): 49 | """Evaluate until checkpoints stop being produced.""" 50 | for ckpt_path in trainer_lib.next_checkpoint(experiment_dir, 51 | timeout_mins=-1): 52 | hparams = seq2act_estimator.load_hparams(experiment_dir) 53 | hparams.set_hparam("batch_size", FLAGS.eval_batch_size) 54 | eval_input_fn = seq2act_estimator.create_input_fn( 55 | FLAGS.eval_files, hparams.batch_size, 56 | -1, 2, 57 | input_utils.DataSource.from_str(FLAGS.eval_data_source), 58 | max_range=hparams.max_span, 59 | max_dom_pos=hparams.max_dom_pos, 60 | max_pixel_pos=hparams.max_pixel_pos, 61 | mean_synthetic_length=hparams.mean_synthetic_length, 62 | stddev_synthetic_length=hparams.stddev_synthetic_length, 63 | load_extra=True, 64 | load_screen=hparams.load_screen, 65 | 
load_dom_dist=(hparams.screen_encoder == "gcn")) 66 | estimator = create_estimator(experiment_dir, hparams, 67 | decode_length=FLAGS.decode_length) 68 | estimator.evaluate(input_fn=eval_input_fn, 69 | steps=FLAGS.eval_steps, 70 | checkpoint_path=ckpt_path, 71 | name=FLAGS.eval_name) 72 | 73 | 74 | def create_estimator(experiment_dir, hparams, decode_length=20): 75 | """Creates an estimator with given hyper parameters.""" 76 | if FLAGS.worker_gpu > 1: 77 | strategy = tf.distribute.MirroredStrategy() 78 | else: 79 | strategy = None 80 | config = tf.estimator.RunConfig( 81 | save_checkpoints_steps=1000, save_summary_steps=300, 82 | train_distribute=strategy) 83 | model_fn = seq2act_estimator.create_model_fn( 84 | hparams, 85 | seq2act_estimator.compute_additional_loss\ 86 | if hparams.use_additional_loss else None, 87 | seq2act_estimator.compute_additional_metric\ 88 | if hparams.use_additional_loss else None, 89 | compute_seq_accuracy=True, 90 | decode_length=decode_length) 91 | if FLAGS.reference_checkpoint: 92 | latest_checkpoint = tf.train.latest_checkpoint( 93 | FLAGS.reference_checkpoint) 94 | ws = tf.estimator.WarmStartSettings( 95 | ckpt_to_initialize_from=latest_checkpoint, 96 | vars_to_warm_start=["embed_tokens/task_embed_w", "encode_decode/.*", 97 | "output_layer/.*"]) 98 | else: 99 | ws = None 100 | estimator = tf.estimator.Estimator( 101 | model_fn=model_fn, model_dir=experiment_dir, config=config, 102 | warm_start_from=ws) 103 | return estimator 104 | 105 | 106 | def train(experiment_dir): 107 | """Trains the model.""" 108 | if FLAGS.hparam_file: 109 | hparams = seq2act_estimator.load_hparams(FLAGS.hparam_file) 110 | else: 111 | hparams = seq2act_estimator.create_hparams() 112 | 113 | estimator = create_estimator(experiment_dir, hparams) 114 | seq2act_estimator.save_hyperparams(hparams, experiment_dir) 115 | train_file_list = FLAGS.train_file_list.split(",") 116 | train_source_list = FLAGS.train_source_list.split(",") 117 | train_batch_sizes = FLAGS.train_batch_sizes.split(",") 118 | print("* xm_train", train_file_list, train_source_list, train_batch_sizes) 119 | if len(train_file_list) > 1: 120 | train_input_fn = seq2act_estimator.create_hybrid_input_fn( 121 | train_file_list, 122 | [input_utils.DataSource.from_str(s) for s in train_source_list], 123 | map(int, train_batch_sizes), 124 | max_range=hparams.max_span, 125 | max_dom_pos=hparams.max_dom_pos, 126 | max_pixel_pos=hparams.max_pixel_pos, 127 | mean_synthetic_length=hparams.mean_synthetic_length, 128 | stddev_synthetic_length=hparams.stddev_synthetic_length, 129 | batch_size=hparams.batch_size, 130 | boost_input=FLAGS.boost_input, 131 | load_screen=hparams.load_screen, 132 | buffer_size=FLAGS.shuffle_size, 133 | shuffle_size=FLAGS.shuffle_size, 134 | load_dom_dist=(hparams.screen_encoder == "gcn")) 135 | else: 136 | train_input_fn = seq2act_estimator.create_input_fn( 137 | train_file_list[0], 138 | hparams.batch_size, 139 | -1, -1, input_utils.DataSource.from_str(train_source_list[0]), 140 | max_range=hparams.max_span, 141 | max_dom_pos=hparams.max_dom_pos, 142 | max_pixel_pos=hparams.max_pixel_pos, 143 | mean_synthetic_length=hparams.mean_synthetic_length, 144 | stddev_synthetic_length=hparams.stddev_synthetic_length, 145 | load_extra=False, 146 | load_screen=hparams.load_screen, 147 | buffer_size=FLAGS.shuffle_size, 148 | shuffle_size=FLAGS.shuffle_size, 149 | load_dom_dist=(hparams.screen_encoder == "gcn")) 150 | estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps) 151 | 152 | 153 | def main(_): 154 | 
"""The main function.""" 155 | if FLAGS.exp_mode == "train": 156 | train(FLAGS.experiment_dir) 157 | elif FLAGS.exp_mode == "eval": 158 | continuous_eval(FLAGS.experiment_dir) 159 | 160 | if __name__ == "__main__": 161 | tf.app.run() 162 | -------------------------------------------------------------------------------- /bin/setup_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test seq2act train.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | from seq2act.bin import seq2act_train_eval 23 | 24 | flags = tf.flags 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def main(_): 29 | """The main function.""" 30 | seq2act_train_eval.train(FLAGS.experiment_dir) 31 | 32 | if __name__ == "__main__": 33 | tf.app.run() 34 | -------------------------------------------------------------------------------- /bin/train_seq2act.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | source gbash.sh || exit 17 | 18 | DEFINE_string experiment_dir --required "" "Specify the experimental directory" 19 | DEFINE_string train "ground" "Specify the training type: [parse, ground]" 20 | DEFINE_string hparam_file "./seq2act/ckpt_hparams/grounding/" \ 21 | "Specify the hyper-parameter file" 22 | DEFINE_string parser_checkpoint "seq2act/ckpt_hparams/tuple_extract" \ 23 | "Specify the checkpoint of tuple extraction" 24 | DEFINE_string rico_sca_train "./seq2act/data/rico_sca/*[!0].tfrecord" \ 25 | "Specify the path to rico_sca dataset" 26 | DEFINE_string android_howto_train "./seq2act/data/android_howto/*[!0].tfrecord" \ 27 | "Specify the path to android_howto dataset" 28 | 29 | gbash::init_google "$@" 30 | 31 | set -e 32 | set -x 33 | 34 | virtualenv -p python3 . 
35 | source ./bin/activate 36 | 37 | pip install tensorflow 38 | pip install -r seq2act/requirements.txt 39 | 40 | if (( "${FLAGS_train}" == "parse" )); then 41 | train_file_list="${FLAGS_android_howto_train},${FLAGS_rico_sca_train}" 42 | train_batch_sizes="128,128" 43 | train_source_list="android_howto,rico_sca" 44 | train_steps=1000000 45 | python -m seq2act.bin.seq2act_train_eval --exp_mode "train" \ 46 | --experiment_dir "${FLAGS_experiment_dir}" \ 47 | --hparam_file "${FLAGS_hparam_file}" \ 48 | --train_steps "${train_steps}" \ 49 | --train_file_list "${train_file_list}" \ 50 | --train_batch_sizes "${train_batch_sizes}" \ 51 | --train_source_list "${train_source_list}" \ 52 | else 53 | train_file_list="${FLAGS_rico_sca_train}" 54 | train_batch_sizes="64" 55 | train_source_list="rico_sca" 56 | train_steps=250000 57 | python -m seq2act.bin.seq2act_train_eval --exp_mode "train" \ 58 | --experiment_dir "${FLAGS_experiment_dir}" \ 59 | --hparam_file "${FLAGS_hparam_file}" \ 60 | --train_steps "${train_steps}" \ 61 | --train_file_list "${train_file_list}" \ 62 | --train_batch_sizes "${train_batch_sizes}" \ 63 | --train_source_list "${train_source_list}" \ 64 | --reference_checkpoint "${FLAGS_parser_checkpoint}" 65 | fi 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /ckpt_hparams/grounding/hparams.json: -------------------------------------------------------------------------------- 1 | {"learning_rate_constant": 0.1, "multiproblem_per_task_threshold": "", "conv_first_kernel": 3, "eval_drop_long_sequences": false, "action_vocab_size": 5, "use_fixed_batch_size": false, "use_additional_loss": false, "split_targets_strided_training": false, "no_data_parallelism": false, "freeze_reference_model": true, "weights_fn": {}, "force_full_predict": false, "learning_rate_decay_steps": 5000, "optimizer_adam_epsilon": 1e-09, "shared_embedding_and_softmax_weights": true, "unidirectional_encoder": false, "multiproblem_max_input_length": -1, "moe_hidden_sizes": "2048", "multiproblem_target_eval_only": false, "proximity_bias": false, "layer_prepostprocess_dropout_broadcast_dims": "", "area_key_mode": "none", "span_aggregation": "sum", "moe_loss_coef": 0.001, "batch_shuffle_size": 512, "max_length": 256, "norm_epsilon": 1e-06, "memory_height": 1, "screen_encoder_layers": 6, "factored_logits": false, "split_targets_max_chunks": 100, "daisy_chain_variables": true, "optimizer_adafactor_multiply_by_parameter_scale": true, "scheduled_sampling_method": "parallel", "learning_rate_decay_staircase": false, "multiply_embedding_mode": "sqrt_depth", "use_custom_ops": true, "summarize_vars": false, "attention_dropout_broadcast_dims": "", "split_targets_chunk_length": 0, "use_target_space_embedding": false, "batch_size": 64, "weight_dtype": "float32", "filter_size": 512, "label_smoothing": 0.1, "multiproblem_fixed_train_length": -1, "clip_grad_norm": 0.0, "multiproblem_reweight_label_loss": false, "optimizer_multistep_accumulate_steps": 0, "moe_num_experts": 16, "sampling_keep_top_k": -1, "synthetic_screen_noise": 0.0, "mixed_precision_optimizer_loss_scaler": "exponential", "min_length": 0, "multiproblem_vocab_size": -1, "warm_start_from_second": "", "learning_rate_decay_scheme": "noam", "optimizer_adam_beta1": 0.9, "optimizer_adam_beta2": 0.997, "scheduled_sampling_warmup_steps": 50000, "grad_noise_scale": 0.0, "vocab_divisor": 1, "screen_encoder": "transformer", "pretrained_model_dir": "", "layer_preprocess_sequence": "n", "multiproblem_mixing_schedule": "constant", 
"optimizer_adafactor_beta1": 0.0, "mean_synthetic_length": 1.0, "num_joint_layers": 2, "parameter_attention_key_channels": 0, "causal_decoder_self_attention": true, "bottom": {}, "max_pixel_pos": 100, "heads_share_relative_embedding": false, "learning_rate_warmup_steps": 8000, "ffn_layer": "dense_relu_dense", "learning_rate": 0.2, "prepend_mode": "none", "multiproblem_max_target_length": -1, "eval_run_autoregressive": false, "learning_rate_decay_rate": 1.0, "multiproblem_label_weight": 0.5, "kernel_width": 1, "num_area_layers": 0, "attention_dropout": 0.4, "symbol_dropout": 0.0, "max_dom_pos": 500, "top": {}, "compress_steps": 0, "stddev_synthetic_length": 2.0, "learning_rate_minimum": null, "gpu_automatic_mixed_precision": false, "summarize_grads": false, "scheduled_sampling_prob": 0.0, "use_pad_remover": true, "reference_warmup_steps": 0, "scheduled_sampling_gold_mixin_prob": 0.5, "optimizer_adafactor_beta2": 0.999, "dis_loss_ratio": 0.01, "max_area_height": 1, "moe_k": 2, "task_vocab_size": 59429, "mixed_precision_optimizer_init_loss_scale": 32768, "weight_noise": 0.0, "initializer_gain": 1.0, "learning_rate_schedule": "constant*linear_warmup*rsqrt_decay", "layer_postprocess_sequence": "da", "min_length_bucket": 8, "attention_variables_3d": false, "split_to_length": 0, "obj_text_aggregation": "sum", "sampling_method": "argmax", "optimizer_momentum_momentum": 0.9, "load_screen": true, "pad_batch": false, "optimizer_adafactor_factored": true, "screen_embedding_feature": "text_pos_type_dom_click", "optimizer_zero_grads": false, "video_num_target_frames": 1, "activation_dtype": "float32", "max_target_seq_length": 0, "max_span": 10, "pack_dataset": false, "optimizer_adafactor_decay_type": "pow", "max_area_width": 1, "norm_type": "layer", "hard_attention_k": 0, "add_relative_to_values": false, "parameter_attention_value_channels": 0, "alignment": "dot_product_attention", "optimizer_momentum_nesterov": false, "multiproblem_schedule_max_examples": 10000000.0, "area_value_mode": "none", "num_hidden_layers": 6, "mlperf_mode": false, "relu_dropout": 0.3, "tpu_enable_host_call": false, "scheduled_sampling_warmup_schedule": "exp", "scheduled_sampling_num_passes": 1, "weight_decay": 0.0, "moe_overhead_train": 1.0, "initializer": "uniform_unit_scaling", "name": {}, "gan_update": "center", "compute_verb_obj_separately": true, "num_encoder_layers": 0, "optimizer": "adam", "num_decoder_layers": 0, "kernel_height": 3, "sampling_temp": 1.0, "self_attention_type": "dot_product", "loss": {}, "video_num_input_frames": 1, "dropout": 0.2, "max_relative_position": 0, "gumbel_noise_weight": 0.0, "optimizer_adafactor_memory_exponent": 0.8, "shared_embedding": false, "learning_rate_cosine_cycle_steps": 250000, "gen_loss_ratio": 0.01, "max_input_seq_length": 0, "layer_prepostprocess_dropout": 0.2, "moe_overhead_eval": 2.0, "overload_eval_metric_name": "", "attention_key_channels": 0, "optimizer_adafactor_clipping_threshold": 1.0, "hidden_size": 128, "relu_dropout_broadcast_dims": "", "length_bucket_step": 1.1, "multiproblem_schedule_threshold": 0.5, "symbol_modality_num_shards": 16, "attention_value_channels": 0, "num_heads": 8, "nbr_decoder_problems": 1, "pos": "timing"} 2 | -------------------------------------------------------------------------------- /ckpt_hparams/tuple_extract/hparams.json: -------------------------------------------------------------------------------- 1 | {"moe_overhead_eval": 2.0, "dis_loss_ratio": 0.1, "max_dom_pos": 500, "optimizer_momentum_momentum": 0.9, "activation_dtype": "float32", 
"num_joint_layers": 2, "moe_k": 2, "moe_num_experts": 16, "attention_dropout_broadcast_dims": "", "max_span": 20, "learning_rate": 0.2, "max_area_width": 1, "learning_rate_decay_scheme": "noam", "freeze_reference_model": false, "max_target_seq_length": 0, "weight_decay": 0.0, "attention_key_channels": 0, "learning_rate_warmup_steps": 8000, "max_input_seq_length": 0, "multiproblem_mixing_schedule": "constant", "multiproblem_target_eval_only": false, "tpu_enable_host_call": false, "kernel_height": 3, "max_area_height": 1, "optimizer": "adam", "moe_hidden_sizes": "2048", "split_to_length": 0, "ffn_layer": "dense_relu_dense", "span_aggregation": "sum", "layer_postprocess_sequence": "da", "multiproblem_fixed_train_length": -1, "layer_prepostprocess_dropout": 0.1, "sampling_method": "argmax", "optimizer_adam_epsilon": 1e-09, "scheduled_sampling_num_passes": 1, "shared_embedding_and_softmax_weights": true, "dropout": 0.2, "learning_rate_minimum": null, "use_fixed_batch_size": false, "optimizer_adafactor_factored": true, "multiply_embedding_mode": "sqrt_depth", "reference_warmup_steps": 0, "gen_loss_ratio": 0.1, "relu_dropout": 0.1, "mixed_precision_optimizer_init_loss_scale": 32768, "optimizer_adam_beta2": 0.997, "optimizer_adafactor_memory_exponent": 0.8, "vocab_divisor": 1, "optimizer_zero_grads": false, "multiproblem_max_target_length": -1, "length_bucket_step": 1.1, "hidden_size": 128, "optimizer_adafactor_beta2": 0.999, "optimizer_adafactor_beta1": 0.0, "action_vocab_size": 6, "max_relative_position": 0, "initializer_gain": 1.0, "sampling_temp": 1.0, "area_key_mode": "none", "learning_rate_cosine_cycle_steps": 250000, "area_value_mode": "none", "attention_value_channels": 0, "screen_encoder_layers": 2, "task_vocab_size": 59429, "learning_rate_schedule": "constant*linear_warmup*rsqrt_decay", "attention_dropout": 0.1, "optimizer_multistep_accumulate_steps": 0, "stddev_synthetic_length": 2.0, "scheduled_sampling_gold_mixin_prob": 0.5, "min_length_bucket": 8, "scheduled_sampling_prob": 0.0, "causal_decoder_self_attention": true, "shared_embedding": false, "norm_type": "layer", "weight_noise": 0.0, "optimizer_adam_beta1": 0.9, "batch_size": 128, "sampling_keep_top_k": -1, "learning_rate_constant": 0.1, "num_hidden_layers": 6, "use_additional_loss": false, "clip_grad_norm": 0.0, "factored_logits": false, "multiproblem_schedule_threshold": 0.5, "multiproblem_label_weight": 0.5, "moe_overhead_train": 1.0, "num_area_layers": 0, "conv_first_kernel": 3, "norm_epsilon": 1e-06, "pad_batch": false, "symbol_dropout": 0.0, "add_relative_to_values": false, "mlperf_mode": false, "hard_attention_k": 0, "weight_dtype": "float32", "pretrained_model_dir": "", "min_length": 0, "use_custom_ops": true, "multiproblem_max_input_length": -1, "scheduled_sampling_warmup_steps": 50000, "moe_loss_coef": 0.001, "weights_fn": {}, "name": {}, "filter_size": 512, "eval_run_autoregressive": false, "learning_rate_decay_staircase": false, "daisy_chain_variables": true, "kernel_width": 1, "gan_update": "flip", "force_full_predict": false, "obj_text_aggregation": "sum", "num_decoder_layers": 0, "multiproblem_vocab_size": -1, "initializer": "uniform_unit_scaling", "eval_drop_long_sequences": false, "learning_rate_decay_rate": 1.0, "layer_prepostprocess_dropout_broadcast_dims": "", "split_targets_max_chunks": 100, "warm_start_from_second": "", "mean_synthetic_length": 4.0, "video_num_input_frames": 1, "split_targets_chunk_length": 0, "synthetic_screen_noise": 0.0, "gumbel_noise_weight": 0.0, "pack_dataset": false, 
"parameter_attention_key_channels": 0, "symbol_modality_num_shards": 16, "heads_share_relative_embedding": false, "unidirectional_encoder": false, "optimizer_momentum_nesterov": false, "batch_shuffle_size": 512, "split_targets_strided_training": false, "multiproblem_schedule_max_examples": 10000000.0, "mixed_precision_optimizer_loss_scaler": "exponential", "optimizer_adafactor_multiply_by_parameter_scale": true, "compute_verb_obj_separately": true, "pos": "timing", "video_num_target_frames": 1, "optimizer_adafactor_clipping_threshold": 1.0, "prepend_mode": "none", "screen_encoder": "mlp", "scheduled_sampling_warmup_schedule": "exp", "compress_steps": 0, "self_attention_type": "dot_product", "no_data_parallelism": false, "max_pixel_pos": 100, "grad_noise_scale": 0.0, "proximity_bias": false, "loss": {}, "summarize_vars": false, "relu_dropout_broadcast_dims": "", "overload_eval_metric_name": "", "summarize_grads": false, "screen_embedding_feature": "text_type_click", "top": {}, "use_pad_remover": true, "multiproblem_reweight_label_loss": false, "num_heads": 8, "load_screen": false, "num_encoder_layers": 0, "bottom": {}, "use_target_space_embedding": false, "nbr_decoder_problems": 1, "attention_variables_3d": false, "alignment": "dot_product_attention", "label_smoothing": 0.1, "scheduled_sampling_method": "parallel", "gpu_automatic_mixed_precision": false, "memory_height": 1, "learning_rate_decay_steps": 5000, "optimizer_adafactor_decay_type": "pow", "parameter_attention_value_channels": 0, "max_length": 256, "layer_preprocess_sequence": "n", "multiproblem_per_task_threshold": ""} 2 | -------------------------------------------------------------------------------- /data_generation/README.md: -------------------------------------------------------------------------------- 1 | ## Generate RicoSCA Datasets 2 | 3 | **Download Rico Public Dataset** 4 | ``` 5 | # Download dataset from http://interactionmining.org/rico#quick-downloads 6 | # Choose 1 UI Screenshots and View Hierarchies (6 GB) 7 | # Place the downloaded Rico .json data under folder seq2act/data/rico_sca/raw 8 | 9 | Create a folder named "output" under "seq2act/data/rico_sca" 10 | ``` 11 | **Generate Rico SCA tfrecord** 12 | ``` 13 | sh seq2act/data_generation/create_rico_sca.sh 14 | ``` 15 | -------------------------------------------------------------------------------- /data_generation/common.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Functions shared among files under word2act/data_generation.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import os 24 | 25 | import attr 26 | from enum import Enum 27 | import numpy as np 28 | import tensorflow.compat.v1 as tf # tf 29 | 30 | from seq2act.data_generation import config 31 | from seq2act.data_generation import view_hierarchy 32 | 33 | 34 | gfile = tf.gfile 35 | 36 | 37 | @attr.s 38 | class MaxValues(object): 39 | """Represents max values for a task and UI.""" 40 | 41 | # For instrction 42 | max_word_num = attr.ib(default=None) 43 | max_word_length = attr.ib(default=None) 44 | 45 | # For UI objects 46 | max_ui_object_num = attr.ib(default=None) 47 | max_ui_object_word_num = attr.ib(default=None) 48 | max_ui_object_word_length = attr.ib(default=None) 49 | 50 | def update(self, other): 51 | """Update max value from another MaxValues instance. 52 | 53 | This will be used when want to merge several MaxValues instances: 54 | 55 | max_values_list = ... 56 | result = MaxValues() 57 | for v in max_values_list: 58 | result.update(v) 59 | 60 | Then `result` contains merged max values in each field. 61 | 62 | Args: 63 | other: another MaxValues instance, contains updated data. 64 | """ 65 | self.max_word_num = max(self.max_word_num, other.max_word_num) 66 | self.max_word_length = max(self.max_word_length, other.max_word_length) 67 | self.max_ui_object_num = max(self.max_ui_object_num, 68 | other.max_ui_object_num) 69 | self.max_ui_object_word_num = max(self.max_ui_object_word_num, 70 | other.max_ui_object_word_num) 71 | self.max_ui_object_word_length = max(self.max_ui_object_word_length, 72 | other.max_ui_object_word_length) 73 | 74 | 75 | class ActionRules(Enum): 76 | """The rule_id to generate synthetic action.""" 77 | SINGLE_OBJECT_RULE = 0 78 | GRID_CONTEXT_RULE = 1 79 | NEIGHBOR_CONTEXT_RULE = 2 80 | SWIPE_TO_OBJECT_RULE = 3 81 | SWIPE_TO_DIRECTION_RULE = 4 82 | REAL = 5 # The action is not generated, but a real user action. 
83 | CROWD_COMPUTE = 6 84 | DIRECTION_VERB_RULE = 7 # For win, "click button under some tab/combobox 85 | CONSUMED_MULTI_STEP = 8 # For win, if the target verb is not direction_verb 86 | UNCONSUMED_MULTI_STEP = 9 87 | NO_VERB_RULE = 10 88 | 89 | 90 | class ActionTypes(Enum): 91 | """The action types and ids of Android actions.""" 92 | CLICK = 2 93 | INPUT = 3 94 | SWIPE = 4 95 | CHECK = 5 96 | UNCHECK = 6 97 | LONG_CLICK = 7 98 | OTHERS = 8 99 | GO_HOME = 9 100 | GO_BACK = 10 101 | 102 | 103 | VERB_ID_MAP = { 104 | 'check': ActionTypes.CHECK, 105 | 'find': ActionTypes.SWIPE, 106 | 'navigate': ActionTypes.SWIPE, 107 | 'uncheck': ActionTypes.UNCHECK, 108 | 'head to': ActionTypes.SWIPE, 109 | 'enable': ActionTypes.CHECK, 110 | 'turn on': ActionTypes.CHECK, 111 | 'locate': ActionTypes.SWIPE, 112 | 'disable': ActionTypes.UNCHECK, 113 | 'tap and hold': ActionTypes.LONG_CLICK, 114 | 'long press': ActionTypes.LONG_CLICK, 115 | 'look': ActionTypes.SWIPE, 116 | 'press and hold': ActionTypes.LONG_CLICK, 117 | 'turn it on': ActionTypes.CHECK, 118 | 'turn off': ActionTypes.UNCHECK, 119 | 'switch on': ActionTypes.CHECK, 120 | 'visit': ActionTypes.SWIPE, 121 | 'hold': ActionTypes.LONG_CLICK, 122 | 'switch off': ActionTypes.UNCHECK, 123 | 'head': ActionTypes.SWIPE, 124 | 'head over': ActionTypes.SWIPE, 125 | 'long-press': ActionTypes.LONG_CLICK, 126 | 'un-click': ActionTypes.UNCHECK, 127 | 'tap': ActionTypes.CLICK, 128 | 'check off': ActionTypes.UNCHECK, 129 | # 'power on': 21 130 | } 131 | 132 | 133 | class WinActionTypes(Enum): 134 | """The action types and ids of windows actions.""" 135 | LEFT_CLICK = 2 136 | RIGHT_CLICK = 3 137 | DOUBLE_CLICK = 4 138 | INPUT = 5 139 | 140 | 141 | @attr.s 142 | class Action(object): 143 | """The class for a word2act action.""" 144 | instruction_str = attr.ib(default=None) 145 | verb_str = attr.ib(default=None) 146 | obj_desc_str = attr.ib(default=None) 147 | input_content_str = attr.ib(default=None) 148 | action_type = attr.ib(default=None) 149 | action_rule = attr.ib(default=None) 150 | target_obj_idx = attr.ib(default=None) 151 | obj_str_pos = attr.ib(default=None) 152 | input_str_pos = attr.ib(default=None) 153 | verb_str_pos = attr.ib(default=None) 154 | # start/end position of one whole step 155 | step_str_pos = attr.ib(default=[0, 0]) 156 | # Defalt action is 1-step consumed action 157 | is_consumed = attr.ib(default=True) 158 | 159 | def __eq__(self, other): 160 | if not isinstance(other, Action): 161 | return NotImplemented 162 | return self.instruction_str == other.instruction_str 163 | 164 | def is_valid(self): 165 | """Does valid check for action instance. 166 | 167 | Returns true when any component is None or obj_desc_str is all spaces. 168 | 169 | Returns: 170 | a boolean 171 | """ 172 | invalid_obj_pos = (np.array(self.obj_str_pos) == 0).all() 173 | if (not self.instruction_str or invalid_obj_pos or 174 | not self.obj_desc_str.strip()): 175 | return False 176 | 177 | return True 178 | 179 | def has_valid_input(self): 180 | """Does valid check for input positions. 181 | 182 | Returns true when input_str_pos is not all default value. 
183 | 184 | Returns: 185 | a boolean 186 | """ 187 | return (self.input_str_pos != np.array([ 188 | config.LABEL_DEFAULT_VALUE_INT, config.LABEL_DEFAULT_VALUE_INT 189 | ])).any() 190 | 191 | def regularize_strs(self): 192 | """Trims action instance's obj_desc_str, input_content_str, verb_str.""" 193 | self.obj_desc_str = self.obj_desc_str.strip() 194 | self.input_content_str = self.input_content_str.strip() 195 | self.verb_str = self.verb_str.strip() 196 | 197 | def convert_to_lower_case(self): 198 | self.instruction_str = self.instruction_str.lower() 199 | self.obj_desc_str = self.obj_desc_str.lower() 200 | self.input_content_str = self.input_content_str.lower() 201 | self.verb_str = self.verb_str.lower() 202 | 203 | 204 | @attr.s 205 | class ActionEvent(object): 206 | """This class defines ActionEvent class. 207 | 208 | ActionEvent is high level event summarized from low level android event logs. 209 | This example shows the android event logs and the extracted ActionEvent 210 | object: 211 | 212 | Android Event Logs: 213 | [ 42.407808] EV_ABS ABS_MT_TRACKING_ID 00000000 214 | [ 42.407808] EV_ABS ABS_MT_TOUCH_MAJOR 00000004 215 | [ 42.407808] EV_ABS ABS_MT_PRESSURE 00000081 216 | [ 42.407808] EV_ABS ABS_MT_POSITION_X 00004289 217 | [ 42.407808] EV_ABS ABS_MT_POSITION_Y 00007758 218 | [ 42.407808] EV_SYN SYN_REPORT 00000000 219 | [ 42.453256] EV_ABS ABS_MT_PRESSURE 00000000 220 | [ 42.453256] EV_ABS ABS_MT_TRACKING_ID ffffffff 221 | [ 42.453256] EV_SYN SYN_REPORT 00000000 222 | 223 | This log can be generated from this command during runing android emulator: 224 | adb shell getevent -lt /dev/input/event1 225 | 226 | If screen pixel size is [480,800], this is the extracted ActionEvent Object: 227 | ActionEvent( 228 | event_time = 42.407808 229 | action_type = ActionTypes.CLICK 230 | action_object_id = -1 231 | coordinates_x = [17033,] 232 | coordinates_y = [30552,] 233 | coordinates_x_pixel = [249,] 234 | coordinates_y_pixel = [747,] 235 | action_params = [] 236 | ) 237 | """ 238 | 239 | event_time = attr.ib() 240 | action_type = attr.ib() 241 | coordinates_x = attr.ib() 242 | coordinates_y = attr.ib() 243 | action_params = attr.ib() 244 | # These fields will be generated by public method update_info_from_screen() 245 | coordinates_x_pixel = None 246 | coordinates_y_pixel = None 247 | object_id = config.LABEL_DEFAULT_INVALID_INT 248 | leaf_nodes = None # If dedup, the nodes here will be less than XML 249 | debug_target_object_word_sequence = None 250 | 251 | def update_info_from_screen(self, screen_info, dedup=False): 252 | """Updates action event attributes from screen_info. 253 | 254 | Updates coordinates_x(y)_pixel and object_id from the screen_info proto. 255 | 256 | Args: 257 | screen_info: ScreenInfo protobuf 258 | dedup: whether dedup the UI objs with same text or content desc. 259 | Raises: 260 | ValueError when fail to find object id. 261 | """ 262 | self.update_norm_coordinates((config.SCREEN_WIDTH, config.SCREEN_HEIGHT)) 263 | vh = view_hierarchy.ViewHierarchy() 264 | vh.load_xml(screen_info.view_hierarchy.xml.encode('utf-8')) 265 | if dedup: 266 | vh.dedup((self.coordinates_x_pixel[0], self.coordinates_y_pixel[0])) 267 | self.leaf_nodes = vh.get_leaf_nodes() 268 | ui_object_list = vh.get_ui_objects() 269 | self._update_object_id(ui_object_list) 270 | 271 | def _update_object_id(self, ui_object_list): 272 | """Updates ui object index from view_hierarchy. 273 | 274 | If point(X,Y) surrounded by multiple UI objects, select the one with 275 | smallest area. 
276 | 277 | Args: 278 | ui_object_list: . 279 | Raises: 280 | ValueError when fail to find object id. 281 | """ 282 | smallest_area = -1 283 | for index, ui_obj in enumerate(ui_object_list): 284 | box = ui_obj.bounding_box 285 | if (box.x1 <= self.coordinates_x_pixel[0] <= box.x2 and 286 | box.y1 <= self.coordinates_y_pixel[0] <= box.y2): 287 | area = (box.x2 - box.x1) * (box.y2 - box.y1) 288 | if smallest_area == -1 or area < smallest_area: 289 | self.object_id = index 290 | self.debug_target_object_word_sequence = ui_obj.word_sequence 291 | smallest_area = area 292 | 293 | if smallest_area == -1: 294 | raise ValueError(('Object id not found: x,y=%d,%d coordinates fail to ' 295 | 'match every UI bounding box') % 296 | (self.coordinates_x_pixel[0], 297 | self.coordinates_y_pixel[0])) 298 | 299 | def update_norm_coordinates(self, screen_size): 300 | """Update coordinates_x(y)_norm according to screen_size. 301 | 302 | self.coordinate_x is scaled between [0, ANDROID_LOG_MAX_ABS_X] 303 | self.coordinate_y is scaled between [0, ANDROID_LOG_MAX_ABS_Y] 304 | This function recovers coordinate of android event logs back to coordinate 305 | in real screen's pixel level. 306 | 307 | coordinates_x_pixel = coordinates_x/ANDROID_LOG_MAX_ABS_X*horizontal_pixel 308 | coordinates_y_pixel = coordinates_y/ANDROID_LOG_MAX_ABS_Y*vertical_pixel 309 | 310 | For example, 311 | ANDROID_LOG_MAX_ABS_X = ANDROID_LOG_MAX_ABS_Y = 32676 312 | coordinate_x = [17033, ] 313 | object_cords_y = [30552, ] 314 | screen_size = (480, 800) 315 | Then the updated pixel coordinates are as follow: 316 | coordinates_x_pixel = [250, ] 317 | coordinates_y_pixel = [747, ] 318 | 319 | Args: 320 | screen_size: a tuple of screen pixel size. 321 | """ 322 | (horizontal_pixel, vertical_pixel) = screen_size 323 | self.coordinates_x_pixel = [ 324 | int(cord * horizontal_pixel / config.ANDROID_LOG_MAX_ABS_X) 325 | for cord in self.coordinates_x 326 | ] 327 | self.coordinates_y_pixel = [ 328 | int(cord * vertical_pixel / config.ANDROID_LOG_MAX_ABS_Y) 329 | for cord in self.coordinates_y 330 | ] 331 | 332 | 333 | # For Debug: Get distribution info for each cases 334 | word_num_distribution_dict = collections.defaultdict(int) 335 | word_length_distribution_dict = collections.defaultdict(int) 336 | 337 | 338 | def get_word_statistics(file_path): 339 | """Calculates maximum word number/length from ui objects in one xml/json file. 340 | 341 | Args: 342 | file_path: The full path of a xml/json file. 343 | 344 | Returns: 345 | A tuple (max_word_num, max_word_length) 346 | ui_object_num: UI object num. 347 | max_word_num: The maximum number of words contained in all ui objects. 348 | max_word_length: The maximum length of words contained in all ui objects. 349 | """ 350 | max_word_num = 0 351 | max_word_length = 0 352 | 353 | leaf_nodes = get_view_hierarchy_list(file_path) 354 | for view_hierarchy_object in leaf_nodes: 355 | word_sequence = view_hierarchy_object.uiobject.word_sequence 356 | max_word_num = max(max_word_num, len(word_sequence)) 357 | word_num_distribution_dict[len(word_sequence)] += 1 358 | 359 | for word in word_sequence: 360 | max_word_length = max(max_word_length, len(word)) 361 | word_length_distribution_dict[len(word)] += 1 362 | return len(leaf_nodes), max_word_num, max_word_length 363 | 364 | 365 | def get_ui_max_values(file_paths): 366 | """Calculates max values from ui objects in multi xml/json files. 367 | 368 | Args: 369 | file_paths: The full paths of multi xml/json files. 
370 | Returns: 371 | max_values: instrance of MaxValues. 372 | """ 373 | max_values = MaxValues() 374 | for file_path in file_paths: 375 | (ui_object_num, 376 | max_ui_object_word_num, 377 | max_ui_object_word_length) = get_word_statistics(file_path) 378 | 379 | max_values.max_ui_object_num = max( 380 | max_values.max_ui_object_num, ui_object_num) 381 | max_values.max_ui_object_word_num = max( 382 | max_values.max_ui_object_word_num, max_ui_object_word_num) 383 | max_values.max_ui_object_word_length = max( 384 | max_values.max_ui_object_word_length, max_ui_object_word_length) 385 | return max_values 386 | 387 | 388 | def get_ui_object_list(file_path): 389 | """Gets ui object list from view hierarchy leaf nodes. 390 | 391 | Args: 392 | file_path: file path of xml or json 393 | Returns: 394 | A list of ui objects according to view hierarchy leaf nodes. 395 | """ 396 | 397 | vh = _get_view_hierachy(file_path) 398 | return vh.get_ui_objects() 399 | 400 | 401 | def get_view_hierarchy_list(file_path): 402 | """Gets view hierarchy leaf node list. 403 | 404 | Args: 405 | file_path: file path of xml or json 406 | Returns: 407 | A list of view hierarchy leaf nodes. 408 | """ 409 | vh = _get_view_hierachy(file_path) 410 | return vh.get_leaf_nodes() 411 | 412 | 413 | def _get_view_hierachy(file_path): 414 | """Gets leaf nodes view hierarchy lists. 415 | 416 | Args: 417 | file_path: The full path of an input xml/json file. 418 | Returns: 419 | A ViewHierarchy object. 420 | Raises: 421 | ValueError: unsupported file format. 422 | """ 423 | with gfile.GFile(file_path, 'r') as f: 424 | data = f.read() 425 | 426 | _, file_extension = os.path.splitext(file_path) 427 | if file_extension == '.xml': 428 | vh = view_hierarchy.ViewHierarchy( 429 | screen_width=config.SCREEN_WIDTH, screen_height=config.SCREEN_HEIGHT) 430 | vh.load_xml(data) 431 | elif file_extension == '.json': 432 | vh = view_hierarchy.ViewHierarchy( 433 | screen_width=config.RICO_SCREEN_WIDTH, 434 | screen_height=config.RICO_SCREEN_HEIGHT) 435 | vh.load_json(data) 436 | else: 437 | raise ValueError('unsupported file format %s' % file_extension) 438 | return vh 439 | -------------------------------------------------------------------------------- /data_generation/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Configurations for all word2act data generation global configs.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | # Android Emulator Config 24 | ANDROID_LOG_MAX_ABS_X = 32676 25 | ANDROID_LOG_MAX_ABS_Y = 32676 26 | SCREEN_WIDTH = 540 27 | SCREEN_HEIGHT = 960 28 | SCREEN_CHANNEL = 4 29 | # Rico dataset screen config 30 | RICO_SCREEN_WIDTH = 1440 31 | RICO_SCREEN_HEIGHT = 2560 32 | 33 | # Data Generation Config 34 | LABEL_DEFAULT_VALUE_INT = 0 35 | LABEL_DEFAULT_VALUE_STRING = '' 36 | LABEL_DEFAULT_INVALID_INT = -1 37 | LABEL_DEFAULT_INVALID_STRING = '' 38 | 39 | FEATURE_ANCHOR_PADDING_INT = -1 40 | FEATURE_DEFAULT_PADDING_INT = 0 41 | FEATURE_DEFAULT_PADDING_FLOAT = -0.0 42 | FEATURE_DEFAULT_PADDING_STR = '' 43 | TOKEN_DEFAULT_PADDING_INT = 0 44 | 45 | MAX_WORD_NUM_UPPER_BOUND = 30 46 | MAX_WORD_LENGTH_UPPER_BOUND = 50 47 | SHARD_NUM = 10 48 | MAX_INPUT_WORD_NUMBER = 5 49 | 50 | # synthetic action config 51 | MAX_OBJ_NAME_WORD_NUM = 3 52 | MAX_WIN_OBJ_NAME_WORD_NUM = 10 53 | MAX_INPUT_STR_LENGTH = 10 54 | NORM_VERTICAL_NEIGHBOR_MARGIN = 0.01 55 | NORM_HORIZONTAL_NEIGHBOR_MARGIN = 0.01 56 | INPUT_ACTION_UPSAMPLE_RATIO = 1 57 | 58 | # Windows data dimension Config. The numbers are set based on real data 59 | # dimension distribution. 60 | MAX_UI_OBJ_WORD_NUM_UPPER_BOUND = 20 61 | MAX_UI_OBJ_WORD_LENGTH_UPPER_BOUND = 21 62 | 63 | # view hierarchy config 64 | MAX_PER_OBJECT_WORD_NUM = 10 65 | MAX_WORD_LENGTH = 100 66 | TRAINING_BATCH_SIZE = 2 67 | UI_OBJECT_TYPE_NUM = 15 68 | ADJACENT_BOUNDING_BOX_THRESHOLD = 3 69 | -------------------------------------------------------------------------------- /data_generation/create_rico_sca.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | set -e 17 | set -x 18 | 19 | virtualenv -p python3 . 20 | source ./bin/activate 21 | 22 | 23 | pip install -r seq2act/data_generation/requirements.txt 24 | python -m seq2act.data_generation.create_android_synthetic_dataset \ 25 | --input_dir=${PWD}"seq2act/data/rico_sca/raw" \ 26 | --output_dir=${PWD}"seq2act/data/rico_sca/output" \ 27 | --filter_file=${PWD}"/seq2act/data_generation/filter.txt" \ 28 | --thread_num=10 \ 29 | --shard_num=10 \ 30 | --vocab_file=${PWD}"/seq2act/data_generation/lower_case_vocab" \ 31 | --input_candiate_file=${PWD}"/seq2act/data_generation/input_candidate_words.txt" \ 32 | --logtostderr 33 | 34 | -------------------------------------------------------------------------------- /data_generation/create_token_vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Creates token vocabulary using tensor2tensor tokenizer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import csv 24 | import operator 25 | import os 26 | 27 | from tensor2tensor.data_generators import tokenizer 28 | import tensorflow.compat.v1 as tf # tf 29 | 30 | _INPUT_DIR = "/tmp" 31 | _OUTPUT_DIR = "/tmp" 32 | 33 | flags = tf.flags 34 | FLAGS = flags.FLAGS 35 | gfile = tf.gfile 36 | 37 | flags.DEFINE_string( 38 | "corpus_dir", _INPUT_DIR, 39 | "Full path to the directory containing the data files for a set of tasks.") 40 | flags.DEFINE_string( 41 | "vocab_dir", _OUTPUT_DIR, 42 | "Full path to the directory for saving the tf record file.") 43 | flags.DEFINE_string("mode", "write", 44 | "Flag to indicate read vocab csv or write token csv.") 45 | 46 | 47 | word_count = collections.Counter() 48 | freq_count = collections.Counter() 49 | 50 | 51 | def create_token_id_files(corpus_dir, output_vocab_dir): 52 | """Creates token id csv files. 53 | 54 | Args: 55 | corpus_dir: input corpus directory 56 | output_vocab_dir: output token vocabulary csv file directory 57 | """ 58 | walking_iter = gfile.Walk(corpus_dir) 59 | for iter_rst in walking_iter: 60 | valid_filenames = [ 61 | filename for filename in iter_rst[2] 62 | if ".txt" in filename or "wadata" in filename 63 | ] 64 | if not valid_filenames: 65 | continue 66 | input_file_dir = iter_rst[0] 67 | for filename in valid_filenames: 68 | path = os.path.join(input_file_dir, filename) 69 | with gfile.Open(path, "r") as f: 70 | for line in f.read().lower().split("\n"): 71 | tokens = tokenizer.encode(line) 72 | for token in tokens: 73 | word_count[token] += 1 74 | 75 | sorted_vocab = sorted(word_count.items(), key=operator.itemgetter(1)) 76 | tf.logging.info("%d items in vocb", sum(word_count.values())) 77 | 78 | csv_file = gfile.Open(os.path.join(output_vocab_dir, "vocab.csv"), "w+") 79 | csv_writter = csv.writer(csv_file) 80 | 81 | rows = [["", 0, 0], ["", 0, 1], ["", 0, 2], ["", 0, 3]] 82 | for row in rows: 83 | csv_writter.writerow(row) 84 | start_index = len(rows) 85 | for word_freq in reversed(sorted_vocab): 86 | row = [word_freq[0], word_freq[1], start_index] 87 | freq_count[word_freq[1]] += 1 88 | start_index += 1 89 | csv_writter.writerow(row) 90 | tf.logging.info("vocab_size=%d", start_index) 91 | tf.logging.info("token frequency count") 92 | tf.logging.info(sorted(freq_count.items(), key=operator.itemgetter(1))) 93 | csv_file.close() 94 | 95 | 96 | def read_vocab(vocab_path): 97 | """Reads vocabulary csv file. 
98 | 99 | Args: 100 | vocab_path: full path of the vocabulary csv file 101 | 102 | Returns: 103 | tokens: list of token strings 104 | freqs: list of token frequencies 105 | ids: list of token ids 106 | """ 107 | csv_file = gfile.Open(vocab_path, "r") 108 | csv_reader = csv.reader(csv_file) 109 | tokens, freqs, ids = [], [], [] 110 | 111 | for row in csv_reader: 112 | tokens.append(row[0]) 113 | freqs.append(int(row[1])) 114 | ids.append(int(row[2])) 115 | tf.logging.info("Totally %d vocabs", len(tokens)) 116 | return tokens, freqs, ids 117 | -------------------------------------------------------------------------------- /data_generation/proto_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Proto utility module containing helper functions. 17 | 18 | The module handles tasks related to protobufs in word2act: 19 | 1. encodes word2act action and time_step into tf.train.Example proto2. 20 | 2. parses screeninfo protobuf into feature dictionary. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import io 28 | 29 | import numpy as np 30 | from PIL import Image 31 | import tensorflow.compat.v1 as tf 32 | 33 | from seq2act.data_generation import string_utils 34 | from seq2act.data_generation import view_hierarchy 35 | from tensorflow.contrib import framework as contrib_framework 36 | nest = contrib_framework.nest 37 | 38 | 39 | def get_feature_dict(screen_info_proto, padding_shape=None, lower_case=False): 40 | """Gets screen feature dictionary from screen_info protobuf. 41 | 42 | Args: 43 | screen_info_proto: protobuf defined in word2act/proto/rehearsal.proto. 44 | Contains screenshot and xml 45 | padding_shape: The shape of padding size for final feature list. shape = 46 | (max_object_num, max_word_num, max_word_length) If the shape is not given, 47 | then returns the original list without padding. 48 | lower_case: lower case all the ui texts. 49 | 50 | Returns: 51 | A feature dictionary. If padding_shape is not None, all values of the 52 | dictionary are padded. The shape after padding is shown as 'shape = ...'. 53 | Otherwise, shapes of values are not a fixed value. 54 | screenshot: numpy array of screen_info_proto.screenshot 55 | 'ui_obj_str_seq': uiobject's name/content_descriotion/resource_id, numpy 56 | array of strings. 
57 | 'ui_obj_word_id_seq': encoded word sequence, np int array, shape = 58 | (max_object_num, max_word_num) 59 | 'ui_obj_char_id_seq': encoded char sequence, np int array, shape = 60 | (max_object_num, max_word_num, max_word_length) 61 | 'ui_obj_type_seq': type sequence, np int array, shape = (max_object_num,) 62 | 'ui_obj_clickable_seq': clickable sequence, np int array, shape = 63 | (max_object_num,) 64 | 'ui_obj_cord_x_seq': x cordinate sequence, np int array, shape = 65 | (max_object_num*2,) 66 | 'ui_obj_cord_y_seq': y cordinate sequence, np int array, shape = 67 | (max_object_num*2,) 68 | 'ui_obj_v_distance': vertical relation matrix, np float array, 69 | shape = (max_object_num, max_object_num) 70 | 'ui_obj_h_distance': horizontal relation matrix, np float array, shape = 71 | (max_object_num, max_object_num) 72 | 'ui_obj_dom_distance': dom relation matrix, np int array, shape = 73 | (max_object_num, max_object_num) 74 | 'ui_obj_dom_location_seq': dom index from tree traversal, np int array, 75 | shape = (max_object_num*3,) 76 | 77 | 78 | """ 79 | screenshot = Image.open(io.BytesIO(screen_info_proto.screenshot.content)) 80 | screenshot = np.asarray(screenshot, np.float32) 81 | vh = view_hierarchy.ViewHierarchy() 82 | vh.load_xml(screen_info_proto.view_hierarchy.xml.encode('utf-8')) 83 | view_hierarchy_leaf_nodes = vh.get_leaf_nodes() 84 | 85 | ui_object_features_dict = get_ui_objects_feature_dict( 86 | view_hierarchy_leaf_nodes, padding_shape, lower_case) 87 | ui_object_features_dict['screenshot'] = screenshot 88 | 89 | return ui_object_features_dict 90 | 91 | 92 | def get_ui_objects_feature_dict(view_hierarchy_leaf_nodes, 93 | padding_shape=None, 94 | lower_case=False): 95 | """Gets ui object features dictionary from view hierarchy leaf nodes list. 96 | 97 | Args: 98 | view_hierarchy_leaf_nodes: A list of view hierarchy leaf node objects. 99 | padding_shape: The shape of padding size for final feature list. shape = 100 | (max_object_num, max_word_num, max_word_length) If the shape is not given, 101 | then returns the original list without padding. 102 | lower_case: lower case all the ui texts. 103 | 104 | Returns: 105 | A feature dictionary. If padding_shape is not None, all values of the 106 | dictionary are padded. The shape after padding is shown as 'shape = ...'. 107 | Otherwise, shapes of values are not a fixed value. 108 | 'ui_obj_type_seq': type sequence, np int array, shape = (max_object_num,) 109 | 'ui_obj_word_id_seq': encoded word sequence, np int array, shape = 110 | (max_object_num, max_word_num) 111 | 'ui_obj_char_id_seq': encoded char sequence, np int array, shape = 112 | (max_object_num, max_word_num, max_word_length) 113 | 'ui_obj_clickable_seq': clickable sequence, np int array, shape = 114 | (max_object_num,) 115 | 'ui_obj_cord_x_seq': x cordinate sequence, np int array, shape = 116 | (max_object_num*2,) 117 | 'ui_obj_cord_y_seq': y cordinate sequence, np int array, shape = 118 | (max_object_num*2,) 119 | 'ui_obj_v_distance': vertical relation matrix, np float array, shape = 120 | (max_object_num, max_object_num) 121 | 'ui_obj_h_distance': horizontal relation matrix, np float array, shape = 122 | (max_object_num, max_object_num) 123 | 'ui_obj_dom_distance': dom relation matrix, np int array, shape = 124 | (max_object_num, max_object_num) 125 | 'ui_obj_dom_location_seq': dom index from tree traversal, np int array, 126 | shape = (max_object_num*3,) 127 | 'ui_obj_str_seq': uiobject's name/content_descriotion/resource_id, 128 | numpy array of strings. 
129 | """ 130 | ui_object_attributes = _get_ui_object_attributes(view_hierarchy_leaf_nodes, 131 | lower_case) 132 | vh_relations = get_view_hierarchy_leaf_relation(view_hierarchy_leaf_nodes) 133 | if padding_shape is None: 134 | merged_features = {} 135 | for key in ui_object_attributes: 136 | if key == 'obj_str_seq': 137 | merged_features['ui_obj_str_seq'] = ui_object_attributes[key].copy() 138 | else: 139 | merged_features['ui_obj_' + key] = ui_object_attributes[key].copy() 140 | for key in vh_relations: 141 | merged_features['ui_obj_' + key] = vh_relations[key].copy() 142 | return merged_features 143 | else: 144 | if not isinstance(padding_shape, tuple): 145 | assert False, 'padding_shape %s is not a tuple.' % (str(padding_shape)) 146 | if len(padding_shape) != 3: 147 | assert False, 'padding_shape %s contains not exactly 3 elements.' % ( 148 | str(padding_shape)) 149 | 150 | (max_object_num, max_word_num, _) = padding_shape 151 | obj_feature_dict = { 152 | 'ui_obj_type_id_seq': 153 | padding_array(ui_object_attributes['type_id_seq'], (max_object_num,), 154 | -1), 155 | 'ui_obj_str_seq': 156 | padding_array( 157 | ui_object_attributes['obj_str_seq'], (max_object_num,), 158 | padding_type=np.string_, 159 | padding_value=''), 160 | 'ui_obj_word_id_seq': 161 | padding_array( 162 | ui_object_attributes['word_id_seq'], 163 | (max_object_num, max_word_num), 164 | padding_value=0), 165 | 'ui_obj_clickable_seq': 166 | padding_array(ui_object_attributes['clickable_seq'], 167 | (max_object_num,)), 168 | 'ui_obj_cord_x_seq': 169 | padding_array(ui_object_attributes['cord_x_seq'], 170 | (max_object_num * 2,)), 171 | 'ui_obj_cord_y_seq': 172 | padding_array(ui_object_attributes['cord_y_seq'], 173 | (max_object_num * 2,)), 174 | 'ui_obj_v_distance': 175 | padding_array(vh_relations['v_distance'], 176 | (max_object_num, max_object_num), 0, np.float32), 177 | 'ui_obj_h_distance': 178 | padding_array(vh_relations['h_distance'], 179 | (max_object_num, max_object_num), 0, np.float32), 180 | 'ui_obj_dom_distance': 181 | padding_array(vh_relations['dom_distance'], 182 | (max_object_num, max_object_num)), 183 | 'ui_obj_dom_location_seq': 184 | padding_array(ui_object_attributes['dom_location_seq'], 185 | (max_object_num * 3,)), 186 | } 187 | return obj_feature_dict 188 | 189 | 190 | def _get_ui_object_attributes(view_hierarchy_leaf_nodes, lower_case=False): 191 | """Parses ui object informationn from a view hierachy leaf node list. 192 | 193 | Args: 194 | view_hierarchy_leaf_nodes: a list of view hierachy leaf nodes. 195 | lower_case: lower case all the ui texts. 196 | 197 | Returns: 198 | An un-padded attribute dictionary as follow: 199 | 'type_id_seq': numpy array of ui object types from view hierarchy. 200 | 'word_id_seq': numpy array of encoding for words in ui object. 201 | 'char_id_seq': numpy array of encoding for words in ui object. 202 | 'clickable_seq': numpy array of ui object clickable status. 203 | 'cord_x_seq': numpy array of ui object x coordination. 204 | 'cord_y_seq': numpy array of ui object y coordination. 205 | 'dom_location_seq': numpy array of ui object depth, pre-order-traversal 206 | index, post-order-traversal index. 207 | 'word_str_sequence': numpy array of ui object name strings. 
208 | """ 209 | type_sequence = [] 210 | word_id_sequence = [] 211 | char_id_sequence = [] 212 | clickable_sequence = [] 213 | cord_x_sequence = [] 214 | cord_y_sequence = [] 215 | dom_location_sequence = [] 216 | obj_str_sequence = [] 217 | 218 | def _is_ascii(s): 219 | return all(ord(c) < 128 for c in s) 220 | 221 | for vh_node in view_hierarchy_leaf_nodes: 222 | ui_obj = vh_node.uiobject 223 | type_sequence.append(ui_obj.obj_type.value) 224 | cord_x_sequence.append(ui_obj.bounding_box.x1) 225 | cord_x_sequence.append(ui_obj.bounding_box.x2) 226 | cord_y_sequence.append(ui_obj.bounding_box.y1) 227 | cord_y_sequence.append(ui_obj.bounding_box.y2) 228 | clickable_sequence.append(ui_obj.clickable) 229 | dom_location_sequence.extend(ui_obj.dom_location) 230 | 231 | valid_words = [w for w in ui_obj.word_sequence if _is_ascii(w)] 232 | word_sequence = ' '.join(valid_words) 233 | 234 | if lower_case: 235 | word_sequence = word_sequence.lower() 236 | obj_str_sequence.append(word_sequence) 237 | 238 | word_ids, char_ids = string_utils.tokenize_to_ids(word_sequence) 239 | word_id_sequence.append(word_ids) 240 | char_id_sequence.append(char_ids) 241 | ui_feature = { 242 | 'type_id_seq': np.array(type_sequence), 243 | 'word_id_seq': np.array(word_id_sequence), 244 | 'clickable_seq': np.array(clickable_sequence), 245 | 'cord_x_seq': np.array(cord_x_sequence), 246 | 'cord_y_seq': np.array(cord_y_sequence), 247 | 'dom_location_seq': np.array(dom_location_sequence), 248 | 'obj_str_seq': np.array(obj_str_sequence, dtype=np.str), 249 | } 250 | return ui_feature 251 | 252 | 253 | def get_view_hierarchy_leaf_relation(view_hierarchy_leaf_nodes): 254 | """Calculates adjacency relation from list of view hierarchy leaf nodes. 255 | 256 | Args: 257 | view_hierarchy_leaf_nodes: a list of view hierachy leaf nodes. 258 | 259 | Returns: 260 | An un-padded feature dictionary as follow: 261 | 'v_distance': 2d numpy array of ui object vertical adjacency relation. 262 | 'h_distance': 2d numpy array of ui object horizontal adjacency relation. 263 | 'dom_distance': 2d numpy array of ui object dom adjacency relation. 264 | """ 265 | vh_node_num = len(view_hierarchy_leaf_nodes) 266 | vertical_adjacency = np.zeros((vh_node_num, vh_node_num), dtype=np.float32) 267 | horizontal_adjacency = np.zeros((vh_node_num, vh_node_num), dtype=np.float32) 268 | dom_adjacency = np.zeros((vh_node_num, vh_node_num), dtype=np.int64) 269 | for row in range(len(view_hierarchy_leaf_nodes)): 270 | for column in range(len(view_hierarchy_leaf_nodes)): 271 | if row == column: 272 | h_dist = v_dist = dom_dist = 0 273 | else: 274 | node1 = view_hierarchy_leaf_nodes[row] 275 | node2 = view_hierarchy_leaf_nodes[column] 276 | h_dist, v_dist = node1.normalized_pixel_distance(node2) 277 | dom_dist = node1.dom_distance(node2) 278 | vertical_adjacency[row][column] = v_dist 279 | horizontal_adjacency[row][column] = h_dist 280 | dom_adjacency[row][column] = dom_dist 281 | return { 282 | 'v_distance': vertical_adjacency, 283 | 'h_distance': horizontal_adjacency, 284 | 'dom_distance': dom_adjacency 285 | } 286 | 287 | 288 | def padding_dictionary(orig_dict, padding_shape_dict, padding_type_dict, 289 | padding_value_dict): 290 | """Does padding for dictionary of array or numpy array. 291 | 292 | Args: 293 | orig_dict: Original dictionary. 
294 | padding_shape_dict: Dictionary of padding shape, keys are field names, 295 | values are shape tuple 296 | padding_type_dict: Dictionary of padding shape, keys are field names, values 297 | are padded numpy type 298 | padding_value_dict: Dictionary of padding shape, keys are field names, 299 | values are shape tuple 300 | 301 | Returns: 302 | A padded dictionary. 303 | """ 304 | # Asserting the keys of the four dictionaries are exactly same. 305 | assert (set(orig_dict.keys()) == set(padding_shape_dict.keys()) == set( 306 | padding_type_dict.keys()) == set(padding_value_dict.keys())) 307 | padded_dict = {} 308 | for key in orig_dict: 309 | if padding_shape_dict[key]: 310 | padded_dict[key] = padding_array(orig_dict[key], padding_shape_dict[key], 311 | padding_value_dict[key], 312 | padding_type_dict[key]) 313 | else: 314 | padded_dict[key] = np.array(orig_dict[key], dtype=padding_type_dict[key]) 315 | return padded_dict 316 | 317 | 318 | def padding_array(orig_array, 319 | padding_shape, 320 | padding_value=0, 321 | padding_type=np.int64): 322 | """Pads orig_array according to padding shape, number and type. 323 | 324 | The dimension of final result is the smaller dimension between 325 | orig_array.shape and padding_shape. 326 | 327 | For example: 328 | a = [[1,2],[3,4]] 329 | padding_array(a, (3,3), 0, np.int64) = [[1, 2, 0], [3, 4, 0], [0, 0, 0]] 330 | 331 | a = [[1,2,3,4],[5,6,7,8]] 332 | padding_array(a, (3,3), 0, np.int64) = [[1, 2, 3], [5, 6, 7], [0, 0, 0]] 333 | 334 | Args: 335 | orig_array: The original array before padding. 336 | padding_shape: The shape of padding. 337 | padding_value: The number to be padded into new array. 338 | padding_type: The data type to be padded into new array. 339 | 340 | Returns: 341 | A padded numpy array. 342 | """ 343 | # When padding type is string, we need to initialize target_array with object 344 | # type first. And convert it back to np.string_ after _fill_array. Because 345 | # after initialized, numpy string array cannot hold longer string. 346 | # For example: 347 | # >>> a = np.array([''], dtype = np.string_) 348 | # >>> a 349 | # array([''], dtype='|S1') 350 | # >>> a[0] = 'foo' 351 | # >>> a 352 | # array(['f'], dtype='|S1') 353 | if padding_type == np.string_: 354 | used_pad_type = object 355 | else: 356 | used_pad_type = padding_type 357 | target_array = np.full( 358 | shape=padding_shape, fill_value=padding_value, dtype=used_pad_type) 359 | _fill_array(orig_array, target_array) 360 | if padding_type == np.string_: 361 | target_array = target_array.astype(np.string_) 362 | return target_array 363 | 364 | 365 | def _fill_array(orig_array, target_array): 366 | """Fills elements from orig_array to target_array. 367 | 368 | If any dimension of orig_array is larger than target_array, only fills the 369 | array of their shared dimensions. 370 | 371 | Args: 372 | orig_array: original array that contains the filling numbers, could be numpy 373 | array or python list. 
374 | target_array: target array that will be filled with original array numbers, 375 | numpy array 376 | 377 | Raises: 378 | TypeError: if the target_array is not a numpy array 379 | """ 380 | if not isinstance(target_array, np.ndarray): 381 | raise TypeError('target array is not numpy array') 382 | if target_array.ndim == 1: 383 | try: 384 | orig_length = len(orig_array) 385 | except TypeError: 386 | tf.logging.exception( 387 | 'orig_array %s and target_array %s dimension not fit', 388 | orig_array, target_array) 389 | orig_length = 0 390 | if len(target_array) < orig_length: 391 | target_array[:] = orig_array[:len(target_array)] 392 | else: 393 | target_array[:orig_length] = orig_array 394 | return 395 | else: 396 | for sub_orig, sub_target in zip(orig_array, target_array): 397 | _fill_array(sub_orig, sub_target) 398 | 399 | 400 | def features_to_tf_example(features): 401 | """Converts feature dictionary into tf.Example protobuf. 402 | 403 | This function only supports to convert np.int and np.float array. 404 | 405 | Args: 406 | features: A feature dictionary. Keys are field names, values are np array. 407 | 408 | Returns: 409 | A tf.Example protobuf. 410 | 411 | Raises: 412 | ValueError: Feature dictionary's value field is not supported type. 413 | 414 | """ 415 | new_features = {} 416 | for k, v in features.items(): 417 | if not isinstance(v, np.ndarray): 418 | raise ValueError('Value field: %s is not numpy array' % str((k, v))) 419 | v = v.flatten() 420 | if np.issubdtype(v.dtype.type, np.string_): 421 | new_features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v)) 422 | elif np.issubdtype(v.dtype.type, np.integer): 423 | new_features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v)) 424 | elif np.issubdtype(v.dtype.type, np.floating): 425 | new_features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v)) 426 | else: 427 | raise ValueError('Value for %s is not a recognized type; v: %s type: %s' % 428 | (k, str(v[0]), str(type(v[0])))) 429 | return tf.train.Example(features=tf.train.Features(feature=new_features)) 430 | -------------------------------------------------------------------------------- /data_generation/requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.5 2 | lxml>=4.5.0 3 | attr>=0.3.1 4 | absl-py>=0.6.0 5 | numpy>=1.15.4 6 | six>=1.12.0 7 | tensorflow==1.15 # change to 'tensorflow-gpu' for gpu support 8 | tensor2tensor 9 | -------------------------------------------------------------------------------- /data_generation/resources.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
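To make the padding and serialization contract of `proto_utils.py` concrete, here is a small sketch (not part of the repository); the feature name is illustrative only, and the import path is assumed.

```python
# Minimal sketch, not from the repo: pad a ragged 2-D list, then serialize it.
from seq2act.data_generation import proto_utils  # assumed import path

word_ids = [[1, 2], [3, 4, 5]]  # two objects with a ragged number of word ids
padded = proto_utils.padding_array(word_ids, (4, 3), padding_value=0)
# padded == [[1, 2, 0],
#            [3, 4, 5],
#            [0, 0, 0],
#            [0, 0, 0]]   dtype: np.int64

example = proto_utils.features_to_tf_example({'ui_obj_word_id_seq': padded})
# The array is flattened into an int64_list, ready to be written to a TFRecord.
print(example.features.feature['ui_obj_word_id_seq'].int64_list.value)
```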
15 | 16 | """Util functions for loading word embedding data.""" 17 | 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import threading 24 | import numpy as np 25 | import tensorflow.compat.v1 as tf 26 | gfile = tf.gfile 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string( 31 | "vocab_file", "", 32 | "Full path to the directory containing the data files for a set of tasks.") 33 | flags.DEFINE_string( 34 | "input_candiate_file", "", 35 | "Full path to the directory for saving the tf record file.") 36 | 37 | 38 | _candidate_words = None 39 | lock = threading.Lock() 40 | 41 | token_id_map = None 42 | token_id_map_lock = threading.Lock() 43 | 44 | id_token_map = None 45 | id_token_map_lock = threading.Lock() 46 | 47 | 48 | # This func is ONLY for converting subtoken to id by reading vocab file directly 49 | # Do not use this func as tokenizer for raw string, use string_utils.py instead 50 | def tokens_to_ids(token_list): 51 | """Gets line numbers as ids for tokens accroding to vocab file.""" 52 | with token_id_map_lock: 53 | global token_id_map 54 | if not token_id_map: 55 | token_id_map = {} 56 | with gfile.Open(get_vocab_file(), "r") as f: 57 | for idx, token in enumerate(f.read().split("\n")): 58 | # Remove head and tail apostrophes of token 59 | if token[1:-1]: 60 | token_id_map[token[1:-1]] = idx 61 | return [token_id_map[token] for token in token_list] 62 | 63 | 64 | def ids_to_tokens(id_list): 65 | """Gets tokens from id list accroding to vocab file.""" 66 | with id_token_map_lock: 67 | global id_token_map 68 | if not id_token_map: 69 | id_token_map = {} 70 | with gfile.Open(get_vocab_file(), "r") as f: 71 | for idx, token in enumerate(f.read().split("\n")): 72 | # Remove head and tail apostrophes of token 73 | if token[1:-1]: 74 | id_token_map[idx] = token[1:-1] 75 | return [id_token_map[the_id] for the_id in id_list] 76 | 77 | 78 | def _get_candidate_words(): 79 | with lock: 80 | global _candidate_words 81 | if not _candidate_words: 82 | candidate_file = FLAGS.input_candidate_file 83 | with gfile.Open(candidate_file, "r") as f: 84 | _candidate_words = f.read().split("\n") 85 | return _candidate_words 86 | 87 | 88 | def get_random_words(sample_size): 89 | candidate_words = _get_candidate_words() 90 | return np.random.choice(candidate_words, sample_size, replace=False) 91 | 92 | 93 | def get_vocab_file(): 94 | return FLAGS.vocab_file 95 | -------------------------------------------------------------------------------- /data_generation/string_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility to handle tasks related to string encoding. 
17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import re 24 | import threading 25 | 26 | import nltk 27 | from tensor2tensor.data_generators import text_encoder 28 | from tensor2tensor.data_generators import tokenizer as t2t_tokenizer 29 | import tensorflow.compat.v1 as tf # tf 30 | 31 | from seq2act.data_generation import create_token_vocab 32 | from seq2act.data_generation import resources 33 | 34 | flags = tf.flags 35 | FLAGS = flags.FLAGS 36 | flags.DEFINE_enum( 37 | 'token_type', 't2t_subtoken', 38 | ['simple', 'nltk_token', 't2t_subtoken', 't2t_token'], 39 | 'The way to represent words: by using token and char or by subtoken') 40 | 41 | 42 | embed_dict = {} 43 | 44 | # Singleton encoder to do subtokenize, which loads vocab file only once. 45 | # Please use _get_subtoken_encoder() to get this singleton instance. 46 | _subtoken_encoder = None 47 | _token_vocab = None 48 | 49 | lock = threading.Lock() 50 | 51 | 52 | class EmptyTextError(ValueError): 53 | pass 54 | 55 | 56 | class CharPosError(ValueError): 57 | pass 58 | 59 | 60 | class UnknownTokenError(ValueError): 61 | pass 62 | 63 | 64 | def _get_subtoken_encoder(): 65 | with lock: 66 | global _subtoken_encoder 67 | if not _subtoken_encoder: 68 | _subtoken_encoder = text_encoder.SubwordTextEncoder( 69 | resources.get_vocab_file()) 70 | return _subtoken_encoder 71 | 72 | 73 | def _get_token_vocab(): 74 | with lock: 75 | global _token_vocab 76 | if not _token_vocab: 77 | _token_vocab = {} 78 | tokens, _, _ = create_token_vocab.read_vocab(resources.get_vocab_file()) 79 | _token_vocab = dict(zip(tokens, range(len(tokens)))) 80 | return _token_vocab 81 | 82 | 83 | def subtokenize_to_ids(text): 84 | """Subtokenizes text string to subtoken ids according to vocabulary.""" 85 | return _get_subtoken_encoder().encode(text) 86 | 87 | 88 | def t2t_tokenize_to_ids(text): 89 | """Tokenize text string with tensor2tensor tokenizer.""" 90 | token_vocab = _get_token_vocab() 91 | tokens = t2t_tokenizer.encode(text) 92 | token_ids = [] 93 | for token in tokens: 94 | if token not in token_vocab: 95 | raise UnknownTokenError('Unknown token %s' % token) 96 | else: 97 | token_ids.append(token_vocab[token]) 98 | return token_ids, tokens 99 | 100 | 101 | stat_fix_dict = collections.defaultdict(int) 102 | 103 | 104 | def _fix_char_position(text, start, end): 105 | """Fixes char position by extending the substring. 106 | 107 | In text_encoder.SubwordTextEncoder, alphanumeric chars vs non-alphanumeric 108 | will be splited as 2 different categories in token level, like: 109 | abc "settings" def -> 110 | 0) abc 111 | 1) space" 112 | 2) settings 113 | 3) "space 114 | 4) def 115 | So if the substring specified by start/end is <"settings">, then its tokens: 116 | 0) " 117 | 1) settings 118 | 2) " 119 | will mismatch the tokens of whole text, because <"> != 120 | Solution is extenting the substring: if the first char is non-alphanumeric and 121 | the previous char is also non-alphanumeric, then move start backforward. Do 122 | same on the end position. 123 | 124 | Args: 125 | text: whole text. 126 | start: char level start position. 127 | end: char level end position (exclusive). 128 | Returns: 129 | start: fixed start position. 130 | end: fixed end position (exclusive). 
131 | """ 132 | original_start, original_end = start, end 133 | if text[start: end].strip(): # Do trim if the subtext is more than spaces 134 | while text[start] == ' ': 135 | start += 1 136 | while text[end-1] == ' ': 137 | end -= 1 138 | 139 | def same_category(a, b): 140 | return a.isalnum() and b.isalnum() or not a.isalnum() and not b.isalnum() 141 | 142 | while start > 0 and same_category(text[start-1], text[start]): 143 | start -= 1 144 | while end < len(text) and same_category(text[end-1], text[end]): 145 | end += 1 146 | 147 | edit_distance = abs(start - original_start) + abs(end - original_end) 148 | stat_fix_dict[edit_distance] += 1 149 | return start, end 150 | 151 | 152 | def get_t2t_token_pos_from_char_pos(text, start, end): 153 | """Converts char level position to t2t token/subtoken level position. 154 | 155 | Example: please click "settings" app. 156 | | | 157 | char-level: start end 158 | 159 | Tokens: [u'please', u'click', u' "', u'settings', u'app', u'"', u'.'] 160 | |____________________| | 161 | prev tokens curr tokens 162 | 163 | The start/end position of curr tokens should be (3, 4). 164 | '3' is calculated by counting the tokens of prev tokens. 165 | 166 | Args: 167 | text: whole text. 168 | start: char level start position. 169 | end: char level end position (exclusive). 170 | Returns: 171 | token_start, token_end: token level start/end position. 172 | Raises: 173 | ValueError: Empty token or wrong index to search in text. 174 | """ 175 | if start < 0 or end > len(text): 176 | raise CharPosError('Position annotation out of the boundaries of text.') 177 | 178 | start, end = _fix_char_position(text, start, end) 179 | tokens, _ = tokenize_to_ids(text) 180 | prev, _ = tokenize_to_ids(text[0:start]) 181 | curr, _ = tokenize_to_ids(text[start:end]) 182 | 183 | if curr == tokens[len(prev): len(prev) + len(curr)]: 184 | return len(prev), len(prev) + len(curr) 185 | 186 | space = 1535 187 | 188 | # try ignore the last token(' ') of prev tokens. 189 | if prev[-1] == space and curr == tokens[len(prev)-1: len(prev) + len(curr)-1]: 190 | return len(prev)-1, len(prev) + len(curr)-1 191 | 192 | if text[start: end] == ' ': 193 | raise EmptyTextError('Single space between words will be ignored.') 194 | 195 | assert False, 'Fail to locate start/end positions in text' 196 | 197 | 198 | def text_sequence_to_ids(text_seq, vocab_idx_dict): 199 | """Encodes list of words into word id sequence and character id sequence. 200 | 201 | Retrieves words' index and char's ascii code as encoding. If word is not 202 | contained in vocab_idx_dict, len(vocab_idx_dict) is the word's encoding 203 | number. 
204 | 205 | For Example: 206 | vocab_idx_dict = {'hi':0, 'hello':1, 'apple':2} 207 | text_sequence_to_ids(['hello', 'world'], vocab_idx_dict) returns: 208 | word_ids = [1, 3] 209 | char_ids = [[104, 101, 108, 108, 111], [119, 111, 114, 108, 100]] 210 | 211 | Args: 212 | text_seq: list of words to be encoded 213 | vocab_idx_dict: a dictionary, keys are vocabulary, values are words' index 214 | 215 | Returns: 216 | word_ids: A 1d list of intergers, encoded word id sequence 217 | char_ids: A 2d list of integers, encoded char id sequence 218 | """ 219 | word_ids = [ 220 | vocab_idx_dict[word.lower()] 221 | if word.lower() in vocab_idx_dict else len(vocab_idx_dict) 222 | for word in text_seq 223 | ] 224 | char_ids = [] 225 | for word in text_seq: 226 | char_ids.append([ord(ch) for ch in word.lower()]) 227 | return word_ids, char_ids 228 | 229 | 230 | def tokenizer_with_punctuation(origin_string): 231 | """Extracts tokens including punctuation from origial string.""" 232 | tokens = nltk.word_tokenize(origin_string) 233 | 234 | # Note: nltk changes: left double quote to `` and right double quote to ''. 235 | # As we don't need this feature, so change them back to origial quotes 236 | tokens = ['"' if token == '``' or token == '\'\'' else token 237 | for token in tokens] 238 | 239 | result = [] 240 | for token in tokens: 241 | # nltk will separate " alone, which is good. But: 242 | # nltk will keep ' together with neightbor word, we need split the ' in head 243 | # tai. If ' is in middle of a word, leave it unchanged, like n't. 244 | # Example: 245 | # doesn't -> 2 tokens: does, n't. 246 | # 'settings' -> 3 tokens: ', setting, '. 247 | if token == '\'': 248 | result.append(token) 249 | elif token.startswith('\'') and token.endswith('\''): 250 | result.extend(['\'', token[1:-1], '\'']) 251 | elif token.startswith('\''): 252 | result.extend(['\'', token[1:]]) 253 | elif token.endswith('\''): 254 | result.extend([token[:-1], '\'']) 255 | 256 | # nltk keeps abbreviation like 'ok.' as single word, so split tailing dot. 257 | elif len(token) > 1 and token.endswith('.'): 258 | result.extend([token[:-1], '.']) 259 | else: 260 | result.append(token) 261 | 262 | # Now nltk will split https://caldav.calendar.yahoo.com to 263 | # 'https', ':', '//caldav.calendar.yahoo.com' 264 | # Combine them together: 265 | tokens = result 266 | result = [] 267 | i = 0 268 | while i < len(tokens): 269 | if (i < len(tokens) -2 and 270 | tokens[i] in ['http', 'https'] and 271 | tokens[i+1] == ':' and 272 | tokens[i+2].startswith('//')): 273 | result.append(tokens[i] + tokens[i+1] + tokens[i+2]) 274 | i += 3 275 | else: 276 | result.append(tokens[i]) 277 | i += 1 278 | 279 | return result 280 | 281 | 282 | def tokenizer(action_str): 283 | """Extracts token from action string. 284 | 285 | Removes punctuation, extra space and changes all words to lower cases. 286 | 287 | Args: 288 | action_str: the action string. 289 | 290 | Returns: 291 | action_str_tokens: A list of clean tokens. 292 | 293 | """ 294 | action_str_no_punc = re.sub(r'[^\w\s]|\n', ' ', action_str).strip() 295 | tokens = action_str_no_punc.split(' ') 296 | action_str_tokens = [token for token in tokens if token] 297 | return action_str_tokens 298 | 299 | 300 | def is_ascii_str(token_str): 301 | """Checks if the given token string is construced with all ascii chars. 302 | 303 | Args: 304 | token_str: A token string. 305 | 306 | Returns: 307 | A boolean to indicate if the token_str is ascii string or not. 
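As a quick illustration of the difference between the two tokenizers above, here is a sketch that is not part of the repository; the exact `nltk` output may vary slightly across versions and requires its `punkt` data to be installed.

```python
# Minimal sketch, not from the repo: punctuation-preserving vs. stripped tokens.
from seq2act.data_generation import string_utils  # assumed import path

print(string_utils.tokenizer_with_punctuation('click "Settings" to open it.'))
# roughly: ['click', '"', 'Settings', '"', 'to', 'open', 'it', '.']

print(string_utils.tokenizer('click "Settings" to open it.'))
# ['click', 'Settings', 'to', 'open', 'it']  (punctuation dropped, case kept)
```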
308 | """ 309 | return all(ord(token_char) < 128 for token_char in token_str) 310 | 311 | 312 | def replace_non_ascii(text, replace_with=' '): 313 | """Replaces all non-ASCII chars in strinng.""" 314 | return ''.join([i if ord(i) < 128 else replace_with for i in text]) 315 | 316 | 317 | def get_index_of_list_in_list(base_list, the_sublist, 318 | start_pos=0, lookback_pivot=None): 319 | """Gets the start and end(exclusive) indexes of a sublist in base list. 320 | 321 | Examples: 322 | call with (['00', '.', '22', '33', '44'. '.' '66'], ['22', '33'], 3) 323 | raise ValueError # Search from 3rd and never lookback. 324 | call with (['00', '.', '22', '33', '44'. '.' '66'], ['22', '33'], 3, '.') 325 | return (2, 4) # Search from 3rd and lookback until previous dot('.') 326 | Args: 327 | base_list: list of str (or any other type), the base list. 328 | the_sublist: list of str (or any other type), the sublist search for. 329 | start_pos: the index to start search. 330 | lookback_pivot: string. If not None, the start_pos will be moved backforward 331 | until an item equal to lookback_pivot. If no previous item matchs 332 | lookback_pivot, start_pos will be set at the beginning of base_list. 333 | Returns: 334 | int, int: the start and end indexes(exclusive) of the sublist in base list. 335 | Raises: 336 | ValueError: when sublist not found in base list. 337 | """ 338 | if lookback_pivot is not None: 339 | current = start_pos -1 340 | while current >= 0: 341 | if base_list[current] == lookback_pivot: 342 | break 343 | current -= 1 344 | start_pos = current + 1 345 | 346 | if not base_list or not the_sublist: 347 | return ValueError('Empty base_list or sublist.') 348 | for i in range(start_pos, len(base_list) - len(the_sublist) + 1): 349 | if the_sublist == base_list[i: i + len(the_sublist)]: 350 | return i, i + len(the_sublist) 351 | raise ValueError('Sublist not found in list') 352 | 353 | 354 | def tokenize(text): 355 | """Totenizes text to subtext with specific granularity.""" 356 | global embed_dict 357 | if FLAGS.token_type == 't2t_subtoken': 358 | ids = _get_subtoken_encoder().encode(text) 359 | return [_get_subtoken_encoder().decode([the_id]) for the_id in ids] 360 | else: 361 | assert False, 'Unknown tokenize mode' 362 | 363 | 364 | def tokenize_to_ids(text): 365 | """Totenizes text to ids of subtext with specific granularity.""" 366 | if FLAGS.token_type == 't2t_subtoken': 367 | ids = _get_subtoken_encoder().encode(text) 368 | subtokens = [_get_subtoken_encoder().decode([the_id]) for the_id in ids] 369 | char_ids = [] 370 | for subtoken in subtokens: 371 | char_ids.append([ord(ch) for ch in subtoken.lower()]) 372 | return ids, char_ids 373 | else: 374 | assert False, 'Unknown tokenize mode' 375 | 376 | 377 | def get_token_pos_from_char_pos(text, start, end): 378 | if FLAGS.token_type == 'simple': 379 | raise NotImplementedError() 380 | elif FLAGS.token_type == 'nltk_token': 381 | raise NotImplementedError() 382 | elif FLAGS.token_type == 't2t_subtoken' or FLAGS.token_type == 't2t_token': 383 | return get_t2t_token_pos_from_char_pos(text, start, end) 384 | else: 385 | assert False, 'Unknown tokenize mode' 386 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /layers/area_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils for area computation.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from tensor2tensor.layers import common_layers 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def area_bounds(length, max_area_width): 25 | """Compute the area bounds.""" 26 | with tf.name_scope("compute_area_bounds"): 27 | start_list = [] 28 | end_list = [] 29 | for area_size in range(max_area_width): 30 | starts = tf.range(tf.maximum(length - area_size, 0)) 31 | ends = starts + area_size + 1 32 | start_list.append(starts) 33 | end_list.append(ends) 34 | area_starts = tf.concat(start_list, axis=0) 35 | area_ends = tf.concat(end_list, axis=0) 36 | return area_starts, area_ends 37 | 38 | 39 | def compute_sum_image(features, max_area_width): 40 | """Computes the vector sums of possible areas. (TODO: liyang) use t2t. 41 | 42 | Args: 43 | features: a tensor in shape of [batch_size, length, depth] 44 | max_area_width: a constant scalar. 45 | Returns: 46 | sum_image: vector sums of all the area combination. 47 | area_starts: the start position of each area. 48 | area_ends: the end position of each area. 
49 | """ 50 | with tf.name_scope("compute_sum_image", values=[features]): 51 | integral_image = tf.cumsum(features, axis=1, name="compute_integral_image") 52 | padded_integral_image = tf.pad( 53 | integral_image, [[0, 0], [1, 0], [0, 0]], constant_values=0) 54 | start_list = [] 55 | end_list = [] 56 | dst_images = [] 57 | src_images = [] 58 | shape = common_layers.shape_list(padded_integral_image) 59 | batch_size = shape[0] 60 | length = shape[1] 61 | for area_size in range(max_area_width): 62 | dst_images.append(padded_integral_image[:, area_size + 1:, :]) 63 | src_images.append(padded_integral_image[:, :-area_size - 1, :]) 64 | starts = tf.tile(tf.expand_dims(tf.range( 65 | tf.maximum(length - area_size - 1, 0)), 0), [batch_size, 1]) 66 | ends = starts + area_size + 1 67 | start_list.append(starts) 68 | end_list.append(ends) 69 | sum_image = tf.subtract(tf.concat(dst_images, axis=1), 70 | tf.concat(src_images, axis=1)) 71 | area_starts = tf.concat(start_list, axis=1) 72 | area_ends = tf.concat(end_list, axis=1) 73 | return sum_image, area_starts, area_ends 74 | 75 | 76 | def compute_alternative_span_rep(hiddens, features, max_area_width, 77 | hidden_size, advanced=False): 78 | """Computes the vector sums of possible areas. (TODO: liyang) use t2t. 79 | 80 | Args: 81 | hiddens: the hidden representation of features. 82 | features: a tensor in shape of [batch_size, length, depth]. 83 | max_area_width: a constant scalar. 84 | hidden_size: the target hidden_size. 85 | advanced: whether to use advanced representations that includes start-end 86 | encoding, the weighted sum encoding and the size encoding. 87 | Returns: 88 | summary_features: representations for all the area combination in the shape 89 | of [batch_size, length, hidden_size]. 90 | """ 91 | with tf.name_scope("compute_start_end_image", values=[hiddens, features]): 92 | dst_images = [hiddens] 93 | src_images = [hiddens] 94 | # Starting from effective area size 2 95 | for area_size in range(max_area_width - 1): 96 | dst_images.append(hiddens[:, (area_size + 1):, :]) 97 | src_images.append(hiddens[:, :-(area_size + 1), :]) 98 | end_image = tf.concat(dst_images, axis=1) 99 | start_image = tf.concat(src_images, axis=1) 100 | start_end_image = tf.concat([end_image, start_image], axis=-1) 101 | if advanced: 102 | weights = tf.exp(tf.nn.sigmoid(tf.layers.dense( 103 | tf.layers.dense(hiddens, units=hidden_size, 104 | activation=tf.nn.relu), units=1))) 105 | features = weights * features 106 | normalizers, area_starts, area_ends = compute_sum_image( 107 | weights, max_area_width) 108 | sum_images, _, _ = compute_sum_image(features, max_area_width) 109 | final_images = tf.math.divide_no_nan(sum_images, normalizers) 110 | sizes = area_ends - area_starts 111 | size_embeddings = tf.nn.embedding_lookup( 112 | params=tf.get_variable( 113 | name="span_len_w", 114 | shape=[max_area_width, max_area_width]), 115 | ids=sizes, name="embed_span_len") 116 | summary_features = tf.layers.dense( 117 | tf.concat([start_end_image, final_images, size_embeddings], axis=-1), 118 | units=hidden_size) 119 | else: 120 | summary_features = tf.layers.dense( 121 | start_end_image, 122 | units=hidden_size) 123 | return summary_features 124 | 125 | 126 | def area_range_to_index(area_range, length, max_area_width): 127 | """Computes the indices of each area in the area expansion. 128 | 129 | Args: 130 | area_range: tensor in shape of [batch_size, 2] 131 | length: a scalar tensor gives the length of the original feature space. 132 | max_area_width: a constant scalar. 
133 | Returns: 134 | indices: area indices tensor in shape of [batch_size] 135 | """ 136 | with tf.control_dependencies([tf.assert_equal(tf.rank(area_range), 2), 137 | tf.assert_equal(tf.shape(area_range)[1], 2)]): 138 | area_range = tf.cast(area_range, tf.int32) 139 | target_size = area_range[:, 1] - area_range[:, 0] 140 | with tf.control_dependencies([ 141 | tf.assert_less(target_size, max_area_width + 1, summarize=100000)]): 142 | sizes = target_size - 1 143 | start_length = length 144 | pre_end_length = length - sizes + 1 145 | base = (start_length + pre_end_length) *\ 146 | (start_length - pre_end_length + 1) // 2 147 | base = tf.where( 148 | tf.less_equal(target_size, 1), 149 | tf.zeros_like(target_size), 150 | base) 151 | offset = area_range[:, 0] 152 | return base + offset 153 | 154 | 155 | def batch_gather(values, indices): 156 | """Gather slices from values. 157 | 158 | Args: 159 | values: a tensor in the shape of [batch_size, length, depth]. 160 | indices: a tensor in the shape of [batch_size, slice_count] where 161 | slice_count < length. 162 | Returns: 163 | a tensor in the shape of [batch_size, slice_count, depth]. 164 | """ 165 | with tf.control_dependencies([ 166 | tf.assert_equal(tf.rank(values), 3, message="values"), 167 | tf.assert_equal(tf.rank(indices), 2, message="indices"), 168 | tf.assert_equal(tf.shape(values)[0], tf.shape(indices)[0], 169 | message="batch"), 170 | ]): 171 | shape = common_layers.shape_list(indices) 172 | depth = common_layers.shape_list(values)[-1] 173 | batch_indices = tf.reshape(tf.tile( 174 | tf.expand_dims(tf.range(shape[0]), [1]), 175 | [1, shape[1]]), [-1, 1]) 176 | indices = tf.concat([batch_indices, tf.cast( 177 | tf.reshape(indices, [-1, 1]), tf.int32)], axis=-1) 178 | slices = tf.gather_nd(params=values, indices=indices) 179 | return tf.reshape(slices, [shape[0], shape[1], depth]) 180 | 181 | 182 | def query_area(query, area_encodings, area_bias): 183 | """Predicts a range of tokens based on the query. 184 | 185 | Args: 186 | query: a Tensor of shape [batch_size, length, depth] 187 | area_encodings: a tensor in shape of [batch_size, num_areas, depth] 188 | area_bias: a tensor in shape of [batch_size, num_areas]. 189 | Returns: 190 | the logits to each area. 191 | """ 192 | with tf.control_dependencies([tf.assert_equal(tf.rank(query), 3), 193 | tf.assert_equal(tf.rank(area_encodings), 3), 194 | tf.assert_equal(tf.shape(query)[-1], 195 | tf.shape(area_encodings)[-1]), 196 | tf.assert_equal(tf.rank(area_bias), 2)]): 197 | dot_products = tf.matmul(query, tf.transpose(area_encodings, [0, 2, 1])) 198 | area_logits = dot_products + tf.expand_dims(area_bias, 1) 199 | return area_logits 200 | 201 | 202 | def area_loss(logits, ranges, length, max_area_width, allow_empty=False): 203 | """Computes the loss regarding areas. 204 | 205 | Args: 206 | logits: the predictions of each area [batch_size, query_length, num_areas]. 207 | ranges: the groundtruth [batch_size, query_length, 2]. 208 | length: the length of the original tensor. 209 | max_area_width: the maximum area width. 210 | allow_empty: whether to allow empty refs. 211 | Returns: 212 | the loss. 
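A small worked example (not part of the repository) may help with `area_range_to_index`: the flattened area expansion produced by `compute_sum_image` is laid out width-major, so with `length=5` and `max_area_width=3`, width-1 spans occupy indices 0-4, width-2 spans indices 5-8, and width-3 spans indices 9-11. The import path below is assumed.

```python
# Minimal sketch, not from the repo: map [start, end) spans to area indices.
import tensorflow.compat.v1 as tf
from seq2act.layers import area_utils  # assumed import path

tf.disable_v2_behavior()

ranges = tf.constant([[3, 4],    # width 1, start 3 -> index 3
                      [2, 4],    # width 2, start 2 -> index 5 + 2 = 7
                      [1, 4]])   # width 3, start 1 -> index 9 + 1 = 10
indices = area_utils.area_range_to_index(ranges, length=5, max_area_width=3)

with tf.Session() as sess:
    print(sess.run(indices))  # [ 3  7 10]
```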
213 | """ 214 | num_areas = common_layers.shape_list(logits)[-1] 215 | ranges = tf.reshape(ranges, [-1, 2]) 216 | indices = area_range_to_index(area_range=ranges, 217 | length=length, 218 | max_area_width=max_area_width) 219 | if allow_empty: 220 | indices = tf.where( 221 | tf.greater(ranges[:, 1], ranges[:, 0]), indices + 1, 222 | tf.zeros_like(indices)) 223 | logits = tf.reshape(logits, [-1, num_areas]) 224 | losses = tf.losses.sparse_softmax_cross_entropy( 225 | labels=indices, logits=logits, 226 | reduction=tf.losses.Reduction.NONE) 227 | with tf.control_dependencies([tf.assert_greater_equal(ranges[:, 1], 228 | ranges[:, 0])]): 229 | if not allow_empty: 230 | mask = tf.greater(ranges[:, 1], ranges[:, 0]) 231 | losses = losses * tf.cast(mask, tf.float32) 232 | return tf.reduce_mean(losses) 233 | 234 | 235 | def area_to_refs(starts, ends, areas): 236 | return tf.concat([ 237 | batch_gather(tf.expand_dims(starts, 2), areas), 238 | batch_gather(tf.expand_dims(ends, 2), areas)], axis=-1) 239 | -------------------------------------------------------------------------------- /layers/common_embed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functions for embedding tokens.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensor2tensor.layers import common_layers 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def embed_tokens(tokens, task_vocab_size, hidden_size, hparams, 26 | embed_scope=None): 27 | """Embeds tokens.""" 28 | with tf.variable_scope("embed_tokens" if embed_scope is None else embed_scope, 29 | reuse=tf.AUTO_REUSE) as scope: 30 | input_embeddings = tf.nn.embedding_lookup( 31 | params=tf.get_variable( 32 | name="task_embed_w", 33 | shape=[task_vocab_size, hidden_size]), 34 | ids=tokens, name="embed_tokens") 35 | if hparams.get("freeze_reference_model", False): 36 | input_embeddings = tf.stop_gradient(input_embeddings) 37 | return input_embeddings, scope 38 | 39 | 40 | def average_bag_of_embeds(embeddings, mask, use_bigrams=False, 41 | bigram_embed_scope=None, append_start_end=False): 42 | """Averages a bag of embeds. 43 | 44 | Args: 45 | embeddings: a float Tensor of shape [None, length, depth] 46 | mask: a boolean Tensor of shape [None, length] 47 | use_bigrams: whether to use bigrams. 48 | bigram_embed_scope: the variable scope. 49 | append_start_end: whether to append start and end tokens. 
50 | Returns: 51 | word_embed: a Tensor of shape [None, embed_size] 52 | """ 53 | if bigram_embed_scope is None: 54 | var_scope = "average_bow" 55 | else: 56 | var_scope = bigram_embed_scope 57 | with tf.variable_scope(var_scope, reuse=tf.AUTO_REUSE): 58 | with tf.control_dependencies([ 59 | tf.assert_equal(tf.rank(embeddings), 3, summarize=100), 60 | tf.assert_equal(tf.rank(mask), 2, summarize=100), 61 | ]): 62 | lengths = tf.cast( 63 | tf.reduce_sum(tf.cast(mask, tf.int32), -1, keepdims=True), tf.float32) 64 | batch_size = common_layers.shape_list(embeddings)[0] 65 | length = common_layers.shape_list(embeddings)[1] 66 | depth = common_layers.shape_list(embeddings)[2] 67 | embeddings = tf.where( 68 | tf.tile(tf.expand_dims(mask, 2), [1, 1, depth]), embeddings, 69 | tf.zeros_like(embeddings)) 70 | if use_bigrams: 71 | if append_start_end: 72 | span_start_embed = tf.get_variable(name="span_start_embed", 73 | shape=[depth]) 74 | span_end_embed = tf.get_variable(name="span_end_embed", 75 | shape=[depth]) 76 | span_end_embed = tf.expand_dims(tf.expand_dims(span_end_embed, 0), 0) 77 | start = tf.expand_dims( 78 | tf.tile(tf.expand_dims(span_start_embed, 0), [batch_size, 1]), 1) 79 | # Prefix the start 80 | embeddings = tf.concat([start, embeddings], axis=1) 81 | # Pad for the end slot 82 | embeddings = tf.pad(embeddings, [[0, 0], [0, 1], [0, 0]]) 83 | span_end_embed = tf.tile(span_end_embed, [batch_size, length + 2, 1]) 84 | mask_with_start = tf.pad( 85 | tf.pad(tf.to_int32(mask), [[0, 0], [1, 0]], 86 | constant_values=1), [[0, 0], [0, 1]], 87 | constant_values=0) 88 | mask_with_end = tf.pad(mask_with_start, [[0, 0], [1, 0]], 89 | constant_values=1)[:, :-1] 90 | mask = tf.cast(mask_with_end, tf.bool) 91 | mask_of_end = tf.expand_dims(mask_with_end - mask_with_start, 2) 92 | embeddings = embeddings + span_end_embed * tf.to_float(mask_of_end) 93 | bigram_embeddings = tf.layers.dense( 94 | tf.concat([embeddings[:, :-1, :], embeddings[:, 1:, :]], axis=-1), 95 | units=depth) 96 | bigram_mask = tf.to_float(tf.expand_dims(mask[:, 1:], 2)) 97 | masked_bigram_embeddings = bigram_embeddings * bigram_mask 98 | embeddings = tf.concat( 99 | [embeddings, masked_bigram_embeddings], axis=1) 100 | lengths = lengths + lengths - 1 101 | avg_embeddings = tf.div(tf.reduce_sum(embeddings, axis=1), 102 | tf.maximum(lengths, 1.0)) 103 | return avg_embeddings 104 | -------------------------------------------------------------------------------- /layers/encode_screen.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
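Before moving on to the screen encoder, here is a toy sketch (not part of the repository) of `average_bag_of_embeds` in its default configuration (no bigrams), run in the TF1 graph mode used throughout the codebase; the import path is assumed.

```python
# Minimal sketch, not from the repo: average a masked bag of embeddings.
import numpy as np
import tensorflow.compat.v1 as tf
from seq2act.layers import common_embed  # assumed import path

tf.disable_v2_behavior()

embeds = tf.constant(np.random.rand(2, 4, 8), dtype=tf.float32)  # [batch, length, depth]
mask = tf.constant([[True, True, False, False],                  # 2 valid positions
                    [True, True, True, True]])                   # 4 valid positions
avg = common_embed.average_bag_of_embeds(embeds, mask)

with tf.Session() as sess:
    print(sess.run(avg).shape)  # (2, 8): mean over the unmasked positions only
```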
15 | 16 | """Utils for encoding a UI screen.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.layers import common_attention 23 | from tensor2tensor.layers import common_layers 24 | from tensor2tensor.models import transformer 25 | import tensorflow.compat.v1 as tf 26 | from seq2act.layers import common_embed 27 | 28 | 29 | def prepare_encoder_input(features, hparams, embed_scope=None, 30 | embed_token_fn=common_embed.embed_tokens): 31 | """Prepares the input for the screen encoder. 32 | 33 | Args: 34 | features: the feature dict. 35 | hparams: the hyperparameter. 36 | embed_scope: the embedding variable scope. 37 | embed_token_fn: the function for embedding tokens. 38 | Returns: 39 | object_embedding: a Tensor of shape 40 | [batch_size, num_steps, max_object_count, embed_depth] 41 | object_mask: a binary tensor of shape 42 | [batch_size, num_steps, max_object_count] 43 | nonpadding_bias: a Tensor of shape 44 | [batch_size, num_steps, max_object_count] 45 | """ 46 | with tf.control_dependencies([ 47 | tf.assert_equal(tf.rank(features["obj_text"]), 4)]): 48 | if hparams.get("synthetic_screen_noise", 0.) > 0.: 49 | num_objects = tf.shape(features["obj_text"])[2] 50 | # [batch, length, num_objects] 51 | target_obj_mask = tf.cast( 52 | tf.one_hot(features["objects"], depth=num_objects), tf.bool) 53 | num_tokens = tf.shape(features["obj_text"])[-1] 54 | target_obj_mask = tf.tile( 55 | tf.expand_dims(target_obj_mask, 3), 56 | [1, 1, 1, num_tokens]) 57 | # Randomly keep tokens 58 | keep_mask = tf.greater_equal( 59 | tf.random_uniform(shape=tf.shape(features["obj_text"])), 60 | hparams.synthetic_screen_noise) 61 | # Keep paddings 62 | keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0), 63 | keep_mask) 64 | # Keep targets 65 | target_obj_mask = tf.logical_or(target_obj_mask, keep_mask) 66 | features["obj_text"] = tf.where( 67 | target_obj_mask, features["obj_text"], 68 | tf.random_uniform(shape=tf.shape(features["obj_text"]), 69 | maxval=50000, dtype=tf.int32)) 70 | text_embeddings, _ = embed_token_fn( 71 | features["obj_text"], 72 | hparams.task_vocab_size, 73 | hparams.hidden_size, hparams, 74 | embed_scope=embed_scope) 75 | with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE): 76 | if hparams.obj_text_aggregation == "max": 77 | embed_bias = tf.cast(tf.less(features["obj_text"], 2), 78 | tf.float32) * -1e7 79 | with tf.control_dependencies([tf.assert_equal(tf.rank(embed_bias), 4)]): 80 | text_embeddings = tf.reduce_max( 81 | text_embeddings + tf.expand_dims(embed_bias, 4), -2) 82 | no_txt_embed = tf.get_variable( 83 | name="no_txt_embed", shape=[hparams.hidden_size]) 84 | shape = common_layers.shape_list(text_embeddings) 85 | no_txt_embed = tf.tile( 86 | tf.reshape(no_txt_embed, [1, 1, 1, hparams.hidden_size]), 87 | [shape[0], shape[1], shape[2], 1]) 88 | text_embeddings = tf.maximum(text_embeddings, no_txt_embed) 89 | elif hparams.obj_text_aggregation == "sum": 90 | # [batch, step, #max_obj, #max_token] 0 for padded tokens 91 | real_objects = tf.cast( 92 | tf.greater_equal(features["obj_text"], 2), tf.float32) 93 | # [batch, step, #max_obj, hidden] 0s for padded objects 94 | text_embeddings = tf.reduce_sum( 95 | text_embeddings * tf.expand_dims(real_objects, 4), -2) 96 | elif hparams.obj_text_aggregation == "mean": 97 | shape_list = common_layers.shape_list(text_embeddings) 98 | embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:]) 99 | emb_sum = 
tf.reduce_sum(tf.abs(embeddings), axis=-1) 100 | non_paddings = tf.not_equal(emb_sum, 0.0) 101 | embeddings = common_embed.average_bag_of_embeds( 102 | embeddings, non_paddings, use_bigrams=True, 103 | bigram_embed_scope=embed_scope, append_start_end=True) 104 | text_embeddings = tf.reshape( 105 | embeddings, shape_list[:3] + [hparams.hidden_size]) 106 | else: 107 | raise ValueError("Unrecognized token aggregation %s" % ( 108 | hparams.obj_text_aggregation)) 109 | with tf.control_dependencies([ 110 | tf.assert_equal(tf.rank(features["obj_type"]), 3), 111 | tf.assert_equal(tf.rank(features["obj_clickable"]), 3)]): 112 | with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE): 113 | type_embedding = tf.nn.embedding_lookup( 114 | params=tf.get_variable( 115 | name="embed_type_w", shape=[hparams.get("num_types", 100), 116 | hparams.hidden_size]), 117 | ids=tf.maximum(features["obj_type"], 0)) 118 | clickable_embedding = tf.nn.embedding_lookup( 119 | params=tf.get_variable( 120 | name="embed_clickable_w", shape=[2, hparams.hidden_size]), 121 | ids=features["obj_clickable"]) 122 | with tf.control_dependencies([ 123 | tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]): 124 | def _create_embed(feature_name, vocab_size, depth): 125 | """Embed a position feature.""" 126 | pos_embedding_list = [] 127 | with tf.variable_scope("encode_object_" + feature_name, 128 | reuse=tf.AUTO_REUSE): 129 | num_featues = common_layers.shape_list(features[feature_name])[-1] 130 | for i in range(num_featues): 131 | pos_embedding_list.append(tf.nn.embedding_lookup( 132 | params=tf.get_variable( 133 | name=feature_name + "_embed_w_%d" % i, 134 | shape=[vocab_size, depth]), 135 | ids=features[feature_name][:, :, :, i])) 136 | pos_embedding = tf.add_n(pos_embedding_list) 137 | return pos_embedding 138 | pos_embedding = _create_embed("obj_screen_pos", 139 | hparams.max_pixel_pos, 140 | hparams.hidden_size) 141 | if "all" == hparams.screen_embedding_feature or ( 142 | "dom" in hparams.screen_embedding_feature): 143 | dom_embedding = _create_embed("obj_dom_pos", 144 | hparams.max_dom_pos, 145 | hparams.hidden_size) 146 | object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32) 147 | if hparams.screen_embedding_feature == "all": 148 | object_embed = ( 149 | text_embeddings + type_embedding + pos_embedding + dom_embedding) 150 | elif "text" in hparams.screen_embedding_feature: 151 | object_embed += text_embeddings 152 | elif "type" in hparams.screen_embedding_feature: 153 | object_embed += type_embedding 154 | elif "pos" in hparams.screen_embedding_feature: 155 | object_embed += pos_embedding 156 | elif "dom" in hparams.screen_embedding_feature: 157 | object_embed += dom_embedding 158 | elif "click" in hparams.screen_embedding_feature: 159 | object_embed += clickable_embedding 160 | object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32) 161 | object_embed = object_embed * tf.expand_dims(object_mask, 3) 162 | att_bias = (1. - object_mask) * common_attention.large_compatible_negative( 163 | object_embed.dtype) 164 | return object_embed, object_mask, att_bias 165 | 166 | 167 | def transformer_encoder(features, hparams, 168 | embed_scope=None, 169 | embed_token_fn=common_embed.embed_tokens, 170 | attention_weights=None): 171 | """Encodes a screen using Transformer. 172 | 173 | Args: 174 | features: the feature dict. 175 | hparams: the hyperparameter. 176 | embed_scope: the scope for token embedding. 177 | embed_token_fn: the embed function. 
178 | attention_weights: the attention_weights dict. 179 | Returns: 180 | encoder_outputs: a Tensor of shape 181 | [batch_size, num_steps, max_object_count, hidden_size] 182 | encoder_attn_bias: A tensor of shape 183 | [batch_size, num_steps, max_object_count] 184 | """ 185 | tf.logging.info("Using Transformer screen encoder") 186 | # Remove the default positional encoding in Transformer 187 | object_embed, object_mask, encoder_attn_bias = prepare_encoder_input( 188 | features=features, hparams=hparams, embed_scope=embed_scope, 189 | embed_token_fn=embed_token_fn) 190 | with tf.variable_scope("encode_screen", reuse=tf.AUTO_REUSE): 191 | shape = tf.shape(object_embed) 192 | with tf.control_dependencies([ 193 | tf.assert_equal(shape[3], hparams.hidden_size)]): 194 | object_embed = tf.reshape(object_embed, 195 | [shape[0] * shape[1], shape[2], 196 | hparams.hidden_size]) 197 | encoder_input = tf.nn.dropout( 198 | object_embed, 199 | keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 200 | self_attention_bias = tf.expand_dims(tf.expand_dims( 201 | tf.reshape(encoder_attn_bias, [shape[0] * shape[1], shape[2]]), 202 | axis=1), axis=1) 203 | encoder_output = transformer.transformer_encoder( 204 | encoder_input=encoder_input, 205 | encoder_self_attention_bias=self_attention_bias, 206 | hparams=hparams, 207 | save_weights_to=attention_weights, 208 | make_image_summary=not common_layers.is_xla_compiled()) 209 | encoder_output = tf.reshape(encoder_output, 210 | [shape[0], shape[1], shape[2], shape[3]]) 211 | return encoder_output, object_mask, encoder_attn_bias 212 | 213 | 214 | def gcn_encoder(features, hparams, embed_scope, 215 | embed_token_fn=common_embed.embed_tokens, 216 | adjcency_feature="obj_dom_dist", 217 | discretize=True): 218 | """Encodes a screen using Graph Convolution Networks. 219 | 220 | Args: 221 | features: the feature dict. 222 | hparams: the hyperparameter. 223 | embed_scope: the variable scope for token embedding. 224 | embed_token_fn: the embed function. 225 | adjcency_feature: the feature name for the adjacency matrix. 226 | discretize: whether to discretize the matrix. 
227 | Returns: 228 | encoder_outputs: a Tensor of shape 229 | [batch_size, num_steps, max_object_count, hidden_size] 230 | encoder_attn_bias: A tensor of shape 231 | [batch_size, num_steps, max_object_count] 232 | """ 233 | tf.logging.info("Using GCN screen encoder") 234 | # [batch_size, num_steps, max_num_objects, depth] 235 | inputs, object_mask, encoder_attn_bias = prepare_encoder_input( 236 | features=features, hparams=hparams, embed_scope=embed_scope, 237 | embed_token_fn=embed_token_fn) 238 | # [batch_size, num_steps, max_num_objects, max_num_objects] 239 | if discretize: 240 | adjacency_matrix = tf.cast(tf.where( 241 | tf.greater(features[adjcency_feature], 1), 242 | tf.zeros_like(features[adjcency_feature]), 243 | tf.ones_like(features[adjcency_feature])), tf.float32) 244 | else: 245 | adjacency_matrix = tf.cast(features[adjcency_feature], tf.float32) 246 | dom_dist_variance = 0.1 247 | numerator = tf.exp( 248 | adjacency_matrix * adjacency_matrix / (-2.0 * dom_dist_variance)) 249 | denominator = tf.sqrt(2.0 * 3.141 * dom_dist_variance) 250 | adjacency_matrix = numerator / denominator 251 | encoder_outputs = graph_cnn(inputs, object_mask, hparams.num_hidden_layers, 252 | hparams.hidden_size, 253 | dropout=hparams.layer_prepostprocess_dropout, 254 | adjacency_matrix=adjacency_matrix, 255 | norm_type=hparams.norm_type, 256 | norm_epsilon=hparams.norm_epsilon) 257 | return encoder_outputs, object_mask, encoder_attn_bias 258 | 259 | 260 | def graph_cnn(inputs, object_mask, num_layers, hidden_size, dropout, 261 | adjacency_matrix, norm_type="layer", norm_epsilon=0.001, 262 | test=False): 263 | """Encodes a screen using Graph Convolution Networks. 264 | 265 | Args: 266 | inputs: [batch_size, num_steps, max_object_count, depth]. 267 | object_mask: [batch_size, num_steps, max_object_count]. 268 | num_layers: the number of layers. 269 | hidden_size: the hidden layer size. 270 | dropout: dropout ratio. 271 | adjacency_matrix: the adjacency matrix 272 | [batch_size, num_steps, max_object_count, max_object_count]. 273 | norm_type: the norm_type. 274 | norm_epsilon: norm_epsilon. 275 | test: whether it's in the test mode. 276 | Returns: 277 | hidden: a Tensor of shape 278 | [batch_size, num_steps, max_object_count, depth] 279 | """ 280 | # [batch_size, num_steps, max_num_objects, max_num_objects] 281 | normalizer = tf.div(1., tf.sqrt(tf.reduce_sum( 282 | adjacency_matrix, -1, keepdims=True))) 283 | normalizer = normalizer * tf.expand_dims(tf.expand_dims( 284 | tf.eye(tf.shape(normalizer)[-2]), 0), 0) 285 | adjacency_matrix = tf.matmul( 286 | tf.matmul(normalizer, adjacency_matrix), normalizer) 287 | hidden = inputs 288 | for layer in range(num_layers): 289 | with tf.variable_scope("gcn_layer_" + str(layer), reuse=tf.AUTO_REUSE): 290 | hidden = tf.matmul(adjacency_matrix, hidden) 291 | # [batch_size, num_steps, max_num_objects, depth] 292 | if not test: 293 | hidden = tf.layers.dense(inputs=hidden, units=hidden_size) 294 | hidden = common_layers.apply_norm( 295 | hidden, norm_type, hidden_size, 296 | epsilon=norm_epsilon) 297 | hidden = tf.nn.relu(hidden) 298 | hidden = tf.nn.dropout(hidden, keep_prob=1.0 - dropout) 299 | # zero out padding objects 300 | hidden = hidden * tf.expand_dims(object_mask, 3) 301 | return hidden 302 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /models/seq2act_estimator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """seq2act estimator.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from tensor2tensor.layers import common_layers 23 | from tensor2tensor.models import transformer 24 | from tensor2tensor.utils import learning_rate 25 | from tensor2tensor.utils import optimize 26 | import tensorflow.compat.v1 as tf 27 | from seq2act.models import input as input_utils 28 | from seq2act.models import seq2act_model 29 | from seq2act.utils import decode_utils 30 | 31 | flags = tf.flags 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("train_files", None, "the file names for training") 35 | flags.DEFINE_string("eval_files", None, "the file names for eval") 36 | flags.DEFINE_integer("worker_gpu", 1, "worker_gpu") 37 | flags.DEFINE_integer("worker_replicas", 1, "num_workers") 38 | flags.DEFINE_integer("train_steps", 100000, "train_steps") 39 | flags.DEFINE_string("hparams", "", "the hyper parameters") 40 | flags.DEFINE_string("domain_hparams", "", 41 | "the domain-specific hyper parameters") 42 | flags.DEFINE_string("domain_train_files", None, "the file names for training") 43 | flags.DEFINE_string("domain_eval_files", None, "the file names for eval") 44 | flags.DEFINE_string("metric_types", 45 | "final_accuracy,ref_accuracy,basic_accuracy", 46 | "metric types") 47 | flags.DEFINE_boolean("post_processing", True, 48 | "post processing the predictions") 49 | 50 | 51 | def _ref_accuracy(features, pred_dict, nonpadding, name, metrics, 52 | decode_refs=None, 53 | measure_beginning_eos=False, 54 | debug=False): 55 | """Computes the accuracy of reference prediction. 56 | 57 | Args: 58 | features: the feature dict. 59 | pred_dict: the dictionary to hold the prediction results. 60 | nonpadding: a 2D boolean tensor for masking out paddings. 61 | name: the name of the feature to be predicted. 62 | metrics: the eval metrics. 63 | decode_refs: decoded references. 64 | measure_beginning_eos: whether to measure the beginning and the end. 
65 | debug: whether to output mismatches. 66 | """ 67 | if decode_refs is not None: 68 | gt_seq_lengths = decode_utils.verb_refs_to_lengths(features["task"], 69 | features["verb_refs"]) 70 | pr_seq_lengths = decode_utils.verb_refs_to_lengths(decode_refs["task"], 71 | decode_refs["verb_refs"]) 72 | full_acc, partial_acc = decode_utils.sequence_accuracy( 73 | features[name], decode_refs[name], gt_seq_lengths, pr_seq_lengths, 74 | debug=debug, name=name) 75 | metrics[name + "_full_accuracy"] = tf.metrics.mean(full_acc) 76 | metrics[name + "_partial_accuracy"] = tf.metrics.mean(partial_acc) 77 | if measure_beginning_eos: 78 | nonpadding = tf.reshape(nonpadding, [-1]) 79 | refs = tf.reshape(features[name], [-1, 2]) 80 | predict_refs = tf.reshape(pred_dict[name], [-1, 2]) 81 | metrics[name + "_start"] = tf.metrics.accuracy( 82 | labels=tf.boolean_mask(refs[:, 0], nonpadding), 83 | predictions=tf.boolean_mask(predict_refs[:, 0], nonpadding), 84 | name=name + "_start_accuracy") 85 | metrics[name + "_end"] = tf.metrics.accuracy( 86 | labels=tf.boolean_mask(refs[:, 1], nonpadding), 87 | predictions=tf.boolean_mask(predict_refs[:, 1], nonpadding), 88 | name=name + "_end_accuracy") 89 | 90 | 91 | def _eval(metrics, pred_dict, loss_dict, features, areas, compute_seq_accuracy, 92 | hparams, metric_types, decode_length=20): 93 | """Internal eval function.""" 94 | # Assume data sources are not mixed within each batch 95 | if compute_seq_accuracy: 96 | decode_features = {} 97 | for key in features: 98 | if not key.endswith("_refs"): 99 | decode_features[key] = features[key] 100 | decode_utils.decode_n_step(seq2act_model.compute_logits, 101 | decode_features, areas, 102 | hparams, n=decode_length, beam_size=1) 103 | decode_features["input_refs"] = decode_utils.unify_input_ref( 104 | decode_features["verbs"], decode_features["input_refs"]) 105 | acc_metrics = decode_utils.compute_seq_metrics( 106 | features, decode_features) 107 | metrics["seq_full_acc"] = tf.metrics.mean(acc_metrics["complete_refs_acc"]) 108 | metrics["seq_partial_acc"] = tf.metrics.mean( 109 | acc_metrics["partial_refs_acc"]) 110 | if "final_accuracy" in metric_types: 111 | metrics["complet_act_accuracy"] = tf.metrics.mean( 112 | acc_metrics["complete_acts_acc"]) 113 | metrics["partial_seq_acc"] = tf.metrics.mean( 114 | acc_metrics["partial_acts_acc"]) 115 | print0 = tf.print("*** lang", features["raw_task"], summarize=100) 116 | with tf.control_dependencies([print0]): 117 | loss_dict["total_loss"] = tf.identity(loss_dict["total_loss"]) 118 | else: 119 | decode_features = None 120 | if "ref_accuracy" in metric_types: 121 | with tf.control_dependencies([ 122 | tf.assert_equal(tf.rank(features["verb_refs"]), 3), 123 | tf.assert_equal(tf.shape(features["verb_refs"])[-1], 2)]): 124 | _ref_accuracy(features, pred_dict, 125 | tf.less(features["verb_refs"][:, :, 0], 126 | features["verb_refs"][:, :, 1]), 127 | "verb_refs", metrics, decode_features, 128 | measure_beginning_eos=True) 129 | _ref_accuracy(features, pred_dict, 130 | tf.less(features["obj_refs"][:, :, 0], 131 | features["obj_refs"][:, :, 1]), 132 | "obj_refs", metrics, decode_features, 133 | measure_beginning_eos=True) 134 | _ref_accuracy(features, pred_dict, 135 | tf.less(features["input_refs"][:, :, 0], 136 | features["input_refs"][:, :, 1]), 137 | "input_refs", metrics, decode_features, 138 | measure_beginning_eos=True) 139 | if "basic_accuracy" in metric_types: 140 | target_verbs = tf.reshape(features["verbs"], [-1]) 141 | verb_nonpadding = tf.greater(target_verbs, 1) 142 | 
target_verbs = tf.boolean_mask(target_verbs, verb_nonpadding) 143 | predict_verbs = tf.boolean_mask(tf.reshape(pred_dict["verbs"], [-1]), 144 | verb_nonpadding) 145 | metrics["verb"] = tf.metrics.accuracy( 146 | labels=target_verbs, 147 | predictions=predict_verbs, 148 | name="verb_accuracy") 149 | input_mask = tf.reshape( 150 | tf.less(features["verb_refs"][:, :, 0], 151 | features["verb_refs"][:, :, 1]), [-1]) 152 | metrics["input"] = tf.metrics.accuracy( 153 | labels=tf.boolean_mask( 154 | tf.reshape(tf.to_int32( 155 | tf.less(features["input_refs"][:, :, 0], 156 | features["input_refs"][:, :, 1])), [-1]), input_mask), 157 | predictions=tf.boolean_mask( 158 | tf.reshape(pred_dict["input"], [-1]), input_mask), 159 | name="input_accuracy") 160 | metrics["object"] = tf.metrics.accuracy( 161 | labels=tf.boolean_mask(tf.reshape(features["objects"], [-1]), 162 | verb_nonpadding), 163 | predictions=tf.boolean_mask(tf.reshape(pred_dict["objects"], [-1]), 164 | verb_nonpadding), 165 | name="object_accuracy") 166 | metrics["eval_object_loss"] = tf.metrics.mean( 167 | tf.reduce_mean( 168 | tf.boolean_mask(tf.reshape(loss_dict["object_losses"], [-1]), 169 | verb_nonpadding))) 170 | metrics["eval_verb_loss"] = tf.metrics.mean( 171 | tf.reduce_mean( 172 | tf.boolean_mask(tf.reshape(loss_dict["verbs_losses"], [-1]), 173 | verb_nonpadding))) 174 | 175 | 176 | def decode_sequence(features, areas, hparams, decode_length, 177 | post_processing=True): 178 | """Decodes the entire sequence in an auto-regressive way.""" 179 | decode_utils.decode_n_step(seq2act_model.compute_logits, 180 | features, areas, 181 | hparams, n=decode_length, beam_size=1) 182 | if post_processing: 183 | features["input_refs"] = decode_utils.unify_input_ref( 184 | features["verbs"], features["input_refs"]) 185 | pred_lengths = decode_utils.verb_refs_to_lengths(features["task"], 186 | features["verb_refs"], 187 | include_eos=False) 188 | predicted_actions = tf.concat([ 189 | features["verb_refs"], 190 | features["obj_refs"], 191 | features["input_refs"], 192 | tf.to_int32(tf.expand_dims(features["verbs"], 2)), 193 | tf.to_int32(tf.expand_dims(features["objects"], 2))], axis=-1) 194 | if post_processing: 195 | predicted_actions = tf.where( 196 | tf.tile(tf.expand_dims( 197 | tf.sequence_mask(pred_lengths, 198 | maxlen=tf.shape(predicted_actions)[1]), 199 | 2), [1, 1, tf.shape(predicted_actions)[-1]]), predicted_actions, 200 | tf.zeros_like(predicted_actions)) 201 | return predicted_actions 202 | 203 | 204 | def create_model_fn(hparams, compute_additional_loss_fn=None, 205 | compute_additional_metric_fn=None, 206 | compute_seq_accuracy=False, 207 | decode_length=20): 208 | """Creates the model function. 209 | 210 | Args: 211 | hparams: the hyper parameters. 212 | compute_additional_loss_fn: the optional callback for calculating 213 | additional loss. 214 | compute_additional_metric_fn: the optional callback for computing 215 | additional metrics. 216 | compute_seq_accuracy: whether to compute seq accuracy. 217 | decode_length: the maximum decoding length. 218 | Returns: 219 | the model function for estimator. 
220 | """ 221 | def model_fn(features, labels, mode): 222 | """The model function for creating an Estimtator.""" 223 | del labels 224 | input_count = tf.reduce_sum( 225 | tf.to_int32(tf.greater(features["input_refs"][:, :, 1], 226 | features["input_refs"][:, :, 0]))) 227 | tf.summary.scalar("input_count", input_count) 228 | loss_dict, pred_dict, areas = seq2act_model.core_graph( 229 | features, hparams, mode, compute_additional_loss_fn) 230 | if mode == tf.estimator.ModeKeys.PREDICT: 231 | pred_dict["sequences"] = decode_sequence( 232 | features, areas, hparams, decode_length, 233 | post_processing=FLAGS.post_processing) 234 | return tf.estimator.EstimatorSpec(mode, predictions=pred_dict) 235 | elif mode == tf.estimator.ModeKeys.EVAL: 236 | metrics = {} 237 | _eval(metrics, pred_dict, loss_dict, features, 238 | areas, compute_seq_accuracy, 239 | hparams, 240 | metric_types=FLAGS.metric_types.split(","), 241 | decode_length=decode_length) 242 | if compute_additional_metric_fn: 243 | compute_additional_metric_fn(metrics, pred_dict, features) 244 | return tf.estimator.EstimatorSpec( 245 | mode, loss=loss_dict["total_loss"], eval_metric_ops=metrics) 246 | else: 247 | assert mode == tf.estimator.ModeKeys.TRAIN 248 | loss = loss_dict["total_loss"] 249 | for loss_name in loss_dict: 250 | if loss_name == "total_loss": 251 | continue 252 | if loss_name.endswith("losses"): 253 | continue 254 | tf.summary.scalar(loss_name, loss_dict[loss_name]) 255 | step_num = tf.to_float(tf.train.get_global_step()) 256 | schedule_string = hparams.learning_rate_schedule 257 | names = schedule_string.split("*") 258 | names = [name.strip() for name in names if name.strip()] 259 | ret = tf.constant(1.0) 260 | for name in names: 261 | ret *= learning_rate.learning_rate_factor(name, step_num, hparams) 262 | train_op = optimize.optimize(loss, ret, hparams) 263 | return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) 264 | return model_fn 265 | 266 | 267 | def create_input_fn(files, 268 | batch_size, 269 | repeat, 270 | required_agreement, 271 | data_source, 272 | max_range, 273 | max_dom_pos, 274 | max_pixel_pos, 275 | mean_synthetic_length, 276 | stddev_synthetic_length, 277 | load_extra=False, 278 | load_screen=True, 279 | buffer_size=8 * 1024, 280 | shuffle_size=8 * 1024, 281 | load_dom_dist=False): 282 | """Creats the input function.""" 283 | def input_fn(): 284 | return input_utils.input_fn(data_files=files, 285 | batch_size=batch_size, 286 | repeat=repeat, 287 | required_agreement=required_agreement, 288 | data_source=data_source, 289 | max_range=max_range, 290 | max_dom_pos=max_dom_pos, 291 | max_pixel_pos=max_pixel_pos, 292 | mean_synthetic_length=mean_synthetic_length, 293 | stddev_synthetic_length=stddev_synthetic_length, 294 | load_extra=load_extra, 295 | buffer_size=buffer_size, 296 | load_screen=load_screen, 297 | shuffle_size=shuffle_size, 298 | load_dom_dist=load_dom_dist) 299 | return input_fn 300 | 301 | 302 | def create_hybrid_input_fn(data_files_list, 303 | data_source_list, 304 | batch_size_list, 305 | max_range, 306 | max_dom_pos, 307 | max_pixel_pos, 308 | mean_synthetic_length, 309 | stddev_synthetic_length, 310 | batch_size, 311 | boost_input=False, 312 | load_screen=True, 313 | buffer_size=1024 * 8, 314 | shuffle_size=1024, 315 | load_dom_dist=False): 316 | """Creats the input function.""" 317 | def input_fn(): 318 | return input_utils.hybrid_input_fn( 319 | data_files_list, 320 | data_source_list, 321 | batch_size_list, 322 | max_range=max_range, 323 | max_dom_pos=max_dom_pos, 324 
| max_pixel_pos=max_pixel_pos, 325 | mean_synthetic_length=mean_synthetic_length, 326 | stddev_synthetic_length=stddev_synthetic_length, 327 | hybrid_batch_size=batch_size, 328 | boost_input=boost_input, 329 | load_screen=load_screen, 330 | buffer_size=buffer_size, 331 | shuffle_size=shuffle_size, 332 | load_dom_dist=load_dom_dist) 333 | return input_fn 334 | 335 | 336 | def create_hparams(): 337 | """Creates hyper parameters.""" 338 | hparams = getattr(transformer, "transformer_base")() 339 | hparams.add_hparam("reference_warmup_steps", 0) 340 | hparams.add_hparam("max_span", 20) 341 | hparams.add_hparam("task_vocab_size", 59429) 342 | hparams.add_hparam("load_screen", True) 343 | 344 | hparams.set_hparam("hidden_size", 16) 345 | hparams.set_hparam("num_hidden_layers", 2) 346 | hparams.add_hparam("freeze_reference_model", False) 347 | hparams.add_hparam("mean_synthetic_length", 1.0) 348 | hparams.add_hparam("stddev_synthetic_length", .0) 349 | hparams.add_hparam("instruction_encoder", "transformer") 350 | hparams.add_hparam("instruction_decoder", "transformer") 351 | hparams.add_hparam("clip_norm", 0.) 352 | hparams.add_hparam("span_rep", "area") 353 | hparams.add_hparam("dom_dist_variance", 1.0) 354 | 355 | hparams.add_hparam("attention_mechanism", "luong") # "bahdanau" 356 | hparams.add_hparam("output_attention", True) 357 | hparams.add_hparam("attention_layer_size", 128) 358 | 359 | # GAN-related hyper params 360 | hparams.add_hparam("dis_loss_ratio", 0.01) 361 | hparams.add_hparam("gen_loss_ratio", 0.01) 362 | hparams.add_hparam("gan_update", "center") 363 | hparams.add_hparam("num_joint_layers", 2) 364 | hparams.add_hparam("use_additional_loss", False) 365 | 366 | hparams.add_hparam("compute_verb_obj_separately", True) 367 | hparams.add_hparam("synthetic_screen_noise", 0.) 
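  # Screen-encoder and grounding hyperparameters. The defaults below can be
  # overridden through the --hparams flag, which is parsed near the end of
  # this function, e.g. (an illustrative value, not a required setting):
  #   --hparams="screen_encoder=transformer,obj_text_aggregation=mean"
  # "screen_encoder" accepts "mlp", "transformer" or "gcn", and
  # "obj_text_aggregation" accepts "max", "sum" or "mean"
  # (see seq2act_grounding.py and encode_screen.py).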
368 | hparams.add_hparam("screen_encoder", "mlp") 369 | hparams.add_hparam("screen_encoder_layers", 2) 370 | hparams.add_hparam("action_vocab_size", 6) 371 | hparams.add_hparam("max_pixel_pos", 100) 372 | hparams.add_hparam("max_dom_pos", 500) 373 | hparams.add_hparam("span_aggregation", "sum") 374 | hparams.add_hparam("obj_text_aggregation", "sum") 375 | hparams.add_hparam("screen_embedding_feature", "text_pos_type") 376 | hparams.add_hparam("alignment", "dot_product_attention") 377 | hparams.parse(FLAGS.hparams) 378 | 379 | hparams.set_hparam("use_target_space_embedding", False) 380 | hparams.set_hparam("filter_size", hparams.hidden_size * 4) 381 | hparams.set_hparam("attention_layer_size", hparams.hidden_size) 382 | hparams.set_hparam("dropout", hparams.layer_prepostprocess_dropout) 383 | return hparams 384 | 385 | 386 | def save_hyperparams(hparams, output_dir): 387 | """Save the model hyperparameters.""" 388 | if not tf.gfile.Exists(output_dir): 389 | tf.gfile.MakeDirs(output_dir) 390 | if not tf.gfile.Exists(os.path.join(output_dir, "hparams.json")): 391 | with tf.gfile.GFile( 392 | os.path.join(output_dir, "hparams.json"), mode="w") as f: 393 | f.write(hparams.to_json()) 394 | 395 | 396 | def load_hparams(checkpoint_path): 397 | """Prepares the hyper-parameters.""" 398 | hparams = create_hparams() 399 | with tf.gfile.Open(os.path.join(checkpoint_path, "hparams.json"), 400 | "r") as hparams_file: 401 | hparams_string = " ".join(hparams_file.readlines()) 402 | hparams.parse_json(hparams_string) 403 | tf.logging.info("hparams: %s" % hparams) 404 | return hparams 405 | 406 | 407 | def _pad_to_max(x, y, constant_values=0): 408 | """Pad x and y to their maximum shape.""" 409 | shape_x = common_layers.shape_list(x) 410 | shape_y = common_layers.shape_list(y) 411 | assert len(shape_x) == len(shape_y) 412 | pad_x = [[0, 0]] 413 | pad_y = [[0, 0]] 414 | for dim in range(len(shape_x) - 1): 415 | add_y = shape_x[dim + 1] - shape_y[dim + 1] 416 | add_x = -add_y 417 | pad_x.append([0, tf.maximum(add_x, 0)]) 418 | pad_y.append([0, tf.maximum(add_y, 0)]) 419 | x = tf.pad(x, pad_x, constant_values=constant_values) 420 | y = tf.pad(y, pad_y, constant_values=constant_values) 421 | return x, y 422 | -------------------------------------------------------------------------------- /models/seq2act_grounding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
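# Data-flow sketch (mirrors seq2act_model.compute_logits in this repository):
#
#   references = seq2act_reference.compute_logits(features, hparams, train=...)
#   action_logits, obj_logits, consumed_logits = compute_logits(
#       features, references, hparams)
#
# `references` is the dict produced by seq2act_reference.compute_logits; this
# module consumes its embedding scope, area logits and hidden states.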
15 | 16 | """The grounding models.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from tensor2tensor.layers import common_layers 21 | import tensorflow.compat.v1 as tf 22 | from seq2act.layers import area_utils 23 | from seq2act.layers import common_embed 24 | from seq2act.layers import encode_screen 25 | from seq2act.models import seq2act_reference 26 | 27 | 28 | def encode_screen_ffn(features, hparams, embed_scope): 29 | """Encodes a screen with feed forward neural network. 30 | 31 | Args: 32 | features: the feature dict. 33 | hparams: the hyperparameter. 34 | embed_scope: the name scope. 35 | Returns: 36 | encoder_outputs: a Tensor of shape 37 | [batch_size, num_steps, max_object_count, hidden_size] 38 | obj_mask: A tensor of shape 39 | [batch_size, num_steps, max_object_count] 40 | """ 41 | object_embed, obj_mask, obj_bias = encode_screen.prepare_encoder_input( 42 | features=features, hparams=hparams, 43 | embed_scope=embed_scope) 44 | for layer in range(hparams.num_hidden_layers): 45 | with tf.variable_scope( 46 | "encode_screen_ff_layer_%d" % layer, reuse=tf.AUTO_REUSE): 47 | object_embed = tf.layers.dense(object_embed, units=hparams.hidden_size) 48 | object_embed = common_layers.apply_norm( 49 | object_embed, hparams.norm_type, hparams.hidden_size, 50 | epsilon=hparams.norm_epsilon) 51 | object_embed = tf.nn.relu(object_embed) 52 | object_embed = tf.nn.dropout( 53 | object_embed, 54 | keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 55 | object_embed = object_embed * tf.expand_dims(obj_mask, 3) 56 | return object_embed, obj_bias 57 | 58 | 59 | def compute_logits(features, references, hparams): 60 | """Grounds using the predicted references. 61 | 62 | Args: 63 | features: the feature dict. 64 | references: the dict that keeps the reference results. 65 | hparams: the hyper-parameters. 
66 | Returns: 67 | action_logits: [batch_size, num_steps, num_actions] 68 | object_logits: [batch_size, num_steps, max_num_objects] 69 | """ 70 | lang_hidden_layers = hparams.num_hidden_layers 71 | pos_embed = hparams.pos 72 | hparams.set_hparam("num_hidden_layers", hparams.screen_encoder_layers) 73 | hparams.set_hparam("pos", "none") 74 | with tf.variable_scope("compute_grounding_logits", reuse=tf.AUTO_REUSE): 75 | # Encode objects 76 | if hparams.screen_encoder == "gcn": 77 | screen_encoding, _, screen_encoding_bias = ( 78 | encode_screen.gcn_encoder( 79 | features, hparams, references["embed_scope"], 80 | discretize=False)) 81 | elif hparams.screen_encoder == "transformer": 82 | screen_encoding, _, screen_encoding_bias = ( 83 | encode_screen.transformer_encoder( 84 | features, hparams, references["embed_scope"])) 85 | elif hparams.screen_encoder == "mlp": 86 | screen_encoding, screen_encoding_bias = encode_screen_ffn( 87 | features, hparams, references["embed_scope"]) 88 | else: 89 | raise ValueError( 90 | "Unsupported encoder: %s" % hparams.screen_encoder) 91 | # Compute query 92 | if hparams.compute_verb_obj_separately: 93 | verb_hidden, object_hidden = _compute_query_embedding( 94 | features, references, hparams, references["embed_scope"]) 95 | else: 96 | verb_hidden = references["verb_hidden"] 97 | object_hidden = references["object_hidden"] 98 | # Predict actions 99 | with tf.variable_scope("compute_action_logits", reuse=tf.AUTO_REUSE): 100 | action_logits = tf.layers.dense( 101 | verb_hidden, units=hparams.action_vocab_size) 102 | # Predict objects 103 | obj_logits, consumed_logits = _compute_object_logits( 104 | hparams, 105 | object_hidden, 106 | screen_encoding, 107 | screen_encoding_bias) 108 | hparams.set_hparam("num_hidden_layers", lang_hidden_layers) 109 | hparams.set_hparam("pos", pos_embed) 110 | return action_logits, obj_logits, consumed_logits 111 | 112 | 113 | def _compute_object_logits(hparams, object_hidden, 114 | screen_encoding, screen_encoding_bias): 115 | """The output layer for a specific domain.""" 116 | with tf.variable_scope("compute_object_logits", reuse=tf.AUTO_REUSE): 117 | if hparams.alignment == "cosine_similarity": 118 | object_hidden = tf.layers.dense( 119 | object_hidden, units=hparams.hidden_size) 120 | screen_encoding = tf.layers.dense( 121 | screen_encoding, units=hparams.hidden_size) 122 | norm_screen_encoding = tf.math.l2_normalize(screen_encoding, axis=-1) 123 | norm_obj_hidden = tf.math.l2_normalize(object_hidden, axis=-1) 124 | align_logits = tf.matmul(norm_screen_encoding, 125 | tf.expand_dims(norm_obj_hidden, 3)) 126 | elif hparams.alignment == "scaled_cosine_similarity": 127 | object_hidden = tf.layers.dense( 128 | object_hidden, units=hparams.hidden_size) 129 | screen_encoding = tf.reshape( 130 | screen_encoding, 131 | common_layers.shape_list( 132 | screen_encoding)[:-1] + [hparams.hidden_size]) 133 | screen_encoding = tf.layers.dense( 134 | screen_encoding, units=hparams.hidden_size) 135 | norm_screen_encoding = tf.math.l2_normalize(screen_encoding, axis=-1) 136 | norm_obj_hidden = tf.math.l2_normalize(object_hidden, axis=-1) 137 | dot_products = tf.matmul(norm_screen_encoding, 138 | tf.expand_dims(norm_obj_hidden, 3)) 139 | align_logits = tf.layers.dense(dot_products, units=1) 140 | elif hparams.alignment == "dot_product_attention": 141 | object_hidden = tf.layers.dense( 142 | object_hidden, units=hparams.hidden_size) 143 | align_logits = tf.matmul(screen_encoding, 144 | tf.expand_dims(object_hidden, 3)) 145 | elif hparams.alignment == 
"mlp_attention": 146 | batch_size = tf.shape(screen_encoding)[0] 147 | num_steps = tf.shape(screen_encoding)[1] 148 | num_objects = tf.shape(screen_encoding)[2] 149 | tiled_object_hidden = tf.tile(tf.expand_dims(object_hidden, 2), 150 | [1, 1, num_objects, 1]) 151 | align_feature = tf.concat([tiled_object_hidden, screen_encoding], axis=-1) 152 | align_feature = tf.reshape( 153 | align_feature, 154 | [batch_size, num_steps, num_objects, hparams.hidden_size * 2]) 155 | with tf.variable_scope("align", reuse=tf.AUTO_REUSE): 156 | align_hidden = tf.layers.dense(align_feature, units=hparams.hidden_size) 157 | align_hidden = common_layers.apply_norm( 158 | align_hidden, hparams.norm_type, hparams.hidden_size, 159 | epsilon=hparams.norm_epsilon) 160 | align_hidden = tf.nn.tanh(align_hidden) 161 | align_logits = tf.layers.dense(align_hidden, units=1) 162 | else: 163 | raise ValueError("Unsupported alignment: %s" % hparams.alignment) 164 | 165 | obj_logits = tf.squeeze(align_logits, [3]) + screen_encoding_bias 166 | # [batch_size, num_steps] 167 | batch_size = common_layers.shape_list(obj_logits)[0] 168 | num_steps = common_layers.shape_list(obj_logits)[1] 169 | # [batch_size * num_steps, 1] 170 | batch_indices = tf.to_int64(tf.reshape( 171 | tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, num_steps]), 172 | [-1, 1])) 173 | step_indices = tf.to_int64(tf.reshape( 174 | tf.tile(tf.expand_dims(tf.range(num_steps), 0), [batch_size, 1]), 175 | [-1, 1])) 176 | object_indices = tf.reshape(tf.argmax(obj_logits, -1), [-1, 1]) 177 | indices = tf.concat([batch_indices, step_indices, object_indices], -1) 178 | # [batch_size, num_steps, depth] 179 | depth = tf.shape(screen_encoding)[-1] 180 | best_logits = tf.reshape( 181 | tf.gather_nd(screen_encoding, indices=indices), 182 | [batch_size, num_steps, depth]) 183 | consumed_logits = tf.layers.dense( 184 | tf.reshape(tf.concat([object_hidden, best_logits], -1), 185 | [batch_size, num_steps, hparams.hidden_size * 2]), 186 | 2) 187 | with tf.control_dependencies([tf.assert_equal( 188 | tf.reduce_all(tf.math.is_nan(consumed_logits)), False, 189 | data=[tf.shape(best_logits), best_logits, 190 | tf.constant("screen_encoding"), screen_encoding, 191 | tf.constant("indices"), indices], 192 | summarize=10000, message="consumed_logits_nan")]): 193 | consumed_logits = tf.identity(consumed_logits) 194 | return obj_logits, consumed_logits 195 | 196 | 197 | def _compute_query_embedding(features, references, hparams, embed_scope=None): 198 | """Computes lang embeds for verb and object from predictions. 199 | 200 | Args: 201 | features: a dictionary contains "inputs" that is a tensor in shape of 202 | [batch_size, num_tokens], "verb_id_seq" that is in shape of 203 | [batch_size, num_actions], "object_spans" and "param_span" tensor 204 | in shape of [batch_size, num_actions, 2]. 0 is used as padding or 205 | non-existent values. 206 | references: the dict that keeps the reference results. 207 | hparams: the general hyperparameters for the model. 208 | embed_scope: the embedding variable scope. 
209 | Returns: 210 | verb_embeds: a Tensor of shape 211 | [batch_size, num_steps, depth] 212 | object_embeds: 213 | [batch_size, num_steps, depth] 214 | """ 215 | pred_verb_refs = seq2act_reference.predict_refs( 216 | references["verb_area_logits"], 217 | references["areas"]["starts"], references["areas"]["ends"]) 218 | pred_obj_refs = seq2act_reference.predict_refs( 219 | references["obj_area_logits"], 220 | references["areas"]["starts"], references["areas"]["ends"]) 221 | input_embeddings, _ = common_embed.embed_tokens( 222 | features["task"], hparams.task_vocab_size, hparams.hidden_size, hparams, 223 | embed_scope=references["embed_scope"]) 224 | if hparams.obj_text_aggregation == "sum": 225 | area_encodings, _, _ = area_utils.compute_sum_image( 226 | input_embeddings, max_area_width=hparams.max_span) 227 | shape = common_layers.shape_list(features["task"]) 228 | encoder_input_length = shape[1] 229 | verb_embeds = seq2act_reference.span_embedding( 230 | encoder_input_length, area_encodings, pred_verb_refs, hparams) 231 | object_embeds = seq2act_reference.span_embedding( 232 | encoder_input_length, area_encodings, pred_obj_refs, hparams) 233 | elif hparams.obj_text_aggregation == "mean": 234 | verb_embeds = seq2act_reference.span_average_embed( 235 | input_embeddings, pred_verb_refs, embed_scope, hparams) 236 | object_embeds = seq2act_reference.span_average_embed( 237 | input_embeddings, pred_obj_refs, embed_scope, hparams) 238 | else: 239 | raise ValueError("Unrecognized query aggreggation %s" % ( 240 | hparams.span_aggregation)) 241 | return verb_embeds, object_embeds 242 | 243 | 244 | def compute_losses(loss_dict, features, action_logits, obj_logits, 245 | consumed_logits): 246 | """Compute the loss based on the logits and labels.""" 247 | valid_obj_mask = tf.to_float(tf.greater(features["verbs"], 1)) 248 | action_losses = tf.losses.sparse_softmax_cross_entropy( 249 | labels=features["verbs"], 250 | logits=action_logits, 251 | reduction=tf.losses.Reduction.NONE) * valid_obj_mask 252 | action_loss = tf.reduce_mean(action_losses) 253 | object_losses = tf.losses.sparse_softmax_cross_entropy( 254 | labels=features["objects"], 255 | logits=obj_logits, 256 | reduction=tf.losses.Reduction.NONE) * valid_obj_mask 257 | object_loss = tf.reduce_mean(object_losses) 258 | if "consumed" in features: 259 | consumed_loss = tf.reduce_mean( 260 | tf.losses.sparse_softmax_cross_entropy( 261 | labels=features["consumed"], 262 | logits=consumed_logits, 263 | reduction=tf.losses.Reduction.NONE) * valid_obj_mask) 264 | else: 265 | consumed_loss = 0.0 266 | loss_dict["grounding_loss"] = action_loss + object_loss + consumed_loss 267 | loss_dict["verbs_loss"] = action_loss 268 | loss_dict["objects_loss"] = object_loss 269 | loss_dict["verbs_losses"] = action_losses 270 | loss_dict["object_losses"] = object_losses 271 | loss_dict["consumed_loss"] = consumed_loss 272 | return loss_dict["grounding_loss"] 273 | 274 | 275 | def compute_predictions(prediction_dict, action_logits, obj_logits, 276 | consumed_logits): 277 | """Predict the action tuple based on the logits.""" 278 | prediction_dict["verbs"] = tf.argmax(action_logits, -1) 279 | prediction_dict["objects"] = tf.argmax(obj_logits, -1) 280 | prediction_dict["consumed"] = tf.argmax(consumed_logits, -1) 281 | -------------------------------------------------------------------------------- /models/seq2act_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Seq2act model.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import tensorflow.compat.v1 as tf 21 | from seq2act.models import seq2act_grounding 22 | from seq2act.models import seq2act_reference 23 | 24 | 25 | def compute_logits(features, hparams, mode, 26 | use_cache=None, cache=None): 27 | """Computes the logits.""" 28 | if mode != tf.estimator.ModeKeys.TRAIN: 29 | for key in hparams.values(): 30 | if key.endswith("dropout"): 31 | setattr(hparams, key, 0.0) 32 | setattr(hparams, "synthetic_screen_noise", 0.0) 33 | tf.logging.info(hparams) 34 | references = seq2act_reference.compute_logits( 35 | features, hparams, 36 | train=(mode == tf.estimator.ModeKeys.TRAIN)) 37 | if use_cache is not None and cache is not None: 38 | for key in cache: 39 | references[key] = tf.where( 40 | tf.equal(use_cache, 1), tf.concat([ 41 | cache[key], cache[key][:, -1:, :]], axis=1), references[key]) 42 | action_logits, obj_logits, consumed_logits = seq2act_grounding.compute_logits( 43 | features, references, hparams) 44 | return action_logits, obj_logits, consumed_logits, references 45 | 46 | 47 | def compute_loss(loss_dict, features, action_logits, obj_logits, 48 | consumed_logits, references, hparams): 49 | """Computes the loss.""" 50 | total_loss = seq2act_reference.compute_losses( 51 | loss_dict, features, references, hparams) 52 | grounding_loss = seq2act_grounding.compute_losses( 53 | loss_dict, features, action_logits, obj_logits, consumed_logits) 54 | global_step = tf.train.get_global_step() 55 | if global_step: 56 | total_loss += tf.cond( 57 | tf.greater(global_step, hparams.reference_warmup_steps), 58 | lambda: grounding_loss, 59 | lambda: tf.constant(0.)) 60 | else: 61 | total_loss += grounding_loss 62 | loss_dict["total_loss"] = total_loss 63 | 64 | 65 | def predict(prediction_dict, action_logits, obj_logits, consumed_logits, 66 | references): 67 | """Compute predictions.""" 68 | seq2act_reference.compute_predictions(prediction_dict, references) 69 | seq2act_grounding.compute_predictions(prediction_dict, 70 | action_logits, obj_logits, 71 | consumed_logits) 72 | 73 | 74 | def core_graph(features, hparams, mode, 75 | compute_additional_loss=None): 76 | """The core TF graph for the estimator.""" 77 | action_logits, obj_logits, consumed_logits, references = ( 78 | compute_logits(features, hparams, mode)) 79 | prediction_dict = {} 80 | loss_dict = {} 81 | if mode != tf.estimator.ModeKeys.PREDICT: 82 | compute_loss(loss_dict, features, 83 | action_logits, obj_logits, consumed_logits, references, 84 | hparams) 85 | if compute_additional_loss: 86 | compute_additional_loss(hparams, features, references["decoder_output"], 87 | loss_dict, prediction_dict, mode) 88 | if mode != tf.estimator.ModeKeys.TRAIN: 89 | if mode == tf.estimator.ModeKeys.PREDICT: 90 | prediction_dict["task"] = features["task"] 91 | 
prediction_dict["raw_task"] = features["raw_task"] 92 | prediction_dict["data_source"] = features["data_source"] 93 | predict(prediction_dict, action_logits, obj_logits, consumed_logits, 94 | references) 95 | return loss_dict, prediction_dict, references["areas"] 96 | -------------------------------------------------------------------------------- /models/seq2act_reference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """The reference models.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import copy 21 | from tensor2tensor.layers import common_attention 22 | from tensor2tensor.layers import common_layers 23 | from tensor2tensor.models import transformer 24 | import tensorflow.compat.v1 as tf 25 | from seq2act.layers import area_utils 26 | from seq2act.layers import common_embed 27 | 28 | 29 | def span_embedding(encoder_input_length, area_encodings, spans, hparams): 30 | """Computes the embedding for each span. (TODO: liyang): comment shapes.""" 31 | with tf.control_dependencies([tf.assert_equal(tf.rank(area_encodings), 3)]): 32 | area_indices = area_utils.area_range_to_index( 33 | area_range=tf.reshape(spans, [-1, 2]), length=encoder_input_length, 34 | max_area_width=hparams.max_span) 35 | return area_utils.batch_gather( 36 | area_encodings, tf.reshape(area_indices, 37 | [tf.shape(spans)[0], tf.shape(spans)[1]])) 38 | 39 | 40 | def span_average_embed(area_encodings, spans, embed_scope): 41 | """Embeds a span of tokens using averaging. 42 | 43 | Args: 44 | area_encodings: [batch_size, length, depth]. 45 | spans: [batch_size, ref_len, 2]. 46 | embed_scope: the variable scope for embedding. 47 | Returns: 48 | the average embeddings in the shape of [batch_size, ref_lengths, depth]. 
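    Spans are treated as half-open token ranges, i.e. the tokens in [start, end).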
49 | """ 50 | ref_len = common_layers.shape_list(spans)[1] 51 | starts_ends = tf.reshape(spans, [-1, 2]) 52 | depth = common_layers.shape_list(area_encodings)[-1] 53 | length = common_layers.shape_list(area_encodings)[1] 54 | area_encodings = tf.reshape( 55 | tf.tile(tf.expand_dims(area_encodings, 1), [1, ref_len, 1, 1]), 56 | [-1, length, depth]) 57 | area_ranges = starts_ends[:, 1] - starts_ends[:, 0] 58 | max_num_tokens = tf.reduce_max(area_ranges) 59 | def _fetch_embeddings(area_encoding_and_range): 60 | """Fetches embeddings for the range.""" 61 | area_encoding = area_encoding_and_range[0] 62 | area_range = area_encoding_and_range[1] 63 | start = area_range[0] 64 | end = area_range[1] 65 | embeddings = area_encoding[start:end, :] 66 | em_len = area_range[1] - area_range[0] 67 | embeddings = tf.pad(embeddings, [[0, max_num_tokens - em_len], [0, 0]], 68 | constant_values=0.0) 69 | return embeddings 70 | # [batch_size * ref_len, max_num_tokens, depth] 71 | area_embeddings = tf.map_fn(_fetch_embeddings, 72 | [area_encodings, starts_ends], 73 | dtype=tf.float32, infer_shape=False) 74 | # To give a fixed dimension 75 | area_embeddings = tf.reshape(area_embeddings, 76 | [-1, max_num_tokens, depth]) 77 | emb_sum = tf.reduce_sum(tf.abs(area_embeddings), axis=-1) 78 | non_paddings = tf.not_equal(emb_sum, 0.0) 79 | # [batch_size * ref_len, depth] 80 | area_embeddings = common_embed.average_bag_of_embeds( 81 | area_embeddings, non_paddings, use_bigrams=True, 82 | bigram_embed_scope=embed_scope, append_start_end=True) 83 | area_embeddings = tf.reshape(area_embeddings, [-1, ref_len, depth]) 84 | return area_embeddings 85 | 86 | 87 | def _prepare_decoder_input(area_encoding, decoder_nonpadding, 88 | features, hparams, 89 | embed_scope=None): 90 | """Prepare the input for the action decoding. 91 | 92 | Args: 93 | area_encoding: the encoder output in shape of [batch_size, area_len, depth]. 94 | decoder_nonpadding: the nonpadding mask for the decoding seq. 95 | features: a dictionary of tensors in the shape of [batch_size, seq_length]. 96 | hparams: the hyperparameters. 97 | embed_scope: the embedding scope. 98 | Returns: 99 | decoder_input: decoder input in shape of 100 | [batch_size, num_steps, latent_depth] 101 | decoder_self_attention_bias: decoder attention bias. 
102 | """ 103 | with tf.variable_scope("prepare_decoder_input", reuse=tf.AUTO_REUSE): 104 | shape = common_layers.shape_list(features["task"]) 105 | batch_size = shape[0] 106 | encoder_input_length = shape[1] 107 | depth = common_layers.shape_list(area_encoding)[-1] 108 | if hparams.span_aggregation == "sum": 109 | verb_embeds = span_embedding(encoder_input_length, 110 | area_encoding, features["verb_refs"], 111 | hparams) 112 | object_embeds = span_embedding(encoder_input_length, 113 | area_encoding, features["obj_refs"], 114 | hparams) 115 | input_embeds = span_embedding(encoder_input_length, 116 | area_encoding, features["input_refs"], 117 | hparams) 118 | non_input_embeds = tf.tile(tf.expand_dims(tf.expand_dims( 119 | tf.get_variable(name="non_input_embeds", 120 | shape=[depth]), 121 | 0), 0), [batch_size, tf.shape(features["input_refs"])[1], 1]) 122 | input_embeds = tf.where( 123 | tf.tile( 124 | tf.expand_dims(tf.equal(features["input_refs"][:, :, 1], 125 | features["input_refs"][:, :, 0]), 2), 126 | [1, 1, tf.shape(input_embeds)[-1]]), 127 | non_input_embeds, 128 | input_embeds) 129 | elif hparams.span_aggregation == "mean": 130 | area_encoding = area_encoding[:, :encoder_input_length, :] 131 | verb_embeds = span_average_embed(area_encoding, features["verb_refs"], 132 | embed_scope) 133 | object_embeds = span_average_embed(area_encoding, features["obj_refs"], 134 | embed_scope) 135 | input_embeds = span_average_embed(area_encoding, features["input_refs"], 136 | embed_scope) 137 | else: 138 | raise ValueError("Unrecognized span aggregation method %s" % ( 139 | hparams.span_aggregation)) 140 | embeds = verb_embeds + object_embeds + input_embeds 141 | embeds = tf.multiply( 142 | tf.expand_dims(decoder_nonpadding, 2), embeds) 143 | start_embed = tf.tile(tf.expand_dims(tf.expand_dims( 144 | tf.get_variable(name="start_step_embed", 145 | shape=[depth]), 0), 0), 146 | [batch_size, 1, 1]) 147 | embeds = tf.concat([start_embed, embeds], axis=1) 148 | embeds = embeds[:, :-1, :] 149 | decoder_self_attention_bias = ( 150 | common_attention.attention_bias_lower_triangle( 151 | common_layers.shape_list(features["verb_refs"])[1])) 152 | if hparams.pos == "timing": 153 | decoder_input = common_attention.add_timing_signal_1d(embeds) 154 | elif hparams.pos == "emb": 155 | decoder_input = common_attention.add_positional_embedding( 156 | embeds, hparams.max_length, "targets_positional_embedding", 157 | None) 158 | else: 159 | decoder_input = embeds 160 | return decoder_input, decoder_self_attention_bias 161 | 162 | 163 | def encode_decode_task(features, hparams, train, attention_weights=None): 164 | """Model core graph for the one-shot action. 165 | 166 | Args: 167 | features: a dictionary contains "inputs" that is a tensor in shape of 168 | [batch_size, num_tokens], "verb_id_seq" that is in shape of 169 | [batch_size, num_actions], "object_spans" and "param_span" tensor 170 | in shape of [batch_size, num_actions, 2]. 0 is used as padding or 171 | non-existent values. 172 | hparams: the general hyperparameters for the model. 173 | train: the train mode. 174 | attention_weights: the dict to keep attention weights for analysis. 175 | Returns: 176 | loss_dict: the losses for training. 177 | prediction_dict: the predictions for action tuples. 178 | areas: the area encodings of the task. 179 | scope: the embedding scope. 
180 | """ 181 | del train 182 | input_embeddings, scope = common_embed.embed_tokens( 183 | features["task"], 184 | hparams.task_vocab_size, 185 | hparams.hidden_size, hparams) 186 | with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE): 187 | encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0) 188 | input_embeddings = tf.multiply( 189 | tf.expand_dims(encoder_nonpadding, 2), 190 | input_embeddings) 191 | encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( 192 | transformer.transformer_prepare_encoder( 193 | input_embeddings, None, hparams, features=None)) 194 | encoder_input = tf.nn.dropout( 195 | encoder_input, 196 | keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 197 | if hparams.instruction_encoder == "transformer": 198 | encoder_output = transformer.transformer_encoder( 199 | encoder_input, 200 | self_attention_bias, 201 | hparams, 202 | save_weights_to=attention_weights, 203 | make_image_summary=not common_layers.is_xla_compiled()) 204 | else: 205 | raise ValueError("Unsupported instruction encoder %s" % ( 206 | hparams.instruction_encoder)) 207 | span_rep = hparams.get("span_rep", "area") 208 | area_encodings, area_starts, area_ends = area_utils.compute_sum_image( 209 | encoder_output, max_area_width=hparams.max_span) 210 | current_shape = tf.shape(area_encodings) 211 | if span_rep == "area": 212 | area_encodings, _, _ = area_utils.compute_sum_image( 213 | encoder_output, max_area_width=hparams.max_span) 214 | elif span_rep == "basic": 215 | area_encodings = area_utils.compute_alternative_span_rep( 216 | encoder_output, input_embeddings, max_area_width=hparams.max_span, 217 | hidden_size=hparams.hidden_size, advanced=False) 218 | elif span_rep == "coref": 219 | area_encodings = area_utils.compute_alternative_span_rep( 220 | encoder_output, input_embeddings, max_area_width=hparams.max_span, 221 | hidden_size=hparams.hidden_size, advanced=True) 222 | else: 223 | raise ValueError("xyz") 224 | areas = {} 225 | areas["encodings"] = area_encodings 226 | areas["starts"] = area_starts 227 | areas["ends"] = area_ends 228 | with tf.control_dependencies([tf.print("encoder_output", 229 | tf.shape(encoder_output)), 230 | tf.assert_equal(current_shape, 231 | tf.shape(area_encodings), 232 | summarize=100)]): 233 | paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32) 234 | padding_sum, _, _ = area_utils.compute_sum_image( 235 | tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2), 236 | max_area_width=hparams.max_span) 237 | num_areas = common_layers.shape_list(area_encodings)[1] 238 | area_paddings = tf.reshape(tf.minimum(tf.to_float(padding_sum), 1.0), 239 | [-1, num_areas]) 240 | areas["bias"] = area_paddings 241 | decoder_nonpadding = tf.to_float( 242 | tf.greater(features["verb_refs"][:, :, 1], 243 | features["verb_refs"][:, :, 0])) 244 | if hparams.instruction_encoder == "lstm": 245 | hparams_decoder = copy.copy(hparams) 246 | hparams_decoder.set_hparam("pos", "none") 247 | else: 248 | hparams_decoder = hparams 249 | decoder_input, decoder_self_attention_bias = _prepare_decoder_input( 250 | area_encodings, decoder_nonpadding, features, hparams_decoder, 251 | embed_scope=scope) 252 | decoder_input = tf.nn.dropout( 253 | decoder_input, keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 254 | if hparams.instruction_decoder == "transformer": 255 | decoder_output = transformer.transformer_decoder( 256 | decoder_input=decoder_input, 257 | encoder_output=encoder_output, 258 | decoder_self_attention_bias=decoder_self_attention_bias, 259 | 
encoder_decoder_attention_bias=encoder_decoder_attention_bias, 260 | hparams=hparams_decoder) 261 | else: 262 | raise ValueError("Unsupported instruction encoder %s" % ( 263 | hparams.instruction_encoder)) 264 | return decoder_output, decoder_nonpadding, areas, scope 265 | 266 | 267 | def predict_refs(logits, starts, ends): 268 | """Outputs the refs based on area predictions.""" 269 | with tf.control_dependencies([ 270 | tf.assert_equal(tf.rank(logits), 3), 271 | tf.assert_equal(tf.rank(starts), 2), 272 | tf.assert_equal(tf.rank(ends), 2)]): 273 | predicted_areas = tf.argmax(logits, -1) 274 | return area_utils.area_to_refs(starts, ends, predicted_areas) 275 | 276 | 277 | def compute_logits(features, hparams, train): 278 | """Computes reference logits and auxiliary information. 279 | 280 | Args: 281 | features: the feature dict. 282 | hparams: the hyper-parameters. 283 | train: whether it is in the train mode. 284 | Returns: 285 | a dict that contains: 286 | input_logits: [batch_size, num_steps, 2] 287 | verb_area_logits: [batch_size, num_steps, num_areas] 288 | obj_area_logits: [batch_size, num_steps, num_areas] 289 | input_area_logits: [batch_size, num_steps, num_areas] 290 | verb_hidden: [batch_size, num_steps, hidden_size] 291 | obj_hidden: [batch_size, num_steps, hidden_size] 292 | areas: a dict that contains area representation of the source sentence. 293 | """ 294 | latent_state, _, areas, embed_scope = encode_decode_task( 295 | features, hparams, train) 296 | task_encoding = areas["encodings"] 297 | task_encoding_bias = areas["bias"] 298 | def _output(latent_state, hparams, name): 299 | """Output layer.""" 300 | with tf.variable_scope("latent_to_" + name, reuse=tf.AUTO_REUSE): 301 | hidden = tf.layers.dense(latent_state, units=hparams.hidden_size) 302 | hidden = common_layers.apply_norm( 303 | hidden, hparams.norm_type, hparams.hidden_size, 304 | epsilon=hparams.norm_epsilon) 305 | return tf.nn.relu(hidden) 306 | with tf.variable_scope("output_layer", values=[latent_state, task_encoding, 307 | task_encoding_bias], 308 | reuse=tf.AUTO_REUSE): 309 | with tf.control_dependencies([tf.assert_equal(tf.rank(latent_state), 3)]): 310 | verb_hidden = _output(latent_state, hparams, "verb") 311 | object_hidden = _output(latent_state, hparams, "object") 312 | verb_hidden = tf.nn.dropout( 313 | verb_hidden, 314 | keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 315 | object_hidden = tf.nn.dropout( 316 | object_hidden, 317 | keep_prob=1.0 - hparams.layer_prepostprocess_dropout) 318 | with tf.variable_scope("verb_refs", reuse=tf.AUTO_REUSE): 319 | verb_area_logits = area_utils.query_area( 320 | tf.layers.dense(verb_hidden, units=hparams.hidden_size, 321 | name="verb_query"), 322 | task_encoding, task_encoding_bias) 323 | with tf.variable_scope("object_refs", reuse=tf.AUTO_REUSE): 324 | obj_area_logits = area_utils.query_area( 325 | tf.layers.dense(object_hidden, units=hparams.hidden_size, 326 | name="obj_query"), 327 | task_encoding, task_encoding_bias) 328 | with tf.variable_scope("input_refs", reuse=tf.AUTO_REUSE): 329 | input_logits = tf.layers.dense( 330 | _output(latent_state, hparams, "input"), units=2) 331 | input_area_logits = area_utils.query_area( 332 | _output(latent_state, hparams, "input_refs"), 333 | task_encoding, task_encoding_bias) 334 | references = {} 335 | references["input_logits"] = input_logits 336 | references["verb_area_logits"] = verb_area_logits 337 | references["obj_area_logits"] = obj_area_logits 338 | references["input_area_logits"] = input_area_logits 339 | 
references["verb_hidden"] = verb_hidden 340 | references["object_hidden"] = object_hidden 341 | references["areas"] = areas 342 | references["decoder_output"] = latent_state 343 | references["embed_scope"] = embed_scope 344 | if hparams.freeze_reference_model: 345 | for key in ["input_logits", "verb_area_logits", "obj_area_logits", 346 | "input_area_logits", "verb_hidden", "object_hidden", 347 | "decoder_output"]: 348 | references[key] = tf.stop_gradient(references[key]) 349 | for key in references["areas"]: 350 | references["areas"][key] = tf.stop_gradient(references["areas"][key]) 351 | return references 352 | 353 | 354 | def compute_losses(loss_dict, features, references, hparams): 355 | """Compute the loss based on the logits and labels.""" 356 | # Commented code can be useful for examining seq lengths distribution 357 | # srcs, _, counts = tf.unique_with_counts(features["data_source"]) 358 | # lengths = tf.reduce_sum(tf.to_int32( 359 | # tf.greater(features["verb_refs"][:, :, 1], 360 | # features["verb_refs"][:, :, 0])), -1) - 1 361 | # lengths, _, len_counts = tf.unique_with_counts(lengths) 362 | # with tf.control_dependencies([ 363 | # tf.print("sources", srcs, counts, lengths, len_counts, summarize=1000)]): 364 | input_mask = tf.to_float( 365 | tf.greater(features["verb_refs"][:, :, 1], 366 | features["verb_refs"][:, :, 0])) 367 | input_loss = tf.reduce_mean( 368 | tf.losses.sparse_softmax_cross_entropy( 369 | labels=tf.to_int32( 370 | tf.greater(features["input_refs"][:, :, 1], 371 | features["input_refs"][:, :, 0])), 372 | logits=references["input_logits"], 373 | reduction=tf.losses.Reduction.NONE) * input_mask) 374 | encoder_input_length = common_layers.shape_list(features["task"])[1] 375 | verb_area_loss = area_utils.area_loss( 376 | logits=references["verb_area_logits"], ranges=features["verb_refs"], 377 | length=encoder_input_length, 378 | max_area_width=hparams.max_span) 379 | object_area_loss = area_utils.area_loss( 380 | logits=references["obj_area_logits"], 381 | ranges=features["obj_refs"], 382 | length=encoder_input_length, 383 | max_area_width=hparams.max_span) 384 | input_area_loss = area_utils.area_loss( 385 | logits=references["input_area_logits"], ranges=features["input_refs"], 386 | length=encoder_input_length, 387 | max_area_width=hparams.max_span) 388 | loss_dict["reference_loss"] = ( 389 | input_loss + verb_area_loss + object_area_loss + input_area_loss) 390 | loss_dict["input_loss"] = input_loss 391 | loss_dict["verb_refs_loss"] = verb_area_loss 392 | loss_dict["obj_refs_loss"] = object_area_loss 393 | loss_dict["input_refs_loss"] = input_area_loss 394 | return loss_dict["reference_loss"] 395 | 396 | 397 | def compute_predictions(prediction_dict, references): 398 | """Predict the action tuple based on the logits.""" 399 | prediction_dict["input"] = tf.argmax(references["input_logits"], -1) 400 | prediction_dict["verb_refs"] = predict_refs(references["verb_area_logits"], 401 | references["areas"]["starts"], 402 | references["areas"]["ends"]) 403 | prediction_dict["obj_refs"] = predict_refs(references["obj_area_logits"], 404 | references["areas"]["starts"], 405 | references["areas"]["ends"]) 406 | prediction_dict["input_refs"] = predict_refs( 407 | references["input_area_logits"], 408 | references["areas"]["starts"], 409 | references["areas"]["ends"]) * tf.to_int32(tf.expand_dims( 410 | prediction_dict["input"], 2)) 411 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py>=0.6.0 2 | numpy>=1.15.4 3 | six>=1.12.0 4 | tensorflow==1.15 # change to 'tensorflow-gpu' for gpu support 5 | tensor2tensor 6 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | set -e 17 | set -x 18 | 19 | virtualenv -p python3 . 20 | source ./bin/activate 21 | 22 | pip install tensorflow 23 | pip install -r seq2act/requirements.txt 24 | python -m seq2act.bin.setup_test --train_file_list "seq2act/data/android_howto/*.tfrecord,seq2act/data/rico_sca/*.tfrecord" \ 25 | --train_source_list "android_howto,rico_sca" \ 26 | --train_batch_sizes "2,2" \ 27 | --train_steps 2 \ 28 | --batch_size 2 \ 29 | --experiment_dir "/tmp/seq2act" \ 30 | --logtostderr 31 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /utils/decode_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """decode_utils.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from enum import Enum 22 | from tensor2tensor.layers import common_layers 23 | import tensorflow.compat.v1 as tf 24 | from seq2act.layers import area_utils 25 | 26 | 27 | class ActionTypes(Enum): 28 | """The action types and ids of Android actions.""" 29 | CLICK = 2 30 | INPUT = 3 31 | SWIPE = 4 32 | CHECK = 5 33 | UNCHECK = 6 34 | LONG_CLICK = 7 35 | OTHERS = 8 36 | 37 | 38 | def verb_refs_to_lengths(task, verb_refs, include_eos=True): 39 | """Computes the length of a sequence.""" 40 | eos_positions = tf.to_int32(tf.expand_dims( 41 | tf.where(tf.equal(task, 1))[:, 1], 1)) 42 | seq_mask = tf.logical_not(tf.cast(tf.cumsum(tf.to_int32( 43 | tf.logical_and( 44 | tf.equal(verb_refs[:, :, 0], eos_positions), 45 | tf.equal(verb_refs[:, :, 1], eos_positions + 1))), axis=-1), tf.bool)) 46 | lengths = tf.reduce_sum(tf.to_float(seq_mask), axis=-1) 47 | if include_eos: 48 | lengths = lengths + 1 49 | return lengths 50 | 51 | 52 | def compute_seq_metrics(label_dict, feature_dict, debug=False, mask=None): 53 | """Compute the reference accuracy.""" 54 | gt_lengths = verb_refs_to_lengths(label_dict["task"], 55 | label_dict["verb_refs"], include_eos=False) 56 | pred_lengths = verb_refs_to_lengths(feature_dict["task"], 57 | feature_dict["verb_refs"], 58 | include_eos=False) 59 | gt_actions = tf.concat([ 60 | tf.expand_dims(label_dict["verbs"], 2), 61 | tf.expand_dims(label_dict["objects"], 2), 62 | label_dict["input_refs"]], axis=-1) 63 | pr_actions = tf.concat([ 64 | tf.expand_dims(feature_dict["verbs"], 2), 65 | tf.expand_dims(feature_dict["objects"], 2), 66 | feature_dict["input_refs"]], axis=-1) 67 | complete_act_acc, partial_act_acc = sequence_accuracy( 68 | gt_actions, pr_actions, gt_lengths, pred_lengths, 69 | debug=debug, name="act") 70 | gt_refs = tf.concat([ 71 | label_dict["verb_refs"], 72 | label_dict["obj_refs"], 73 | label_dict["input_refs"]], axis=-1) 74 | pr_refs = tf.concat([ 75 | feature_dict["verb_refs"], 76 | feature_dict["obj_refs"], 77 | feature_dict["input_refs"]], axis=-1) 78 | if mask is not None: 79 | mask = tf.expand_dims(tf.expand_dims(mask, 0), 0) 80 | gt_refs = gt_refs * mask 81 | pr_refs = pr_refs * mask 82 | pred_lengths = gt_lengths 83 | with tf.control_dependencies([tf.print( 84 | "mask", gt_refs, pr_refs, summarize=100)]): 85 | complete_refs_acc, partial_refs_acc = sequence_accuracy( 86 | gt_refs, pr_refs, gt_lengths, pred_lengths, 87 | debug=debug, name="ref") 88 | refs_metrics = {} 89 | refs_metrics["complete_acts_acc"] = complete_act_acc 90 | refs_metrics["partial_acts_acc"] = partial_act_acc 91 | refs_metrics["complete_refs_acc"] = complete_refs_acc 92 | refs_metrics["partial_refs_acc"] = partial_refs_acc 93 | refs_metrics["gt_seq"] = gt_actions 94 | refs_metrics["pred_seq"] = pr_actions 95 | return refs_metrics 96 | 97 | 98 | def unify_input_ref(pred_verbs, pred_input_ref): 99 | """Changes the input ref to zero according if pred_verbs are not input.""" 100 | pred_verbs = tf.expand_dims(pred_verbs, axis=-1) 101 | same_dim_verbs = tf.concat([pred_verbs, pred_verbs], axis=-1) 102 | zero_refs = tf.zeros_like(pred_input_ref) 103 | return tf.where( 104 | tf.equal(same_dim_verbs, ActionTypes.INPUT.value), pred_input_ref, 105 | zero_refs) 106 | 107 | 108 | def sequence_accuracy(gt_seqs, decode_seqs, gt_seq_lengths, pr_seq_lengths, 109 | debug=False, name=""): 110 | """Computes the complete and the partial sequence 
accuracy.""" 111 | gt_shape = common_layers.shape_list(gt_seqs) 112 | pr_shape = common_layers.shape_list(decode_seqs) 113 | batch_size = gt_shape[0] 114 | depth = gt_shape[-1] 115 | gt_len = gt_shape[1] 116 | pr_len = pr_shape[1] 117 | max_len = tf.maximum(gt_len, pr_len) 118 | gt_seqs = tf.pad(gt_seqs, 119 | [[0, 0], [0, max_len - gt_len], [0, 0]]) 120 | decode_seqs = tf.pad(decode_seqs, 121 | [[0, 0], [0, max_len - pr_len], [0, 0]]) 122 | gt_seqs = tf.where( 123 | tf.tile( 124 | tf.expand_dims(tf.sequence_mask(gt_seq_lengths, maxlen=max_len), 2), 125 | [1, 1, depth]), 126 | gt_seqs, 127 | tf.fill(tf.shape(gt_seqs), -1)) 128 | decode_seqs = tf.where( 129 | tf.tile( 130 | tf.expand_dims(tf.sequence_mask(pr_seq_lengths, maxlen=max_len), 2), 131 | [1, 1, depth]), 132 | decode_seqs, 133 | tf.fill(tf.shape(decode_seqs), -1)) 134 | # [batch_size, decode_length] 135 | corrects = tf.reduce_all(tf.equal(gt_seqs, decode_seqs), -1) 136 | correct_mask = tf.reduce_all(corrects, -1) 137 | # [batch_size] 138 | if debug: 139 | incorrect_mask = tf.logical_not(correct_mask) 140 | incorrect_gt = tf.boolean_mask(gt_seqs, incorrect_mask) 141 | incorrect_pr = tf.boolean_mask(decode_seqs, incorrect_mask) 142 | with tf.control_dependencies([tf.print(name + "_mismatch", 143 | incorrect_gt, 144 | incorrect_pr, 145 | summarize=1000)]): 146 | correct_mask = tf.identity(correct_mask) 147 | correct_seqs = tf.to_float(correct_mask) 148 | total_correct_seqs = tf.reduce_sum(correct_seqs) 149 | mean_complete_accuracy = total_correct_seqs / tf.to_float(batch_size) 150 | # Compute partial accuracy 151 | errors = tf.logical_not(corrects) 152 | errors = tf.cast(tf.cumsum(tf.to_float(errors), axis=-1), tf.bool) 153 | # [batch_size] 154 | correct_steps = tf.reduce_sum(tf.to_float(tf.logical_not(errors)), axis=-1) 155 | mean_partial_accuracy = tf.reduce_mean( 156 | tf.div(tf.minimum(correct_steps, gt_seq_lengths), gt_seq_lengths)) 157 | return mean_complete_accuracy, mean_partial_accuracy 158 | 159 | 160 | def _advance(step, beam_log_probs, previous_refs, 161 | area_logits, areas, batch_size, beam_size, append_refs=True, 162 | condition=None): 163 | """Advance one element in the tuple for a decoding step. 164 | 165 | Args: 166 | step: the current decoding step. 167 | beam_log_probs: [batch_size * beam_size] 168 | previous_refs: [batch_size * beam_size, input_length - 1, 2] 169 | area_logits: [batch_size * beam_size, num_areas] 170 | areas: the areas. 171 | batch_size: the batch size. 172 | beam_size: the beam_size. 173 | append_refs: returning references or ids. 174 | condition: conditional probability mask in shape [batch_size * beam_size]. 
175 | Returns: 176 | beam_log_probs: [batch_size * beam_size] 177 | references in shape of [batch_size * beam_size, input_length, 2] or 178 | ids in shape of [batch_size * beam_size] 179 | """ 180 | with tf.control_dependencies([ 181 | tf.equal(tf.shape(beam_log_probs), (batch_size * beam_size,))]): 182 | num_expansions = tf.minimum(beam_size, tf.shape(area_logits)[-1]) 183 | # [batch_size * beam_size, num_expansions] 184 | area_log_probs = common_layers.log_prob_from_logits(area_logits) 185 | if condition is not None: 186 | area_log_probs = area_log_probs * tf.to_float( 187 | tf.expand_dims(condition, 1)) 188 | top_area_log_probs, top_area_ids = tf.nn.top_k( 189 | area_log_probs, k=num_expansions) 190 | if append_refs: 191 | # [batch_size * beam_size, num_expansions, 2] 192 | refs = area_utils.area_to_refs(areas["starts"], areas["ends"], 193 | top_area_ids) 194 | if condition is not None: 195 | refs = refs * tf.expand_dims(tf.expand_dims(condition, 1), 2) 196 | refs = tf.reshape(refs, [batch_size, beam_size, num_expansions, 1, 2]) 197 | if step > 0: 198 | previous_refs = tf.reshape( 199 | previous_refs, [batch_size, beam_size, 1, step, 2]) 200 | previous_refs = tf.tile(previous_refs, [1, 1, num_expansions, 1, 1]) 201 | new_refs = tf.concat([previous_refs, refs], axis=3) 202 | else: 203 | new_refs = refs 204 | new_refs = tf.reshape( 205 | new_refs, [batch_size * beam_size * num_expansions, step + 1, 2]) 206 | # [batch_size, beam_size * num_expansions] 207 | log_probs = tf.reshape(tf.expand_dims(beam_log_probs, 1) + top_area_log_probs, 208 | [batch_size, beam_size * num_expansions]) 209 | # [batch_size, beam_size] 210 | beam_log_probs, beam_indices = tf.nn.top_k(log_probs, k=beam_size) 211 | beam_indices = tf.reshape(beam_indices, [-1]) 212 | beam_log_probs = tf.reshape(beam_log_probs, [batch_size * beam_size]) 213 | indices = tf.reshape( 214 | tf.tile(tf.expand_dims(tf.range(batch_size) * beam_size * num_expansions, 215 | axis=1), [1, beam_size]), [-1]) + beam_indices 216 | if append_refs: 217 | new_refs = tf.gather(new_refs, indices=indices) 218 | else: 219 | new_refs = tf.gather(tf.reshape(top_area_ids, [-1]), indices=indices) 220 | return beam_log_probs, new_refs 221 | 222 | 223 | def decode_one_step(step, live_beams, eos_positions, 224 | compute_logits, 225 | beam_log_probs, batch_size, beam_size, 226 | features, areas, hparams, 227 | use_cache, cache, 228 | mode=tf.estimator.ModeKeys.EVAL, 229 | always_consumed=True): 230 | """decode one step.""" 231 | # features: [batch_size * beam_size, step + 1, ...] 
232 | # [batch_size * beam_size, num_areas] 233 | action_logits, object_logits, consumed_logits, references = compute_logits( 234 | features, hparams, mode, use_cache, cache) 235 | input_logits = references["input_logits"] 236 | verb_area_logits = references["verb_area_logits"] 237 | obj_area_logits = references["obj_area_logits"] 238 | input_area_logits = references["input_area_logits"] 239 | cache = {} 240 | cache["input_logits"] = input_logits 241 | cache["verb_area_logits"] = verb_area_logits 242 | cache["obj_area_logits"] = obj_area_logits 243 | cache["input_area_logits"] = input_area_logits 244 | # step + 1 245 | input_length = tf.shape(features["verb_refs"])[1] 246 | output_length = tf.shape(verb_area_logits)[1] 247 | with tf.control_dependencies([ 248 | tf.assert_equal(input_length, output_length)]): 249 | # Decode consumed 250 | beam_log_probs, is_ref_consumed = _advance( 251 | step, 252 | beam_log_probs, 253 | previous_refs=None, 254 | area_logits=consumed_logits[:, -1, :], areas=None, 255 | batch_size=batch_size, beam_size=beam_size, append_refs=False, 256 | condition=tf.to_int32(live_beams)) 257 | if always_consumed: 258 | use_cache = tf.zeros_like(use_cache) 259 | else: 260 | use_cache = 1 - is_ref_consumed 261 | # Decode actions and objects greedy 262 | _, action = tf.nn.top_k(action_logits[:, -1, :]) 263 | features["verbs"] = tf.concat([ 264 | features["verbs"][:, :step], 265 | action, features["verbs"][:, step + 1:]], axis=1) 266 | features["verbs"] = tf.where(tf.equal(use_cache, 1), 267 | # Emit CLICK (2) if not consumed 268 | tf.fill(tf.shape(features["verbs"]), 2), 269 | features["verbs"]) 270 | _, obj = tf.nn.top_k(object_logits[:, -1, :]) 271 | features["objects"] = tf.concat([ 272 | features["objects"][:, :step], 273 | obj, features["objects"][:, step + 1:]], axis=1) 274 | # Decode verb refs 275 | beam_log_probs, new_refs = _advance( 276 | step, 277 | beam_log_probs, 278 | previous_refs=features["verb_refs"][:, :-1, :], 279 | area_logits=verb_area_logits[:, -1, :], areas=areas, 280 | batch_size=batch_size, beam_size=beam_size, 281 | condition=tf.to_int32(live_beams)) 282 | features["verb_refs"] = new_refs 283 | live_beams = tf.logical_and( 284 | live_beams, 285 | tf.not_equal(new_refs[:, -1, 0], eos_positions)) 286 | # Decode object refs 287 | beam_log_probs, new_refs = _advance( 288 | step, 289 | beam_log_probs, 290 | previous_refs=features["obj_refs"][:, :-1, :], 291 | area_logits=obj_area_logits[:, -1, :], areas=areas, 292 | batch_size=batch_size, beam_size=beam_size, 293 | condition=tf.to_int32(live_beams)) 294 | features["obj_refs"] = new_refs 295 | # Decode input refs 296 | beam_log_probs, need_inputs = _advance( 297 | step, 298 | beam_log_probs, 299 | previous_refs=None, 300 | area_logits=input_logits[:, -1, :], areas=None, 301 | batch_size=batch_size, beam_size=beam_size, append_refs=False, 302 | condition=tf.to_int32(live_beams)) 303 | beam_log_probs, new_refs = _advance( 304 | step, 305 | beam_log_probs, 306 | previous_refs=features["input_refs"][:, :-1, :], 307 | area_logits=input_area_logits[:, -1, :], areas=areas, 308 | batch_size=batch_size, beam_size=beam_size, 309 | condition=tf.to_int32(live_beams) * need_inputs) 310 | features["input_refs"] = new_refs 311 | return beam_log_probs, live_beams, use_cache, cache 312 | 313 | 314 | def _expand_to_beam(features, beam_size): 315 | shape_list = common_layers.shape_list(features) 316 | batch_size = shape_list[0] 317 | features = tf.expand_dims(features, axis=1) 318 | tile_dims = [1] * features.shape.ndims 319 
| tile_dims[1] = beam_size 320 | shape_list[0] = batch_size * beam_size 321 | features = tf.reshape(tf.tile(features, tile_dims), shape_list) 322 | return features 323 | 324 | 325 | def _recover_shape(features, beam_size): 326 | shape_list = common_layers.shape_list(features) 327 | batch_size = shape_list.pop(0) // beam_size 328 | shape_list = [batch_size, beam_size] + shape_list 329 | features = tf.reshape(features, shape_list) 330 | return features 331 | 332 | 333 | def decode_n_step(compute_logits, features, areas, 334 | hparams, n=20, beam_size=1, top_beam=True): 335 | """Decode for n steps. 336 | 337 | Args: 338 | compute_logits: the callback function for computing the logits. 339 | features: a dictionary of features. 340 | areas: the dict of area index mapping, with each tensor in the shape of 341 | [batch_size, num_areas]. 342 | hparams: the hyperparameters. 343 | n: the number of steps to decode. 344 | beam_size: the beam size for beam search. 345 | top_beam: whether to return the results from the top beam only. 346 | """ 347 | print(features) 348 | use_obj_dom_dist = ("obj_dom_dist" in features) 349 | batch_size = tf.shape(features["task"])[0] 350 | beam_log_probs = tf.fill([batch_size * beam_size], 0.) 351 | live_beams = tf.fill([batch_size * beam_size], True) 352 | use_cache = tf.fill([batch_size * beam_size], 0) 353 | cache = {} 354 | for step in range(n): 355 | if step == 0: 356 | features["verb_refs"] = tf.zeros([batch_size, 1, 2], tf.int32) 357 | features["obj_refs"] = tf.zeros([batch_size, 1, 2], tf.int32) 358 | features["input_refs"] = tf.zeros([batch_size, 1, 2], tf.int32) 359 | for key in features: 360 | features[key] = _expand_to_beam(features[key], beam_size) 361 | areas["starts"] = _expand_to_beam(areas["starts"], beam_size) 362 | areas["ends"] = _expand_to_beam(areas["ends"], beam_size) 363 | # Backup the screen features 364 | def pad_to_match(feature, target_length, rank, constant_values): 365 | """Pad the feature to the decode length.""" 366 | padding_list = [] 367 | target_length = tf.maximum(target_length, tf.shape(feature)[1]) 368 | for r in range(rank): 369 | if r == 1: 370 | padding_list.append([0, target_length - tf.shape(feature)[1]]) 371 | else: 372 | padding_list.append([0, 0]) 373 | return tf.pad(feature, padding_list, constant_values=constant_values, 374 | name="pad_to_match") 375 | features["backup_obj_text"] = pad_to_match(features["obj_text"], n, 4, 0) 376 | features["backup_obj_type"] = pad_to_match(features["obj_type"], n, 3, -1) 377 | features["backup_obj_clickable"] = pad_to_match( 378 | features["obj_clickable"], n, 3, 0) 379 | features["backup_obj_screen_pos"] = pad_to_match( 380 | features["obj_screen_pos"], n, 4, 0) 381 | features["backup_obj_dom_pos"] = pad_to_match(features["obj_dom_pos"], 382 | n, 4, 0) 383 | if use_obj_dom_dist: 384 | features["backup_obj_dom_dist"] = pad_to_match(features["obj_dom_dist"], 385 | n, 4, 0) 386 | # Set the screen features 387 | features["obj_text"] = features["obj_text"][:, :1] 388 | features["obj_type"] = features["obj_type"][:, :1] 389 | features["obj_clickable"] = features["obj_clickable"][:, :1] 390 | features["obj_screen_pos"] = features["obj_screen_pos"][:, :1] 391 | features["obj_dom_pos"] = features["obj_dom_pos"][:, :1] 392 | if use_obj_dom_dist: 393 | features["obj_dom_dist"] = features["obj_dom_dist"][:, :1] 394 | else: 395 | features["verb_refs"] = tf.pad(features["verb_refs"], 396 | [[0, 0], [0, 1], [0, 0]], 397 | name="pad_verb_refs") 398 | features["obj_refs"] = tf.pad(features["obj_refs"],
399 | [[0, 0], [0, 1], [0, 0]], 400 | name="pad_obj_refs") 401 | features["input_refs"] = tf.pad(features["input_refs"], 402 | [[0, 0], [0, 1], [0, 0]], 403 | name="pad_input_refs") 404 | # Fill in the screen information 405 | features["obj_text"] = features["backup_obj_text"][:, :step + 1] 406 | features["obj_type"] = features["backup_obj_type"][:, :step + 1] 407 | features["obj_clickable"] = features["backup_obj_clickable"][:, :step + 1] 408 | features["obj_screen_pos"] = ( 409 | features["backup_obj_screen_pos"][:, :step + 1]) 410 | features["obj_dom_pos"] = ( 411 | features["backup_obj_dom_pos"][:, :step + 1]) 412 | if use_obj_dom_dist: 413 | features["obj_dom_dist"] = ( 414 | features["backup_obj_dom_dist"][:, :step + 1]) 415 | eos_positions = tf.to_int32(tf.where(tf.equal(features["task"], 1))[:, 1]) 416 | beam_log_probs, live_beams, use_cache, cache = decode_one_step( 417 | step, live_beams, eos_positions, 418 | compute_logits, 419 | beam_log_probs, 420 | batch_size, beam_size, features, 421 | areas, hparams, 422 | use_cache=use_cache, cache=cache, 423 | always_consumed=True) 424 | for key in features: 425 | features[key] = _recover_shape(features[key], beam_size) 426 | if top_beam: 427 | features[key] = features[key][:, 0] 428 | if key in ["obj_type", "obj_clickable"]: 429 | features[key] = tf.pad( 430 | features[key], [[0, 0], 431 | [0, n - tf.shape(features[key])[1]], [0, 0]], 432 | constant_values=-1 if key.endswith("type") else 0, 433 | name="pad_type_clickable") 434 | elif key in ["obj_text", "obj_screen_pos", "obj_dom_pos", "obj_dom_dist"]: 435 | features[key] = tf.pad(features[key], 436 | [[0, 0], [0, n - tf.shape(features[key])[1]], 437 | [0, 0], [0, 0]], 438 | name="pad_rest_screen_features") 439 | --------------------------------------------------------------------------------
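The decode utilities above target the TF 1.x graph mode pinned in requirements.txt. As a quick orientation aid, here is a minimal, hypothetical sketch (not part of the repository) of how `sequence_accuracy` from `utils/decode_utils.py` could be exercised on toy action tuples; the tensor values, lengths, and the `toy` name are invented for illustration.

```python
# Minimal sketch, assuming the TF 1.15 environment from seq2act/requirements.txt.
# Shapes follow the code above: [batch, length, depth] action tuples plus
# per-example float lengths (as produced by verb_refs_to_lengths).
import tensorflow.compat.v1 as tf

from seq2act.utils import decode_utils

# One example with two decoding steps; each step is a (verb, object,
# input_start, input_end) tuple. The prediction disagrees on the second object.
gt_actions = tf.constant([[[2, 5, 0, 0], [3, 7, 1, 2]]], dtype=tf.int32)
pr_actions = tf.constant([[[2, 5, 0, 0], [3, 6, 1, 2]]], dtype=tf.int32)
gt_lengths = tf.constant([2.0])
pr_lengths = tf.constant([2.0])

complete_acc, partial_acc = decode_utils.sequence_accuracy(
    gt_actions, pr_actions, gt_lengths, pr_lengths, name="toy")

with tf.Session() as sess:
  # Complete-sequence accuracy is 0.0 (one step differs); partial accuracy is
  # 0.5 (only the first of the two ground-truth steps is matched).
  print(sess.run([complete_acc, partial_acc]))
```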
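Similarly, a small hypothetical sketch of `unify_input_ref`, which keeps a predicted input span only for steps whose predicted verb is INPUT (id 3 in `ActionTypes`); the toy verbs and spans below are invented for illustration.

```python
# Sketch under the same assumed TF 1.15 setup: the CLICK step (verb id 2) has
# its input span reset to [0, 0], while the INPUT step keeps [7, 9].
import tensorflow.compat.v1 as tf

from seq2act.utils import decode_utils

pred_verbs = tf.constant([[2, 3]])                # [batch, steps]
pred_input_ref = tf.constant([[[4, 6], [7, 9]]])  # [batch, steps, 2]

unified = decode_utils.unify_input_ref(pred_verbs, pred_input_ref)

with tf.Session() as sess:
  print(sess.run(unified))  # [[[0 0] [7 9]]]
```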
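Finally, a brief sketch of the beam bookkeeping used by `decode_n_step`: `_expand_to_beam` tiles every `[batch, ...]` feature to `[batch * beam_size, ...]` before decoding, and `_recover_shape` folds the result back to `[batch, beam_size, ...]` so the top beam can be selected. These helpers are module-private, so calling them directly is only for illustration; the toy tensor is invented.

```python
# Sketch (assumed toy input): round-trips a [batch=2, length=2] tensor through
# the beam expansion used internally by decode_n_step with beam_size=3.
import tensorflow.compat.v1 as tf

from seq2act.utils import decode_utils

feats = tf.constant([[1, 2], [3, 4]])
tiled = decode_utils._expand_to_beam(feats, 3)      # shape [6, 2]
recovered = decode_utils._recover_shape(tiled, 3)   # shape [2, 3, 2]

with tf.Session() as sess:
  print(sess.run(tiled))      # each row repeated beam_size times
  print(sess.run(recovered))  # grouped back into [batch, beam, length]
```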