├── AUTHORS ├── CONTRIBUTING.md ├── DENSE_TABLE_RETRIEVER.md ├── DOT.md ├── INTERMEDIATE_PRETRAIN_DATA.md ├── LICENSE ├── MANIFEST.in ├── MATE.md ├── PRETRAIN_DATA.md ├── README.md ├── TABLEFORMER.md ├── notebooks ├── retrieval_predictions.ipynb ├── sqa_predictions.ipynb ├── tabfact_predictions.ipynb └── wtq_predictions.ipynb ├── requirements.txt ├── setup.py ├── tapas ├── __init__.py ├── create_intermediate_pretrain_examples_main.py ├── create_pretrain_examples_main.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── dataset_test.py │ ├── table_dataset.py │ ├── table_dataset_test.py │ └── table_dataset_test_utils.py ├── experiments │ ├── __init__.py │ ├── prediction_utils.py │ ├── prediction_utils_test.py │ ├── table_retriever_experiment.py │ ├── tapas_classifier_experiment.py │ └── tapas_pretraining_experiment.py ├── models │ ├── __init__.py │ ├── bert │ │ ├── __init__.py │ │ ├── modeling.py │ │ ├── modeling_test.py │ │ ├── optimization.py │ │ ├── optimization_test.py │ │ └── table_bert.py │ ├── segmented_tensor.py │ ├── segmented_tensor_test.py │ ├── table_retriever_model.py │ ├── table_retriever_model_test.py │ ├── tapas_classifier_model.py │ ├── tapas_classifier_model_test.py │ ├── tapas_classifier_model_utils.py │ ├── tapas_pretraining_model.py │ └── tapas_pretraining_model_test.py ├── protos │ ├── __init__.py │ ├── annotated_text.proto │ ├── interaction.proto │ ├── negative_retrieval_examples.proto │ ├── retriever_info.proto │ ├── table_pruning.proto │ └── table_selection.proto ├── retrieval │ ├── add_negative_tables_to_interactions.py │ ├── add_negative_tables_to_interactions_main.py │ ├── add_negative_tables_to_interactions_test.py │ ├── create_baseline_results.py │ ├── create_e2e_interactions.py │ ├── create_retrieval_data_main.py │ ├── create_retrieval_pretrain_data_main.py │ ├── e2e_eval.py │ ├── e2e_eval_utils.py │ ├── e2e_eval_utils_test.py │ ├── testdata │ │ ├── neural_retrieval_00.jsonl │ │ └── retrieval_interaction.pbtxt │ ├── tf_example_utils.py │ ├── tf_example_utils_test.py │ ├── tfidf_baseline.py │ ├── tfidf_baseline_utils.py │ └── tfidf_baseline_utils_test.py ├── run_task_main.py ├── scripts │ ├── __init__.py │ ├── calc_metrics.py │ ├── calc_metrics_test.py │ ├── calc_metrics_utils.py │ ├── convert_predictions.py │ ├── convert_predictions_utils.py │ ├── convert_predictions_utils_test.py │ ├── eval_table_retriever.py │ ├── eval_table_retriever_test.py │ ├── eval_table_retriever_utils.py │ ├── eval_wikisql.py │ ├── prediction_utils.py │ ├── prediction_utils_test.py │ ├── preprocess_nq.py │ ├── preprocess_nq_test.py │ ├── preprocess_nq_utils.py │ └── testdata │ │ ├── nq_raw_examples.txt │ │ ├── table_00.html │ │ ├── table_01.html │ │ └── table_02.html ├── testdata │ └── classification_examples.tfrecords └── utils │ ├── __init__.py │ ├── attention_utils.py │ ├── attention_utils_test.py │ ├── beam_runner.py │ ├── beam_utils.py │ ├── beam_utils_test.py │ ├── constants.py │ ├── contrastive_statements.py │ ├── contrastive_statements_test.py │ ├── contrastive_statements_test_utils.py │ ├── create_data.py │ ├── create_data_file_io.py │ ├── create_data_test.py │ ├── experiment_utils.py │ ├── experiment_utils_test.py │ ├── file_utils.py │ ├── hparam_utils.py │ ├── hybridqa_rc_utils.py │ ├── hybridqa_rc_utils_test.py │ ├── hybridqa_utils.py │ ├── hybridqa_utils_test.py │ ├── interaction_utils.py │ ├── interaction_utils_parser.py │ ├── interaction_utils_parser_test.py │ ├── interaction_utils_test.py │ ├── intermediate_pretrain_utils.py │ ├── intermediate_pretrain_utils_test.py 
│ ├── interpretation_utils.py │ ├── interpretation_utils_test.py │ ├── number_annotation_utils.py │ ├── number_annotation_utils_test.py │ ├── number_utils.py │ ├── number_utils_test.py │ ├── pretrain_utils.py │ ├── pruning_utils.py │ ├── pruning_utils_test.py │ ├── sem_tab_fact_utils.py │ ├── sem_tab_fact_utils_test.py │ ├── sentence_tokenizer.py │ ├── sentence_tokenizer_test.py │ ├── span_prediction_utils.py │ ├── span_prediction_utils_test.py │ ├── sqa_utils.py │ ├── synthesize_entablement.py │ ├── synthesize_entablement_test.py │ ├── tabfact_utils.py │ ├── table_pruning.py │ ├── table_pruning_test.py │ ├── tableformer_utils.py │ ├── tableformer_utils_test.py │ ├── task_utils.py │ ├── tasks.py │ ├── testdata │ ├── interaction_00.pbtxt │ ├── interaction_01.pbtxt │ ├── interaction_02.pbtxt │ ├── interaction_03.pbtxt │ ├── interaction_04.pbtxt │ ├── pretrain_interactions.txtpb │ ├── questions.tsv │ ├── questions_aggregation.tsv │ ├── questions_float_answer.tsv │ ├── sem_tab_fact_20502.xml │ ├── sem_tab_fact_20502_interaction.txtpb │ ├── sem_tab_fact_20502_interaction_v2.txtpb │ ├── tf_example_02.pbtxt │ ├── tf_example_02_conv.pbtxt │ ├── tf_example_03.pbtxt │ └── vocab.txt │ ├── text_index.py │ ├── text_index_test.py │ ├── text_utils.py │ ├── text_utils_test.py │ ├── tf_example_utils.py │ ├── tf_example_utils_test.py │ ├── wikisql_utils.py │ ├── wikisql_utils_test.py │ ├── wtq_utils.py │ └── wtq_utils_test.py └── tox.ini
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of the Tapas authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 | 
7 | Google Inc.
8 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 | 
--------------------------------------------------------------------------------
/DOT.md:
--------------------------------------------------------------------------------
1 | # DoT: An efficient Double Transformer for NLP tasks with tables
2 | This document contains models and steps to reproduce the results of [DoT: An efficient Double Transformer for NLP tasks with tables](https://arxiv.org/abs/2106.00479) published at ACL Findings 2021.
3 | 
4 | ## DoT Models
5 | DoT is a double transformer model: a first pruning transformer selects 256 tokens and passes them to a second, task-specific transformer that solves the main task.
6 | 
7 | The results of the best models presented in the paper are reported in the following table.
8 | 
9 | Dataset  | Model                     | Accuracy   | Best  | NPE/s | Link
10 | -------- | ------------------------- | ---------- | ----- | ----- | ------
11 | WikiSQL  | HEM-1024->DoT(s-256->l)   | 85.3±0.4   | 85.76 | 1250  | [wikisql_hem_1024_dot_small_256_large.zip](https://storage.googleapis.com/tapas_models/2021_08_20/wikisql_hem_1024_dot_small_256_large.zip)
12 | TABFACT  | HEM-1024->DoT(s-256->l)   | 81.6±0.3   | 81.74 | 1300  | [tabfact_hem_1024_dot_small_256_large.zip](https://storage.googleapis.com/tapas_models/2021_08_20/tabfact_hem_1024_dot_small_256_large.zip)
13 | WikiTQ   | CC-1024->C-DoT(m-256->l)  | 50.1±0.5   | 50.14 | 950   | [wtq_cc_1024_column_dot_medium_256_large.zip](https://storage.googleapis.com/tapas_models/2021_08_20/wtq_cc_1024_column_dot_medium_256_large.zip)
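Conceptually, the hand-off between the two transformers is a top-k token selection. The sketch below illustrates just that step, with invented names; in the real model the selection is driven by the `table_pruning.proto` config described below and the two transformers are learned jointly.

```python
import tensorflow.compat.v1 as tf

def select_top_k_tokens(token_embeddings, pruning_logits, k=256):
  """Keeps the k highest-scoring tokens per example.

  token_embeddings: <float>[batch_size, seq_length, hidden_size], output
    of the small pruning transformer.
  pruning_logits: <float>[batch_size, seq_length], one relevance score
    per token (the TOP_K selection in the pruning config below).
  """
  _, top_indices = tf.nn.top_k(pruning_logits, k=k)
  # Restore the original token order so position information is preserved.
  top_indices = tf.sort(top_indices, axis=-1)
  # <float>[batch_size, k, hidden_size]: the input of the task transformer.
  return tf.gather(token_embeddings, top_indices, batch_dims=1)
```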
14 | 
15 | ## Learning DoT model
16 | The different steps to learn DoT models are:
17 | 1. Pre-train the pruning and task-specific transformers.
18 | 2. Create the data for the fine-tuning.
19 | 3. Create the pruning config.
20 | 4. Fine-tune DoT: jointly learn the two transformers.
21 | 
22 | ### Pre-training
23 | All DoT models are initialized from models pre-trained with a masked LM objective, intermediate data and SQA. Both the pruning and the task-specific transformer are pre-trained.
24 | 
25 | ### Generating HEM or CC data
26 | To reduce the input size of the interactions, we can use heuristic exact match (HEM) or cell concatenation (CC). The code then extracts the corresponding tf_examples.
27 | ```bash
28 | HEURISTIC="hem"
29 | 
30 | python -m tapas.run_task_main \
31 |   --task=${TASK} \
32 |   --input_dir="${task_data_dir}" \
33 |   --output_dir="${output_dir}/${HEURISTIC}" \
34 |   --max_seq_length="1024" \
35 |   --prune_columns="true" \
36 |   --bert_vocab_file="${tapas_data_dir}/vocab.txt" \
37 |   --mode="create_data"
38 | ```
39 | For CC use HEURISTIC="CC" and prune_columns=false.
40 | We use max_seq_length=1024 as all the DoT models use a heuristic to reduce the input length to 1024 (.-1024->DoT).
41 | 
42 | ### Create the pruning transformer config
43 | The pruning config follows `tapas/protos/table_pruning.proto`.
44 | For example, to create a config for -DoT(s-256->.) use:
45 | 
46 | ```bash
47 | # proto-file: tapas/protos/table_pruning.proto
48 | # proto-message: tapas.TablePruningModel
49 | max_num_tokens: 256
50 | tapas{
51 |   selection: TOKENS
52 |   loss: {
53 |     unsupervised: {regularization:NONE}
54 |     train: {selection_fn: TOP_K}
55 |     eval: {selection_fn: TOP_K}
56 |   }
57 |   reset_position_index_per_cell: true
58 |   bert_config_file: "${TAPAS_S_CHECK_POINT}/bert_config.json"
59 |   bert_init_checkpoint: "${TAPAS_S_CHECK_POINT}/model.ckpt"
60 | }
61 | ```
62 | Then use the path to the created file:
63 | ```bash
64 | CONFIG_FILE=".textproto"
65 | ```
66 | 
67 | ### Fine-tuning DoT
68 | DoT has been used for three datasets. To select the dataset, set TASK to WIKISQL, TABFACT, or WTQ.
69 | 
70 | ```bash
71 | python -m tapas.run_task_main \
72 |   --task="${TASK}" \
73 |   --max_seq_length=1024 \
74 |   --output_dir="${output_dir}/${HEURISTIC}" \
75 |   --init_checkpoint="${TAPAS_L_CHECK_POINT}/model.ckpt" \
76 |   --bert_config_file="${TAPAS_L_CHECK_POINT}/bert_config.json" \
77 |   --table_pruning_config_file="${CONFIG_FILE}" \
78 |   --reset_position_index_per_cell=true \
79 |   --mode="train"
80 | ```
81 | 
82 | ## Licence
83 | This code and data are licensed under the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License).\
84 | See also the Wikipedia [Copyrights](https://en.wikipedia.org/wiki/Wikipedia:Copyrights) page.
85 | 
86 | ## How to cite this data and code?
87 | You can cite the [paper](https://arxiv.org/abs/2106.00479) published in the Findings of ACL 2021.
88 | 
--------------------------------------------------------------------------------
/INTERMEDIATE_PRETRAIN_DATA.md:
--------------------------------------------------------------------------------
1 | # Intermediate Pre-Training Data
2 | 
3 | In our latest models, two training objectives are used after the standard masked
4 | language modeling and before fine-tuning, hence the name intermediate pre-training. Both objectives are binary classification tasks on a sentence-table pair, where the tables are real Wikipedia tables based on the ones released [here](https://github.com/google-research/tapas/blob/master/PRETRAIN_DATA.md). To explain each of the two tasks, let us use the following example table:
5 | 
6 | | Rank | Player          | Country       | Earnings  | Events | Wins |
7 | |------|-----------------|---------------|-----------|--------|------|
8 | | 1    | Greg Norman     | Australia     | 1,654,959 | 16     | 3    |
9 | | 2    | Billy Mayfair   | United States | 1,543,192 | 28     | 2    |
10 | | 3    | Lee Janzen      | United States | 1,378,966 | 28     | 3    |
11 | | 4    | Corey Pavin     | United States | 1,340,079 | 22     | 2    |
12 | | 5    | Steve Elkington | Australia     | 1,254,352 | 21     | 2    |
13 | 
14 | ## Synthetic Examples
15 | 
16 | We generate synthetic examples using a simple grammar.
17 | The grammar randomly generates two SQL-like phrases that can use aggregations or constants and compares them against each other using equality or numeric comparisons.
18 | We assign a binary label according to the truth value of the generated statement, and the algorithm is adjusted to produce the same number of positives and negatives.
19 | The total number of examples obtained is 3.7 million.
20 | 
21 | Below are some examples; we recommend looking at section 3.2 in the [paper](https://www.aclweb.org/anthology/2020.findings-emnlp.27/) for the full details.
22 | 
23 | 1. **2** is less than **wins when Player is Lee Janzen**.
24 |    The right hand side corresponds to the query
25 | 
26 |        SELECT wins FROM table WHERE player = "Lee Janzen"
27 | 
28 | 2. **The sum of Earnings when Country is Australia** is **2,909,311**.
29 |    The left hand side corresponds to the query
30 | 
31 |        SELECT SUM(earnings) FROM table WHERE country = "Australia"
32 | 
33 | Although the language is artificial, these examples can improve the model's numerical reasoning skills.
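To make the comparison semantics concrete, here is how the second statement above can be checked against the example table with pandas. This is only an illustration of how the truth value is assigned, not the actual generation grammar (which, per the file tree above, presumably lives in `tapas/utils/synthesize_entablement.py`).

```python
import pandas as pd

# The example table from above, with earnings as plain numbers.
table = pd.DataFrame({
    "player": ["Greg Norman", "Billy Mayfair", "Lee Janzen",
               "Corey Pavin", "Steve Elkington"],
    "country": ["Australia", "United States", "United States",
                "United States", "Australia"],
    "earnings": [1654959, 1543192, 1378966, 1340079, 1254352],
})

# SELECT SUM(earnings) FROM table WHERE country = "Australia"
lhs = table.loc[table["country"] == "Australia", "earnings"].sum()

# "The sum of Earnings when Country is Australia is 2,909,311."
label = bool(lhs == 2909311)  # True: 1,654,959 + 1,254,352 = 2,909,311
print(lhs, label)
```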
34 | 
35 | ## Counterfactual Examples
36 | 
37 | Counterfactual examples are created by randomly replacing an entity in a statement with a different but plausible entity.
38 | The original sentences are also obtained from Wikipedia, by selecting text in the vicinity of the table or text that links to the table from a different page.
39 | For example, if we get a text that reads
40 | 
41 | > Greg Norman has the highest earnings
42 | 
43 | we replace "Greg Norman" with another entity appearing in the same column of the table, obtaining
44 | 
45 | > ~~Greg Norman~~ **Steve Elkington** has the highest earnings.
46 | 
47 | The total number of examples obtained is 4.1 million, and the model is then asked to detect whether the sentence has been corrupted or not.
48 | The sentences obtained in this manner sound more natural, but the type of logical inference
49 | that the model has to perform on the table is typically simpler, since looking at a single row is often enough. You may also check section 3.1 in the [paper](https://www.aclweb.org/anthology/2020.findings-emnlp.27/) for the full details.
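A minimal sketch of the replacement step follows; the names are ours, and the real pipeline (presumably `tapas/utils/contrastive_statements.py` in the tree above) additionally handles locating entity mentions and matching them to table columns.

```python
import random

def corrupt_statement(statement, mention, column_values):
  """Swaps `mention` in `statement` for another value from the same column.

  Returns (new_statement, label) where label 0 marks a corrupted sentence.
  This sketch assumes the mention has already been found in the statement
  and matched to a table column.
  """
  candidates = [v for v in column_values if v != mention]
  if not candidates:
    return statement, 1  # Nothing to swap with: keep as a positive.
  return statement.replace(mention, random.choice(candidates)), 0

players = ["Greg Norman", "Billy Mayfair", "Lee Janzen",
           "Corey Pavin", "Steve Elkington"]
print(corrupt_statement("Greg Norman has the highest earnings",
                        "Greg Norman", players))
```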
50 | 
51 | ## Model release
52 | 
53 | The models trained jointly on these datasets are released both with and without fine-tuning on the end tasks (TabFact, SQA, etc.). All the links, together with the expected results, are listed in the main [readme](https://github.com/google-research/tapas/blob/master/README.md#models).
54 | 
55 | ## Code Release
56 | 
57 | Data generation code can be run with:
58 | 
59 | ```bash
60 | python3 setup.py sdist
61 | python3 tapas/create_intermediate_pretrain_examples_main.py \
62 |   --input_file="gs://tapas_models/2021_07_22/interactions.txtpb.gz" \
63 |   --vocab_file="gs://tapas_models/2021_07_22/vocab.txt" \
64 |   --output_dir="gs://your_bucket/output" \
65 |   --runner_type="DATAFLOW" \
66 |   --gc_project="your-project" \
67 |   --gc_region="us-west1" \
68 |   --gc_job_name="create-intermediate" \
69 |   --gc_staging_location="gs://your_bucket/staging" \
70 |   --gc_temp_location="gs://your_bucket/tmp" \
71 |   --extra_packages=dist/tapas-0.0.1.dev0.tar.gz
72 | ```
73 | 
74 | You can also run the pipeline locally, but that will take a long time:
75 | 
76 | ```bash
77 | python3 tapas/create_intermediate_pretrain_examples_main.py \
78 |   --input_file="$data/interactions.txtpb.gz" \
79 |   --output_dir="$data/" \
80 |   --vocab_file="$data/vocab.txt" \
81 |   --runner_type="DIRECT"
82 | ```
83 | 
84 | ## Licence
85 | 
86 | This code and data are licensed under the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License).\
87 | See also the Wikipedia [Copyrights](https://en.wikipedia.org/wiki/Wikipedia:Copyrights) page.
88 | 
89 | ## How to cite this data and code?
90 | 
91 | You can cite the [paper](https://www.aclweb.org/anthology/2020.findings-emnlp.27/) published in the
92 | Findings of EMNLP 2020.
93 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | 
--------------------------------------------------------------------------------
/MATE.md:
--------------------------------------------------------------------------------
1 | # MATE: Multi-view Attention for Table Transformer Efficiency
2 | This document contains models and steps to reproduce the results of [MATE: Multi-view Attention for Table Transformer Efficiency](https://arxiv.org/abs/2109.04312) published at EMNLP 2021.
3 | 
4 | ## MATE Model
5 | 
6 | Based on the intuition that attention across tokens in different columns and
7 | rows is not needed, MATE uses two types of attention heads that can either only
8 | attend within the same column or within the same row.
9 | 
10 | MATE can be (approximately) implemented in linear time by adapting an idea from
11 | Reformer (Kitaev et al., 2020): column heads sort the input according to
12 | a column order and row heads according to the row order.
13 | Then the input is bucketed and attention is restricted to adjacent buckets.
14 | 
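As a concrete picture of the dense (non-bucketed) variant, the sketch below builds the attention mask for a single head from per-token column and row ids. The names and shapes are ours, not the repository's implementation (the tree above suggests `tapas/utils/attention_utils.py` for the real code, which also covers the efficient bucketed version).

```python
import tensorflow.compat.v1 as tf

def mate_head_mask(column_ids, row_ids, is_row_head):
  """Attention mask for a single MATE head (1.0 = attention allowed).

  column_ids, row_ids: <int32>[batch_size, seq_length]; following TAPAS
  conventions, id 0 marks tokens outside the table (question/padding),
  which we let attend to and from everything.
  """
  ids = row_ids if is_row_head else column_ids
  same_segment = tf.equal(ids[:, :, None], ids[:, None, :])
  outside_table = tf.equal(ids, 0)
  either_outside = tf.logical_or(outside_table[:, :, None],
                                 outside_table[:, None, :])
  return tf.cast(tf.logical_or(same_segment, either_outside), tf.float32)
```

A headwise mode then amounts to calling this once with `is_row_head=True` for the row heads and once with `is_row_head=False` for the column heads.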
15 | ## Using MATE
16 | 
17 | Using MATE for pre-training or fine-tuning can be accomplished through the
18 | following configuration flags in `tapas_classifier_experiment.py`:
19 | 
20 | * `--restrict_attention_mode=same_colum_or_row` Attention from tokens in
21 | different columns and rows is masked out.
22 | * `--restrict_attention_mode=headwise_same_colum_or_row` Row heads mask out
23 | attention between different rows, and column heads between different columns.
24 | The `bucket_size` and `header_size` arguments defined below can optionally be
25 | applied to mimic the efficient implementation.
26 | * `--restrict_attention_mode=headwise_efficient` Similar to
27 | `headwise_same_colum_or_row` but uses a log-linear implementation that sorts
28 | the input tokens by column or row order, depending on the type of attention head.
29 | * `--restrict_attention_bucket_size=` For sparse attention modes, further
30 | restricts attention to consecutive buckets of uniform size. Two tokens may only
31 | attend to each other if they fall in consecutive buckets of this size.
32 | Only required for `restrict_attention_mode=headwise_efficient`.
33 | * `--restrict_attention_header_size=` For sparse attention modes, the size of
34 | the leading section that attends to/from everything else. Only required for
35 | `restrict_attention_mode=headwise_efficient`.
36 | * `--restrict_attention_row_heads_ratio=` For sparse attention modes, the
37 | proportion of heads that should focus on rows vs. columns. Default is 0.5.
38 | 
39 | ## Licence
40 | 
41 | This code and data are licensed under the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License).\
42 | See also the Wikipedia [Copyrights](https://en.wikipedia.org/wiki/Wikipedia:Copyrights) page.
43 | 
44 | ## How to cite this data and code?
45 | 
46 | You can cite the [paper](https://arxiv.org/abs/2109.04312) published at
47 | EMNLP 2021.
48 | 
--------------------------------------------------------------------------------
/PRETRAIN_DATA.md:
--------------------------------------------------------------------------------
1 | # Pre-Training Data
2 | 
3 | The pre-training data consists of 6.2 million table-text examples extracted from
4 | the [English Wikipedia](https://en.wikipedia.org/wiki/Wikipedia) in December 2019.
5 | The associated text of a table is the page title and description, the table caption, as well as the section title and section text.
6 | 
7 | 
8 | ## Example
9 | 
10 | This is an example in proto text format extracted from [this](https://en.wikipedia.org/wiki/ARY_Film_Award_for_Best_Dialogue) page.
11 | 
12 | ```
13 | table: {
14 |   columns: { text: "Year" }
15 |   columns: { text: "Film" }
16 |   columns: { text: "Dialogue-writer(s)" }
17 |   rows: {
18 |     cells: { text: "2013\n(1st)" }
19 |     cells: { text: "" }
20 |     cells: { text: "" }
21 |   }
22 |   rows: {
23 |     cells: { text: "2013\n(1st)" }
24 |     cells: { text: "Main Hoon Shahid Afridi" }
25 |     cells: { text: "Vasay Chaudhry" }
26 |   }
27 |   table_id: "http://en.wikipedia.org/wiki/ARY_Film_Award_for_Best_Dialogue_1"
28 | }
29 | questions: {
30 |   id: "TITLE"
31 |   original_text: "ARY Film Award for Best Dialogue"
32 | }
33 | questions: {
34 |   id: "DESCRIPTION"
35 |   original_text: "The ARY Film Award for Best Dialogue is the ARY Film Award for the best dialogues of the year in film. It is one of three writing awards in the Technical Awarding category."
36 | }
37 | questions: {
38 |   id: "SEGMENT_TITLE"
39 |   original_text: "2010s"
40 | }
41 | ```
42 | 
43 | ## Data
44 | 
45 | You can find the latest version of the data [here](https://storage.googleapis.com/tapas_models/2020_05_11/interactions.txtpb.gz).
46 | We also provide a small [snapshot](https://storage.googleapis.com/tapas_models/2020_05_11/interactions_sample.txtpb.gz) of the first 100 interactions.
47 | 
48 | ## Conversion to TF Examples
49 | 
50 | `create_pretrain_examples_main.py` converts the data to TF examples.
51 | It can be run locally (which will take a long time on a single machine) or as a [Dataflow](https://cloud.google.com/dataflow) pipeline on Google Cloud.
52 | You can find command line snippets [here](https://github.com/google-research/tapas#pre-training).
53 | 
54 | ## Parsing Protobuffers in Text Format
55 | 
56 | In case you want to work with the data in ways we didn't anticipate, you can
57 | simply parse it into proto objects line by line.
58 | 
59 | Here is a simple example:
60 | 
61 | ```python
62 | from google.protobuf import text_format
63 | from tapas.protos import interaction_pb2
64 | 
65 | for line in input_file:
66 |   interaction = text_format.Parse(line, interaction_pb2.Interaction())
67 | ```
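For instance, to stream the compressed sample release above directly (assuming, per the note above, one interaction per line; the filename is the snapshot linked in the Data section):

```python
import gzip

from google.protobuf import text_format
from tapas.protos import interaction_pb2

with gzip.open("interactions_sample.txtpb.gz", "rt", encoding="utf-8") as input_file:
  for line in input_file:
    interaction = text_format.Parse(line, interaction_pb2.Interaction())
    print(interaction.table.table_id)
```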
68 | 
69 | ## Licence
70 | 
71 | This data is licensed under the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License).\
72 | See also the Wikipedia [Copyrights](https://en.wikipedia.org/wiki/Wikipedia:Copyrights) page.
73 | 
74 | ## How to cite this data?
75 | 
76 | You can cite the [ACL 2020 paper](https://www.aclweb.org/anthology/2020.acl-main.398/).
77 | 
--------------------------------------------------------------------------------
/TABLEFORMER.md:
--------------------------------------------------------------------------------
1 | # TableFormer: Robust Transformer Modeling for Table-Text Encoding
2 | This document contains models and steps to reproduce the results of [TableFormer: Robust Transformer Modeling for Table-Text Encoding](https://arxiv.org/abs/2203.00274) published at ACL 2022.
3 | 
4 | ## TableFormer Model
5 | 
6 | TableFormer encodes the general table structure along with the associated text
7 | by introducing task-independent relative attention biases for table-text
8 | encoding to facilitate the following:
9 | 
10 | * structural inductive bias for better table understanding and table-text
11 | alignment,
12 | * robustness to table row/column perturbation.
13 | 
14 | TableFormer is:
15 | * strictly invariant to row and column order, and
16 | * able to understand tables better thanks to its tabular inductive biases.
17 | 
18 | Our evaluations show that TableFormer outperforms strong baselines in all
19 | settings on the SQA, WTQ and TABFACT table reasoning datasets, and achieves
20 | state-of-the-art performance on SQA, especially when facing answer-invariant
21 | row and column order perturbations (a 6% improvement over the best baseline):
22 | previous SOTA models' performance drops by 4%-6% under such perturbations,
23 | while TableFormer is unaffected.
24 | 
25 | ## Using TableFormer
26 | 
27 | Using TableFormer for pre-training and fine-tuning can be accomplished through
28 | the following configuration flags in `tapas_pretraining_experiment.py` and
29 | `tapas_classifier_experiment.py`, respectively:
30 | 
31 | * `--restrict_attention_mode=table_attention` Uses the 13 relative relational
32 | ids introduced in TableFormer (see the sketch below).
33 | * `--attention_bias_use_relative_scalar_only` Whether to use just a scalar bias
34 | or an embedding per relative id per head per layer.
35 | * `--attention_bias_disabled` Which relational id to disable. This should
36 | only be used for ablation studies; otherwise it defaults to 0.
37 | 
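The sketch below shows how such per-head relative biases enter standard attention in the scalar variant; it is our illustration of the mechanism, with invented names, not the repository's modeling code (the tree above suggests `tapas/utils/tableformer_utils.py` for the real implementation).

```python
import tensorflow.compat.v1 as tf

NUM_RELATION_IDS = 13  # Relative relational ids used by table_attention.

def add_table_attention_bias(attention_scores, relation_ids, num_heads):
  """Adds a learned scalar bias per relational id to raw attention scores.

  attention_scores: <float>[batch_size, num_heads, seq_length, seq_length].
  relation_ids: <int32>[batch_size, seq_length, seq_length], each entry in
    [0, NUM_RELATION_IDS), encoding e.g. "same row" or "same column".
  Calling this once per layer gives one bias per id per head per layer; the
  embedding variant would learn a vector per id instead of a single scalar.
  """
  bias_table = tf.get_variable(
      "relative_bias", shape=[NUM_RELATION_IDS, num_heads],
      initializer=tf.zeros_initializer())
  # Gather: <float>[batch_size, seq_length, seq_length, num_heads], then
  # transpose to [batch_size, num_heads, seq_length, seq_length].
  bias = tf.gather(bias_table, relation_ids)
  bias = tf.transpose(bias, [0, 3, 1, 2])
  return attention_scores + bias
```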
38 | 
39 | ## Licence
40 | 
41 | This code and data are licensed under the [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License).\
42 | See also the Wikipedia [Copyrights](https://en.wikipedia.org/wiki/Wikipedia:Copyrights) page.
43 | 
44 | ## How to cite this data and code?
45 | 
46 | You can cite the [paper](https://arxiv.org/abs/2203.00274) published at
47 | ACL 2022.
48 | 
49 | ```
50 | @inproceedings{yang-etal-2022-tableformer,
51 |     title="{TableFormer: Robust Transformer Modeling for Table-Text Encoding}",
52 |     author="Jingfeng Yang and Aditya Gupta and Shyam Upadhyay and Luheng He and Rahul Goel and Shachi Paul",
53 |     booktitle = "Proc. of ACL",
54 |     year = "2022"
55 | }
56 | ```
57 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | apache-beam[gcp]==2.28.0
2 | dataclasses~=0.7; python_version < '3.7'
3 | frozendict==1.2
4 | pandas~=1.0.0
5 | scikit-learn~=0.22.1
6 | tensorflow~=2.2.0
7 | tf-models-official~=2.2.0
8 | # Kaggle required by tf-models-official is incompatible with py36 from slugify
9 | kaggle<1.5.8
10 | tensorflow-probability==0.10.1
11 | tf_slim~=1.1.0
12 | nltk~=3.5
13 | beautifulsoup4==4.9.3
14 | html5lib==1.1
15 | gensim~=3.8.3
16 | lxml~=4.6.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Install source code for the Tapas paper.""" 16 | from distutils import spawn 17 | import glob 18 | import os 19 | import subprocess 20 | import sys 21 | 22 | from setuptools import find_packages 23 | from setuptools import setup 24 | 25 | 26 | def find_protoc(): 27 | """Find the Protocol Compiler.""" 28 | if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]): 29 | return os.environ["PROTOC"] 30 | elif os.path.exists("../src/protoc"): 31 | return "../src/protoc" 32 | elif os.path.exists("../src/protoc.exe"): 33 | return "../src/protoc.exe" 34 | elif os.path.exists("../vsprojects/Debug/protoc.exe"): 35 | return "../vsprojects/Debug/protoc.exe" 36 | elif os.path.exists("../vsprojects/Release/protoc.exe"): 37 | return "../vsprojects/Release/protoc.exe" 38 | else: 39 | return spawn.find_executable("protoc") 40 | 41 | 42 | def needs_update(source, target): 43 | """Returns wheter target file is old or does not exist.""" 44 | if not os.path.exists(target): 45 | return True 46 | if not os.path.exists(source): 47 | return False 48 | return os.path.getmtime(source) > os.path.getmtime(target) 49 | 50 | 51 | def fail(message): 52 | """Write message to stderr and finish.""" 53 | sys.stderr.write(message + "\n") 54 | sys.exit(-1) 55 | 56 | 57 | def generate_proto(protoc, source): 58 | """Invokes the Protocol Compiler to generate a _pb2.py.""" 59 | 60 | target = source.replace(".proto", "_pb2.py") 61 | 62 | if needs_update(source, target): 63 | print(f"Generating {target}...") 64 | 65 | if not os.path.exists(source): 66 | fail(f"Cannot find required file: {source}") 67 | 68 | if protoc is None: 69 | fail("protoc is not installed nor found in ../src. Please compile it " 70 | "or install the binary package.") 71 | 72 | protoc_command = [protoc, "-I.", "--python_out=.", source] 73 | if subprocess.call(protoc_command) != 0: 74 | fail(f"Command fail: {' '.join(protoc_command)}") 75 | 76 | 77 | def prepare(): 78 | """Find all proto files and generate the pb2 ones.""" 79 | proto_file_patterns = ["./tapas/protos/*.proto"] 80 | protoc = find_protoc() 81 | for file_pattern in proto_file_patterns: 82 | for proto_file in glob.glob(file_pattern, recursive=True): 83 | generate_proto(protoc, proto_file) 84 | 85 | 86 | def read(fname): 87 | return open( 88 | os.path.join(os.path.dirname(__file__), fname), encoding="utf-8").read() 89 | 90 | 91 | prepare() 92 | setup( 93 | name="tapas-table-parsing", 94 | version="0.0.1.dev", 95 | packages=find_packages(), 96 | description="Tapas: Table-based Question Answering.", 97 | long_description_content_type="text/markdown", 98 | long_description=read("README.md"), 99 | author="Google Inc.", 100 | url="https://github.com/google-research/tapas", 101 | license="Apache 2.0", 102 | install_requires=read("requirements.txt").strip().split("\n")) 103 | -------------------------------------------------------------------------------- /tapas/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
--------------------------------------------------------------------------------
/tapas/create_intermediate_pretrain_examples_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Pipeline for generating synthetic statements for tables.
16 | 
17 | This implements the intermediate pre-training generation discussed in
18 | "Understanding tables with intermediate pre-training"
19 | (https://arxiv.org/abs/2010.00571).
20 | 
21 | The script will output TF examples (default) or interactions
22 | (if `convert_to_examples` is set to false).
23 | 
24 | The outputs will always be deterministically split into a train and a test
25 | set.
26 | 
27 | If `convert_to_examples` is true, the intermediate interactions will also be
28 | written to an additional interaction file
29 | (with path f"{output_dir}/interaction{output_suffix}").
30 | """ 31 | 32 | from absl import app 33 | from absl import flags 34 | from tapas.utils import beam_runner 35 | from tapas.utils import intermediate_pretrain_utils 36 | from tapas.utils import synthesize_entablement 37 | from tapas.utils import tf_example_utils 38 | 39 | flags.DEFINE_string( 40 | "input_file", 41 | None, 42 | "'.txtpb' or '.tfrecord' files with interaction protos.", 43 | ) 44 | flags.DEFINE_string( 45 | "output_dir", 46 | None, 47 | "output directory.", 48 | ) 49 | flags.DEFINE_string( 50 | "output_suffix", 51 | ".tfrecord", 52 | "Should be '.tfrecod' or '.txtpb'", 53 | ) 54 | flags.DEFINE_string( 55 | "vocab_file", 56 | None, 57 | "The vocabulary file that the BERT model was trained on.", 58 | ) 59 | flags.DEFINE_bool( 60 | "convert_to_examples", 61 | True, 62 | "If true convert interactions to examples.", 63 | ) 64 | flags.DEFINE_enum_class( 65 | "mode", 66 | intermediate_pretrain_utils.Mode.ALL, 67 | intermediate_pretrain_utils.Mode, 68 | "Mode to run in.", 69 | ) 70 | flags.DEFINE_integer( 71 | "max_seq_length", 72 | 128, 73 | "See tf_example_utils.ClassifierConversionConfig", 74 | ) 75 | flags.DEFINE_boolean( 76 | "use_fake_table", 77 | False, 78 | "Replace table with a constant.", 79 | ) 80 | flags.DEFINE_boolean( 81 | "add_opposite_table", 82 | False, 83 | "If, true add opposite table.", 84 | ) 85 | flags.DEFINE_float( 86 | "prob_count_aggregation", 87 | 0.02, 88 | "See SynthesizationConfig.", 89 | ) 90 | flags.DEFINE_float( 91 | "drop_without_support_rate", 92 | 1.0, 93 | "If true, drop contrastive examples without support.", 94 | ) 95 | 96 | FLAGS = flags.FLAGS 97 | 98 | 99 | def main(unused_argv): 100 | del unused_argv 101 | config = synthesize_entablement.SynthesizationConfig( 102 | prob_count_aggregation=FLAGS.prob_count_aggregation,) 103 | conversion_config = None 104 | if FLAGS.convert_to_examples: 105 | conversion_config = tf_example_utils.ClassifierConversionConfig( 106 | vocab_file=FLAGS.vocab_file, 107 | max_seq_length=FLAGS.max_seq_length, 108 | max_column_id=FLAGS.max_seq_length, 109 | max_row_id=FLAGS.max_seq_length, 110 | strip_column_names=False, 111 | ) 112 | pipeline = intermediate_pretrain_utils.build_pipeline( 113 | mode=FLAGS.mode, 114 | config=config, 115 | use_fake_table=FLAGS.use_fake_table, 116 | add_opposite_table=FLAGS.add_opposite_table, 117 | drop_without_support_rate=FLAGS.drop_without_support_rate, 118 | input_file=FLAGS.input_file, 119 | output_dir=FLAGS.output_dir, 120 | output_suffix=FLAGS.output_suffix, 121 | conversion_config=conversion_config) 122 | beam_runner.run(pipeline) 123 | 124 | 125 | if __name__ == "__main__": 126 | flags.mark_flag_as_required("input_file") 127 | flags.mark_flag_as_required("output_dir") 128 | app.run(main) 129 | -------------------------------------------------------------------------------- /tapas/create_pretrain_examples_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Converts interactions in text format to TF examples.
16 | 
17 | # Direct runner.
18 | python3 tapas/create_pretrain_examples_main.py \
19 |   --input_file="/tmp/interactions_sample.txtpb.gz" \
20 |   --output_dir="/tmp/" \
21 |   --vocab_file="/tmp/vocab.txt" \
22 |   --runner_type="DIRECT"
23 | 
24 | # Dataflow runner (runs on Google Cloud).
25 | python3 setup.py sdist
26 | python3 tapas/create_pretrain_examples_main.py \
27 |   --input_file="gs://tapas_models/2020_05_11/interactions_sample.txtpb.gz" \
28 |   --output_dir="gs://your_bucket/output" \
29 |   --vocab_file="gs://tapas_models/2020_05_11/vocab.txt" \
30 |   --runner_type=DATAFLOW \
31 |   --gc_project="your-project" \
32 |   --gc_region="us-west1" \
33 |   --gc_job_name="create-pretrain" \
34 |   --gc_staging_location="gs://your_bucket/staging" \
35 |   --gc_temp_location="gs://your_bucket/tmp" \
36 |   --extra_packages=dist/tapas-0.0.1.dev0.tar.gz
37 | 
38 | """
39 | 
40 | from absl import app
41 | from absl import flags
42 | 
43 | from tapas.utils import beam_runner
44 | from tapas.utils import create_data
45 | from tapas.utils import tf_example_utils
46 | 
47 | FLAGS = flags.FLAGS
48 | 
49 | flags.DEFINE_string("input_file", None,
50 |                     "Compressed interaction in text format.")
51 | flags.DEFINE_string("output_dir", None,
52 |                     "Directory where new data is written to.")
53 | flags.DEFINE_string("vocab_file", None,
54 |                     "The vocabulary file that the BERT model was trained on.")
55 | 
56 | flags.DEFINE_integer("max_seq_length", 128,
57 |                      "See tf_example_utils.PretrainConversionConfig")
58 | flags.DEFINE_integer("max_predictions_per_seq", 20,
59 |                      "See tf_example_utils.PretrainConversionConfig")
60 | flags.DEFINE_integer("random_seed", 12345,
61 |                      "See tf_example_utils.PretrainConversionConfig")
62 | flags.DEFINE_integer("dupe_factor", 10,
63 |                      "See tf_example_utils.PretrainConversionConfig")
64 | flags.DEFINE_float("masked_lm_prob", 0.15,
65 |                    "See tf_example_utils.PretrainConversionConfig")
66 | flags.DEFINE_integer("max_column_id", 512,
67 |                      "See tf_example_utils.PretrainConversionConfig")
68 | flags.DEFINE_integer("max_row_id", 512,
69 |                      "See tf_example_utils.PretrainConversionConfig")
70 | flags.DEFINE_integer("min_num_rows", 0,
71 |                      "See tf_example_utils.PretrainConversionConfig")
72 | flags.DEFINE_integer("min_num_columns", 0,
73 |                      "See tf_example_utils.PretrainConversionConfig")
74 | flags.DEFINE_integer("min_question_length", 8,
75 |                      "See tf_example_utils.PretrainConversionConfig")
76 | flags.DEFINE_integer("max_question_length", 32,
77 |                      "See tf_example_utils.PretrainConversionConfig")
78 | flags.DEFINE_bool("always_continue_cells", True,
79 |                   "See tf_example_utils.PretrainConversionConfig")
80 | 
81 | 
82 | def main(argv):
83 |   if len(argv) > 1:
84 |     raise app.UsageError("Too many command-line arguments.")
85 |   config = tf_example_utils.PretrainConversionConfig(
86 |       vocab_file=FLAGS.vocab_file,
87 |       max_seq_length=FLAGS.max_seq_length,
88 |       max_predictions_per_seq=FLAGS.max_predictions_per_seq,
89 |       random_seed=FLAGS.random_seed,
90 |       masked_lm_prob=FLAGS.masked_lm_prob,
91 |       max_column_id=FLAGS.max_column_id,
92 |       max_row_id=FLAGS.max_row_id,
93 |       min_question_length=FLAGS.min_question_length,
94 |       max_question_length=FLAGS.max_question_length,
95 |       always_continue_cells=FLAGS.always_continue_cells,
96 |       strip_column_names=False,
97 |   )
98 |   pipeline = create_data.build_pretraining_pipeline(
99 |       input_file=FLAGS.input_file,
100 | 
output_dir=FLAGS.output_dir, 101 | output_suffix=".tfrecord", 102 | config=config, 103 | dupe_factor=FLAGS.dupe_factor, 104 | min_num_rows=FLAGS.min_num_rows, 105 | min_num_columns=FLAGS.min_num_columns, 106 | ) 107 | beam_runner.run(pipeline) 108 | 109 | 110 | if __name__ == "__main__": 111 | app.run(main) 112 | -------------------------------------------------------------------------------- /tapas/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Contains dataset utility functions.""" 16 | 17 | from typing import Iterable, Text, Callable, Any, Mapping 18 | import tensorflow.compat.v1 as tf 19 | 20 | 21 | ParseExampleFn = Callable[[Any], Mapping[Text, Any]] 22 | 23 | 24 | def read_dataset( 25 | parse_examples_fn, 26 | name, 27 | file_patterns, 28 | data_format, 29 | compression_type, 30 | is_training, 31 | params, 32 | ): 33 | """Returns an input_fn that can be used with the tf.Estimator API.""" 34 | with tf.variable_scope(name): 35 | batch_size = params["batch_size"] 36 | # This is used mainly by the test to remove any source of randomness. 
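# "cycle_length" caps how many input files the parallel_interleave below reads concurrently.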
37 | cycle_length = params.get("cycle_length", 64) 38 | 39 | dataset = tf.data.Dataset.list_files(file_patterns, shuffle=is_training) 40 | 41 | if is_training: 42 | dataset = dataset.repeat() 43 | 44 | def fetch_dataset(filename): 45 | if data_format == "tfrecord": 46 | buffer_size = 8 * 1024 * 1024 # 8 MiB per file 47 | return tf.data.TFRecordDataset( 48 | filename, 49 | buffer_size=buffer_size, 50 | compression_type=compression_type, 51 | ) 52 | raise ValueError("Unsupported data_format: {}".format(data_format)) 53 | 54 | dataset = dataset.apply( 55 | tf.data.experimental.parallel_interleave( 56 | fetch_dataset, sloppy=is_training, cycle_length=cycle_length)) 57 | 58 | if is_training: 59 | dataset = dataset.shuffle(1024) 60 | else: 61 | max_eval_count = params.get("max_eval_count") 62 | if max_eval_count is not None: 63 | dataset = dataset.take(max_eval_count) 64 | 65 | parse_fn = parse_examples_fn 66 | 67 | dataset = dataset.apply( 68 | tf.data.experimental.map_and_batch( 69 | parse_fn, 70 | batch_size=batch_size, 71 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 72 | drop_remainder=params.get("drop_remainder", False) or is_training)) 73 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 74 | return dataset 75 | 76 | 77 | def build_parser_function(feature_types, 78 | params): 79 | """Returns a parse function that can be used by read_dataset.""" 80 | del params 81 | 82 | def parse_examples(serialized_examples): 83 | features = tf.io.parse_single_example(serialized_examples, feature_types) 84 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 85 | # So cast all int64 to int32. 86 | for name in list(features.keys()): 87 | t = features[name] 88 | if t.dtype == tf.int64: 89 | t = tf.cast(t, tf.int32) 90 | features[name] = t 91 | return features 92 | 93 | return parse_examples 94 | -------------------------------------------------------------------------------- /tapas/datasets/dataset_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import tempfile 18 | 19 | from absl.testing import parameterized 20 | from tapas.datasets import dataset 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | 25 | tf.disable_v2_behavior() 26 | 27 | 28 | def write_tf_example(filename, data_format, features): 29 | example = tf.train.Example(features=tf.train.Features(feature=features)) 30 | if data_format == "tfrecord": 31 | with tf.io.TFRecordWriter(filename) as writer: 32 | writer.write(example.SerializeToString()) 33 | else: 34 | raise ValueError("Unsupported data_format: {}".format(data_format)) 35 | 36 | 37 | class DatasetTest(parameterized.TestCase, tf.test.TestCase): 38 | 39 | def setUp(self): 40 | super(DatasetTest, self).setUp() 41 | 42 | # We add a prefix because the dataset API matches files and then sort them 43 | # lexicographically. 
44 | self._file1 = tempfile.mktemp(prefix="1", suffix="test") 45 | self._file2 = tempfile.mktemp(prefix="2", suffix="test-00010-of-00020") 46 | self._file_patterns = [ 47 | self._file1, 48 | # We use a ? to check that glob mechanism works. 49 | self._file2.replace("00010-of-00020", "000?0-of-00020") 50 | ] 51 | 52 | # Creates empty files to avoid errors in tearDown when self.cached_session() 53 | # is executed. 54 | open(self._file1, "a").close() 55 | open(self._file2, "a").close() 56 | 57 | self._file_patterns = [self._file1, self._file2] 58 | 59 | def tearDown(self): 60 | super(DatasetTest, self).tearDown() 61 | 62 | os.remove(self._file1) 63 | os.remove(self._file2) 64 | 65 | @parameterized.named_parameters( 66 | ("train_f1_f2", "tfrecord", True, dict(batch_size=2), (True, True)), 67 | ("train_f1", "tfrecord", True, dict(batch_size=1), (True, False)), 68 | ("train_f2", "tfrecord", True, dict(batch_size=1), (False, True)), 69 | ("test_f1_f2", "tfrecord", False, dict(batch_size=2, cycle_length=1), 70 | (True, True)), 71 | ("test_f1", "tfrecord", False, dict(batch_size=1, cycle_length=1), 72 | (True, False)), 73 | ("test_f2", "tfrecord", False, dict(batch_size=1, cycle_length=1), 74 | (False, True))) 75 | def test_read_dataset(self, data_format, is_training, params, 76 | include_patterns): 77 | write_tf_example( 78 | self._file1, data_format, { 79 | "name": 80 | tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"one"])), 81 | "number": 82 | tf.train.Feature(int64_list=tf.train.Int64List(value=[1])), 83 | }) 84 | write_tf_example( 85 | self._file2, data_format, { 86 | "name": 87 | tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"two"])), 88 | "number": 89 | tf.train.Feature(int64_list=tf.train.Int64List(value=[2])), 90 | }) 91 | 92 | feature_types = { 93 | "name": tf.io.FixedLenFeature([], tf.string), 94 | "number": tf.io.FixedLenFeature([], tf.int64), 95 | } 96 | 97 | parse_fn = dataset.build_parser_function(feature_types, params) 98 | 99 | def filter_fn(xs): 100 | return [x for (x, include) in zip(xs, include_patterns) if include] 101 | 102 | patterns = filter_fn(self._file_patterns) 103 | ds = dataset.read_dataset( 104 | parse_fn, 105 | "dataset", 106 | patterns, 107 | data_format, 108 | compression_type="", 109 | is_training=is_training, 110 | params=params, 111 | ) 112 | feature_tuple = tf.data.make_one_shot_iterator(ds).get_next() 113 | 114 | with self.cached_session() as sess: 115 | feature_tuple = sess.run(feature_tuple) 116 | 117 | if params["batch_size"] == 1: 118 | self.assertIsInstance(feature_tuple, dict) 119 | else: 120 | self.assertLen(feature_tuple, params["batch_size"]) 121 | 122 | if not is_training: 123 | expected_names = filter_fn([b"one", b"two"]) 124 | expected_numbers = filter_fn([1, 2]) 125 | self.assertSequenceEqual(list(feature_tuple["name"]), expected_names) 126 | self.assertSequenceEqual(list(feature_tuple["number"]), expected_numbers) 127 | 128 | @parameterized.named_parameters( 129 | ("tfrecord", "tfrecord")) 130 | def test_read_dataset_test_shape_is_fully_known(self, data_format): 131 | write_tf_example(self._file1, data_format, { 132 | "number": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])), 133 | }) 134 | feature_types = { 135 | "number": tf.io.FixedLenFeature([], tf.int64), 136 | } 137 | params = {"batch_size": 5} 138 | parse_fn = dataset.build_parser_function(feature_types, params) 139 | ds = dataset.read_dataset( 140 | parse_fn, 141 | "dataset", 142 | file_patterns=[self._file1], 143 | data_format=data_format, 144 | 
compression_type="", 145 | is_training=True, 146 | params=params, 147 | ) 148 | feature_tuple = tf.data.make_one_shot_iterator(ds).get_next() 149 | feature_tuple["number"].shape.assert_is_fully_defined() 150 | 151 | 152 | if __name__ == "__main__": 153 | tf.test.main() 154 | -------------------------------------------------------------------------------- /tapas/datasets/table_dataset_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl import logging 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tapas.datasets import table_dataset 21 | from tapas.datasets import table_dataset_test_utils 22 | import tensorflow.compat.v1 as tf 23 | 24 | tf.disable_v2_behavior() 25 | 26 | 27 | class TableDatasetTest(parameterized.TestCase, tf.test.TestCase): 28 | 29 | @parameterized.named_parameters( 30 | ("train_with_aggregation", 10, 10, table_dataset.TableTask.CLASSIFICATION, 31 | True, False, False, False, False), 32 | ("train_with_weak_supervision", 10, 10, 33 | table_dataset.TableTask.CLASSIFICATION, True, True, False, False, False), 34 | ("train", 10, 10, table_dataset.TableTask.CLASSIFICATION, False, False, 35 | False, False, False), 36 | ("pretrain", 10, 10, table_dataset.TableTask.PRETRAINING, False, False, 37 | False, False, False), 38 | ("predict", 10, 10, table_dataset.TableTask.CLASSIFICATION, False, False, 39 | True, False, False), 40 | ("predict_with_aggregation", 10, 10, 41 | table_dataset.TableTask.CLASSIFICATION, True, False, True, False, False), 42 | ("predict_with_weak_supervision", 10, 10, 43 | table_dataset.TableTask.CLASSIFICATION, True, True, True, False, False), 44 | ("train_with_candidate_answers", 10, 10, 45 | table_dataset.TableTask.CLASSIFICATION, True, True, False, True, False), 46 | ("train_with_classification", 10, 10, 47 | table_dataset.TableTask.CLASSIFICATION, True, False, False, False, True), 48 | ("predict_with_classification", 10, 10, 49 | table_dataset.TableTask.CLASSIFICATION, True, False, True, False, True), 50 | ("train_with_dual_encoder", 10, 10, table_dataset.TableTask.RETRIEVAL, 51 | False, False, False, False, False), 52 | ("predict_with_dual_encoder", 10, 10, table_dataset.TableTask.RETRIEVAL, 53 | False, False, True, False, False), 54 | ) 55 | def test_parse_table_examples(self, max_seq_length, max_predictions_per_seq, 56 | task_type, add_aggregation_function_id, 57 | add_answer, include_id, add_candidate_answers, 58 | add_classification_labels): 59 | logging.info("Setting random seed to 42") 60 | np.random.seed(42) 61 | max_num_candidates = 10 62 | values = table_dataset_test_utils.create_random_example( 63 | max_seq_length, 64 | max_predictions_per_seq, 65 | task_type, 66 | add_aggregation_function_id, 67 | add_classification_labels, 68 | add_answer, 69 | include_id, 70 | 
vocab_size=10, 71 | segment_vocab_size=3, 72 | num_columns=3, 73 | num_rows=2, 74 | add_candidate_answers=add_candidate_answers, 75 | max_num_candidates=max_num_candidates) 76 | example = table_dataset_test_utils.make_tf_example(values) 77 | 78 | params = {} 79 | parse_fn = table_dataset.parse_table_examples( 80 | max_seq_length=max_seq_length, 81 | max_predictions_per_seq=max_predictions_per_seq, 82 | task_type=task_type, 83 | add_aggregation_function_id=add_aggregation_function_id, 84 | add_classification_labels=add_classification_labels, 85 | add_answer=add_answer, 86 | include_id=include_id, 87 | add_candidate_answers=add_candidate_answers, 88 | max_num_candidates=max_num_candidates, 89 | params=params, 90 | ) 91 | features = parse_fn(example.SerializeToString()) 92 | 93 | with self.cached_session() as sess: 94 | features_vals = sess.run(features) 95 | 96 | for value in values: 97 | if value == "can_indexes": 98 | continue 99 | if values[value].dtype == np.float32 or values[value].dtype == np.int32: 100 | np.testing.assert_almost_equal(features_vals[value], values[value]) 101 | else: # Handle feature as string. 102 | np.testing.assert_equal(features_vals[value], values[value]) 103 | 104 | if add_candidate_answers: 105 | self.assertEqual(features_vals["can_label_ids"].dtype, np.int32) 106 | self.assertAllEqual(features_vals["can_label_ids"].shape, 107 | [max_num_candidates, max_seq_length]) 108 | 109 | # The total number of label_ids set to 1 must match the total number 110 | # of indices. 111 | num_indices = len(values["can_indexes"]) 112 | self.assertEqual(features_vals["can_label_ids"].sum(), num_indices) 113 | 114 | # Check that the correct indices are set to 1. 115 | cand_id = 0 116 | cand_start = 0 117 | for i in range(len(values["can_indexes"])): 118 | while i - cand_start >= values["can_sizes"][cand_id]: 119 | cand_id += 1 120 | cand_start = i 121 | token_id = values["can_indexes"][i] 122 | self.assertEqual(features_vals["can_label_ids"][cand_id, token_id], 1) 123 | 124 | 125 | if __name__ == "__main__": 126 | absltest.main() 127 | -------------------------------------------------------------------------------- /tapas/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/models/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/models/bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from tapas.models.bert import optimization 17 | import tensorflow.compat.v1 as tf 18 | 19 | tf.disable_v2_behavior() 20 | 21 | 22 | class OptimizationTest(tf.test.TestCase): 23 | 24 | def test_adam(self): 25 | with self.test_session() as sess: 26 | w = tf.get_variable( 27 | "w", 28 | shape=[3], 29 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 30 | x = tf.constant([0.4, 0.2, -0.5]) 31 | loss = tf.reduce_mean(tf.square(x - w)) 32 | tvars = tf.trainable_variables() 33 | grads = tf.gradients(loss, tvars) 34 | global_step = tf.train.get_or_create_global_step() 35 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 36 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 37 | init_op = tf.group(tf.global_variables_initializer(), 38 | tf.local_variables_initializer()) 39 | sess.run(init_op) 40 | for _ in range(100): 41 | sess.run(train_op) 42 | w_np = sess.run(w) 43 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 44 | 45 | def test_gradient_accumulation_empty_variables(self): 46 | optimizer = optimization.GradientAccumulationOptimizer( 47 | tf.train.RMSPropOptimizer( 48 | learning_rate=.2, decay=.9, momentum=.9, epsilon=1.0), 49 | steps=2, 50 | ) 51 | self.assertEmpty(optimizer.variables()) 52 | 53 | def test_gradient_accumulation(self): 54 | with self.test_session() as sess: 55 | w = tf.get_variable( 56 | "w", 57 | shape=[3], 58 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 59 | x = tf.constant([0.4, 0.2, -0.5]) 60 | loss = tf.reduce_mean(tf.square(x - w)) 61 | tvars = tf.trainable_variables() 62 | grads = tf.gradients(loss, tvars) 63 | global_step = tf.train.get_or_create_global_step() 64 | optimizer = optimization.GradientAccumulationOptimizer( 65 | optimization.AdamWeightDecayOptimizer(learning_rate=0.2), 66 | steps=2, 67 | ) 68 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 69 | init_op = tf.group(tf.global_variables_initializer(), 70 | tf.local_variables_initializer()) 71 | sess.run(init_op) 72 | 73 | # After one step the weights should be unchanged 74 | sess.run(train_op) 75 | w_np = sess.run(w) 76 | self.assertAllClose(w_np.flat, [0.1, -0.2, -0.1]) 77 | 78 | # After two steps the weights should have changed 79 | sess.run(train_op) 80 | w_np = sess.run(w) 81 | self.assertNotAllClose(w_np.flat, [0.1, -0.2, -0.1]) 82 | 83 | for _ in range(200): 84 | sess.run(train_op) 85 | w_np = sess.run(w) 86 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 87 | 88 | 89 | if __name__ == "__main__": 90 | tf.test.main() 91 | -------------------------------------------------------------------------------- /tapas/models/tapas_classifier_model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
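# The token scoring implemented in this module reduces the encoder output
# H: [batch_size, seq_length, hidden_dim] against a single learned weight
# vector w: [hidden_dim]. As a shape-only numpy sketch of the same formula
# (names are illustrative, not part of this module):
#
#   logits = (np.einsum("bsj,j->bs", H, w) + b) / temperature
#
# Column logits are then obtained by averaging token logits per cell and per
# column (see compute_column_logits below).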
15 | """TAPAS BERT model utils for classification.""" 16 | 17 | from typing import Dict, Text, Tuple, Optional 18 | from tapas.models import segmented_tensor 19 | import tensorflow.compat.v1 as tf 20 | 21 | EPSILON_ZERO_DIVISION = 1e-10 22 | CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0 23 | 24 | 25 | def classification_initializer(): 26 | """Classification layer initializer.""" 27 | return tf.truncated_normal_initializer(stddev=0.02) 28 | 29 | 30 | def extract_answer_from_features( 31 | features, use_answer_as_supervision 32 | ): 33 | """Extracts the answer, numeric_values, numeric_values_scale.""" 34 | if use_answer_as_supervision: 35 | answer = tf.squeeze(features["answer"], axis=[1]) 36 | numeric_values = features["numeric_values"] 37 | numeric_values_scale = features["numeric_values_scale"] 38 | else: 39 | answer = None 40 | numeric_values = None 41 | numeric_values_scale = None 42 | return answer, numeric_values, numeric_values_scale 43 | 44 | 45 | def compute_token_logits(output_layer, temperature, 46 | init_cell_selection_weights_to_zero): 47 | """Computes logits per token. 48 | 49 | Args: 50 | output_layer: [batch_size, seq_length, hidden_dim] Output of the 51 | encoder layer. 52 | temperature: float Temperature for the Bernoulli distribution. 53 | init_cell_selection_weights_to_zero: Whether the initial weights should be 54 | set to 0. This ensures that all tokens have the same prior probability. 55 | 56 | Returns: 57 | [batch_size, seq_length] Logits per token. 58 | """ 59 | hidden_size = output_layer.shape.as_list()[-1] 60 | output_weights = tf.get_variable( 61 | "output_weights", [hidden_size], 62 | initializer=tf.zeros_initializer() 63 | if init_cell_selection_weights_to_zero else classification_initializer()) 64 | output_bias = tf.get_variable( 65 | "output_bias", shape=(), initializer=tf.zeros_initializer()) 66 | logits = (tf.einsum("bsj,j->bs", output_layer, output_weights) + 67 | output_bias) / temperature 68 | return logits 69 | 70 | 71 | # TODO(eisenjulian): Move more methods from tapas_classifier_model 72 | def compute_column_logits(output_layer, 73 | cell_index, 74 | cell_mask, 75 | init_cell_selection_weights_to_zero, 76 | allow_empty_column_selection): 77 | """Computes logits for each column. 78 | 79 | Args: 80 | output_layer: [batch_size, seq_length, hidden_dim] Output of the 81 | encoder layer. 82 | cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that 83 | groups tokens into cells. 84 | cell_mask: [batch_size, max_num_rows * max_num_cols] Input mask per 85 | cell, 1 for cells that exists in the example and 0 for padding. 86 | init_cell_selection_weights_to_zero: Whether the initial weights should be 87 | set to 0. This is also applied to column logits, as they are used to 88 | select the cells. This ensures that all columns have the same prior 89 | probability. 90 | allow_empty_column_selection: Allow to select no column. 91 | 92 | Returns: 93 | [batch_size, max_num_cols] Logits per column. Logits will be set to 94 | a very low value (such that the probability is 0) for the special id 0 95 | (which means "outside the table") or columns that do not apear in the 96 | table. 
97 | """ 98 | hidden_size = output_layer.shape.as_list()[-1] 99 | column_output_weights = tf.get_variable( 100 | "column_output_weights", [hidden_size], 101 | initializer=tf.zeros_initializer() 102 | if init_cell_selection_weights_to_zero else classification_initializer()) 103 | column_output_bias = tf.get_variable( 104 | "column_output_bias", shape=(), initializer=tf.zeros_initializer()) 105 | token_logits = ( 106 | tf.einsum("bsj,j->bs", output_layer, column_output_weights) + 107 | column_output_bias) 108 | 109 | # Average the logits per cell and then per column. 110 | # Note that by linearity it doesn't matter if we do the averaging on the 111 | # embeddings or on the logits. For performance we do the projection first. 112 | # [batch_size, max_num_cols * max_num_rows] 113 | cell_logits, cell_logits_index = segmented_tensor.reduce_mean( 114 | token_logits, cell_index) 115 | 116 | column_index = cell_index.project_inner(cell_logits_index) 117 | # [batch_size, max_num_cols] 118 | column_logits, out_index = segmented_tensor.reduce_sum( 119 | cell_logits * cell_mask, column_index) 120 | cell_count, _ = segmented_tensor.reduce_sum(cell_mask, column_index) 121 | column_logits /= cell_count + EPSILON_ZERO_DIVISION 122 | 123 | # Mask columns that do not appear in the example. 124 | is_padding = tf.logical_and(cell_count < 0.5, 125 | tf.not_equal(out_index.indices, 0)) 126 | column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(is_padding, tf.float32) 127 | 128 | if not allow_empty_column_selection: 129 | column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast( 130 | tf.equal(out_index.indices, 0), tf.float32) 131 | 132 | return column_logits 133 | -------------------------------------------------------------------------------- /tapas/models/tapas_pretraining_model_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from absl.testing import parameterized 17 | import numpy as np 18 | from tapas.datasets import table_dataset 19 | from tapas.datasets import table_dataset_test_utils 20 | from tapas.models import tapas_pretraining_model 21 | from tapas.models.bert import modeling 22 | import tensorflow.compat.v1 as tf 23 | from tensorflow.compat.v1 import estimator as tf_estimator 24 | 25 | 26 | class TapasPretrainingModelTest(parameterized.TestCase, tf.test.TestCase): 27 | 28 | def _generator_kwargs(self): 29 | return dict( 30 | max_seq_length=10, 31 | max_predictions_per_seq=5, 32 | task_type=table_dataset.TableTask.PRETRAINING, 33 | add_aggregation_function_id=False, 34 | add_classification_labels=False, 35 | add_answer=False, 36 | include_id=False, 37 | vocab_size=10, 38 | segment_vocab_size=3, 39 | num_columns=3, 40 | num_rows=2, 41 | add_candidate_answers=False, 42 | max_num_candidates=10) 43 | 44 | def _create_estimator(self, params): 45 | tf.logging.info("Setting random seed to {}".format(42)) 46 | np.random.seed(42) 47 | 48 | # Small bert model for testing. 49 | bert_config = modeling.BertConfig.from_dict({ 50 | "vocab_size": 10, 51 | "type_vocab_size": [3, 256, 256, 2, 256, 256, 10], 52 | "num_hidden_layers": 2, 53 | "num_attention_heads": 2, 54 | "hidden_size": 128, 55 | "intermediate_size": 512, 56 | }) 57 | model_fn = tapas_pretraining_model.model_fn_builder( 58 | bert_config=bert_config, 59 | init_checkpoint=params["init_checkpoint"], 60 | learning_rate=params["learning_rate"], 61 | num_train_steps=params["num_train_steps"], 62 | num_warmup_steps=params["num_warmup_steps"], 63 | use_tpu=params["use_tpu"]) 64 | 65 | estimator = tf_estimator.tpu.TPUEstimator( 66 | use_tpu=params["use_tpu"], 67 | model_fn=model_fn, 68 | config=tf_estimator.tpu.RunConfig( 69 | model_dir=self.get_temp_dir(), 70 | save_summary_steps=params["num_train_steps"], 71 | save_checkpoints_steps=params["num_train_steps"]), 72 | train_batch_size=params["batch_size"], 73 | predict_batch_size=params["batch_size"], 74 | eval_batch_size=params["batch_size"]) 75 | 76 | return estimator 77 | 78 | @parameterized.named_parameters(("no_checkpoint", False), 79 | ("with_checkpoint", True)) 80 | def test_build_model_train_and_evaluate(self, load_checkpoint): 81 | """Tests that we can train, save, load and evaluate the model.""" 82 | params = { 83 | "batch_size": 2, 84 | "init_checkpoint": None, 85 | "learning_rate": 5e-5, 86 | "num_train_steps": 50, 87 | "num_warmup_steps": 10, 88 | "num_eval_steps": 20, 89 | "use_tpu": False, 90 | } 91 | 92 | estimator = self._create_estimator(params) 93 | generator_kwargs = self._generator_kwargs() 94 | 95 | def _input_fn(params): 96 | return table_dataset_test_utils.create_random_dataset( 97 | num_examples=params["batch_size"], 98 | batch_size=params["batch_size"], 99 | repeat=True, 100 | generator_kwargs=generator_kwargs) 101 | 102 | estimator.train(_input_fn, max_steps=params["num_train_steps"]) 103 | 104 | if load_checkpoint: 105 | params.update({"init_checkpoint": self.get_temp_dir()}) 106 | estimator = self._create_estimator(params) 107 | estimator.train(_input_fn, max_steps=params["num_train_steps"]) 108 | 109 | eval_metrics = estimator.evaluate(_input_fn, steps=params["num_eval_steps"]) 110 | 111 | for metric_name in ("masked_lm_loss", "masked_lm_accuracy", "loss", 112 | "next_sentence_accuracy", "next_sentence_loss"): 113 | self.assertIn(metric_name, eval_metrics) 114 | 115 | def _predict_input_fn(params): 116 | dataset = table_dataset_test_utils.create_random_dataset( 117 | 
num_examples=params["batch_size"], 118 | batch_size=params["batch_size"], 119 | repeat=True, 120 | generator_kwargs=generator_kwargs) 121 | return dataset.take(2) 122 | 123 | for predictions in estimator.predict(_predict_input_fn): 124 | self.assertIn("masked_lm_predictions", predictions) 125 | 126 | 127 | if __name__ == "__main__": 128 | tf.test.main() 129 | -------------------------------------------------------------------------------- /tapas/protos/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/protos/annotated_text.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | import "tapas/protos/interaction.proto"; 19 | 20 | message AnnotatedText { 21 | extend language.tapas.Question { 22 | optional AnnotatedText annotated_question_ext = 304163798; 23 | } 24 | extend language.tapas.Cell { 25 | optional AnnotatedText annotated_cell_ext = 304163798; 26 | } 27 | repeated Annotation annotations = 1; 28 | } 29 | 30 | message AnnotationDescription { 31 | extend language.tapas.Interaction { 32 | optional AnnotationDescription annotation_descriptions_ext = 319400515; 33 | } 34 | // For each entity that appears in the interaction, the map has a textual 35 | // description of the entity, like the first section of its Wikipedia page. 36 | map descriptions = 1; 37 | } 38 | 39 | message Annotation { 40 | // Indices refer to 'original_text' of a question or 'text' of a cell. 41 | // Inclusive begin byte. 42 | optional int64 begin_byte_index = 1; 43 | // Exclusive end byte. 44 | optional int64 end_byte_index = 2; 45 | // An identifier for example a Wikipedia URL. 46 | optional string identifier = 3; 47 | } 48 | -------------------------------------------------------------------------------- /tapas/protos/interaction.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | // An interaction represents a sequence of questions answerable from a single 19 | // table. 20 | message Interaction { 21 | optional string id = 1; 22 | optional Table table = 2; 23 | repeated Question questions = 3; 24 | 25 | extensions 10000 to max; 26 | } 27 | 28 | message Question { 29 | optional string id = 1; 30 | // The question string after normalization. 31 | optional string text = 2; 32 | // The original raw question string. 33 | optional string original_text = 3; 34 | // Numeric value spans in 'text'. 35 | optional NumericValueSpans annotations = 4; 36 | optional Answer answer = 5; 37 | repeated Answer alternative_answers = 6; 38 | 39 | extensions 10000 to max; 40 | } 41 | 42 | message AnswerCoordinate { 43 | optional int32 row_index = 1; 44 | optional int32 column_index = 2; 45 | } 46 | 47 | message Answer { 48 | // Coordinates of cells that contain the answers. 49 | repeated AnswerCoordinate answer_coordinates = 1; 50 | 51 | // A function that is applied to the answer cells in order to obtain 52 | // the final answer. 53 | enum AggregationFunction { 54 | NONE = 0; 55 | // Sums all cell values. Numeric cells only. 56 | SUM = 1; 57 | // Averages all cell values. Numeric cells only. 58 | AVERAGE = 2; 59 | // Counts the number of answers. 60 | COUNT = 3; 61 | } 62 | optional AggregationFunction aggregation_function = 2; 63 | 64 | // Answers in text format. 65 | repeated string answer_texts = 3; 66 | 67 | // Present if the answer can be represented as a single float value, for 68 | // example produced by an aggregation ('the average population of all 69 | // countries'). 70 | optional float float_value = 4; 71 | 72 | // If true, this answer can be used to construct training/test examples. If 73 | // false, some error was triggered during parsing of this answer. 74 | optional bool is_valid = 5 [default = true]; 75 | 76 | // Present if the answer can be represented as a single integer value, for 77 | // example when it's a classification or entailment task. 78 | optional int32 class_index = 6; 79 | 80 | extensions 10000 to max; 81 | } 82 | 83 | // Represents a simple table with m rows and n columns. 84 | message Table { 85 | // The names of the n columns. 86 | repeated Cell columns = 1; 87 | 88 | // m rows containing n cells each. 89 | repeated Cells rows = 2; 90 | 91 | // Some unique identifier of this table. 92 | optional string table_id = 3; 93 | 94 | // The title of the document the table appears in. 95 | optional string document_title = 4; 96 | 97 | // Title or caption of the table. 98 | optional string caption = 5; 99 | 100 | // The URL the table was found on. 101 | optional string document_url = 6; 102 | 103 | // Other versions of the same document that the table occurs on. 104 | repeated string alternative_document_urls = 7; 105 | 106 | // Other versions of the same table.
107 | repeated string alternative_table_ids = 8; 108 | 109 | // Heading of the table on the document. 110 | optional string context_heading = 9; 111 | 112 | extensions 10000 to max; 113 | } 114 | 115 | message Cell { 116 | optional string text = 1; 117 | optional NumericValue numeric_value = 2; 118 | 119 | extensions 10000 to max; 120 | } 121 | 122 | message Cells { 123 | repeated Cell cells = 1; 124 | } 125 | 126 | message Date { 127 | optional int32 year = 1; 128 | optional int32 month = 2; 129 | optional int32 day = 3; 130 | } 131 | 132 | message NumericValue { 133 | oneof value { 134 | float float_value = 1; 135 | Date date = 2; 136 | } 137 | } 138 | 139 | message NumericValueSpan { 140 | optional int32 begin_index = 1; 141 | optional int32 end_index = 2; 142 | repeated NumericValue values = 3; 143 | } 144 | 145 | message NumericValueSpans { 146 | repeated NumericValueSpan spans = 1; 147 | } 148 | -------------------------------------------------------------------------------- /tapas/protos/negative_retrieval_examples.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | import "tapas/protos/interaction.proto"; 19 | 20 | message NegativeRetrievalExamples { 21 | extend language.tapas.Question { 22 | optional NegativeRetrievalExamples negative_retrieval_examples_ext = 23 | 288888272; 24 | } 25 | // The examples correspond to the negative tables. 26 | repeated NegativeRetrievalExample examples = 1; 27 | } 28 | 29 | message NegativeRetrievalExample { 30 | // One negative table. 31 | optional Table table = 1; 32 | enum Type { 33 | BASELINE = 1; 34 | DOCUMENT = 2; 35 | CORRUPTED = 3; 36 | } 37 | optional Type type = 2; 38 | // If the table was retrieved by the baseline, its rank. 39 | optional int32 rank = 3; 40 | // The similarity score between the negative table and the question. 41 | // A positive score represents a high similarity. 42 | optional float score = 4; 43 | } 44 | -------------------------------------------------------------------------------- /tapas/protos/retriever_info.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
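// RetrieverInfo (declared below) attaches the retrieval rank and score of a
// table to a Question. An illustrative instance in text format (the values
// are made up):
//
//   rank: 3
//   score: 14.25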
14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | import "tapas/protos/interaction.proto"; 19 | 20 | message RetrieverInfo { 21 | extend language.tapas.Question { 22 | optional RetrieverInfo question_ext = 337075296; 23 | } 24 | optional int32 rank = 1; 25 | optional double score = 2; 26 | } 27 | -------------------------------------------------------------------------------- /tapas/protos/table_pruning.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | // The loss used to learn the weights. 19 | message Loss { 20 | // The hard selection strategy used for train and/or for test. 21 | message HardSelection { 22 | enum SelectionFn { 23 | // No hard selection is used. Select all the tokens. 24 | ALL = 0; 25 | // Selects the best tokens up to max_num_tokens. Returns the TOP_K scores. 26 | TOP_K = 1; 27 | // Selects the best tokens up to max_num_tokens. 28 | // Returns the TOP_K mask values in 0,1. 29 | MASK_TOP_K = 2; 30 | } 31 | optional SelectionFn selection_fn = 1; 32 | } 33 | 34 | 35 | // Uses an unsupervised model to learn the required columns. 36 | // The back probagation is always activated. 37 | message Unsupervised { 38 | enum Regularization { 39 | // No Regularization is used. 40 | NONE = 0; 41 | // Computes L1 over all the tokens scores. 42 | L1 = 1; 43 | // Computes L2 over all the tokens scores. 44 | L2 = 2; 45 | // Computes l1 on tokens sequence then l2 on the batch. 46 | L1_L2 = 3; 47 | } 48 | optional Regularization regularization = 5; 49 | } 50 | 51 | oneof loss { 52 | Unsupervised unsupervised = 200; 53 | } 54 | 55 | optional HardSelection train = 400; 56 | optional HardSelection eval = 500; 57 | // Enables the pruning model to use a loss similar to the tapas model. 58 | optional bool add_classification_loss = 600 [default = false]; 59 | } 60 | 61 | // Uses the average cosine similarity to score the tokens. 62 | message AvgCosSimilarity { 63 | // Enables the use of positional embeddingins to compute the average cosine 64 | // similarity. 65 | optional bool use_positional_embeddings = 2 [default = false]; 66 | // The loss used to learn the weights. 67 | optional Loss loss = 3; 68 | } 69 | 70 | // Uses a TAPAS model to score the columns or the tokens. 71 | message TAPAS { 72 | // Specifies the use of the columns scores or tokens scores. 73 | enum Selection { 74 | COLUMNS = 0; 75 | TOKENS = 1; 76 | } 77 | optional Selection selection = 2; 78 | optional string bert_config_file = 3; 79 | optional string bert_init_checkpoint = 4; 80 | optional bool reset_position_index_per_cell = 6 [default = false]; 81 | 82 | // The loss used to learn the weights. 83 | optional Loss loss = 5; 84 | } 85 | 86 | // Select the k first tokens up to max_num_tokens. 
87 | // If max_num_tokens = tapas_max_num_tokens no table pruning is used. 88 | message FirstTokens {} 89 | 90 | // Options for table pruning models. 91 | message TablePruningModel { 92 | optional int32 max_num_tokens = 1; 93 | oneof table_pruning_model { 94 | AvgCosSimilarity avg_cos_similarity = 2; 95 | TAPAS tapas = 3; 96 | FirstTokens first_tokens = 4; 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /tapas/protos/table_selection.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Google AI Language Team Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | syntax = "proto2"; 15 | 16 | package language.tapas; 17 | 18 | import "tapas/protos/interaction.proto"; 19 | 20 | message TableSelection { 21 | extend language.tapas.Question { 22 | optional TableSelection table_selection_ext = 288888271; 23 | } 24 | 25 | // Tokens that should be added to the TF example. 26 | // Must not be empty! 27 | message TokenCoordinates { 28 | // The header row has index 0, the first data row index 1. 29 | optional int32 row_index = 1; 30 | optional int32 column_index = 2; 31 | optional int32 token_index = 3; 32 | } 33 | repeated TokenCoordinates selected_tokens = 3; 34 | 35 | message ModelPredictionStatsPerModel { 36 | // Identifier of the model that produced this result. 37 | optional string model_id = 1; 38 | // Whether the model's prediction is correct. 39 | optional bool is_correct = 2; 40 | } 41 | 42 | // For each column of the table and every model that correctly answered the 43 | // question when run on the whole table, whether it also correctly 44 | // answered the question when this column was removed from the input. 45 | // That is, is_correct == false implies that this column is relevant for 46 | // answering the question. 47 | // For example, for the following output: 48 | /* model_prediction_stats: { 49 | model_id: "1" 50 | is_correct: true 51 | } 52 | model_prediction_stats: { 53 | model_id: "5" 54 | is_correct: false 55 | } 56 | column: 0 57 | */ 58 | // Models 2, 3 and 4 didn't answer the question correctly even when running 59 | // on the whole table. When column 0 was removed from the input, model 1 60 | // answered the question correctly and model 5 incorrectly. 61 | message ModelPredictionStatsPerColumn { 62 | // Column index. 63 | optional int32 column = 3; 64 | repeated ModelPredictionStatsPerModel model_prediction_stats = 2; 65 | } 66 | 67 | message ModelPredictionStats { 68 | repeated ModelPredictionStatsPerColumn column_prediction_stats = 1; 69 | // Model predictions of the unmodified inputs. 70 | repeated ModelPredictionStatsPerModel model_prediction_stats = 2; 71 | } 72 | 73 | optional ModelPredictionStats model_prediction_stats = 2; 74 | 75 | message DebugInfo { 76 | message Column { 77 | // Index of the column. 78 | optional int32 index = 1; 79 | 80 | // The score assigned by some scorer.
82 | optional double score = 2; 83 | 84 | // True if the column was selected as relevant by some column selector. 85 | optional bool is_selected = 3; 86 | 87 | // True if the column is needed to find the final answer (gold data). 88 | optional bool is_required = 4; 89 | } 90 | 91 | repeated Column columns = 1; 92 | } 93 | 94 | optional DebugInfo debug = 100; 95 | } 96 | -------------------------------------------------------------------------------- /tapas/retrieval/add_negative_tables_to_interactions_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Adds nearest neighbors as negatives to interactions. 16 | 17 | Given 'interaction_dir', extracts all files and searches for 18 | the nearest neighbor files in json format in 'json_dir'. 19 | 20 | For example, for an interaction file 'test' we expect to find a 21 | file 'test_results.jsonl' in 'json_dir'. 22 | 23 | The table ids in the json files should have a corresponding table in 24 | 'input_tables_file'. 25 | 26 | After processing, the interactions are written to 'output_dir', keeping the 27 | original file name.
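Example invocation (the paths are hypothetical; all flags are defined below):

  python add_negative_tables_to_interactions_main.py \
    --interaction_dir=/tmp/interactions \
    --json_dir=/tmp/jsons \
    --input_tables_file=/tmp/tables.tfrecord \
    --output_dir=/tmp/output \
    --max_num_negatives=10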
28 | """ 29 | import os 30 | from typing import List, Text 31 | 32 | from absl import app 33 | from absl import flags 34 | import dataclasses 35 | from tapas.retrieval.add_negative_tables_to_interactions import add_negative_tables_to_interactions 36 | from tapas.utils import beam_runner 37 | import tensorflow.compat.v1 as tf 38 | 39 | 40 | flags.DEFINE_string("interaction_dir", None, 41 | "Directory with interaction tfrecords.") 42 | 43 | flags.DEFINE_string("json_dir", None, "Directory with jsonl files.") 44 | 45 | flags.DEFINE_string("input_tables_file", None, "The tfrecord tables file.") 46 | 47 | flags.DEFINE_string("output_dir", None, 48 | "Directory where new interactions are written to.") 49 | 50 | flags.DEFINE_integer( 51 | "max_num_negatives", 52 | None, 53 | "Max negatives examples to add to interaction.", 54 | ) 55 | 56 | FLAGS = flags.FLAGS 57 | 58 | 59 | @dataclasses.dataclass(frozen=True) 60 | class InputsOutputs: 61 | input_interaction_files: List[Text] 62 | input_json_files: List[Text] 63 | output_interaction_files: List[Text] 64 | 65 | 66 | def _get_inputs_outputs( 67 | interaction_dir, 68 | json_dir, 69 | output_dir, 70 | ): 71 | """Gets input and output files.""" 72 | 73 | interaction_paths = tf.io.gfile.glob( 74 | os.path.join(interaction_dir, "*.tfrecord")) 75 | if not interaction_paths: 76 | raise ValueError(f"No interactions found: {interaction_dir}") 77 | 78 | interaction_names = [os.path.basename(path) for path in interaction_paths] 79 | 80 | json_files = [] 81 | for name in interaction_names: 82 | base_name = os.path.splitext(name)[0] 83 | json_file = os.path.join(json_dir, f"{base_name}_results.jsonl") 84 | json_files.append(json_file) 85 | if not tf.io.gfile.exists(json_file): 86 | raise ValueError(f"Missing file: {json_file}") 87 | 88 | outputs = [os.path.join(output_dir, name) for name in interaction_names] 89 | return InputsOutputs( 90 | input_interaction_files=interaction_paths, 91 | input_json_files=json_files, 92 | output_interaction_files=outputs, 93 | ) 94 | 95 | 96 | def main(unused_argv): 97 | r"""Reads nearest neigbors adds them to the interactions.""" 98 | del unused_argv 99 | 100 | inputs_outputs = _get_inputs_outputs( 101 | FLAGS.interaction_dir, 102 | FLAGS.json_dir, 103 | FLAGS.output_dir, 104 | ) 105 | pipeline = add_negative_tables_to_interactions( 106 | max_num_negatives=FLAGS.max_num_negatives, 107 | input_interactions_files=inputs_outputs.input_interaction_files, 108 | input_tables_file=FLAGS.input_tables_file, 109 | input_json_files=inputs_outputs.input_json_files, 110 | output_files=inputs_outputs.output_interaction_files, 111 | ) 112 | beam_runner.run(pipeline) 113 | 114 | 115 | if __name__ == "__main__": 116 | flags.mark_flag_as_required("interaction_dir") 117 | flags.mark_flag_as_required("input_tables_file") 118 | flags.mark_flag_as_required("json_dir") 119 | flags.mark_flag_as_required("output_dir") 120 | 121 | app.run(main) 122 | -------------------------------------------------------------------------------- /tapas/retrieval/create_baseline_results.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Uses the baseline to create nearest neighbor results.""" 16 | 17 | import collections 18 | import json 19 | import os 20 | 21 | from absl import app 22 | from absl import flags 23 | from tapas.retrieval import tfidf_baseline_utils 24 | from tapas.scripts import prediction_utils 25 | from tapas.scripts import preprocess_nq_utils 26 | import tensorflow.compat.v1 as tf 27 | import tqdm 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("input_dir", None, "Interaction protos in tfrecord format.") 32 | 33 | flags.DEFINE_string("table_file", None, "Table protos in tfrecord format.") 34 | 35 | flags.DEFINE_string("output_dir", None, 36 | "Directory where interactions will be written to.") 37 | 38 | flags.DEFINE_integer("max_rank", None, "Max rank to consider.") 39 | 40 | flags.DEFINE_integer("title_multiplicator", 15, "See create_bm25_index.") 41 | 42 | 43 | def main(argv): 44 | if len(argv) > 1: 45 | raise app.UsageError("Too many command-line arguments.") 46 | 47 | print("Creating output dir ...") 48 | tf.io.gfile.makedirs(FLAGS.output_dir) 49 | 50 | interaction_files = [] 51 | for filename in tf.io.gfile.listdir(FLAGS.input_dir): 52 | interaction_files.append(os.path.join(FLAGS.input_dir, filename)) 53 | 54 | tables = {} 55 | if FLAGS.table_file: 56 | print("Reading tables ...") 57 | for table in tqdm.tqdm( 58 | tfidf_baseline_utils.iterate_tables(FLAGS.table_file), total=375_000): 59 | tables[table.table_id] = table 60 | 61 | print("Adding interaction tables ...") 62 | for interaction_file in interaction_files: 63 | interactions = prediction_utils.iterate_interactions(interaction_file) 64 | for interaction in interactions: 65 | tables[interaction.table.table_id] = interaction.table 66 | 67 | print("Creating index ...") 68 | index = tfidf_baseline_utils.create_bm25_index( 69 | tables=tables.values(), 70 | title_multiplicator=FLAGS.title_multiplicator, 71 | num_tables=len(tables), 72 | ) 73 | 74 | print("Processing interactions ...") 75 | for interaction_file in interaction_files: 76 | interactions = list(prediction_utils.iterate_interactions(interaction_file)) 77 | 78 | examples = collections.defaultdict(list) 79 | for interaction in interactions: 80 | example_id, _ = preprocess_nq_utils.parse_interaction_id(interaction.id) 81 | examples[example_id].append(interaction) 82 | 83 | filename = os.path.basename(interaction_file) 84 | filename = os.path.splitext(filename)[0] 85 | output = os.path.join(FLAGS.output_dir, filename + "_results.jsonl") 86 | with tf.io.gfile.GFile(output, "w") as file_writer: 87 | num_correct = 0 88 | with tqdm.tqdm( 89 | examples.items(), 90 | total=len(examples), 91 | desc=filename, 92 | ) as pbar: 93 | for nr, example in enumerate(pbar): 94 | example_id, interaction_list = example 95 | 96 | questions = [] 97 | for interaction in interaction_list: 98 | if len(interaction.questions) != 1: 99 | raise ValueError(f"Unexpected question in {interaction}") 100 | questions.append(interaction.questions[0]) 101 | 102 | if len(set(q.original_text for q in questions)) != 1: 103 | raise ValueError(f"Different questions
{questions}") 104 | question_text = questions[0].original_text 105 | scored_hits = index.retrieve(question_text) 106 | scored_hits = scored_hits[:FLAGS.max_rank] 107 | 108 | table_scores = [] 109 | for scored_hit in scored_hits: 110 | table_scores.append({ 111 | "table_id": scored_hit[0], 112 | "score": -scored_hit[1], 113 | }) 114 | 115 | result = { 116 | "query_id": example_id + "_0_0", 117 | "table_scores": table_scores, 118 | } 119 | 120 | file_writer.write(json.dumps(result)) 121 | file_writer.write("\n") 122 | 123 | 124 | if __name__ == "__main__": 125 | app.run(main) 126 | -------------------------------------------------------------------------------- /tapas/retrieval/create_retrieval_data_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert interactions to Tensorflow examples.""" 16 | 17 | from absl import app 18 | from absl import flags 19 | from tapas.retrieval import tf_example_utils 20 | from tapas.utils import beam_runner 21 | from tapas.utils import create_data 22 | from tapas.utils import create_data_file_io 23 | 24 | 25 | flags.DEFINE_string("input_interactions_dir", None, "Directory with inputs.") 26 | flags.DEFINE_string("input_tables_dir", None, "Directory with inputs.") 27 | flags.DEFINE_string("output_dir", None, "Directory with outputs.") 28 | flags.DEFINE_string("vocab_file", None, 29 | "The vocabulary file that the BERT model was trained on.") 30 | flags.DEFINE_integer("max_seq_length", None, 31 | "Max length of a sequence in word pieces.") 32 | flags.DEFINE_float("max_column_id", None, "Max column id to extract.") 33 | flags.DEFINE_float("max_row_id", None, "Max row id to extract.") 34 | flags.DEFINE_integer("cell_trim_length", -1, 35 | "If > 0: Trim cells so that the length is <= this value.") 36 | flags.DEFINE_boolean("use_document_title", None, 37 | "Include document title text in the tf example.") 38 | flags.DEFINE_enum_class("converter_impl", create_data.ConverterImplType.PYTHON, 39 | create_data.ConverterImplType, 40 | "Implementation to map interactions to tf examples.") 41 | FLAGS = flags.FLAGS 42 | 43 | 44 | def run(inputs, outputs, input_format): 45 | beam_runner.run( 46 | create_data.build_retrieval_pipeline( 47 | input_files=inputs, 48 | input_format=input_format, 49 | output_files=outputs, 50 | config=tf_example_utils.RetrievalConversionConfig( 51 | vocab_file=FLAGS.vocab_file, 52 | max_seq_length=FLAGS.max_seq_length, 53 | max_column_id=FLAGS.max_column_id, 54 | max_row_id=FLAGS.max_row_id, 55 | strip_column_names=False, 56 | cell_trim_length=FLAGS.cell_trim_length, 57 | use_document_title=FLAGS.use_document_title, 58 | ), 59 | converter_impl=FLAGS.converter_impl, 60 | )).wait_until_finish() 61 | 62 | 63 | def main(_): 64 | inputs, outputs = create_data_file_io.get_inputs_and_outputs( 65 | FLAGS.input_interactions_dir, FLAGS.output_dir) 66 
| if not inputs: 67 | raise ValueError(f"Input dir is empty: '{FLAGS.input_interactions_dir}'") 68 | 69 | run(inputs, outputs, create_data.InputFormat.INTERACTION) 70 | 71 | if FLAGS.input_tables_dir is not None: 72 | table_inputs, table_outputs = create_data_file_io.get_inputs_and_outputs( 73 | FLAGS.input_tables_dir, FLAGS.output_dir) 74 | run(table_inputs, table_outputs, create_data.InputFormat.TABLE) 75 | 76 | 77 | if __name__ == "__main__": 78 | flags.mark_flag_as_required("input_interactions_dir") 79 | flags.mark_flag_as_required("max_column_id") 80 | flags.mark_flag_as_required("max_row_id") 81 | flags.mark_flag_as_required("max_seq_length") 82 | flags.mark_flag_as_required("output_dir") 83 | flags.mark_flag_as_required("vocab_file") 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /tapas/retrieval/create_retrieval_pretrain_data_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert interactions to Tensorflow examples.""" 16 | 17 | from typing import Iterable, Optional, Text, Tuple 18 | 19 | from absl import app 20 | from absl import flags 21 | import apache_beam as beam 22 | import nltk 23 | from tapas.protos import interaction_pb2 24 | from tapas.utils import beam_runner 25 | from tapas.utils import beam_utils 26 | from tapas.utils import create_data 27 | from tapas.utils import sentence_tokenizer 28 | 29 | 30 | flags.DEFINE_string("inputs", None, "Interaction tables.") 31 | flags.DEFINE_string("output_dir", None, "Directory with outputs.") 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | _NS = "main" 37 | _TITLE_QUESTION_ID = "TITLE" 38 | 39 | 40 | def get_title(interaction): 41 | for question in interaction.questions: 42 | if question.id == _TITLE_QUESTION_ID: 43 | return question.original_text 44 | return None 45 | 46 | 47 | def _to_retrieval_interaction_fn( 48 | interaction 49 | ): 50 | """Converts pretraining interaction to retrieval interaction.""" 51 | beam.metrics.Metrics.counter(_NS, "Interactions").inc() 52 | title = get_title(interaction) 53 | if title is None or not title: 54 | beam.metrics.Metrics.counter(_NS, "Interactions without title").inc() 55 | return 56 | 57 | interaction = beam_utils.rekey(interaction) 58 | interaction.table.document_title = title 59 | 60 | word_tok = nltk.tokenize.treebank.TreebankWordTokenizer() 61 | 62 | for question in interaction.questions: 63 | if question.id == _TITLE_QUESTION_ID: 64 | continue 65 | 66 | text = question.original_text 67 | 68 | for paragraph in text.split("\n"): 69 | for sentence in sentence_tokenizer.tokenize(paragraph): 70 | sentence = sentence.strip() 71 | if not sentence: 72 | continue 73 | 74 | beam.metrics.Metrics.counter(_NS, "Sentences").inc() 75 | num_tokens = word_tok.tokenize(sentence) 76 | if len(num_tokens) < 4: 77 | 
beam.metrics.Metrics.counter(_NS, "Sentence too short").inc() 78 | continue 79 | if len(num_tokens) > 32: 80 | beam.metrics.Metrics.counter(_NS, "Sentence too long").inc() 81 | continue 82 | 83 | new_interaction = interaction_pb2.Interaction() 84 | new_interaction.CopyFrom(interaction) 85 | del new_interaction.questions[:] 86 | new_question = new_interaction.questions.add() 87 | new_question.id = hex( 88 | beam_utils.to_numpy_seed(obj=(interaction.id, sentence))) 89 | new_interaction.id = new_question.id 90 | new_question.original_text = sentence 91 | 92 | beam.metrics.Metrics.counter(_NS, "Examples").inc() 93 | yield new_interaction.id, new_interaction 94 | 95 | 96 | def build_pipeline(inputs, output_dir): 97 | """Builds the pipeline.""" 98 | 99 | def _pipeline(root): 100 | 101 | interactions = ( 102 | create_data.read_interactions(root, inputs, name="input") 103 | | "DropKey" >> beam.Map(beam.Values()) 104 | | "ToRetrievalExample" >> beam.FlatMap(_to_retrieval_interaction_fn) 105 | | "Reshuffle" >> beam.transforms.util.Reshuffle()) 106 | 107 | # We expect ~37,568,664 interactions by taking 1 / 5000 for test test we 108 | # get a reasonable test set size of ~7513. 109 | beam_utils.split_by_table_id_and_write( 110 | interactions, 111 | output_dir, 112 | train_suffix="@*", 113 | test_suffix="@*", 114 | num_splits=5000, 115 | ) 116 | 117 | return _pipeline 118 | 119 | 120 | def main(_): 121 | beam_runner.run( 122 | build_pipeline(inputs=FLAGS.inputs, 123 | output_dir=FLAGS.output_dir)).wait_until_finish() 124 | 125 | 126 | if __name__ == "__main__": 127 | flags.mark_flag_as_required("inputs") 128 | flags.mark_flag_as_required("output_dir") 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /tapas/retrieval/e2e_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
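# Example invocation (the paths are hypothetical; both flags are defined
# below):
#
#   python e2e_eval.py \
#     --interaction_file=/tmp/interactions.tfrecord \
#     --prediction_file=/tmp/predictions.tsv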
15 | """Runs e2e eval on a pair of interactions / predctions.""" 16 | 17 | from absl import app 18 | from absl import flags 19 | 20 | from tapas.retrieval import e2e_eval_utils 21 | 22 | flags.DEFINE_string("interaction_file", None, "TFRecord of interactions.") 23 | flags.DEFINE_string("prediction_file", None, "Predictions in TSV format.") 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def main(argv): 29 | if len(argv) > 1: 30 | raise app.UsageError("Too many command-line arguments.") 31 | result = e2e_eval_utils.evaluate_retrieval_e2e( 32 | FLAGS.interaction_file, 33 | FLAGS.prediction_file, 34 | ) 35 | print(result) 36 | 37 | 38 | if __name__ == "__main__": 39 | flags.mark_flag_as_required("interaction_file") 40 | flags.mark_flag_as_required("prediction_file") 41 | app.run(main) 42 | -------------------------------------------------------------------------------- /tapas/retrieval/testdata/neural_retrieval_00.jsonl: -------------------------------------------------------------------------------- 1 | {"query_id": "q_1_0", "table_scores": [{"table_id": "table_1", "score": -10.0}, {"table_id": "table_3", "score":-8.0}, {"table_id":"table_2", "score": -6.0}]} 2 | {"query_id": "q_2_0", "table_scores": [{"table_id": "table_2", "score": -100.0}, {"table_id": "table_1", "score":-14.0}, {"table_id": "table_3", "score": -1.0}]} 3 | -------------------------------------------------------------------------------- /tapas/retrieval/testdata/retrieval_interaction.pbtxt: -------------------------------------------------------------------------------- 1 | id: "int_id" 2 | table: { 3 | columns: { 4 | text: "Created by" 5 | } 6 | columns: { 7 | text: "Original work" 8 | } 9 | columns: { 10 | text: "Novels" 11 | } 12 | rows: { 13 | cells: { 14 | text: "Thomas Harris" 15 | } 16 | cells: { 17 | text: "Red Dragon" 18 | } 19 | cells: { 20 | text: "Red Dragon" 21 | } 22 | } 23 | table_id: "tab_id_0" 24 | document_title: "Hannibal Lecter" 25 | } 26 | questions: { 27 | id: "q_id_0" 28 | original_text: "what order do the hannibal lecter movies go in" 29 | answer: { 30 | answer_texts: "Hannibal" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /tapas/retrieval/tf_example_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Conversion code for table retrieval.""" 16 | 17 | from typing import Optional 18 | 19 | from tapas.protos import interaction_pb2 20 | from tapas.protos import negative_retrieval_examples_pb2 21 | from tapas.utils import text_utils 22 | from tapas.utils import tf_example_utils as base 23 | import tensorflow.compat.v1 as tf 24 | 25 | _SEP = base._SEP # pylint:disable=protected-access 26 | _MAX_INT = base._MAX_INT # pylint:disable=protected-access 27 | _NegativeRetrievalExample = negative_retrieval_examples_pb2.NegativeRetrievalExample 28 | 29 | 30 | def _join_features(features, negative_example_features): 31 | """Joins the features of two tables.""" 32 | 33 | def append_feature(values, other_values): 34 | values.extend(other_values) 35 | 36 | for k, n_v in negative_example_features.items(): 37 | v = features[k] 38 | for feature_type in ['bytes_list', 'float_list', 'int64_list']: 39 | if n_v.HasField(feature_type) != v.HasField(feature_type): 40 | raise ValueError(f'feature types are incomapatible: {k}') 41 | if n_v.HasField(feature_type): 42 | append_feature( 43 | getattr(v, feature_type).value, 44 | getattr(n_v, feature_type).value) 45 | break 46 | else: 47 | raise ValueError(f'Unsupported feature type: {k}') 48 | 49 | 50 | RetrievalConversionConfig = base.RetrievalConversionConfig 51 | 52 | 53 | class ToRetrievalTensorflowExample(base.ToTrimmedTensorflowExample): 54 | """Class for converting retrieval examples. 55 | 56 | These examples are used for building a two tower model. 57 | One tower consists of document titlte and table the other tower is solely 58 | made up of the questions. 59 | """ 60 | 61 | def __init__(self, config): 62 | super(ToRetrievalTensorflowExample, self).__init__(config) 63 | self._use_document_title = config.use_document_title 64 | 65 | def convert( 66 | self, 67 | interaction, 68 | index, 69 | negative_example, 70 | ): 71 | """Converts question at 'index' to example.""" 72 | table = interaction.table 73 | 74 | num_rows = len(table.rows) 75 | if num_rows >= self._max_row_id: 76 | num_rows = self._max_row_id - 1 77 | 78 | num_columns = len(table.columns) 79 | if num_columns >= self._max_column_id: 80 | num_columns = self._max_column_id - 1 81 | 82 | title = table.document_title 83 | if not self._use_document_title: 84 | title = '' 85 | title_tokens = self._tokenizer.tokenize(title) 86 | tokenized_table = self._tokenize_table(table) 87 | 88 | while True: 89 | try: 90 | _, features = self._to_trimmed_features( 91 | question=None, 92 | table=table, 93 | question_tokens=title_tokens, 94 | tokenized_table=tokenized_table, 95 | num_columns=num_columns, 96 | num_rows=num_rows) 97 | break 98 | except ValueError: 99 | pass 100 | # Since this is retrieval we might get away with removing some cells of 101 | # the table. 102 | # TODO(thomasmueller) Consider taking the token length into account. 
103 | if num_columns >= num_rows: 104 | num_columns -= 1 105 | else: 106 | num_rows -= 1 107 | if num_columns == 0 or num_rows == 0: 108 | raise ValueError('Cannot fit table into sequence.') 109 | 110 | question = interaction.questions[index] 111 | features['question_id'] = base.create_string_feature( 112 | [question.id.encode('utf8')]) 113 | features['question_id_ints'] = base.create_int_feature( 114 | text_utils.str_to_ints( 115 | question.id, length=text_utils.DEFAULT_INTS_LENGTH)) 116 | 117 | q_tokens = self._tokenizer.tokenize(question.text) 118 | q_tokens = self._serialize_text(q_tokens)[0] 119 | q_tokens.append(base.Token(_SEP, _SEP)) 120 | q_input_ids = self._to_token_ids(q_tokens) 121 | self._pad_to_seq_length(q_input_ids) 122 | q_input_mask = [1] * len(q_tokens) 123 | self._pad_to_seq_length(q_input_mask) 124 | features['question_input_ids'] = base.create_int_feature(q_input_ids) 125 | features['question_input_mask'] = base.create_int_feature(q_input_mask) 126 | if question: 127 | features['question_hash'] = base.create_int_feature( 128 | [base.fingerprint(question.text) % _MAX_INT]) 129 | 130 | if negative_example is not None: 131 | n_table = negative_example.table 132 | n_title_tokens = self._tokenizer.tokenize(n_table.document_title) 133 | n_tokenized_table = self._tokenize_table(n_table) 134 | n_num_rows = self._get_num_rows(n_table, drop_rows_to_fit=True) 135 | n_num_columns = self._get_num_columns(n_table) 136 | _, n_example_features = self._to_trimmed_features( 137 | question=None, 138 | table=n_table, 139 | question_tokens=n_title_tokens, 140 | tokenized_table=n_tokenized_table, 141 | num_columns=n_num_columns, 142 | num_rows=n_num_rows, 143 | drop_rows_to_fit=True) 144 | _join_features(features, n_example_features) 145 | return tf.train.Example(features=tf.train.Features(feature=features)) 146 | -------------------------------------------------------------------------------- /tapas/retrieval/tfidf_baseline.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
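# Example invocation (the paths are hypothetical; all flags are defined
# below):
#
#   python tfidf_baseline.py \
#     --interaction_files=/tmp/test.tfrecord \
#     --table_file=/tmp/tables.tfrecord \
#     --max_table_rank=50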
15 | """A simple TF-IDF model for table retrieval.""" 16 | 17 | from typing import Iterable, List, Text 18 | 19 | from absl import app 20 | from absl import flags 21 | from absl import logging 22 | import pandas as pd 23 | from tapas.protos import interaction_pb2 24 | from tapas.retrieval import tfidf_baseline_utils 25 | from tapas.scripts import prediction_utils 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | flags.DEFINE_list("interaction_files", None, 30 | "Interaction protos in tfrecord format.") 31 | 32 | flags.DEFINE_string("table_file", None, "Table protos in tfrecord format.") 33 | 34 | flags.DEFINE_integer("max_table_rank", 50, "Max number of tables to retrieve.") 35 | 36 | flags.DEFINE_integer("min_term_rank", 100, 37 | "Min term frequency rank to consider.") 38 | 39 | flags.DEFINE_boolean("drop_term_frequency", True, 40 | "If True, ignore term frequency term.") 41 | 42 | 43 | def _print(message): 44 | logging.info(message) 45 | print(message) 46 | 47 | 48 | def evaluate(index, max_table_rank, 49 | thresholds, 50 | interactions, 51 | rows): 52 | """Evaluates index against interactions.""" 53 | ranks = [] 54 | for nr, interaction in enumerate(interactions): 55 | for question in interaction.questions: 56 | scored_hits = index.retrieve(question.original_text) 57 | reference_table_id = interaction.table.table_id 58 | for rank, (table_id, _) in enumerate(scored_hits[:max_table_rank]): 59 | if table_id == reference_table_id: 60 | ranks.append(rank) 61 | break 62 | if nr % (len(interactions) // 10) == 0: 63 | _print(f"Processed {nr:5d} / {len(interactions):5d}.") 64 | 65 | def precision_at_th(threshold): 66 | return sum(1 for rank in ranks if rank < threshold) / len(interactions) 67 | 68 | values = [f"{precision_at_th(threshold):.4}" for threshold in thresholds] 69 | rows.append(values) 70 | 71 | 72 | def create_index(tables, 73 | title_multiplicator, use_bm25): 74 | if use_bm25: 75 | return tfidf_baseline_utils.create_bm25_index( 76 | tables, 77 | title_multiplicator=title_multiplicator, 78 | ) 79 | return tfidf_baseline_utils.create_inverted_index( 80 | tables=tables, 81 | min_rank=FLAGS.min_term_rank, 82 | drop_term_frequency=FLAGS.drop_term_frequency, 83 | title_multiplicator=title_multiplicator, 84 | ) 85 | 86 | 87 | def get_hparams(): 88 | hparams = [] 89 | for multiplier in [1, 2]: 90 | hparams.append({"multiplier": multiplier, "use_bm25": False}) 91 | for multiplier in [10, 15]: 92 | hparams.append({"multiplier": multiplier, "use_bm25": True}) 93 | return hparams 94 | 95 | 96 | def main(_): 97 | 98 | max_table_rank = FLAGS.max_table_rank 99 | thresholds = [1, 5, 10, 15, max_table_rank] 100 | 101 | for interaction_file in FLAGS.interaction_files: 102 | _print(f"Test set: {interaction_file}") 103 | interactions = list(prediction_utils.iterate_interactions(interaction_file)) 104 | 105 | for use_local_index in [True, False]: 106 | 107 | rows = [] 108 | row_names = [] 109 | 110 | for hparams in get_hparams(): 111 | 112 | name = "local" if use_local_index else "global" 113 | name += "_bm25" if hparams["use_bm25"] else "_tfidf" 114 | name += f'_tm{hparams["multiplier"]}' 115 | 116 | _print(name) 117 | if use_local_index: 118 | index = create_index( 119 | tables=(i.table for i in interactions), 120 | title_multiplicator=hparams["multiplier"], 121 | use_bm25=hparams["use_bm25"], 122 | ) 123 | else: 124 | index = create_index( 125 | tables=tfidf_baseline_utils.iterate_tables(FLAGS.table_file), 126 | title_multiplicator=hparams["multiplier"], 127 | use_bm25=hparams["use_bm25"], 128 | ) 129 | 
_print("... index created.") 130 | evaluate(index, max_table_rank, thresholds, interactions, rows) 131 | row_names.append(name) 132 | 133 | df = pd.DataFrame(rows, columns=thresholds, index=row_names) 134 | _print(df.to_string()) 135 | 136 | 137 | if __name__ == "__main__": 138 | flags.mark_flag_as_required("interaction_files") 139 | flags.mark_flag_as_required("table_file") 140 | app.run(main) 141 | -------------------------------------------------------------------------------- /tapas/retrieval/tfidf_baseline_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from tapas.protos import interaction_pb2 19 | from tapas.retrieval import tfidf_baseline_utils 20 | 21 | 22 | class TfIdfBaselineUtilsTest(parameterized.TestCase): 23 | 24 | @parameterized.parameters((False, [("A", [("table_0", 1.0)]), 25 | ("B", [("table_1", 1.0)]), 26 | ("A C", [("table_0", 0.75), 27 | ("table_1", 0.25)])]), 28 | (True, [("A", [("table_0", 0.5)]), 29 | ("B", [("table_1", 1.0)])])) 30 | def test_simple(self, drop_term_frequency, expected): 31 | index = tfidf_baseline_utils.create_inverted_index( 32 | [ 33 | interaction_pb2.Table(table_id="table_0", document_title="a a c"), 34 | interaction_pb2.Table(table_id="table_1", document_title="b c") 35 | ], 36 | drop_term_frequency=drop_term_frequency) 37 | for query, results in expected: 38 | self.assertEqual(index.retrieve(query), results) 39 | 40 | def test_simple_bm25(self): 41 | expected = [("AA", [("table_0", 1.5475852968796064)]), 42 | ("BB", [("table_1", 1.2426585328757855)]), 43 | ("AA CC", [("table_0", 2.0749815245480145), 44 | ("table_1", 0.668184203698534)])] 45 | index = tfidf_baseline_utils.create_bm25_index([ 46 | interaction_pb2.Table(table_id="table_0", document_title="aa aa cc"), 47 | interaction_pb2.Table(table_id="table_1", document_title="bb cc"), 48 | interaction_pb2.Table(table_id="table_2", document_title="dd"), 49 | interaction_pb2.Table(table_id="table_3", document_title="ee"), 50 | interaction_pb2.Table(table_id="table_4", document_title="ff"), 51 | interaction_pb2.Table(table_id="table_5", document_title="gg"), 52 | interaction_pb2.Table(table_id="table_6", document_title="hh"), 53 | ]) 54 | for query, results in expected: 55 | self.assertEqual(index.retrieve(query), results) 56 | 57 | def test_min_rank(self): 58 | index = tfidf_baseline_utils.create_inverted_index([ 59 | interaction_pb2.Table(table_id="table_0", document_title="Table A"), 60 | interaction_pb2.Table(table_id="table_1", document_title="Table B") 61 | ], 62 | min_rank=1) 63 | self.assertEqual(index.retrieve("A"), [("table_0", 1.0)]) 64 | self.assertEqual(index.retrieve("B"), [("table_1", 1.0)]) 65 | 66 | 67 | if __name__ == "__main__": 68 | absltest.main() 69 | 
-------------------------------------------------------------------------------- /tapas/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/scripts/calc_metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Denotation accuracy calculation for TAPAS predictions over WikiSQL.""" 16 | 17 | import os 18 | 19 | from absl import app 20 | from absl import flags 21 | from tapas.scripts import calc_metrics_utils 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | flags.DEFINE_string('interactions_file', None, 26 | 'The file that contains interaction protos.') 27 | 28 | flags.DEFINE_string('prediction_files', None, 29 | 'The file that contains model predictions.') 30 | 31 | flags.DEFINE_string('denotation_errors_path', None, 32 | 'If not None, denotation errors are written there.') 33 | 34 | flags.DEFINE_bool('is_strong_supervision_available', False, 35 | 'Whether strong supervision (gold structure labels) is available.') 36 | 37 | 38 | def main(_): 39 | examples = calc_metrics_utils.read_data_examples_from_interactions( 40 | FLAGS.interactions_file) 41 | 42 | prediction_file_name = os.path.basename(FLAGS.prediction_files) 43 | calc_metrics_utils.read_predictions(FLAGS.prediction_files, examples) 44 | if FLAGS.is_strong_supervision_available: 45 | results = calc_metrics_utils.calc_structure_metrics( 46 | examples, FLAGS.denotation_errors_path) 47 | print('%s: joint_accuracy=%s' % (FLAGS.prediction_files, results.joint_acc)) 48 | 49 | denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy( 50 | examples, FLAGS.denotation_errors_path, prediction_file_name) 51 | print('%s: denotation_accuracy=%s' % 52 | (FLAGS.prediction_files, denotation_accuracy)) 53 | 54 | 55 | if __name__ == '__main__': 56 | flags.mark_flag_as_required('prediction_files') 57 | flags.mark_flag_as_required('interactions_file') 58 | app.run(main) 59 | -------------------------------------------------------------------------------- /tapas/scripts/convert_predictions.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Scripts to convert prediction files to other formats.""" 16 | 17 | from absl import app 18 | from absl import flags 19 | from tapas.scripts import convert_predictions_utils 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_list('interaction_files', None, 24 | 'A list of files that contain interaction protos.') 25 | flags.DEFINE_list('prediction_files', None, 26 | 'A list of files that contain model predictions.') 27 | flags.DEFINE_string('output_directory', None, 28 | 'Output directory where converted files will be stored.') 29 | flags.DEFINE_enum_class('dataset_format', None, 30 | convert_predictions_utils.DatasetFormat, 31 | 'Dataset format.') 32 | 33 | 34 | def main(_): 35 | convert_predictions_utils.convert(FLAGS.interaction_files, 36 | FLAGS.prediction_files, 37 | FLAGS.output_directory, 38 | FLAGS.dataset_format) 39 | 40 | 41 | if __name__ == '__main__': 42 | flags.mark_flag_as_required('interaction_files') 43 | flags.mark_flag_as_required('prediction_files') 44 | flags.mark_flag_as_required('output_directory') 45 | flags.mark_flag_as_required('dataset_format') 46 | app.run(main) 47 | -------------------------------------------------------------------------------- /tapas/scripts/convert_predictions_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Utilities to convert predictions file to other formats.""" 16 | 17 | import enum 18 | import os 19 | from typing import Text, List 20 | 21 | from tapas.scripts import calc_metrics_utils 22 | from tapas.scripts import prediction_utils 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class DatasetFormat(enum.Enum): 27 | WIKITABLEQUESTIONS = 0 28 | 29 | 30 | def _convert_single_wtq(interaction_file, prediction_file, 31 | output_file): 32 | """Convert predictions to WikiTablequestions format.""" 33 | 34 | interactions = dict( 35 | (prediction_utils.parse_interaction_id(i.id), i) 36 | for i in prediction_utils.iterate_interactions(interaction_file)) 37 | missing_interaction_ids = set(interactions.keys()) 38 | 39 | with tf.io.gfile.GFile(output_file, 'w') as output_file: 40 | for prediction in prediction_utils.iterate_predictions(prediction_file): 41 | interaction_id = prediction['id'] 42 | if interaction_id in missing_interaction_ids: 43 | missing_interaction_ids.remove(interaction_id) 44 | else: 45 | continue 46 | 47 | coordinates = prediction_utils.parse_coordinates( 48 | prediction['answer_coordinates']) 49 | 50 | denot_pred, _ = calc_metrics_utils.execute( 51 | int(prediction.get('pred_aggr', '0')), coordinates, 52 | prediction_utils.table_to_panda_frame( 53 | interactions[interaction_id].table)) 54 | 55 | answers = '\t'.join(sorted(map(str, denot_pred))) 56 | output_file.write('{}\t{}\n'.format(interaction_id, answers)) 57 | 58 | for interaction_id in missing_interaction_ids: 59 | output_file.write('{}\n'.format(interaction_id)) 60 | 61 | 62 | def _convert_single(interaction_file, prediction_file, 63 | output_file, dataset_format): 64 | if dataset_format == DatasetFormat.WIKITABLEQUESTIONS: 65 | return _convert_single_wtq(interaction_file, prediction_file, output_file) 66 | else: 67 | raise ValueError('Unknown dataset format {}'.format(dataset_format)) 68 | 69 | 70 | def convert(interactions, predictions, 71 | output_directory, dataset_format): 72 | assert len(interactions) == len(predictions) 73 | for interaction_file, prediction_file in zip(interactions, predictions): 74 | output_file = os.path.join(output_directory, 75 | os.path.basename(prediction_file)) 76 | _convert_single(interaction_file, prediction_file, output_file, 77 | dataset_format) 78 | -------------------------------------------------------------------------------- /tapas/scripts/convert_predictions_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import tempfile 17 | 18 | from absl.testing import absltest 19 | import pandas as pd 20 | from tapas.protos import interaction_pb2 21 | from tapas.scripts import convert_predictions_utils 22 | import tensorflow.compat.v1 as tf 23 | 24 | Cell = interaction_pb2.Cell 25 | Cells = interaction_pb2.Cells 26 | Table = interaction_pb2.Table 27 | 28 | 29 | class ConvertPredictionsUtilsTest(absltest.TestCase): 30 | 31 | def test_convert_single_no_pred_aggr(self): 32 | interactions_path = tempfile.mktemp(suffix='.tfrecord') 33 | with tf.python_io.TFRecordWriter(interactions_path) as writer: 34 | writer.write( 35 | interaction_pb2.Interaction( 36 | id='dev-1-2_3', 37 | table=Table( 38 | columns=[Cell(text='A')], 39 | rows=[Cells(cells=[Cell(text='answer')])], 40 | )).SerializeToString()) 41 | writer.write( 42 | interaction_pb2.Interaction(id='dev-2-1_3').SerializeToString()) 43 | 44 | predictions_path = tempfile.mktemp() 45 | predictions_df = pd.DataFrame( 46 | columns=['id', 'annotator', 'position', 'answer_coordinates'], 47 | data=[['dev-1', '2', '3', '["(0,0)"]']]) 48 | predictions_df.to_csv(predictions_path, sep='\t', index=False) 49 | 50 | output_path = tempfile.mktemp() 51 | convert_predictions_utils._convert_single( 52 | interactions_path, 53 | predictions_path, 54 | output_path, 55 | convert_predictions_utils.DatasetFormat.WIKITABLEQUESTIONS, 56 | ) 57 | 58 | with open(output_path, 'rt') as file_handle: 59 | self.assertEqual(file_handle.read(), 'dev-1\tanswer\ndev-2\n') 60 | 61 | 62 | if __name__ == '__main__': 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /tapas/scripts/eval_table_retriever.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Evaluates precision@k scores for table retriever predictions.""" 16 | from absl import app 17 | from absl import flags 18 | 19 | from tapas.scripts import eval_table_retriever_utils 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_string( 24 | 'prediction_files_local', None, 25 | 'A list of files that contain model predictions as a TSV ' 26 | 'file with headers [table_id, query_rep, table_rep].') 27 | flags.DEFINE_string( 28 | 'prediction_files_global', None, 29 | 'A list of files that contain model predictions for all ' 30 | 'of the tables in the corpus. 
Used as the index to ' 31 | 'retrieve tables from.') 32 | 33 | flags.DEFINE_string( 34 | 'retrieval_results_file_path', None, 35 | 'A path to a file where the best table candidates and their scores, for each ' 36 | 'query, will be written.') 37 | 38 | 39 | def main(argv): 40 | 41 | if len(argv) > 1: 42 | raise app.UsageError('Too many command-line arguments.') 43 | 44 | if FLAGS.prediction_files_global: 45 | eval_table_retriever_utils.eval_precision_at_k( 46 | FLAGS.prediction_files_local, 47 | FLAGS.prediction_files_global, 48 | make_tables_unique=True, 49 | retrieval_results_file_path=FLAGS.retrieval_results_file_path) 50 | else: 51 | eval_table_retriever_utils.eval_precision_at_k( 52 | FLAGS.prediction_files_local, 53 | FLAGS.prediction_files_local, 54 | make_tables_unique=True, 55 | retrieval_results_file_path=FLAGS.retrieval_results_file_path) 56 | 57 | 58 | if __name__ == '__main__': 59 | flags.mark_flag_as_required('prediction_files_local') 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /tapas/scripts/eval_wikisql.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Evaluates WikiSQL predictions against NSM gold json files.""" 16 | 17 | import json 18 | 19 | from absl import app 20 | from absl import flags 21 | from tapas.utils import text_utils 22 | import tensorflow.compat.v1 as tf 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_list('reference_files', None, 'NSM json gold file.') 27 | flags.DEFINE_list('prediction_files', None, 28 | 'Produced by "convert_predictions".') 29 | flags.DEFINE_string('dataset', None, '"dev" or "test".') 30 | 31 | 32 | def main(argv): 33 | if len(argv) > 1: 34 | raise app.UsageError('Too many command-line arguments.') 35 | 36 | assert FLAGS.dataset 37 | assert len(FLAGS.reference_files) == len(FLAGS.prediction_files) 38 | for reference_file, prediction_file in zip(FLAGS.reference_files, 39 | FLAGS.prediction_files): 40 | 41 | with tf.io.gfile.GFile(prediction_file, 'r') as input_file: 42 | predictions = {} 43 | for line in input_file: 44 | line = line.strip() 45 | if line: 46 | segments = line.split('\t') 47 | key = segments[0] 48 | value = segments[1:] 49 | predictions[key] = text_utils.normalize_answers(value) 50 | 51 | with tf.io.gfile.GFile(reference_file, 'r') as input_file: 52 | reference = json.load(input_file) 53 | 54 | references = {} 55 | for index, data in enumerate(reference): 56 | key = '%s-%d' % (FLAGS.dataset, index) 57 | references[key] = text_utils.normalize_answers(data) 58 | 59 | num_correct = 0 60 | for key, gold_answer in references.items(): 61 | pred_answer = predictions[key] 62 | is_correct = gold_answer == pred_answer 63 | if is_correct: 64 | num_correct += 1 65 | 66 | print('Correct: ', num_correct, len(references), 67 | num_correct / float(len(references))) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /tapas/scripts/prediction_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utilitity functions to deal with predictions.""" 16 | 17 | import ast 18 | import csv 19 | import os 20 | from typing import Any, Iterable, Set, Text, Tuple, MutableMapping 21 | 22 | import numpy as np 23 | import pandas as pd 24 | 25 | from tapas.protos import interaction_pb2 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | 30 | def parse_coordinates(raw_coordinates): 31 | """Parses cell coordinates from text.""" 32 | return {ast.literal_eval(x) for x in ast.literal_eval(raw_coordinates)} 33 | 34 | 35 | # TODO(thomasmueller) Return a dataclass here. 36 | def iterate_predictions( 37 | prediction_file): 38 | """Iterates through a TSV prediction file.""" 39 | with tf.io.gfile.GFile(prediction_file, 'r') as f: 40 | reader = csv.DictReader(f, delimiter='\t') 41 | for row in reader: 42 | if 'logits_cls' in row: 43 | # Only for binary problems the logit will be a float scalar. 
44 | if row['logits_cls'].startswith('['): 45 | row['logits_cls'] = np.fromstring( 46 | row['logits_cls'][1:-1], sep=' ').tolist() 47 | else: 48 | row['logits_cls'] = float(row['logits_cls']) 49 | yield row 50 | 51 | 52 | def is_tfrecord(filename): 53 | extension = os.path.splitext(filename)[1] 54 | return extension in ['.tfrecord', '.tfrecords'] 55 | 56 | 57 | def iterate_interactions( 58 | interactions_file): 59 | """Get interactions from file.""" 60 | for value in tf.python_io.tf_record_iterator(interactions_file): 61 | interaction = interaction_pb2.Interaction() 62 | interaction.ParseFromString(value) 63 | yield interaction 64 | 65 | 66 | def parse_interaction_id(text): 67 | return text[:text.rindex('-')] 68 | 69 | 70 | def table_to_panda_frame(table): 71 | contents = [[cell.text for cell in row.cells] for row in table.rows] 72 | headers = [ 73 | f'{column.text}_{index}' for index, column in enumerate(table.columns) 74 | ] 75 | return pd.DataFrame(contents, columns=headers) 76 | -------------------------------------------------------------------------------- /tapas/scripts/prediction_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import tempfile 17 | from absl.testing import absltest 18 | from tapas.protos import interaction_pb2 19 | from tapas.scripts import calc_metrics_utils 20 | from tapas.scripts import prediction_utils 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | Cell = interaction_pb2.Cell 25 | Cells = interaction_pb2.Cells 26 | Table = interaction_pb2.Table 27 | 28 | 29 | class PredictionUtilsTest(absltest.TestCase): 30 | 31 | def test_iterate_interactions(self): 32 | filepath = tempfile.mktemp(suffix='.tfrecord') 33 | interactions = [ 34 | interaction_pb2.Interaction(id='dev_723'), 35 | interaction_pb2.Interaction(id='dev_456'), 36 | interaction_pb2.Interaction(id='dev_123'), 37 | ] 38 | with tf.io.TFRecordWriter(filepath) as writer: 39 | for interaction in interactions: 40 | writer.write(interaction.SerializeToString()) 41 | actual_interactions = list(prediction_utils.iterate_interactions(filepath)) 42 | self.assertEqual(interactions, actual_interactions) 43 | 44 | def test_table_to_panda_frame(self):  # Renamed so the test runner picks it up. 45 | frame = prediction_utils.table_to_panda_frame( 46 | Table( 47 | columns=[Cell(text='a'), Cell(text='a')], 48 | rows=[Cells(cells=[Cell(text='0'), Cell(text='1')])])) 49 | self.assertEqual(['0'], 50 | calc_metrics_utils._collect_cells_from_table({(0, 0)}, 51 | frame)) 52 | 53 | def test_iterate_predictions(self): 54 | filepath = tempfile.mktemp(suffix='.tsv') 55 | predictions = [ 56 | { 57 | 'logits_cls': 0.1 58 | }, 59 | { 60 | 'logits_cls': [3.0, 4.0] 61 | }, 62 | ] 63 | with tf.io.gfile.GFile(filepath, mode='w') as writer: 64 | writer.write('logits_cls\n') 65 | writer.write('0.1\n') 66 | writer.write('[3 4]\n') 67 | actual_predictions = list(prediction_utils.iterate_predictions(filepath)) 68 | self.assertEqual(predictions, actual_predictions) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /tapas/scripts/testdata/table_00.html: --------------------------------------------------------------------------------
[HTML test table; markup not preserved in this dump. Recoverable content:]
Grey's Anatomy (season 14)
Promotional poster (Grey's Anatomy season 14 poster.jpg)
Starring: (cast list not preserved)
Country of origin: United States
No. of episodes: 24
Release
Original network: ABC
Original release: September 28, 2017 (2017-09-28) – May 17, 2018 (2018-05-17)
Season chronology
← Previous: Season 13
List of Grey's Anatomy episodes
-------------------------------------------------------------------------------- /tapas/scripts/testdata/table_01.html: --------------------------------------------------------------------------------
[HTML test table; markup not preserved in this dump. Recoverable rows:]
Metric diameter | US Knitting needle number | Crochet hook size
2.25 mm | 1 | B-1
2.75 mm | 2 | C-2
3.25 mm | 3 | D-3
3.5 mm | 4 | E-4
3.75 mm | 5 | F-5
4 mm | 6 | G-6
4.5 mm | 7 | 7
5 mm | 8 | H-8
5.5 mm | 9 | I-9
6 mm | 10 | J-10
6.5 mm | 10.5 | K-10.5
8 mm | 11 | L-11
9 mm | 13 | M/N-13
10 mm | 15 | N/P-15
12.75 mm | 17 |
15 mm | 19 | P/Q
16 mm | | Q
19 mm | 35 | S
25 mm | 50 | U
-------------------------------------------------------------------------------- /tapas/testdata/classification_examples.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/tapas/569a3c31451d941165bd10783f73f494406b3906/tapas/testdata/classification_examples.tfrecords -------------------------------------------------------------------------------- /tapas/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tapas/utils/beam_runner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utilities for running Apache Beam pipelines.""" 16 | 17 | import enum 18 | 19 | from absl import flags 20 | from apache_beam import runners 21 | from apache_beam.options import pipeline_options 22 | from apache_beam.runners.direct import direct_runner 23 | 24 | 25 | class RunnerType(enum.Enum): 26 | DIRECT = 1 27 | DATAFLOW = 2 28 | 29 | 30 | flags.DEFINE_enum_class("runner_type", RunnerType.DIRECT, RunnerType, 31 | "Runner type to use.") 32 | # Google Cloud options. 33 | # See https://beam.apache.org/get-started/wordcount-example/ 34 | flags.DEFINE_string("gc_project", None, "e.g. my-project-id") 35 | # GC regions: https://cloud.google.com/compute/docs/regions-zones 36 | flags.DEFINE_string("gc_region", None, "e.g. us-central1") 37 | flags.DEFINE_string("gc_job_name", None, "e.g. myjob") 38 | flags.DEFINE_string("gc_staging_location", None, 39 | "e.g. gs://your-bucket/staging") 40 | flags.DEFINE_string("gc_temp_location", None, "e.g. gs://your-bucket/temp") 41 | flags.DEFINE_boolean("save_main_session", False, 42 | "Useful when getting NameErrors from global imports.") 43 | # Pass Tapas sources to GC. 
44 | # See https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/ 45 | flags.DEFINE_list( 46 | "extra_packages", 47 | None, 48 | "Packed Tapas sources (python3 setup.py sdist).", 49 | ) 50 | 51 | FLAGS = flags.FLAGS 52 | 53 | 54 | def run_type(pipeline, runner_type): 55 | """Executes the pipeline with the given runner type.""" 56 | if runner_type == RunnerType.DIRECT: 57 | print("Running pipeline with the direct runner; this might take a long time!") 58 | return direct_runner.DirectRunner().run(pipeline) 59 | if runner_type == RunnerType.DATAFLOW: 60 | options = pipeline_options.PipelineOptions() 61 | gc_options = options.view_as(pipeline_options.GoogleCloudOptions) 62 | gc_options.project = FLAGS.gc_project 63 | gc_options.region = FLAGS.gc_region 64 | gc_options.job_name = FLAGS.gc_job_name 65 | gc_options.staging_location = FLAGS.gc_staging_location 66 | gc_options.temp_location = FLAGS.gc_temp_location 67 | setup = options.view_as(pipeline_options.SetupOptions) 68 | setup.extra_packages = FLAGS.extra_packages 69 | setup.save_main_session = FLAGS.save_main_session 70 | return runners.DataflowRunner().run(pipeline, options=options) 71 | raise ValueError(f"Unsupported runner type: {runner_type}") 72 | 73 | 74 | def run(pipeline): 75 | return run_type(pipeline, FLAGS.runner_type) 76 | -------------------------------------------------------------------------------- /tapas/utils/beam_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Utilities around apache beams.""" 16 | 17 | from typing import Iterable, List, Tuple 18 | 19 | from tapas.protos import interaction_pb2 20 | from tapas.utils import pretrain_utils 21 | 22 | to_numpy_seed = pretrain_utils.to_numpy_seed 23 | split_by_table_id_and_write = pretrain_utils.split_by_table_id_and_write 24 | 25 | 26 | def rekey( 27 | interaction): 28 | new_interaction = interaction_pb2.Interaction() 29 | new_interaction.CopyFrom(interaction) 30 | iid = interaction.table.table_id 31 | iid = hex(to_numpy_seed(iid)) 32 | new_interaction.id = iid 33 | new_interaction.table.table_id = iid 34 | return new_interaction 35 | 36 | 37 | def _get_sharded_ranges( 38 | begin, 39 | end, 40 | max_length, 41 | ): 42 | """Recursively cuts ranges in half to satisfy 'max_length'.""" 43 | if max_length <= 0: 44 | raise ValueError("max_length <= 0.") 45 | length = end - begin 46 | if length <= max_length: 47 | return [(begin, end)] 48 | pivot = begin + length // 2 49 | return (_get_sharded_ranges(begin, pivot, max_length) + 50 | _get_sharded_ranges(pivot, end, max_length)) 51 | 52 | 53 | def get_row_sharded_interactions( 54 | interaction, 55 | max_num_cells, 56 | ): 57 | """Equally shards the interaction row-wise to satisfy 'max_num_cells'.""" 58 | num_columns = len(interaction.table.columns) 59 | max_num_rows = max_num_cells // num_columns 60 | if max_num_rows == 0: 61 | return 62 | for begin, end in _get_sharded_ranges( 63 | begin=0, 64 | end=len(interaction.table.rows), 65 | max_length=max_num_rows, 66 | ): 67 | new_interaction = interaction_pb2.Interaction() 68 | new_interaction.CopyFrom(interaction) 69 | del new_interaction.table.rows[:] 70 | for row in interaction.table.rows[begin:end]: 71 | new_interaction.table.rows.add().CopyFrom(row) 72 | yield new_interaction 73 | -------------------------------------------------------------------------------- /tapas/utils/beam_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from absl import logging 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tapas.protos import interaction_pb2 20 | from tapas.utils import beam_utils 21 | 22 | 23 | class BeamUtilsTest(parameterized.TestCase): 24 | 25 | @parameterized.parameters( 26 | (5, 10, [(0, 5)]), 27 | (5, 2, [(0, 2), (2, 3), (3, 5)]), 28 | (10, 2, [(0, 2), (2, 3), (3, 5), (5, 7), (7, 8), (8, 10)]), 29 | ) 30 | def test_get_sharded_ranges( 31 | self, 32 | end, 33 | max_length, 34 | expected, 35 | ): 36 | self.assertEqual( 37 | beam_utils._get_sharded_ranges(0, end, max_length), expected) 38 | 39 | @parameterized.parameters( 40 | (5, 10, 50), 41 | (5, 10, 5), 42 | (5, 2, 5), 43 | (3, 7, 13), 44 | ) 45 | def test_get_row_sharded_interactions( 46 | self, 47 | num_columns, 48 | num_rows, 49 | max_num_cells, 50 | ): 51 | interaction = interaction_pb2.Interaction() 52 | for i in range(num_columns): 53 | interaction.table.columns.add().text = f'{i}' 54 | for j in range(num_rows): 55 | row = interaction.table.rows.add() 56 | for i in range(num_columns): 57 | row.cells.add().text = f'{j}_{i}' 58 | interactions = list( 59 | beam_utils.get_row_sharded_interactions(interaction, max_num_cells)) 60 | restored_interaction = interaction_pb2.Interaction() 61 | restored_interaction.CopyFrom(interaction) 62 | del restored_interaction.table.rows[:] 63 | for shard in interactions: 64 | self.assertEqual(shard.table.columns, interaction.table.columns) 65 | self.assertLessEqual(len(shard.table.rows) * num_columns, max_num_cells) 66 | for row in shard.table.rows: 67 | restored_interaction.table.rows.add().CopyFrom(row) 68 | logging.info(restored_interaction) 69 | self.assertEqual(interaction, restored_interaction) 70 | 71 | 72 | if __name__ == '__main__': 73 | absltest.main() 74 | -------------------------------------------------------------------------------- /tapas/utils/constants.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Constants used by the Dopa tables project.""" 16 | 17 | import enum 18 | 19 | EMPTY_TEXT = 'EMPTY' 20 | 21 | NUMBER_TYPE = 'number' 22 | DATE_TYPE = 'date' 23 | 24 | 25 | class Relation(enum.Enum): 26 | HEADER_TO_CELL = 1 # Connects header to cell. 27 | CELL_TO_HEADER = 2 # Connects cell to header. 28 | QUERY_TO_HEADER = 3 # Connects query to headers. 29 | QUERY_TO_CELL = 4 # Connects query to cells. 30 | ROW_TO_CELL = 5 # Connects row to cells. 31 | CELL_TO_ROW = 6 # Connects cells to row. 
32 | EQ = 7 # Annotation value is same as cell value 33 | LT = 8 # Annotation value is less than cell value 34 | GT = 9 # Annotation value is greater than cell value 35 | -------------------------------------------------------------------------------- /tapas/utils/contrastive_statements_test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helpers for contrastive statements tests.""" 16 | 17 | import re 18 | from typing import List, Mapping 19 | 20 | from tapas.protos import annotated_text_pb2 21 | from tapas.protos import interaction_pb2 22 | 23 | _AnnotatedText = annotated_text_pb2.AnnotatedText 24 | 25 | 26 | def get_test_interaction(): 27 | """Creates example interaction with annotations.""" 28 | interaction = create_interaction( 29 | [ 30 | ['Name', 'Age', 'Birthday'], 31 | ['Bob', '1.7', '24 April 1950'], 32 | ['Julia', '1.5', '24 April 1951'], 33 | ['Peter', '1.9', '24 March 1950'], 34 | ], 'Robert was born on 24 April 1950.', { 35 | 'bob': 'http://en.wikipedia.org/wiki/Bob', 36 | 'robert': 'http://en.wikipedia.org/wiki/Bob', 37 | 'peter': 'http://en.wikipedia.org/wiki/Peter', 38 | 'julia': 'http://en.wikipedia.org/wiki/Julia', 39 | }) 40 | return interaction 41 | 42 | 43 | def create_interaction( 44 | table, 45 | statement, 46 | mentions, 47 | ): 48 | """Creates interaction proto with annotations by matching entity mentions.""" 49 | interaction = interaction_pb2.Interaction() 50 | for index, row in enumerate(table): 51 | new_row = interaction.table.columns 52 | if index > 0: 53 | new_row = interaction.table.rows.add().cells 54 | for cell in row: 55 | new_row.add().text = cell 56 | for row in interaction.table.rows: 57 | for cell in row.cells: 58 | cell_text = cell.text.lower() 59 | if cell_text not in mentions: 60 | continue 61 | annotated_text = cell.Extensions[_AnnotatedText.annotated_cell_ext] 62 | annotation = annotated_text.annotations.add() 63 | annotation.begin_byte_index = 0 64 | annotation.end_byte_index = len(cell_text) 65 | annotation.identifier = mentions[cell_text] 66 | 67 | question = interaction.questions.add() 68 | question.original_text = statement 69 | q_annotated_text = question.Extensions[_AnnotatedText.annotated_question_ext] 70 | question_text = question.original_text.lower() 71 | for phrase, identifier in mentions.items(): 72 | for match in re.finditer(phrase, question_text): 73 | annotation = q_annotated_text.annotations.add() 74 | begin, end = match.span() 75 | annotation.begin_byte_index = begin 76 | annotation.end_byte_index = end 77 | annotation.identifier = identifier 78 | return interaction 79 | -------------------------------------------------------------------------------- /tapas/utils/create_data_file_io.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI 
Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helpers for mapping input to output filenames.""" 16 | 17 | import os 18 | from typing import Text, Tuple, List, Set 19 | 20 | import tensorflow.compat.v1 as tf 21 | 22 | 23 | 24 | def _check_basename( 25 | basenames, 26 | basename, 27 | input_dir, 28 | ): 29 | if basename in basenames: 30 | raise ValueError("Basename should be unique: " 31 | f"basename: {basename}, input_dir: {input_dir}") 32 | basenames.add(basename) 33 | 34 | 35 | def _is_supported(filename): 36 | extension = os.path.splitext(filename)[1] 37 | return extension in [ 38 | ".txtpb.gz", 39 | ".txtpb", 40 | ".tfrecord", 41 | ".tfrecords", 42 | ] 43 | 44 | 45 | def get_inputs_and_outputs(input_dir, 46 | output_dir): 47 | """Reads files from 'input_dir' and creates corresponding paired outputs. 48 | 49 | Args: 50 | input_dir: Where to read inputs from. 51 | output_dir: Where to write outputs to. 52 | 53 | Returns: 54 | inputs and outputs. 55 | """ 56 | input_files = tf.io.gfile.listdir(input_dir) 57 | 58 | basenames = set() 59 | 60 | inputs = [] 61 | outputs = [] 62 | 63 | for filename in input_files: 64 | if not _is_supported(filename): 65 | print(f"Skipping unsupported file: {filename}") 66 | continue 67 | basename, _ = os.path.splitext(filename) 68 | _check_basename(basenames, basename, input_dir) 69 | inputs.append(filename) 70 | output = f"{basename}.tfrecord" 71 | outputs.append(output) 72 | 73 | inputs = [os.path.join(input_dir, i) for i in inputs] 74 | outputs = [os.path.join(output_dir, o) for o in outputs] 75 | return inputs, outputs 76 | -------------------------------------------------------------------------------- /tapas/utils/experiment_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import os 16 | import tempfile 17 | 18 | from absl.testing import absltest 19 | from tapas.utils import experiment_utils 20 | 21 | import tensorflow.compat.v1 as tf 22 | from google.protobuf import text_format 23 | 24 | 25 | class ExperimentUtilsTest(absltest.TestCase): 26 | 27 | def test_iterate_checkpoints_single_step(self): 28 | results = list( 29 | experiment_utils.iterate_checkpoints( 30 | model_dir='path', 31 | single_step=100, 32 | marker_file_prefix='path', 33 | total_steps=None)) 34 | self.assertEqual(results, [(100, 'path/model.ckpt-100')]) 35 | 36 | def test_iterate_checkpoints_multi_step(self): 37 | test_tmpdir = tempfile.mkdtemp() 38 | checkpoints = [ 39 | os.path.join(test_tmpdir, checkpoint) for checkpoint in 40 | ['model.ckpt-00001', 'model.ckpt-00002', 'model.ckpt-00003'] 41 | ] 42 | # Write fake checkpoint file to tmpdir. 43 | state = tf.train.generate_checkpoint_state_proto( 44 | test_tmpdir, 45 | model_checkpoint_path=checkpoints[-1], 46 | all_model_checkpoint_paths=checkpoints) 47 | with open(os.path.join(test_tmpdir, 'checkpoint'), 'w') as f: 48 | f.write(text_format.MessageToString(state)) 49 | for checkpoint in checkpoints: 50 | with open(f'{checkpoint}.index', 'w') as f: 51 | f.write('\n') 52 | 53 | marker_file_prefix = os.path.join(test_tmpdir, 'marker') 54 | results = list( 55 | experiment_utils.iterate_checkpoints( 56 | model_dir=test_tmpdir, 57 | total_steps=3, 58 | marker_file_prefix=marker_file_prefix)) 59 | 60 | expected_steps = [1, 2, 3] 61 | self.assertEqual(results, list(zip(expected_steps, checkpoints))) 62 | for step in expected_steps: 63 | self.assertTrue(tf.gfile.Exists(f'{marker_file_prefix}-{step}.done')) 64 | 65 | results = list( 66 | experiment_utils.iterate_checkpoints( 67 | model_dir=test_tmpdir, 68 | total_steps=3, 69 | marker_file_prefix=marker_file_prefix)) 70 | self.assertEmpty(results) 71 | 72 | 73 | if __name__ == '__main__': 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /tapas/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helper functions for dealing with local files.""" 16 | 17 | from typing import List, Text 18 | import tensorflow.compat.v1 as tf 19 | 20 | 21 | def make_directories(path): 22 | """Create directory recursively. Don't do anything if directory exists.""" 23 | tf.io.gfile.makedirs(path) 24 | 25 | 26 | def list_directory(path): 27 | """List directory contents.""" 28 | return tf.io.gfile.listdir(path) 29 | -------------------------------------------------------------------------------- /tapas/utils/interaction_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utility functions for creating interaction protos.""" 16 | 17 | import ast 18 | import csv 19 | from typing import Text, Iterable, List 20 | from tapas.protos import interaction_pb2 21 | from tapas.utils import text_utils 22 | 23 | _AggregationFunction = interaction_pb2.Answer.AggregationFunction 24 | 25 | # TSV headers. 26 | _ID = 'id' 27 | _ANNOTATOR = 'annotator' 28 | _POSITION = 'position' 29 | _QUESTION = 'question' 30 | _TABLE_FILE = 'table_file' 31 | _ANSWER_TEXT = 'answer_text' 32 | _ANSWER_COORDINATES = 'answer_coordinates' 33 | _AGGREGATION = 'aggregation' 34 | _ANSWER_FLOAT_VALUE = 'float_answer' 35 | _ANSWER_CLASS_INDEX = 'class_index' 36 | 37 | 38 | def _parse_answer_coordinates(answer_coordinate_str, 39 | answer): 40 | """Populates the answer_coordinates field of `answer` by parsing `answer_coordinate_str`. 41 | 42 | Args: 43 | answer_coordinate_str: A string representation of a Python list of tuple 44 | strings. 45 | For example: "['(1, 4)','(1, 3)', ...]" 46 | answer: an Answer object. 47 | """ 48 | 49 | try: 50 | coords = ast.literal_eval(answer_coordinate_str) 51 | for row_index, column_index in sorted( 52 | ast.literal_eval(coord) for coord in coords): 53 | answer.answer_coordinates.add( 54 | row_index=row_index, column_index=column_index) 55 | except SyntaxError: 56 | raise ValueError('Unable to evaluate %s' % answer_coordinate_str) 57 | 58 | 59 | def _parse_answer_text(answer_text, answer): 60 | """Populates the answer_texts field of `answer` by parsing `answer_text`. 61 | 62 | Args: 63 | answer_text: A string representation of a Python list of strings. 64 | For example: "[u'test', u'hello', ...]" 65 | answer: an Answer object. 66 | """ 67 | try: 68 | for value in ast.literal_eval(answer_text): 69 | answer.answer_texts.append(value) 70 | except SyntaxError: 71 | raise ValueError('Unable to evaluate %s' % answer_text) 72 | 73 | 74 | def read_from_tsv_file( 75 | file_handle): 76 | """Parses a TSV file in SQA format into a list of interactions. 77 | 78 | Args: 79 | file_handle: File handle of a TSV file in SQA format. 80 | 81 | Returns: 82 | Questions grouped into interactions. 
83 | """ 84 | questions = {} 85 | for row in csv.DictReader(file_handle, delimiter='\t'): 86 | sequence_id = text_utils.get_sequence_id(row[_ID], row[_ANNOTATOR]) 87 | key = sequence_id, row[_TABLE_FILE] 88 | if key not in questions: 89 | questions[key] = {} 90 | 91 | position = int(row[_POSITION]) 92 | 93 | answer = interaction_pb2.Answer() 94 | _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer) 95 | _parse_answer_text(row[_ANSWER_TEXT], answer) 96 | 97 | if _AGGREGATION in row: 98 | agg_func = row[_AGGREGATION].upper().strip() 99 | if agg_func: 100 | answer.aggregation_function = _AggregationFunction.Value(agg_func) 101 | if _ANSWER_FLOAT_VALUE in row: 102 | float_value = row[_ANSWER_FLOAT_VALUE] 103 | if float_value: 104 | answer.float_value = float(float_value) 105 | if _ANSWER_CLASS_INDEX in row: 106 | class_index = row[_ANSWER_CLASS_INDEX] 107 | if class_index: 108 | answer.class_index = int(class_index) 109 | 110 | questions[key][position] = interaction_pb2.Question( 111 | id=text_utils.get_question_id(sequence_id, position), 112 | original_text=row[_QUESTION], 113 | answer=answer) 114 | 115 | interactions = [] 116 | for (sequence_id, table_file), question_dict in sorted( 117 | questions.items(), key=lambda sid: sid[0]): 118 | question_list = [ 119 | question for _, question in sorted( 120 | question_dict.items(), key=lambda pos: pos[0]) 121 | ] 122 | interactions.append( 123 | interaction_pb2.Interaction( 124 | id=sequence_id, 125 | questions=question_list, 126 | table=interaction_pb2.Table(table_id=table_file))) 127 | return interactions 128 | -------------------------------------------------------------------------------- /tapas/utils/intermediate_pretrain_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import tempfile 18 | 19 | from absl import flags 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | from tapas.protos import interaction_pb2 24 | from tapas.utils import beam_runner 25 | from tapas.utils import contrastive_statements_test_utils 26 | from tapas.utils import intermediate_pretrain_utils 27 | from tapas.utils import synthesize_entablement 28 | from tapas.utils import tf_example_utils 29 | 30 | import tensorflow.compat.v1 as tf 31 | from google.protobuf import text_format 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | TEST_PATH = "tapas/utils/testdata/" 36 | _RESERVED_SYMBOLS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[EMPTY]") 37 | 38 | 39 | def _read_record(filepath, message): 40 | for value in tf.python_io.tf_record_iterator(filepath): 41 | element = message() 42 | element.ParseFromString(value) 43 | yield element 44 | 45 | 46 | def _create_vocab(vocab, vocab_path): 47 | with tf.gfile.Open(vocab_path, "w") as input_handle: 48 | input_handle.write("\n".join(vocab)) 49 | 50 | 51 | class CreatePretrainingDataTest(parameterized.TestCase): 52 | 53 | def setUp(self): 54 | super(CreatePretrainingDataTest, self).setUp() 55 | 56 | self._test_dir = TEST_PATH 57 | 58 | 59 | @parameterized.parameters( 60 | (beam_runner.RunnerType.DIRECT, False), 61 | (beam_runner.RunnerType.DIRECT, True), 62 | ) 63 | def test_end_to_end(self, runner_type, add_example_conversion): 64 | mode = intermediate_pretrain_utils.Mode.ALL 65 | prob_count_aggregation = 0.2 66 | use_fake_table = False 67 | add_opposite_table = False 68 | drop_without_support_rate = 0.0 69 | 70 | with tempfile.TemporaryDirectory() as temp_dir: 71 | config = None 72 | if add_example_conversion: 73 | vocab_path = os.path.join(temp_dir, "vocab.txt") 74 | _create_vocab(list(_RESERVED_SYMBOLS) + ["released"], vocab_path) 75 | config = tf_example_utils.ClassifierConversionConfig( 76 | vocab_file=vocab_path, 77 | max_seq_length=32, 78 | max_column_id=32, 79 | max_row_id=32, 80 | strip_column_names=False, 81 | ) 82 | 83 | pipeline = intermediate_pretrain_utils.build_pipeline( 84 | mode=mode, 85 | config=synthesize_entablement.SynthesizationConfig( 86 | prob_count_aggregation=prob_count_aggregation), 87 | use_fake_table=use_fake_table, 88 | add_opposite_table=add_opposite_table, 89 | drop_without_support_rate=drop_without_support_rate, 90 | input_file=os.path.join(self._test_dir, 91 | "pretrain_interactions.txtpb"), 92 | output_dir=temp_dir, 93 | output_suffix=".tfrecord", 94 | num_splits=3, 95 | conversion_config=config, 96 | ) 97 | 98 | beam_runner.run_type(pipeline, runner_type).wait_until_finish() 99 | 100 | message_type = interaction_pb2.Interaction 101 | if add_example_conversion: 102 | message_type = tf.train.Example 103 | 104 | for name in [("train"), ("test")]: 105 | self.assertNotEmpty( 106 | list( 107 | _read_record( 108 | os.path.join(temp_dir, f"{name}.tfrecord"), 109 | message_type, 110 | ))) 111 | 112 | if add_example_conversion: 113 | self.assertNotEmpty( 114 | list( 115 | _read_record( 116 | os.path.join(temp_dir, "interactions.tfrecord"), 117 | interaction_pb2.Interaction, 118 | ),)) 119 | 120 | 121 | if __name__ == "__main__": 122 | absltest.main() 123 | -------------------------------------------------------------------------------- /tapas/utils/interpretation_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # coding=utf8 16 | 17 | import random 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from tapas.protos import interaction_pb2 22 | from tapas.utils import interpretation_utils 23 | 24 | _Candidate = interpretation_utils.Candidate 25 | _AggFun = interaction_pb2.Answer.AggregationFunction 26 | 27 | 28 | def _to_interaction(table, float_value): 29 | interaction = interaction_pb2.Interaction() 30 | for row in table: 31 | new_row = interaction.table.rows.add() 32 | for value in row: 33 | new_cell = new_row.cells.add() 34 | if value is not None: 35 | new_cell.numeric_value.float_value = value 36 | new_question = interaction.questions.add() 37 | new_question.answer.float_value = float_value 38 | return interaction 39 | 40 | 41 | class InterpretationUtilsTest(parameterized.TestCase, absltest.TestCase): 42 | 43 | @parameterized.parameters((0,), (1,), (2,), (3,), (4.), (5,)) 44 | def test_float_conversion(self, seed): 45 | interaction = _to_interaction( 46 | [[1.0, 3.0, 6.0], [2.0, 0.0, None], [1.0, None, 0.0]], 3.0) 47 | rng = random.Random(seed) 48 | actual = interpretation_utils.find_candidates(rng, interaction.table, 49 | interaction.questions[0]) 50 | expected = [ 51 | _Candidate(_AggFun.COUNT, 0, (0, 1, 2)), 52 | _Candidate(_AggFun.COUNT, 1, (0, 1, 2)), 53 | _Candidate(_AggFun.COUNT, 2, (0, 1, 2)), 54 | _Candidate(_AggFun.SUM, 0, (0, 1)), 55 | _Candidate(_AggFun.SUM, 0, (1, 2)), 56 | _Candidate(_AggFun.SUM, 1, (0,)), 57 | _Candidate(_AggFun.SUM, 1, (0, 1)), 58 | _Candidate(_AggFun.AVERAGE, 1, (0,)), 59 | _Candidate(_AggFun.AVERAGE, 2, (0, 2)), 60 | ] 61 | self.assertEqual(expected, actual) 62 | 63 | @parameterized.parameters((0,), (1,), (2,), (3,), (4.), (5,)) 64 | def test_random_sampling(self, seed): 65 | interaction = _to_interaction( 66 | [[1.0, 3.0, 6.0], [2.0, 0.0, None], [1.0, None, 0.0]], 3.0) 67 | rng = random.Random(seed) 68 | 69 | interpretation_utils._MAX_NUM_CANDIDATES = 1 70 | 71 | actual = interpretation_utils.find_candidates(rng, interaction.table, 72 | interaction.questions[0]) 73 | self.assertLen(actual, 3) 74 | 75 | @parameterized.parameters((0,), (1,), (2,), (3,), (4.), (5,)) 76 | def test_random_exploration(self, seed): 77 | interaction = _to_interaction( 78 | [[1.0, 3.0, 6.0], [2.0, 0.0, None], [1.0, None, 0.0]], 3.0) 79 | rng = random.Random(seed) 80 | 81 | interpretation_utils._MAX_INDICES_TO_EXPLORE = 1 82 | 83 | actual = interpretation_utils.find_candidates(rng, interaction.table, 84 | interaction.questions[0]) 85 | expected = [ 86 | _Candidate(_AggFun.COUNT, 0, (0, 1, 2)), 87 | _Candidate(_AggFun.COUNT, 1, (0, 1, 2)), 88 | _Candidate(_AggFun.COUNT, 2, (0, 1, 2)), 89 | _Candidate(_AggFun.SUM, 1, (0,)), 90 | _Candidate(_AggFun.AVERAGE, 1, (0,)), 91 | ] 92 | self.assertEqual(expected, actual) 93 | 94 | @parameterized.parameters((0,), (1,), (2,), (3,), (4.), (5,)) 95 | def test_selection_answer(self, seed): 96 | interaction = 
_to_interaction( 97 | [[1.0, 3.0, 6.0], [2.0, 0.0, None], [1.0, None, 0.0]], 100.0) 98 | coords = interaction.questions[0].answer.answer_coordinates.add() 99 | coords.row_index = 1 100 | coords.column_index = 2 101 | rng = random.Random(seed) 102 | 103 | actual = interpretation_utils.find_candidates(rng, interaction.table, 104 | interaction.questions[0]) 105 | expected = [_Candidate(_AggFun.NONE, 2, (1,))] 106 | self.assertEqual(expected, actual) 107 | 108 | 109 | if __name__ == "__main__": 110 | absltest.main() 111 | -------------------------------------------------------------------------------- /tapas/utils/sem_tab_fact_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from absl import flags 18 | from absl import logging 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from tapas.protos import interaction_pb2 22 | from tapas.utils import sem_tab_fact_utils 23 | 24 | from google.protobuf import text_format 25 | 26 | FLAGS = flags.FLAGS 27 | TEST_PATH = 'tapas/utils/testdata/' 28 | 29 | 30 | class GetTableDimensionsTest(parameterized.TestCase): 31 | 32 | def setUp(self): 33 | super().setUp() 34 | self.test_data_dir = TEST_PATH 35 | 36 | @parameterized.parameters( 37 | (sem_tab_fact_utils.Version.V1), 38 | (sem_tab_fact_utils.Version.V2), 39 | ) 40 | def test_process_doc(self, version): 41 | interactions = list( 42 | sem_tab_fact_utils._process_doc( 43 | os.path.join( 44 | self.test_data_dir, 45 | 'sem_tab_fact_20502.xml', 46 | ), 47 | version, 48 | )) 49 | 50 | if version == sem_tab_fact_utils.Version.V1: 51 | name = 'sem_tab_fact_20502_interaction.txtpb' 52 | elif version == sem_tab_fact_utils.Version.V2: 53 | name = 'sem_tab_fact_20502_interaction_v2.txtpb' 54 | else: 55 | raise ValueError(f'Unsupported version: {version.name}') 56 | interaction_file = os.path.join(self.test_data_dir, name) 57 | with open(interaction_file) as input_file: 58 | interaction = text_format.ParseLines(input_file, 59 | interaction_pb2.Interaction()) 60 | self.assertLen(interactions, 4) 61 | logging.info(interactions[0]) 62 | self.assertEqual(interactions[0], interaction) 63 | questions = [ 64 | ( # pylint: disable=g-complex-comprehension 65 | i.questions[0].id, 66 | i.questions[0].original_text, 67 | i.questions[0].answer.class_index, 68 | ) for i in interactions 69 | ] 70 | self.assertEqual(questions, [ 71 | ( 72 | 'sem_tab_fact_20502_Table_2_2_0', 73 | 'At the same time, these networks often occur in tandem at the firm level.', 74 | 1, 75 | ), 76 | ( 77 | 'sem_tab_fact_20502_Table_2_3_0', 78 | 'For each network interaction, there is considerable variation both across and within countries.', 79 | 1, 80 | ), 81 | ( 82 | 'sem_tab_fact_20502_Table_2_5_0', 83 | 'The n value is same for Hong Kong and Malaysia.', 84 | 0, 85 | ), 86 | ( 87 | 
'sem_tab_fact_20502_Table_2_8_0', 88 | 'There are 9 different types country in the given table.', 89 | 1, 90 | ), 91 | ]) 92 | 93 | 94 | if __name__ == '__main__': 95 | absltest.main() 96 | -------------------------------------------------------------------------------- /tapas/utils/sentence_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helper package to find relevant paragraphs in a website using a query.""" 16 | 17 | import functools 18 | import threading 19 | from typing import List, Text 20 | 21 | import nltk 22 | 23 | 24 | @functools.lru_cache() 25 | def _load_sentence_tokenizer(): 26 | """Returns a sentence tokenization function.""" 27 | # Lock to avoid a race-condition in the creation of the download directory. 28 | with threading.Lock(): 29 | nltk.download("punkt") 30 | return nltk.data.load("nltk:tokenizers/punkt/english.pickle") 31 | 32 | 33 | def tokenize(document): 34 | """Split text into sentences.""" 35 | sentence_tokenizer = _load_sentence_tokenizer() 36 | result = [] 37 | for sentence in sentence_tokenizer.tokenize(document): 38 | sentence = sentence.strip() 39 | if sentence: 40 | result.append(sentence) 41 | return result 42 | -------------------------------------------------------------------------------- /tapas/utils/sentence_tokenizer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl.testing import absltest 17 | from tapas.utils import sentence_tokenizer 18 | 19 | 20 | class SentenceTokenizerTest(absltest.TestCase): 21 | 22 | def test_sentence_tokenizer(self): 23 | sentences = sentence_tokenizer.tokenize( 24 | 'A sentence about dogs. Dogs are cute amimals. 
Cats are OK as well') 25 | self.assertEqual(sentences, [ 26 | 'A sentence about dogs.', 27 | 'Dogs are cute amimals.', 28 | 'Cats are OK as well', 29 | ]) 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /tapas/utils/span_prediction_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl import logging 17 | from tapas.utils import span_prediction_utils 18 | import tensorflow.compat.v1 as tf 19 | 20 | tf.disable_v2_behavior() 21 | 22 | 23 | class SpanPredictionUtilsTest(tf.test.TestCase): 24 | 25 | def test_get_start_and_end_ids(self): 26 | label_ids = tf.constant([ 27 | [0, 0, 0, 0, 1, 0], 28 | [0, 0, 1, 1, 1, 0], 29 | [0, 1, 0, 0, 1, 1], 30 | ]) 31 | start_ids, end_ids = span_prediction_utils._get_start_and_end_ids(label_ids) 32 | with self.session() as sess: 33 | sess.run(tf.global_variables_initializer()) 34 | (start_ids_val, end_ids_val) = sess.run([start_ids, end_ids]) 35 | self.assertAllClose(start_ids_val, [ 36 | [0., 0., 0., 0., 1., 0.], 37 | [0., 0., 1., 0., 0., 0.], 38 | [0., 1., 0., 0., 1., 0.], 39 | ]) 40 | self.assertAllClose(end_ids_val, [ 41 | [0., 0., 0., 0., 1., 0.], 42 | [0., 0., 0., 0., 1., 0.], 43 | [0., 1., 0., 0., 0., 1.], 44 | ]) 45 | 46 | def test_get_span_logits(self): 47 | seq_length = 5 48 | batch_size = 2 49 | embedding_dim = 3 50 | embeddings = tf.random.normal(shape=(batch_size, seq_length, embedding_dim)) 51 | 52 | start_ids = tf.constant([ 53 | [0.0, 1.0, 0.0, 0.0, 0.0], 54 | [0.0, 0.0, 1.0, 0.0, 0.0], 55 | ]) 56 | end_ids = tf.constant([ 57 | [0.0, 0.0, 0.0, 1.0, 0.0], 58 | [0.0, 0.0, 0.0, 1.0, 0.0], 59 | ]) 60 | column_ids = tf.constant([ 61 | [0, 1, 1, 1, 2], 62 | [0, 0, 1, 1, 1], 63 | ]) 64 | row_ids = tf.constant([ 65 | [0, 2, 2, 2, 2], 66 | [0, 3, 3, 3, 3], 67 | ]) 68 | 69 | spans, span_logits, loss = span_prediction_utils._get_span_logits( 70 | embeddings, 71 | start_ids, 72 | end_ids, 73 | column_ids, 74 | row_ids, 75 | max_span_length=2, 76 | ) 77 | 78 | span_mask = tf.where( 79 | span_logits > -1000.0, 80 | tf.ones_like(span_logits), 81 | tf.zeros_like(span_logits), 82 | ) 83 | 84 | with self.session() as sess: 85 | sess.run(tf.global_variables_initializer()) 86 | ( 87 | spans_value, 88 | span_logits_value, 89 | loss_value, 90 | span_mask, 91 | ) = sess.run([ 92 | spans, 93 | span_logits, 94 | loss, 95 | span_mask, 96 | ]) 97 | 98 | logging.info("span_value: %s", spans_value) 99 | logging.info("span_logits_value: %s", span_logits_value) 100 | logging.info("loss_value: %s", loss_value) 101 | 102 | self.assertAllClose( 103 | spans, 104 | [ 105 | [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [0, 1], [1, 2], [2, 3], 106 | [3, 4]], 107 | [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [0, 1], [1, 2], [2, 3], 108 | [3, 4]], 109 | ], 110 | ) 111 | 
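# Reading of the expected values below: a span survives the -1000.0 logit
# mask only if every token it covers shares the same nonzero column id and
# row id, i.e. the span stays inside a single table cell. Spans touching
# position 0 (question/padding, ids 0) or straddling a column boundary are
# pushed to a large negative logit and masked out.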
self.assertAllClose( 112 | span_mask, 113 | [ 114 | [0., 1., 1., 1., 1., 0., 1., 1., 0.], 115 | [0., 0., 1., 1., 1., 0., 0., 1., 1.], 116 | ], 117 | ) 118 | self.assertGreater(loss_value, 0.0) 119 | 120 | def test_get_boundary_logits(self): 121 | seq_length = 5 122 | batch_size = 2 123 | embedding_dim = 3 124 | embeddings = tf.random.normal(shape=(batch_size, seq_length, embedding_dim)) 125 | 126 | label_ids = tf.constant([ 127 | [0, 1, 1, 0, 0], 128 | [0, 0, 1, 1, 1], 129 | ], 130 | shape=(batch_size, seq_length)) 131 | column_ids = tf.constant([ 132 | [0, 1, 1, 1, 2], 133 | [0, 0, 1, 1, 1], 134 | ]) 135 | row_ids = tf.constant([ 136 | [0, 2, 2, 2, 2], 137 | [0, 3, 3, 3, 3], 138 | ]) 139 | 140 | spans, span_logits, loss = span_prediction_utils.get_boundary_logits( 141 | embeddings, label_ids, column_ids, row_ids, max_span_length=2) 142 | 143 | all_finite = tf.reduce_all(tf.math.is_finite(span_logits)) 144 | 145 | with self.session() as sess: 146 | sess.run(tf.global_variables_initializer()) 147 | ( 148 | spans_value, 149 | span_logits_value, 150 | loss_value, 151 | ) = sess.run([ 152 | spans, 153 | span_logits, 154 | loss, 155 | ]) 156 | self.assertTrue(sess.run(all_finite)) 157 | 158 | logging.info("spans_value: %s", spans_value) 159 | logging.info("span_logits_value: %s", span_logits_value) 160 | logging.info("loss_value: %s", loss_value) 161 | self.assertAllEqual( 162 | spans_value[0], 163 | spans_value[1], 164 | ) 165 | self.assertAllEqual(spans_value[0], [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], 166 | [0, 1], [1, 2], [2, 3], [3, 4]]) 167 | self.assertGreater(loss_value, 0.0) 168 | self.assertEqual(spans.shape, (2, 9, 2)) 169 | self.assertEqual(span_logits.shape, (2, 9)) 170 | self.assertEqual(loss.shape, ()) 171 | 172 | 173 | if __name__ == "__main__": 174 | tf.test.main() 175 | -------------------------------------------------------------------------------- /tapas/utils/tasks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Lists all the fine-tuning tasks supported by Tapas.""" 16 | 17 | import enum 18 | 19 | 20 | class Task(enum.Enum): 21 | """Fine-tuning tasks supported by Tapas.""" 22 | SQA = 0 23 | WTQ = 1 24 | WIKISQL = 2 25 | WIKISQL_SUPERVISED = 3 26 | TABFACT = 4 27 | NQ_RETRIEVAL = 7 28 | HYBRIDQA_RC = 8 # Reading comprehension (RC) of HybridQA. 29 | HYBRIDQA_E2E = 9 # HybridQA RC using preds of HybridQA Cell Selection. 
30 | HYBRIDQA = 6 31 | SEM_TAB_FACT = 10 # https://sites.google.com/corp/view/sem-tab-facts 32 | -------------------------------------------------------------------------------- /tapas/utils/testdata/interaction_00.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: third_party/py/tapas/protos/interaction.proto 2 | # proto-message: language.tapas.Interaction 3 | id: "nt-12454-1" 4 | table: { 5 | columns: { 6 | text: "Rank" 7 | } 8 | columns: { 9 | text: "City" 10 | } 11 | columns: { 12 | text: "Passengers" 13 | } 14 | columns: { 15 | text: "Ranking" 16 | } 17 | columns: { 18 | text: "Airline" 19 | } 20 | rows: { 21 | cells: { 22 | text: "1" 23 | numeric_value: { 24 | float_value: 1.0 25 | } 26 | } 27 | cells: { 28 | text: "United States, Los Angeles" 29 | } 30 | cells: { 31 | text: "14,749" 32 | numeric_value: { 33 | float_value: 14749.0 34 | } 35 | } 36 | cells: { 37 | text: "" 38 | } 39 | cells: { 40 | text: "Alaska Airlines" 41 | } 42 | } 43 | rows: { 44 | cells: { 45 | text: "2" 46 | numeric_value: { 47 | float_value: 2.0 48 | } 49 | } 50 | cells: { 51 | text: "United States, Houston" 52 | } 53 | cells: { 54 | text: "5,465" 55 | numeric_value: { 56 | float_value: 5465.0 57 | } 58 | } 59 | cells: { 60 | text: "" 61 | } 62 | cells: { 63 | text: "United Express" 64 | } 65 | } 66 | rows: { 67 | cells: { 68 | text: "3" 69 | numeric_value: { 70 | float_value: 3.0 71 | } 72 | } 73 | cells: { 74 | text: "Canada, Calgary" 75 | } 76 | cells: { 77 | text: "3,761" 78 | numeric_value: { 79 | float_value: 3761.0 80 | } 81 | } 82 | cells: { 83 | text: "" 84 | } 85 | cells: { 86 | text: "Air Transat, WestJet" 87 | } 88 | } 89 | rows: { 90 | cells: { 91 | text: "4" 92 | numeric_value: { 93 | float_value: 4.0 94 | } 95 | } 96 | cells: { 97 | text: "Canada, Saskatoon" 98 | } 99 | cells: { 100 | text: "2,282" 101 | numeric_value: { 102 | float_value: 2282.0 103 | } 104 | } 105 | cells: { 106 | text: "4" 107 | } 108 | cells: { 109 | text: "" 110 | } 111 | } 112 | rows: { 113 | cells: { 114 | text: "5" 115 | numeric_value: { 116 | float_value: 5.0 117 | } 118 | } 119 | cells: { 120 | text: "Canada, Vancouver" 121 | } 122 | cells: { 123 | text: "2,103" 124 | numeric_value: { 125 | float_value: 2103.0 126 | } 127 | } 128 | cells: { 129 | text: "" 130 | } 131 | cells: { 132 | text: "Air Transat" 133 | } 134 | } 135 | rows: { 136 | cells: { 137 | text: "6" 138 | numeric_value: { 139 | float_value: 6.0 140 | } 141 | } 142 | cells: { 143 | text: "United States, Phoenix" 144 | } 145 | cells: { 146 | text: "1,829" 147 | numeric_value: { 148 | float_value: 1829.0 149 | } 150 | } 151 | cells: { 152 | text: "1" 153 | } 154 | cells: { 155 | text: "US Airways" 156 | } 157 | } 158 | rows: { 159 | cells: { 160 | text: "7" 161 | numeric_value: { 162 | float_value: 7.0 163 | } 164 | } 165 | cells: { 166 | text: "Canada, Toronto" 167 | } 168 | cells: { 169 | text: "1,202" 170 | numeric_value: { 171 | float_value: 1202.0 172 | } 173 | } 174 | cells: { 175 | text: "1" 176 | } 177 | cells: { 178 | text: "Air Transat, CanJet" 179 | } 180 | } 181 | rows: { 182 | cells: { 183 | text: "8" 184 | numeric_value: { 185 | float_value: 8.0 186 | } 187 | } 188 | cells: { 189 | text: "Canada, Edmonton" 190 | } 191 | cells: { 192 | text: "110" 193 | numeric_value: { 194 | float_value: 110.0 195 | } 196 | } 197 | cells: { 198 | text: "" 199 | } 200 | cells: { 201 | text: "" 202 | } 203 | } 204 | rows: { 205 | cells: { 206 | text: "9" 207 | numeric_value: { 208 | float_value: 9.0 209 | } 
210 | } 211 | cells: { 212 | text: "United States, Oakland" 213 | } 214 | cells: { 215 | text: "107" 216 | numeric_value: { 217 | float_value: 107.0 218 | } 219 | } 220 | cells: { 221 | text: "" 222 | } 223 | cells: { 224 | text: "" 225 | } 226 | } 227 | table_id: "table_csv/203_515.csv" 228 | } 229 | questions: { 230 | id: "nt-12454-1_0" 231 | text: "which cities had less than 2,000 passengers?" 232 | original_text: "which cities had less than 2,000 passengers?" 233 | annotations: { 234 | spans: { 235 | begin_index: 27 236 | end_index: 32 237 | values: { 238 | float_value: 2000.0 239 | } 240 | } 241 | } 242 | answer: { 243 | answer_coordinates: { 244 | row_index: 5 245 | column_index: 1 246 | } 247 | answer_coordinates: { 248 | row_index: 6 249 | column_index: 1 250 | } 251 | answer_coordinates: { 252 | row_index: 7 253 | column_index: 1 254 | } 255 | answer_coordinates: { 256 | row_index: 8 257 | column_index: 1 258 | } 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /tapas/utils/testdata/interaction_01.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: third_party/py/tapas/protos/interaction.proto 2 | # proto-message: language.tapas.Interaction 3 | table: { 4 | columns: { 5 | text: "Title" 6 | } 7 | columns: { 8 | text: "Album" 9 | } 10 | columns: { 11 | text: "Chart" 12 | } 13 | rows: { 14 | cells: { 15 | text: "Mercury" 16 | } 17 | cells: { 18 | text: "Released 1983" 19 | } 20 | cells: { 21 | text: "5" 22 | } 23 | } 24 | rows: { 25 | cells: { 26 | text: "Survival" 27 | } 28 | cells: { 29 | text: "Released 1984" 30 | } 31 | cells: { 32 | text: "" 33 | } 34 | } 35 | table_id: "http://en.wikipedia.org/wiki/!Action_Pact!_1" 36 | } 37 | questions: { 38 | id: "CAPTION" 39 | original_text: "List of albums" 40 | } 41 | questions: { 42 | id: "TEXT" 43 | original_text: "Contains ..." 
44 | } 45 | -------------------------------------------------------------------------------- /tapas/utils/testdata/interaction_02.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: third_party/py/tapas/protos/interaction.proto 2 | # proto-message: language.tapas.Interaction 3 | id: "nt_1135_0" 4 | table: { 5 | columns: { 6 | text: "Team" 7 | } 8 | columns: { 9 | text: "County" 10 | } 11 | columns: { 12 | text: "Wins" 13 | } 14 | columns: { 15 | text: "Years won" 16 | } 17 | rows: { 18 | cells: { 19 | text: "Greystones" 20 | } 21 | cells: { 22 | text: "Wicklow" 23 | } 24 | cells: { 25 | text: "1" 26 | } 27 | cells: { 28 | text: "2011" 29 | } 30 | } 31 | rows: { 32 | cells: { 33 | text: "Ballymore Eustace" 34 | } 35 | cells: { 36 | text: "Kildare" 37 | } 38 | cells: { 39 | text: "1" 40 | } 41 | cells: { 42 | text: "2010" 43 | } 44 | } 45 | rows: { 46 | cells: { 47 | text: "Maynooth" 48 | } 49 | cells: { 50 | text: "Kildare" 51 | } 52 | cells: { 53 | text: "1" 54 | } 55 | cells: { 56 | text: "2009" 57 | } 58 | } 59 | rows: { 60 | cells: { 61 | text: "Ballyroan Abbey" 62 | } 63 | cells: { 64 | text: "Laois" 65 | } 66 | cells: { 67 | text: "1" 68 | } 69 | cells: { 70 | text: "2008" 71 | } 72 | } 73 | rows: { 74 | cells: { 75 | text: "Fingal Ravens" 76 | } 77 | cells: { 78 | text: "Dublin" 79 | } 80 | cells: { 81 | text: "1" 82 | } 83 | cells: { 84 | text: "2007" 85 | } 86 | } 87 | rows: { 88 | cells: { 89 | text: "Confey" 90 | } 91 | cells: { 92 | text: "Kildare" 93 | } 94 | cells: { 95 | text: "1" 96 | } 97 | cells: { 98 | text: "2006" 99 | } 100 | } 101 | rows: { 102 | cells: { 103 | text: "Crettyard" 104 | } 105 | cells: { 106 | text: "Laois" 107 | } 108 | cells: { 109 | text: "1" 110 | } 111 | cells: { 112 | text: "2005" 113 | } 114 | } 115 | rows: { 116 | cells: { 117 | text: "Wolfe Tones" 118 | } 119 | cells: { 120 | text: "Meath" 121 | } 122 | cells: { 123 | text: "1" 124 | } 125 | cells: { 126 | text: "2004" 127 | } 128 | } 129 | rows: { 130 | cells: { 131 | text: "Dundalk Gaels" 132 | } 133 | cells: { 134 | text: "Louth" 135 | } 136 | cells: { 137 | text: "1" 138 | } 139 | cells: { 140 | text: "2003" 141 | } 142 | } 143 | } 144 | questions: { 145 | original_text: "what county is the team that won in 2009 from?" 146 | answer: { 147 | answer_coordinates: { 148 | row_index: 2 149 | column_index: 1 150 | } 151 | } 152 | } 153 | questions: { 154 | original_text: "what is the teams name?" 
155 | answer: { 156 | answer_coordinates: { 157 | row_index: 2 158 | column_index: 0 159 | } 160 | aggregation_function: SUM 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /tapas/utils/testdata/interaction_03.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: third_party/py/tapas/protos/interaction.proto 2 | # proto-message: language.tapas.Interaction 3 | id: "key" 4 | table: { 5 | columns: { 6 | text: "A" 7 | } 8 | columns: { 9 | text: "B" 10 | } 11 | rows: { 12 | cells: { 13 | text: "1" 14 | numeric_value { 15 | float_value: 1.0 16 | } 17 | } 18 | cells: { 19 | text: "1" 20 | numeric_value { 21 | float_value: 1.0 22 | } 23 | } 24 | } 25 | rows: { 26 | cells: { 27 | text: "1" 28 | numeric_value { 29 | float_value: 1.0 30 | } 31 | } 32 | cells: { 33 | text: "1" 34 | numeric_value { 35 | float_value: 1.0 36 | } 37 | } 38 | } 39 | } 40 | questions { 41 | original_text: "query" 42 | answer { 43 | float_value: 2.0 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /tapas/utils/testdata/questions.tsv: -------------------------------------------------------------------------------- 1 | id annotator position question table_file answer_coordinates answer_text 2 | nt-14053 1 0 who were the captains? table_csv/203_386.csv ['(0, 3)', '(1, 3)'] [u'Heinrich Brodda', u'Oskar Staudinger'] 3 | nt-14053 1 1 which ones lost their u-boat on may 5? table_csv/203_386.csv ['(1, 3)', '(2, 3)'] [u'Oskar Staudinger', u'Herbert Neckel'] 4 | nt-14053 1 2 of those, which one is not oskar staudinger? table_csv/203_386.csv ['(2, 3)'] [u'Herbert Neckel'] 5 | nt-5431 0 0 what are all the countries? table_csv/204_703.csv ['(1, 1)', '(0, 1)'] [u'Canada (CAN)', u'Russia (RUS)'] 6 | -------------------------------------------------------------------------------- /tapas/utils/testdata/questions_aggregation.tsv: -------------------------------------------------------------------------------- 1 | id annotator position question table_file answer_coordinates answer_text aggregation 2 | nt-14053 1 0 who were the captains? table_csv/203_386.csv ['(0, 3)', '(1, 3)'] [u'Heinrich Brodda', u'Oskar Staudinger'] 3 | nt-14053 1 1 which ones lost their u-boat on may 5? table_csv/203_386.csv ['(1, 3)'] [u'Oskar Staudinger'] none 4 | nt-14053 1 2 of those, which one is not oskar staudinger? table_csv/203_386.csv ['(2, 3)'] [u'Herbert Neckel'] NONE 5 | nt-4436 0 0 which language has more males then females? table_csv/203_88.csv ['(2, 0)'] [u'Russian'] SUM 6 | nt-4436 0 1 which of those have less than 500 males? table_csv/203_88.csv ['(5, 0)'] [u'Romanian'] COUNT 7 | nt-4436 0 2 the ones have less than 20 females? table_csv/203_88.csv ['(5, 0)', '(7, 0)'] [u'Romanian', u'Estonian'] AVERAGE 8 | 9 | -------------------------------------------------------------------------------- /tapas/utils/testdata/questions_float_answer.tsv: -------------------------------------------------------------------------------- 1 | id annotator position question table_file answer_coordinates answer_text aggregation float_answer 2 | nt-14053 1 0 who were the captains? table_csv/203_386.csv ['(0, 3)', '(1, 3)'] [u'Heinrich Brodda', u'Oskar Staudinger'] 3 | nt-14053 1 1 which ones lost their u-boat on may 5? table_csv/203_386.csv ['(1, 3)'] [u'Oskar Staudinger'] none 4 | nt-14053 1 2 of those, which one is not oskar staudinger? 
table_csv/203_386.csv ['(2, 3)'] [u'Herbert Neckel'] NONE 5 | nt-4436 0 0 which language has more males then females? table_csv/203_88.csv ['(2, 0)'] [u'Russian'] SUM 150. 6 | nt-4436 0 1 which of those have less than 500 males? table_csv/203_88.csv ['(5, 0)'] [u'Romanian'] COUNT 7 7 | nt-4436 0 2 the ones have less than 20 females? table_csv/203_88.csv ['(5, 0)', '(7, 0)'] [u'Romanian', u'Estonian'] AVERAGE 7.5 8 | 9 | -------------------------------------------------------------------------------- /tapas/utils/text_index.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helper package to find relevant paragraphs in a website using a query.""" 16 | 17 | import dataclasses 18 | import enum 19 | import heapq 20 | from typing import Callable, Generic, Sequence, Text, Tuple, TypeVar 21 | 22 | from sklearn.feature_extraction.text import TfidfVectorizer 23 | 24 | T = TypeVar('T') 25 | 26 | 27 | class Analyzer(enum.Enum): 28 | """Helper enum to determine how to compute n-grams to build the TextIndex. 29 | 30 | Option ‘char_wb’ creates character n-grams only from text inside word 31 | boundaries; n-grams at the edges of words are padded with space. 32 | """ 33 | WORD = 'word' 34 | CHAR = 'char' 35 | CHAR_WB = 'char_wb' 36 | 37 | 38 | @dataclasses.dataclass 39 | class SearchResult(Generic[T]): 40 | document: T 41 | score: float 42 | 43 | 44 | class TextIndex(Generic[T]): 45 | """A simple text index from a corpus of text using tf-idf similarity.""" 46 | 47 | def __init__(self, 48 | documents, 49 | text_getter, 50 | ngram_range = (1, 2), 51 | analyzer = Analyzer.WORD, 52 | min_df = 1, 53 | max_df = .9): 54 | """Init parameters for TextIndex. 55 | 56 | Args: 57 | documents: Corpus of documents to be indexed and retrieved. 58 | text_getter: Function to extract text from documents. 59 | ngram_range: tuple (min_n, max_n), default=(1, 2) The lower and upper 60 | boundary of the range of n-values for different n-grams to be extracted. 61 | All values of n such that min_n <= n <= max_n will be used. For example 62 | an ``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means 63 | unigrams and bigrams, and ``(2, 2)`` means only bigrams. 64 | analyzer: Analyzer, {‘word’, ‘char’, ‘char_wb’}. Whether the 65 | feature should be made of word or character n-grams. Option 66 | ‘char_wb’ creates character n-grams only from text inside word 67 | boundaries; n-grams at the edges of words are padded with space. 68 | min_df: float in range [0.0, 1.0] or int (default=1) When building the 69 | vocabulary ignore terms that have a document frequency strictly lower 70 | than the given threshold. This value is also called cut-off in the 71 | literature. If float, the parameter represents a proportion of 72 | documents, integer absolute counts. 
73 | max_df: float in range [0.0, 1.0] or int (default=0.9) When building the 74 | vocabulary ignore terms that have a document frequency strictly higher 75 | than the given threshold (corpus-specific stop words). If float, the 76 | parameter represents a proportion of documents, integer absolute counts. 77 | """ 78 | self._vectorizer = TfidfVectorizer( 79 | ngram_range=ngram_range, 80 | min_df=min_df, 81 | max_df=max_df, 82 | analyzer=analyzer.value) 83 | 84 | self._documents = documents 85 | self._index = self._vectorizer.fit_transform( 86 | map(text_getter, self._documents)) 87 | 88 | def search(self, 89 | query, 90 | retrieval_threshold = 0.0, 91 | num_results = 5): 92 | """Retrieve matching text in the corpus. 93 | 94 | Args: 95 | query: Text used to search for candidates in the corpus.s 96 | retrieval_threshold: Filter results above this threshold. 97 | num_results: Number of results to return. 98 | 99 | Returns: 100 | Tuple of text and float score. Top `num_results` elements in the corpus. 101 | """ 102 | query_vector = self._vectorizer.transform([query]) 103 | scores = zip(self._documents, 104 | self._index.dot(query_vector.T).T.toarray()[0]) 105 | filtered_scores = (SearchResult(doc, score) 106 | for doc, score in scores 107 | if score > retrieval_threshold) 108 | return heapq.nlargest(num_results, filtered_scores, key=lambda p: p.score) 109 | -------------------------------------------------------------------------------- /tapas/utils/text_index_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tests for tapas.utils.google.text_index.""" 16 | 17 | from absl.testing import absltest 18 | from tapas.utils import text_index 19 | 20 | 21 | class TextIndexTest(absltest.TestCase): 22 | 23 | def test_num_results(self): 24 | index = text_index.TextIndex([ 25 | 'A sentence about dogs.', 26 | 'Dogs are cute amimals', 27 | 'Cats are OK as well', 28 | ], lambda x: x) 29 | results = index.search('Two dogs.', num_results=1, retrieval_threshold=0.0) 30 | self.assertLen(results, 1) 31 | self.assertEqual(results[0].document, 'A sentence about dogs.') 32 | 33 | def test_bad_results_filtered(self): 34 | index = text_index.TextIndex([ 35 | 'A sentence about dogs and cats.', 36 | 'Dogs are cute amimals. I like dogs.', 37 | 'Cats are OK as well', 38 | ], lambda x: x) 39 | results = index.search('Two dogs.', num_results=3, retrieval_threshold=0.0) 40 | self.assertLen(results, 2) 41 | first, second = results 42 | self.assertEqual(first.document, 'Dogs are cute amimals. 
I like dogs.') 43 | self.assertEqual(second.document, 'A sentence about dogs and cats.') 44 | self.assertGreater(first.score, second.score) 45 | 46 | 47 | if __name__ == '__main__': 48 | absltest.main() 49 | -------------------------------------------------------------------------------- /tapas/utils/wikisql_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import csv 17 | import json 18 | import os 19 | import tempfile 20 | 21 | from absl import logging 22 | from absl.testing import absltest 23 | from tapas.utils import wikisql_utils 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | def _create_inputs(input_dir, tables, examples): 28 | for name in wikisql_utils._DATASETS: 29 | filename = os.path.join(input_dir, f'{name}.tables.jsonl') 30 | with tf.io.gfile.GFile(filename, 'w') as table_file: 31 | for table in tables: 32 | table_file.write(json.dumps(table) + '\n') 33 | 34 | for name in wikisql_utils._DATASETS: 35 | filename = os.path.join(input_dir, f'{name}.jsonl') 36 | with tf.io.gfile.GFile(filename, 'w') as example_file: 37 | for example in examples: 38 | example_file.write(json.dumps(example) + '\n') 39 | 40 | 41 | class WikisqlUtilsTest(absltest.TestCase): 42 | 43 | def test_simple_test(self): 44 | with tempfile.TemporaryDirectory() as input_dir: 45 | with tempfile.TemporaryDirectory() as output_dir: 46 | _create_inputs( 47 | input_dir, 48 | tables=[{ 49 | 'id': '1-0000001-1', 50 | 'header': ['Text', 'Number'], 51 | 'types': ['text', 'real'], 52 | 'rows': [['A', 1], ['B', 2], ['C', 3]], 53 | }], 54 | examples=[ 55 | { 56 | 'question': 'What is text for 2?', 57 | 'table_id': '1-0000001-1', 58 | 'sql': { 59 | 'agg': 0, # No aggregation 60 | 'sel': 0, # Text column 61 | 'conds': [[1, 0, 2]] # Column 1 = 2 62 | }, 63 | }, 64 | { 65 | 'question': 'What is sum when number is greater than 1?', 66 | 'table_id': '1-0000001-1', 67 | 'sql': { 68 | 'agg': 4, # SUM 69 | 'sel': 1, # Number column 70 | 'conds': [[1, 1, 1]] # Column 1 > 1 71 | } 72 | } 73 | ]) 74 | 75 | wikisql_utils.convert(input_dir=input_dir, output_dir=output_dir) 76 | 77 | table_path = os.path.join( 78 | output_dir, 79 | wikisql_utils._TABLE_DIR_NAME, 80 | '1-0000001-1.csv', 81 | ) 82 | with tf.io.gfile.GFile(table_path) as table_file: 83 | actual = [dict(row) for row in csv.DictReader(table_file)] 84 | self.assertEqual([{ 85 | 'Text': 'A', 86 | 'Number': '1', 87 | }, { 88 | 'Text': 'B', 89 | 'Number': '2', 90 | }, { 91 | 'Text': 'C', 92 | 'Number': '3' 93 | }], actual) 94 | 95 | filename = os.path.join(output_dir, 'dev.tsv') 96 | with tf.io.gfile.GFile(filename) as dev_file: 97 | actual = list(csv.DictReader(dev_file, delimiter='\t')) 98 | logging.info(actual) 99 | self.assertEqual( 100 | { 101 | 'id': 'dev-0', 102 | 'annotator': '0', 103 | 'position': '0', 104 | 'question': 'What is text for 2?', 105 | 
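# The SQL above, conds [[1, 0, 2]] with sel 0, reads "Text where
# Number = 2", which matches table row 1 -- hence coordinate (1, 0).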
'table_file': 'table_csv/1-0000001-1.csv', 106 | 'answer_coordinates': "['(1, 0)']", 107 | 'aggregation': '', 108 | 'answer_text': "['B']", 109 | 'float_answer': '', 110 | }, dict(actual[0])) 111 | self.assertEqual( 112 | { 113 | 'id': 'dev-1', 114 | 'annotator': '0', 115 | 'position': '0', 116 | 'question': 'What is sum when number is greater than 1?', 117 | 'table_file': 'table_csv/1-0000001-1.csv', 118 | 'answer_coordinates': "['(1, 1)', '(2, 1)']", 119 | 'aggregation': 'SUM', 120 | 'answer_text': "['5.0']", 121 | 'float_answer': '5.0', 122 | }, dict(actual[1])) 123 | 124 | 125 | if __name__ == '__main__': 126 | absltest.main() 127 | -------------------------------------------------------------------------------- /tapas/utils/wtq_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import csv 17 | import os 18 | import tempfile 19 | 20 | from absl import logging 21 | from absl.testing import absltest 22 | from tapas.utils import wtq_utils 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | def _create_inputs(input_dir, tables, examples): 27 | for table in tables: 28 | filename = os.path.join(input_dir, table['table_id']) 29 | tf.io.gfile.makedirs(os.path.dirname(filename)) 30 | with tf.io.gfile.GFile(filename, 'w') as table_file: 31 | writer = csv.writer(table_file) 32 | writer.writerow(table['columns']) 33 | writer.writerows(table['rows']) 34 | 35 | for number in range(0, 6): 36 | for name in wtq_utils._get_train_test(number, wtq_utils.Version.V_10): 37 | filepath = os.path.join(input_dir, 'data', name) 38 | tf.io.gfile.makedirs(os.path.dirname(filepath)) 39 | with tf.io.gfile.GFile(filepath, 'w') as example_file: 40 | writer = csv.DictWriter( 41 | example_file, 42 | fieldnames=[ 43 | 'id', 44 | 'utterance', 45 | 'context', 46 | 'targetValue', 47 | ], 48 | delimiter='\t', 49 | ) 50 | writer.writeheader() 51 | for example in examples: 52 | writer.writerow(example) 53 | 54 | 55 | class WtqUtilsTest(absltest.TestCase): 56 | 57 | def test_simple_test(self): 58 | with tempfile.TemporaryDirectory() as input_dir: 59 | with tempfile.TemporaryDirectory() as output_dir: 60 | _create_inputs( 61 | input_dir, 62 | tables=[{ 63 | 'table_id': 'csv/203-csv/515.csv', 64 | 'columns': ['Text', 'Number'], 65 | 'rows': [['A', 1], ['B', 2], ['тапас', 3]], 66 | }], 67 | examples=[ 68 | { 69 | 'id': 'nt-2', 70 | 'utterance': 'What is text for 2?', 71 | 'context': 'csv/203-csv/515.csv', 72 | 'targetValue': 'B', 73 | }, 74 | ]) 75 | 76 | wtq_utils.convert(input_dir=input_dir, output_dir=output_dir) 77 | 78 | table_dir = os.path.join(output_dir, wtq_utils._TABLE_DIR_NAME) 79 | self.assertCountEqual( 80 | tf.io.gfile.listdir(output_dir), [ 81 | 'random-split-1-dev.tsv', 82 | 'random-split-1-train.tsv', 83 | 'random-split-2-dev.tsv', 84 | 'random-split-2-train.tsv', 85 | 'random-split-3-dev.tsv', 86 | 
'random-split-3-train.tsv', 87 | 'random-split-4-dev.tsv', 88 | 'random-split-4-train.tsv', 89 | 'random-split-5-dev.tsv', 90 | 'random-split-5-train.tsv', 91 | 'table_csv', 92 | 'test.tsv', 93 | 'train.tsv', 94 | ]) 95 | self.assertEqual(tf.io.gfile.listdir(table_dir), ['203-515.csv']) 96 | 97 | table_path = os.path.join(table_dir, '203-515.csv') 98 | with tf.io.gfile.GFile(table_path) as table_file: 99 | actual = [dict(row) for row in csv.DictReader(table_file)] 100 | self.assertEqual([{ 101 | 'Text': 'a', 102 | 'Number': '1', 103 | }, { 104 | 'Text': 'b', 105 | 'Number': '2', 106 | }, { 107 | 'Text': 'тапас', 108 | 'Number': '3' 109 | }], actual) 110 | 111 | filename = os.path.join(output_dir, 'test.tsv') 112 | with tf.io.gfile.GFile(filename) as dev_file: 113 | actual = list(csv.DictReader(dev_file, delimiter='\t')) 114 | logging.info(actual) 115 | self.assertEqual( 116 | { 117 | 'id': 'nt-2', 118 | 'annotator': '0', 119 | 'position': '0', 120 | 'question': 'What is text for 2?', 121 | 'table_file': 'table_csv/203-515.csv', 122 | 'answer_coordinates': "['(-1, -1)']", 123 | 'aggregation': 'NONE', 124 | 'answer_text': "['B']", 125 | 'float_answer': '', 126 | }, dict(actual[0])) 127 | 128 | 129 | if __name__ == '__main__': 130 | absltest.main() 131 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37,py38 3 | [testenv] 4 | deps = -r{toxinidir}/requirements.txt 5 | commands = python -m unittest discover -p "*_test.py" 6 | --------------------------------------------------------------------------------
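Both converter tests above (wikisql_utils_test.py and wtq_utils_test.py) assert the same SQA-style TSV schema, so a converted example is just one row over a fixed set of columns. A minimal sketch of writing such a row with the stdlib, assuming the column names taken from those expected dicts (the output path is hypothetical); for weakly supervised WTQ examples the coordinates are the `['(-1, -1)']` placeholder, as wtq_utils_test checks:

```python
# Sketch only: writes one SQA-format row like those asserted in the
# converter tests above. Column names are taken from the expected dicts;
# the output path is hypothetical.
import csv

_FIELDS = [
    "id", "annotator", "position", "question", "table_file",
    "answer_coordinates", "aggregation", "answer_text", "float_answer",
]

with open("/tmp/test.tsv", "w", newline="") as out:
  writer = csv.DictWriter(out, fieldnames=_FIELDS, delimiter="\t")
  writer.writeheader()
  writer.writerow({
      "id": "nt-2",
      "annotator": "0",
      "position": "0",
      "question": "What is text for 2?",
      "table_file": "table_csv/203-515.csv",
      # Weak supervision: coordinates unknown, answer text only.
      "answer_coordinates": "['(-1, -1)']",
      "aggregation": "NONE",
      "answer_text": "['B']",
      "float_answer": "",
  })
```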