├── __init__.py ├── ExampleGeneration ├── __init__.py ├── tests │ ├── utils │ │ ├── __init__.py │ │ └── test_multiprocessing.py │ └── datajobs │ │ ├── data │ │ └── datajob_samples │ │ │ └── parse_wiki_dump_sample.gz │ │ ├── parse_wiki_dump_test.py │ │ ├── classify_table_column_types_test.py │ │ ├── format_questions_test.py │ │ └── gen_synthetic_questions_from_templates_debug.py ├── ExampleGeneration │ ├── __init__.py │ ├── datajobs │ │ ├── __init__.py │ │ └── format_questions.py │ ├── question_generators │ │ ├── tab_reas │ │ │ ├── __init__.py │ │ │ └── simple.py │ │ └── question_generator_factory.py │ ├── bash_scripts │ │ ├── parse_wiki_dump.sh │ │ └── download_reasoning_examples.sh │ ├── common │ │ ├── wikipedia_dump_utils.py │ │ ├── question_template_utils.py │ │ ├── context_utils.py │ │ └── questions_utils.py │ ├── run.py │ ├── datajob.py │ ├── configurations │ │ ├── config_reas.json │ │ └── config_tests.json │ ├── datajob_factory.py │ └── run_multiple_chunks.py ├── wikitextparser │ ├── _comment.py │ ├── __init__.py │ ├── _wikilink.py │ ├── _externallink.py │ ├── _section.py │ ├── _parser_function.py │ ├── _parameter.py │ ├── _argument.py │ └── _wikilist.py ├── externals │ └── wikiextractor │ │ └── extract.sh ├── LICENSE.txt ├── requirements.txt └── README.md ├── Training ├── ContinuousPreTraining │ ├── __init__.py │ ├── Data │ │ ├── __init__.py │ │ ├── dataset_readers │ │ │ ├── __init__.py │ │ │ ├── t5_mlm_dataset.py │ │ │ ├── synthetic_questions_multi_datasets.py │ │ │ ├── iirc_dataset.py │ │ │ └── unified_qa_dataset.py │ │ ├── datasets_wrapper.py │ │ └── data_utils.py │ ├── Training │ │ ├── __init__.py │ │ ├── trainers │ │ │ ├── __init__.py │ │ │ └── basic_trainer.py │ │ ├── trainer_callbacks │ │ │ ├── __init__.py │ │ │ ├── basic_qa_callback.py │ │ │ ├── basic_qa_callback_handler.py │ │ │ ├── multi_task_callback.py │ │ │ └── multi_task_heterogeneous_callback.py │ │ ├── callback_factory.py │ │ ├── optimizers.py │ │ └── trainer_factory.py │ ├── evaluators │ │ ├── __init__.py │ │ ├── drop_eval.py │ │ ├── drop_list_eval.py │ │ ├── span_evaluator.py │ │ ├── basic_qa_evaluator.py │ │ └── iirc_eval.py │ ├── predictors │ │ ├── __init__.py │ │ ├── generative_predictor.py │ │ ├── list_generative_predictor.py │ │ ├── boolean_predictor.py │ │ └── span_predictor.py │ ├── samplers │ │ ├── __init__.py │ │ ├── random_sampler.py │ │ ├── dataset_uniform_sampler.py │ │ ├── lambda_mlm_sampler.py │ │ ├── error_distribution_heterogeneous_sampler.py │ │ └── adaptive_error_heterogeneous_sampler.py │ ├── scripts │ │ ├── __init__.py │ │ ├── preprocess_drop.py │ │ ├── preprocess_mmqa_for_question_classification.py │ │ └── preprocess_mmqa.py │ ├── basic_factory.py │ ├── configurations │ │ ├── drop_config.json │ │ ├── iirc_config.json │ │ ├── iirc_retrieval_config.json │ │ ├── mmqa_config.json │ │ ├── mmqa_para_classifier_config.json │ │ ├── mmqa_question_classifier_config.json │ │ ├── mmqa_retrieval_config.json │ │ ├── PReasM_uniform_config.json │ │ ├── PReasM_errors_config.json │ │ └── PReasM_momentum_config.json │ └── Common │ │ ├── transfomer_utils.py │ │ ├── config.py │ │ └── evaluation_utils.py ├── bash_scripts │ ├── setup_datasets.sh │ ├── download_preasm.sh │ └── download_datasets.sh ├── requirements.txt └── README.md ├── AUTHORS.rst ├── README.md └── .gitignore /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ExampleGeneration/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ExampleGeneration/tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/datajobs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/question_generators/tab_reas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Ori Yoran 6 | -------------------------------------------------------------------------------- 
/ExampleGeneration/tests/datajobs/data/datajob_samples/parse_wiki_dump_sample.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oriyor/turning_tables/HEAD/ExampleGeneration/tests/datajobs/data/datajob_samples/parse_wiki_dump_sample.gz -------------------------------------------------------------------------------- /Training/bash_scripts/setup_datasets.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading datasets..." 2 | ./bash_scripts/download_datasets.sh 3 | 4 | echo "Preprocessing drop..." 5 | python ./ContinuousPreTraining/scripts/preprocess_drop.py 6 | 7 | echo "Preprocessing mmqa..." 8 | python ./ContinuousPreTraining/scripts/preprocess_mmqa.py -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_comment.py: -------------------------------------------------------------------------------- 1 | """Define the Comment class.""" 2 | 3 | 4 | from ._wikitext import SubWikiText 5 | 6 | 7 | class Comment(SubWikiText): 8 | 9 | """Create a new object.""" 10 | 11 | @property 12 | def contents(self) -> str: 13 | """Return contents of this comment.""" 14 | return self.string[4:-3] 15 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/predictors/generative_predictor.py: -------------------------------------------------------------------------------- 1 | def GenerativePredictor(tokenizer, model, input_ids, attention_mask, labels): 2 | """ 3 | get a generative prediction 4 | """ 5 | generated_prediction = tokenizer.batch_decode(model.generate(input_ids), 6 | skip_special_tokens=True) 7 | return generated_prediction 8 | 9 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/drop_eval.py: -------------------------------------------------------------------------------- 1 | from ContinuousPreTraining.Common.evaluation_utils import get_drop_metrics 2 | from ContinuousPreTraining.evaluators.basic_qa_evaluator import BasicQAEvaluator 3 | 4 | 5 | class DropEval(BasicQAEvaluator): 6 | 7 | def evaluate_single_example_method(self, pred): 8 | em_score, f1_score = get_drop_metrics(pred['prediction'], pred['gold']) 9 | return em_score, f1_score 10 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/predictors/list_generative_predictor.py: -------------------------------------------------------------------------------- 1 | def ListGenerativePredictor(tokenizer, model, input_ids, attention_mask, labels): 2 | """ 3 | get a list generative prediction with # as the separator 4 | """ 5 | generated_prediction = tokenizer.batch_decode(model.generate(input_ids), 6 | skip_special_tokens=True) 7 | return [pred.split('#') 8 | for pred in generated_prediction] 9 | 10 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RandomSampler: 5 | """ 6 | indices for a random sample over the iterables 7 | """ 8 | def sample(self, task_iterables_list): 9 | task_choice_list = [] 10 | for i, task_iterable in enumerate(task_iterables_list.values()): 11 | task_choice_list += [i] * task_iterable['num_batches'] 12 | task_choice_list = 
np.array(task_choice_list) 13 | np.random.shuffle(task_choice_list) 14 | 15 | return task_choice_list 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Turning Tables 2 | 3 | 4 | This repository contains the code accompanying the paper: 5 | 6 | **"Turning Tables: Generating Examples from Semi-structured Tables for Endowing Language Models with Reasoning Skills"** [[Preprint]](https://arxiv.org/abs/2107.07261) 7 | 8 | ### Structure 9 | The repository contains: 10 | * Implementation and instructions for generating reasoning examples from semi-structured tables in the `ExampleGeneration` directory. 11 | * Implementation and instructions for training and fine-tuning PReasM in the `Training` directory. 12 | * The generated examples and models described in the paper are publicly available. See `ExampleGeneration` to download the generated examples and `Training` to download the PReasM models. 13 | -------------------------------------------------------------------------------- /ExampleGeneration/tests/utils/test_multiprocessing.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import pytest, os 3 | from ExampleGeneration.common.multi_process_streaming import multi_process_lst, multi_process_data_stream 4 | from ExampleGeneration.common.file_utils import CACHE_DIRECTORY 5 | 6 | class TestMultiProcessing: 7 | 8 | @staticmethod 9 | def apply_on_chunk(chunk): 10 | return [item*3 for item in chunk] 11 | 12 | @staticmethod 13 | def apply_on_lines(lines): 14 | return [int(line)*3 for line in lines] 15 | 16 | def test_multi_process_lst(self): 17 | lst = [1]*100000 18 | res = multi_process_lst(lst, self.apply_on_chunk, chunk_size=1000, n_processes=5) 19 | assert res == self.apply_on_chunk(lst) -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/drop_list_eval.py: -------------------------------------------------------------------------------- 1 | from ContinuousPreTraining.Common.evaluation_utils import get_drop_metrics 2 | from ContinuousPreTraining.evaluators.basic_qa_evaluator import BasicQAEvaluator 3 | 4 | 5 | class DropListEval(BasicQAEvaluator): 6 | 7 | def evaluate_single_example_method(self, pred): 8 | """ 9 | score the prediction against each gold answer in the list and keep the max 10 | """ 11 | max_em_score = 0.0 12 | max_f1_score = 0.0 13 | 14 | for gold in pred['gold']: 15 | em_score, f1_score = get_drop_metrics(pred['prediction'], gold) 16 | max_em_score = max(max_em_score, em_score) 17 | max_f1_score = max(max_f1_score, f1_score) 18 | 19 | return max_em_score, max_f1_score 20 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/dataset_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DatasetUniformSampler: 5 | """ 6 | indices for a uniform sample across iterables: every task is upsampled to the size of the largest iterable 7 | """ 8 | 9 | def sample(self, task_iterables_list): 10 | """ 11 | sample indices 12 | """ 13 | max_num_batches_for_task = max([task_iterable['num_batches'] 14 | for task_iterable in task_iterables_list.values()]) 15 | task_choice_list = [] 16 | for i, task_iterable in enumerate(task_iterables_list.values()): 17 | task_choice_list += [i] * max_num_batches_for_task 18 | task_choice_list =
np.array(task_choice_list) 19 | np.random.shuffle(task_choice_list) 20 | 21 | return task_choice_list 22 | 23 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_callbacks/basic_qa_callback.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl 2 | 3 | 4 | class BasicQaCallback(TrainerCallback): 5 | """ 6 | a basic qa callback class that implements the on_predictions_save call 7 | """ 8 | 9 | def on_predictions_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 10 | """ 11 | Event called when the :class:`~transformers.Trainer` saves evaluation predictions. 12 | """ 13 | pass 14 | 15 | def on_batch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 16 | """ 17 | Event called at the beginning of each training batch, as opposed to each training step (needed when running multi-task gradient accumulation) 18 | """ 19 | pass -------------------------------------------------------------------------------- /ExampleGeneration/externals/wikiextractor/extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # NOTES 4 | # 5 | # - Must expand templates to avoid a large loss of content. 6 | # - Text will not (redundantly) contain the title string. 7 | # - Keep sections. Section title will be marked by "Section::::". 8 | # - Keep lists. List bullets will be marked by "BULLET::::". 9 | # - Keep tables. They're mostly garbage but can be removed later (remove "^!*"). 10 | # - Remove disambiguation pages. Right now there is no use for them. 11 | 12 | INPUT=$1 13 | PROCESSES=$2 14 | TEMPLATES=$3 15 | OUTPUT=$4 16 | 17 | python WikiExtractor.py $INPUT \ 18 | --json \ 19 | --processes $PROCESSES \ 20 | --templates $TEMPLATES \ 21 | --output $OUTPUT \ 22 | --bytes 1M \ 23 | --compress \ 24 | --links \ 25 | --sections \ 26 | --lists \ 27 | --keep_tables \ 28 | --min_text_length 0 \ 29 | --filter_disambig_pages 30 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/span_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class SpanEvaluator: 5 | """ 6 | a class to evaluate the mlm span prediction task 7 | """ 8 | 9 | def evaluate(self, 10 | predictions, 11 | output_predictions_path, 12 | dataset_name): 13 | 14 | # create a vector for all the predictions 15 | preds_all = np.array([]) 16 | for p in predictions: 17 | preds_all = np.append(preds_all, p.correct_predictions.flatten(), axis=0) 18 | 19 | # calculate the different metrics 20 | span_precision = np.average([p.precision for p in predictions]) 21 | span_f1 = np.average([p.f1 for p in predictions]) 22 | token_em = np.average(preds_all) 23 | 24 | result_dict = {f'{dataset_name}_span_precision': span_precision, 25 | f'{dataset_name}_span_f1': span_f1, 26 | f'{dataset_name}_token_em': token_em 27 | } 28 | return result_dict 29 | -------------------------------------------------------------------------------- /Training/bash_scripts/download_preasm.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading PReasM to CheckpointsRestored/PReasM-$Sampler-$Size ..."
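# NOTE (editor's sketch, not in the original script): Sampler and Size must be
# set in the environment before running; assumed values follow the released
# checkpoint names referenced in the configurations, e.g.:
#   Sampler=Err Size=Base ./bash_scripts/download_preasm.sh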
2 | 3 | mkdir -p CheckpointsRestored 4 | cd CheckpointsRestored 5 | 6 | mkdir -p PReasM-$Sampler-$Size 7 | cd PReasM-$Sampler-$Size 8 | 9 | echo "Downloading config.json..." 10 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/config.json 11 | 12 | echo "Downloading optimizer.pt..." 13 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/optimizer.pt 14 | 15 | echo "Downloading pytorch_model.bin..." 16 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/pytorch_model.bin 17 | 18 | echo "Downloading scheduler.pt..." 19 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/scheduler.pt 20 | 21 | echo "Downloading trainer_state.json..." 22 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/trainer_state.json 23 | 24 | echo "Downloading training_args.bin..." 25 | wget https://tabreas.s3.us-west-2.amazonaws.com/PReasM/PReasM-$Sampler-$Size/training_args.bin 26 | 27 | cd ../.. 28 | -------------------------------------------------------------------------------- /ExampleGeneration/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/bash_scripts/parse_wiki_dump.sh: -------------------------------------------------------------------------------- 1 | cd WIKI_DUMP_DIR 2 | 3 | yourfilenames=`ls ./*.bz2` 4 | echo $yourfilenames 5 | cd .. 6 | 7 | counter=0 8 | for eachfile in $yourfilenames 9 | do 10 | echo "starting, first unzipping" 11 | echo $eachfile 12 | 13 | unziped_file_name=${eachfile%."bz2"} 14 | unziped_file_name_parsed=${unziped_file_name:2} 15 | 16 | cd wiki_dump_1 17 | bzip2 -dk $eachfile 18 | cd .. 
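# (editor's note: at this point the current .bz2 shard has been decompressed
# inside wiki_dump_1/; WIKI_DUMP_DIR and OUTPUT_DIR are placeholders to adjust
# before running the script)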
19 | 20 | 21 | echo "finished unzipping" 22 | echo $unziped_file_name_parsed 23 | 24 | unziped_file="wiki_dump_1/${unziped_file_name_parsed}" 25 | output_file="OUTPUT_DIR${unziped_file_name_parsed}_parsed.gz" 26 | 27 | echo $unziped_file 28 | echo $output_file 29 | 30 | chmod 777 $unziped_file 31 | 32 | echo "running script for file" 33 | 34 | python ExampleGeneration/ExampleGeneration/run.py -c ParseWikiDump -in $unziped_file -out $output_file 35 | 36 | echo "removing file" 37 | 38 | rm $unziped_file 39 | 40 | cp "wiki_dump_1/${unziped_file_name_parsed}.bz2" "finished_script/${unziped_file_name_parsed}.bz2" 41 | 42 | echo "finished file" 43 | 44 | done -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/predictors/boolean_predictor.py: -------------------------------------------------------------------------------- 1 | def BooleanPredictor(tokenizer, model, input_ids, attention_mask, labels): 2 | """ 3 | get a boolean prediction: whether the 'yes' or the 'no' token gets a higher probability 4 | """ 5 | # get generated outputs 6 | generated_outputs = model.generate(input_ids, 7 | return_dict_in_generate=True, 8 | output_scores=True, 9 | output_hidden_states=True) 10 | 11 | # calculate yes token and no token for tokenizer 12 | yes_token = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('yes'))[0] 13 | no_token = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('no'))[0] 14 | 15 | # for every question check whether the probability for yes token is higher than no 16 | boolean_predictions = ['yes' if prediction_score[yes_token] >= prediction_score[no_token] 17 | else 'no' 18 | for prediction_score in generated_outputs.scores[0] 19 | ] 20 | return boolean_predictions 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | ExampleGeneration/venv/ 17 | ExampleGeneration/models_cache/ 18 | ExampleGeneration/data/ 19 | ExampleGeneration/reasoning_examples 20 | 21 | # Data models and analysis 22 | Training/ContinuousPreTraining/Data/raw_wikipedia/ 23 | Training/ContinuousPreTraining/Data/MLM/ 24 | Training/ContinuousPreTraining/Data/MLM/mlm_data.jsonl 25 | Training/ContinuousPreTraining/Runs/ 26 | 27 | # Project files 28 | .ropeproject 29 | .project 30 | .pydevproject 31 | .settings 32 | .idea 33 | .idea/ 34 | .idea/workspace.xml 35 | ExampleGeneration/.idea 36 | tags 37 | 38 | # Package files 39 | *.egg 40 | *.eggs/ 41 | .installed.cfg 42 | *.egg-info 43 | 44 | # Unittest and coverage 45 | htmlcov/* 46 | .coverage 47 | .tox 48 | junit.xml 49 | coverage.xml 50 | .pytest_cache/ 51 | 52 | # Build and docs folder/files 53 | build/* 54 | dist/* 55 | sdist/* 56 | docs/api/* 57 | docs/_rst/* 58 | docs/_build/* 59 | cover/* 60 | MANIFEST 61 | 62 | # Per-project virtualenvs 63 | .venv*/ 64 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/datasets_wrapper.py: -------------------------------------------------------------------------------- 1 | class DatasetsWrapper: 2 | """ 3 | class to wrap a list of datasets that share the same config 4 | this will include a list of datasets, the sampler between the datasets, the args for the
dataloader, etc. 5 | """ 6 | 7 | def __init__(self, 8 | datasets, 9 | dataloader_args, 10 | datasets_names = None, 11 | sampler='random', 12 | predictor=None, 13 | eval_method=None, 14 | save_error_distribution=None, 15 | is_train_task=True 16 | ): 17 | 18 | self.datasets = datasets 19 | self.datasets_names = datasets_names 20 | self.sampler = sampler 21 | self.dataloader_args = dataloader_args 22 | self.single_dataset = len(datasets) == 1 23 | self.num_examples = sum([len(dataset) for dataset in self.datasets]) 24 | self.predictor = predictor 25 | self.eval_method = eval_method 26 | self.save_error_distribution = save_error_distribution 27 | self.is_train_task = is_train_task 28 | 29 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/callback_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from ContinuousPreTraining.Common.file_utils import upper_to_lower_notation_name, find_module 4 | 5 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 6 | 7 | 8 | class CallbackFactory: 9 | """ 10 | factory to get a callback from the config 11 | """ 12 | 13 | def __init__(self): 14 | pass 15 | 16 | def find_callback(self, callback_name): 17 | 18 | callback_name_lower = upper_to_lower_notation_name(callback_name) 19 | module_name = find_module(os.path.dirname(os.path.abspath(__file__)), callback_name_lower) 20 | try: 21 | mod = __import__('ContinuousPreTraining.Training.' + module_name, 22 | fromlist=[callback_name]) 23 | except ImportError: 24 | logger.error(module_name + ' module not found!!') 25 | raise ValueError(f'callback {callback_name} not found') 26 | 27 | return getattr(mod, callback_name) 28 | 29 | def get_callback(self, callback_name): 30 | """ 31 | factory method to get a callback 32 | """ 33 | callback = self.find_callback(callback_name) 34 | 35 | # return the callback class 36 | return callback 37 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/basic_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from ContinuousPreTraining.Common.file_utils import upper_to_lower_notation_name, find_module 5 | 6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 7 | 8 | 9 | class BasicFactory: 10 | # todo move all other factories here 11 | """
    generic factory to load an object class from the config 12 | """ 13 | 14 | def __init__(self): 15 | pass 16 | 17 | def find_object(self, object_name): 18 | 19 | object_name_lower = upper_to_lower_notation_name(object_name) 20 | module_name = find_module(os.path.dirname(os.path.abspath(__file__)), object_name_lower) 21 | try: 22 | mod = __import__('ContinuousPreTraining.'
+ module_name, 24 | fromlist=[object_name]) 25 | except ImportError: 26 | logger.error(module_name + ' module not found!!') 27 | raise ValueError(f'object {object_name} not found') 28 | 29 | return getattr(mod, object_name) 30 | 31 | def get_object(self, object_name): 32 | #todo rename 33 | """ 34 | factory method to get an object class 35 | """ 36 | obj = self.find_object(object_name) 37 | 38 | # return the object class 39 | return obj -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/lambda_mlm_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaMlmSampler: 5 | """ 6 | task at index 0 is seen lambda percent of the time, the other task 1-lambda 7 | """ 8 | 9 | def __init__(self, 10 | lmbda 11 | ): 12 | """ 13 | init the dataloaders for each task 14 | get the sampling indices between the tasks 15 | """ 16 | self.lmbda = float(lmbda) 17 | 18 | def sample(self, task_iterables_list): 19 | """ 20 | sample indices 21 | """ 22 | # we must have exactly two tasks 23 | assert len(task_iterables_list) == 2 24 | 25 | # the name of the first task must be WikiTrain 26 | assert [k for k in task_iterables_list][0] == 'WikiTrain' 27 | 28 | # get the number of batches for the second task, add lambda batches for mlm 29 | num_batches_for_second_task = list(task_iterables_list.values())[1]['num_batches'] 30 | alpha = (1-self.lmbda)/self.lmbda 31 | task_choice_list = [0]*int(num_batches_for_second_task*alpha) + [1]*num_batches_for_second_task 32 | task_choice_list = np.array(task_choice_list) 33 | np.random.shuffle(task_choice_list) 34 | 35 | return task_choice_list 36 | 37 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/optimizers.py: -------------------------------------------------------------------------------- 1 | from transformers import Adafactor, AdamW, get_linear_schedule_with_warmup 2 | 3 | 4 | def get_optimizer(optimizer_config, model, args_lr): 5 | """ 6 | helper method to get an optimizer 7 | """ 8 | lr = optimizer_config['lr'] if args_lr is None else args_lr 9 | 10 | # traverse the possible optimizers 11 | if optimizer_config['type'] == 'AdaFactor': 12 | return Adafactor( # https://discuss.huggingface.co/t/t5-finetuning-tips/684/3 13 | model.parameters(), 14 | lr=lr, 15 | eps=(1e-30, 1e-3), 16 | clip_threshold=1.0, 17 | decay_rate=-0.8, 18 | beta1=None, 19 | weight_decay=0.0, 20 | relative_step=False, 21 | scale_parameter=False, 22 | warmup_init=False 23 | ) 24 | 25 | if optimizer_config['type'] == 'AdamW': 26 | return AdamW(model.parameters(), lr=lr) 27 | 28 | 29 | def get_scheduler(optimizer, scheduler_config): 30 | """ 31 | helper method to get a scheduler 32 | """ 33 | return get_linear_schedule_with_warmup( 34 | optimizer, 35 | num_warmup_steps=scheduler_config['num_warmup_steps'], 36 | num_training_steps=scheduler_config['num_training_steps'] 37 | ) 38 | -------------------------------------------------------------------------------- /Training/requirements.txt: -------------------------------------------------------------------------------- 1 | blis==0.4.1 2 | catalogue==1.0.0 3 | certifi==2020.6.20 4 | chardet==3.0.4 5 | click==7.1.2 6 | configparser==5.0.1 7 | cycler==0.10.0 8 | cymem==2.0.3 9 | dataclasses 10 | datasets==1.1.2 11 | dill==0.3.3 12 | docker-pycreds==0.4.0 13 | docopt==0.6.2 14 | filelock==3.0.12 15 | future==0.18.2 16 | gitdb==4.0.5 17 | GitPython==3.1.11 18 |
idna==2.10 19 | importlib-metadata==2.0.0 20 | joblib==0.17.0 21 | jsonlines==1.2.0 22 | kiwisolver==1.2.0 23 | matplotlib==3.3.2 24 | multiprocess==0.70.11.1 25 | murmurhash==1.0.2 26 | numpy==1.19.2 27 | packaging==20.4 28 | pandas==1.1.3 29 | pathtools==0.1.2 30 | Pillow==8.0.1 31 | pipreqs==0.4.10 32 | plac==1.1.3 33 | preshed==3.0.2 34 | promise==2.3 35 | protobuf==3.13.0 36 | psutil==5.7.3 37 | pyarrow==2.0.0 38 | pyparsing==2.4.7 39 | python-dateutil==2.8.1 40 | pytz==2020.1 41 | PyYAML==5.3.1 42 | regex==2020.10.15 43 | requests==2.24.0 44 | sacremoses==0.0.43 45 | scipy==1.5.3 46 | seaborn==0.11.0 47 | sentencepiece==0.1.91 48 | sentry-sdk==0.19.1 49 | shortuuid==1.0.1 50 | six==1.15.0 51 | smmap==3.0.4 52 | spacy==2.3.2 53 | srsly==1.0.2 54 | subprocess32==3.5.4 55 | thinc==7.4.1 56 | tokenizers==0.10.1 57 | transformers==4.3.3 58 | ujson==4.0.1 59 | urllib3==1.25.10 60 | wasabi==0.8.0 61 | watchdog==0.10.3 62 | xxhash==2.0.0 63 | yarg==0.1.9 64 | zipp==3.3.1 65 | json-cfg==0.4.2 66 | # Optionals 67 | # for s3 support 68 | boto3 69 | # for wandb support 70 | wandb==0.10.8 71 | word2number==1.1 -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the wikitextparser.""" 2 | 3 | # Scheme: [N!]N(.N)*[{a|b|rc}N][.postN][.devN] 4 | __version__ = '0.24.4.dev0' 5 | 6 | from ._parameter import Parameter 7 | from ._argument import Argument 8 | from ._externallink import ExternalLink 9 | from ._wikilink import WikiLink 10 | from ._section import Section 11 | from ._comment import Comment 12 | from . import _wikitext 13 | from ._table import Table 14 | from ._template import Template 15 | from ._parser_function import ParserFunction 16 | from ._tag import Tag 17 | from ._tag import START_TAG_PATTERN as _START_TAG_PATTERN 18 | from ._tag import END_TAG_PATTERN as _END_TAG_PATTERN 19 | from ._tag import START_TAG_FINDITER as _START_TAG_FINDITER 20 | from ._wikilist import WikiList 21 | from ._wikilist import LIST_PATTERN_FORMAT as _LIST_PATTERN_FORMAT 22 | 23 | 24 | _wikitext.ExternalLink = ExternalLink 25 | _wikitext.WikiLink = WikiLink 26 | _wikitext.Template = Template 27 | _wikitext.Comment = Comment 28 | _wikitext.ParserFunction = ParserFunction 29 | _wikitext.Parameter = Parameter 30 | _wikitext.Table = Table 31 | _wikitext.Section = Section 32 | _wikitext.WikiList = WikiList 33 | _wikitext.LIST_PATTERN_FORMAT = _LIST_PATTERN_FORMAT 34 | _wikitext.Tag = _wikitext.ExtensionTag = Tag 35 | _wikitext.START_TAG_PATTERN = _START_TAG_PATTERN 36 | _wikitext.END_TAG_PATTERN = _END_TAG_PATTERN 37 | _wikitext.START_TAG_FINDITER = _START_TAG_FINDITER 38 | 39 | WikiText = _wikitext.WikiText 40 | parse = WikiText 41 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/common/wikipedia_dump_utils.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | from xml.etree import ElementTree 3 | 4 | 5 | def iterate_articles(file_path): 6 | with bz2.open(file_path, 'rb') as reader: 7 | content = None 8 | for line in reader: 9 | line = line.decode('utf-8').strip() 10 | if line == '<page>': 11 | content = [line] 12 | elif line == '</page>': 13 | content.append(line) 14 | content = '\n'.join(content) 15 | tree = ElementTree.fromstring(content) 16 | content = None 17 | ns_elem = tree.find('ns') 18 | if ns_elem is None: 19 | continue 20 | if
ns_elem.text.strip() != '0': 21 | continue 22 | title_elem = tree.find('title') 23 | if title_elem is None: 24 | continue 25 | title = title_elem.text 26 | id_elem = tree.find('id') 27 | if id_elem is None: 28 | continue 29 | page_id = id_elem.text 30 | text_elem = tree.find('revision/text') 31 | if text_elem is None: 32 | continue 33 | text = text_elem.text 34 | if text is None: 35 | continue 36 | yield title, page_id, text 37 | else: 38 | if type(content) is list: 39 | content.append(line) 40 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/common/question_template_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def substitute_named_templates(t, named_templates): 4 | for k, v in t.items(): 5 | if isinstance(v, dict): 6 | if 'copy_from' in v: 7 | if v['copy_from'] in named_templates: 8 | for k, tv in named_templates[v['copy_from']].items(): 9 | if k not in v: 10 | v[k] = tv 11 | else: 12 | raise ValueError(f"named template '{v['copy_from']}' not found") 13 | else: 14 | substitute_named_templates(v, named_templates) 15 | 16 | elif isinstance(v, list): 17 | for item in v: 18 | if isinstance(item, dict): 19 | if 'copy_from' in item: 20 | if item['copy_from'] in named_templates: 21 | for k, v in named_templates[item['copy_from']].items(): 22 | if k not in item: 23 | item[k] = v 24 | else: 25 | raise ValueError(f"named template '{item['copy_from']}' not found") 26 | else: 27 | substitute_named_templates(item, named_templates) 28 | 29 | 30 | 31 | def process_question_templates(templates): 32 | named_templates = {} 33 | if templates[0]['name'] == "NamedTemplates": 34 | named_templates = templates[0] 35 | templates = templates[1:] 36 | 37 | for template in templates: 38 | substitute_named_templates(template, named_templates) 39 | 40 | return templates -------------------------------------------------------------------------------- /Training/bash_scripts/download_datasets.sh: -------------------------------------------------------------------------------- 1 | # download drop 2 | cd ./ContinuousPreTraining/Data 3 | mkdir drop 4 | cd ./drop 5 | echo "downloading DROP ..." 6 | wget https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip 7 | unzip drop_dataset.zip 8 | mv ./drop_dataset/*.json . 9 | mv drop_dataset.zip drop_dataset 10 | rm -rf drop_dataset 11 | cd ../../.. 12 | 13 | # download IIRC 14 | cd ./ContinuousPreTraining/Data 15 | echo "downloading IIRC ..." 16 | wget http://jamesf-incomplete-qa.s3.amazonaws.com/iirc.tar.gz 17 | tar -xzf iirc.tar.gz 18 | rm iirc.tar.gz 19 | 20 | echo "downloading IIRC dev in drop format for eval script (see https://github.com/jferguson144/IIRC-baseline)..." 21 | cd ./iirc 22 | wget https://tabreas.s3.us-west-2.amazonaws.com/iirc/iirc_dev_drop_format.json 23 | cd ../../.. 24 | 25 | # download MMQA 26 | cd ./ContinuousPreTraining/Data 27 | echo "downloading MMQA ..." 28 | mkdir mmqa 29 | cd ./mmqa 30 | wget https://github.com/allenai/multimodalqa/blob/master/dataset/MMQA_train.jsonl.gz?raw=true 31 | wget https://github.com/allenai/multimodalqa/blob/master/dataset/MMQA_dev.jsonl.gz?raw=true 32 | wget https://github.com/allenai/multimodalqa/blob/master/dataset/MMQA_test.jsonl.gz?raw=true 33 | wget https://github.com/allenai/multimodalqa/blob/master/dataset/MMQA_texts.jsonl.gz?raw=true 34 | wget https://github.com/allenai/multimodalqa/blob/master/dataset/MMQA_tables.jsonl.gz?raw=true 35 | 36 | echo "downloading preprocessed contexts after pipeline retrieval..."
37 | wget https://tabreas.s3.us-west-2.amazonaws.com/mmqa/parsed_mmqa_dev_retrieval.json 38 | 39 | cd ../../.. -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/basic_qa_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 6 | logging.basicConfig(level=logging.INFO) 7 | 8 | 9 | class BasicQAEvaluator: 10 | """ 11 | A class with an evaluation loop for QA evaluation 12 | """ 13 | def evaluate_single_example_method(self, pred): 14 | """ 15 | method to evaluate a single prediction, implemented by subclasses 16 | """ 17 | raise NotImplementedError 18 | 19 | def evaluate(self, 20 | predictions, 21 | output_predictions_path, 22 | dataset_name): 23 | """ 24 | evaluate a QA dataset 25 | """ 26 | # todo all this should be in the evaluator, it should get the predictions and return a dict 27 | # calculate em and f1 for every prediction 28 | for pred in predictions: 29 | em_score, f1_score = self.evaluate_single_example_method(pred) 30 | pred['em'] = em_score 31 | pred['f1'] = f1_score 32 | 33 | logger.info(f'saving predictions to {output_predictions_path}') 34 | 35 | # save csv with eval 36 | predictions_df = pd.DataFrame(predictions) 37 | predictions_df.to_csv(output_predictions_path) 38 | 39 | em = np.average([p['em'] for p in predictions]) 40 | f1 = np.average([p['f1'] for p in predictions]) 41 | 42 | result_dict = {f'{dataset_name}_em': em, 43 | f'{dataset_name}_f1': f1} 44 | 45 | return result_dict 46 | 47 | -------------------------------------------------------------------------------- /ExampleGeneration/tests/datajobs/parse_wiki_dump_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import pytest 3 | import argparse 4 | from ExampleGeneration.datajobs.parse_wiki_dump import ParseWikiDumpDataJob 5 | import os 6 | 7 | class TestParseWikiDump: 8 | def test_run_datajob(self): 9 | parse = argparse.ArgumentParser("") 10 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 11 | parse.add_argument("-o", "--operation", type=str, help="The task stage to run") 12 | parse.add_argument("-out", "--output_file", type=str, help="") 13 | parse.add_argument("-config", "--config_file_name", type=str, help="DataJobs config file name", default="config_tests.json") 14 | parse.add_argument("-wd", "--working_directory", type=str, help="dir of input file, can be s3 path") 15 | 16 | # In the test no output file will be produced, change -out to create an output 17 | args = parse.parse_args(["-c", "Example","-o", "build_datajob", 18 | "-out","parse_wiki_dump_sample.jsonl", 19 | "-wd", "data"]) 20 | 21 | datajob = ParseWikiDumpDataJob('ParseWikiDump',args) 22 | 23 | # reducing data size to a sample: 24 | datajob._config['max_number_of_examples'] = 100 25 | 26 | # Seems this is the default in the config anyway ..
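# (editor's note, assumption: the pinned path matches the checked-in sample
# tests/datajobs/data/datajob_samples/parse_wiki_dump_sample.gz used as test data)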
27 | datajob.output_path = os.path.join("data", "datajob_samples", 28 | "parse_wiki_dump_sample.gz") 29 | 30 | datajob.run_datajob(args) 31 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_callbacks/basic_qa_callback_handler.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from transformers.trainer_callback import CallbackHandler, TrainerState, TrainerControl 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | class BasicQaCallbackHandler(CallbackHandler): 8 | """ 9 | callback handler for qa events 10 | """ 11 | 12 | def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, 13 | prefix_checkpoint, use_wandb): 14 | return self.call_event("on_save", args, state, control, 15 | prefix_checkpoint=prefix_checkpoint, 16 | use_wandb=use_wandb) 17 | 18 | def on_batch_begin(self, args: TrainingArguments, 19 | state: TrainerState, 20 | control: TrainerControl, 21 | batch_inputs): 22 | """ 23 | in on_batch_begin we pass the batch inputs to the event handler so we can update the counter for each task 24 | """ 25 | self.call_event("on_batch_begin", args, state, control, 26 | batch_inputs=batch_inputs) 27 | 28 | def on_predictions_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, 29 | output_predictions_path, 30 | dataset_name): 31 | return self.call_event("on_predictions_save", args, state, control, 32 | output_predictions_path=output_predictions_path, 33 | dataset_name=dataset_name) 34 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_wikilink.py: -------------------------------------------------------------------------------- 1 | """The WikiLink class.""" 2 | 3 | 4 | from typing import Optional 5 | 6 | from ._wikitext import SubWikiText 7 | 8 | 9 | class WikiLink(SubWikiText): 10 | """Define a class to represent WikiLinks.""" 11 | 12 | @property 13 | def target(self) -> str: 14 | """Return target of this WikiLink.""" 15 | head, pipe, tail = self._atomic_partition(124) 16 | if pipe: 17 | return head[2:] 18 | else: 19 | return head[2:-2] 20 | 21 | @target.setter 22 | def target(self, newtarget: str) -> None: 23 | """Set a new target.""" 24 | head, pipe, tail = self._atomic_partition(124) 25 | if not pipe: 26 | head = head[:-2] 27 | self[2:len(head)] = newtarget 28 | 29 | @property 30 | def text(self) -> Optional[str]: 31 | """Return the text of this WikiLink. Do not include linktrail.""" 32 | head, pipe, tail = self._atomic_partition(124) 33 | if pipe: 34 | return tail[:-2] 35 | return None 36 | 37 | @text.setter 38 | def text(self, newtext: Optional[str]) -> None: 39 | """Set self.text to newtext. Remove it if newtext is None. 40 | 41 | Do not change the linktrail.
42 | """ 43 | head, pipe, tail = self._atomic_partition(124) 44 | if pipe: 45 | if newtext is None: 46 | del self[len(head + pipe) - 1:len(head + pipe + tail) - 2] 47 | else: 48 | self[len(head + pipe):len(head + pipe + tail) - 2] = newtext 49 | elif newtext is not None: 50 | self.insert(-2, '|' + newtext) 51 | -------------------------------------------------------------------------------- /ExampleGeneration/tests/datajobs/classify_table_column_types_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import pytest 3 | import argparse 4 | from ExampleGeneration.datajobs.reas_classify_column_types import ReasClassifyColumnTypesDataJob 5 | 6 | 7 | class TestClassifyTableColumnTypesDataJob: 8 | def test_run_datajob(self): 9 | parse = argparse.ArgumentParser("") 10 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 11 | parse.add_argument("-o", "--operation", type=str, help="The task stage to run") 12 | parse.add_argument("-out", "--output_file", type=str, help="") 13 | parse.add_argument("-config", "--config_file_name", type=str, help="DataJobs config file name", default="config_reas.json") 14 | parse.add_argument("-wd", "--working_directory", type=str, help="dir of input file, can be s3 path") 15 | parse.add_argument("-af", "--annotated_questions_file", type=str, help="dir of input file, can be s3 path", default=None) 16 | # In the test no output file will be produced, change -out to create an output 17 | args = parse.parse_args(["-c", "AddColumnTypeMetadata","-o", "build_datajob", 18 | "-wd", "data"]) 19 | 20 | datajob = ReasClassifyColumnTypesDataJob('ReasClassifyColumnTypes',args) 21 | 22 | # reducing data size to a sample: 23 | datajob._config['max_number_of_examples'] = 100 24 | datajob._config['max_number_of_add_column_type_metadatas'] = 100 25 | datajob.output_path = "data/datajob_samples/classify_table_column_types_sample.jsonl" 26 | 27 | datajob.run_datajob(args) 28 | -------------------------------------------------------------------------------- /ExampleGeneration/requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==21.2.0 2 | beautifulsoup4==4.9.3 3 | blis==0.7.4 4 | boto3==1.18.16 5 | botocore==1.21.16 6 | bs4==0.0.1 7 | cachetools==4.2.2 8 | catalogue==2.0.4 9 | certifi==2021.5.30 10 | charset-normalizer==2.0.4 11 | click==7.1.2 12 | contextvars==2.4 13 | cymem==2.0.5 14 | dataclasses==0.8 15 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0-py3-none-any.whl 16 | fgrequests==0.1.1 17 | google-api-core==1.31.1 18 | google-auth==1.34.0 19 | google-cloud-vision==2.4.2 20 | googleapis-common-protos==1.53.0 21 | grpcio==1.39.0 22 | idna==3.2 23 | immutables==0.16 24 | importlib-metadata==4.6.3 25 | iniconfig==1.1.1 26 | Jinja2==3.0.1 27 | jmespath==0.10.0 28 | joblib==1.0.1 29 | jsonschema==3.2.0 30 | MarkupSafe==2.0.1 31 | murmurhash==1.0.5 32 | numpy==1.19.5 33 | overrides==6.1.0 34 | packaging==21.0 35 | pandas==1.1.5 36 | pathy==0.6.0 37 | pluggy==0.13.1 38 | preshed==3.0.5 39 | proto-plus==1.19.0 40 | protobuf==3.17.3 41 | py==1.10.0 42 | pyasn1==0.4.8 43 | pyasn1-modules==0.2.8 44 | pydantic==1.8.2 45 | pyparsing==2.4.7 46 | pyrsistent==0.18.0 47 | pytest==6.2.4 48 | pytest-pycharm==0.7.0 49 | python-dateutil==2.8.2 50 | pytz==2021.1 51 | regex==2021.8.3 52 | requests==2.26.0 53 | rsa==4.7.2 54 | 
s3transfer==0.5.0 55 | scikit-learn==0.24.2 56 | scipy==1.5.4 57 | six==1.16.0 58 | sklearn==0.0 59 | smart-open==5.1.0 60 | soupsieve==2.2.1 61 | spacy==3.1.1 62 | spacy-legacy==3.0.8 63 | srsly==2.4.1 64 | thinc==8.0.8 65 | threadpoolctl==2.2.0 66 | toml==0.10.2 67 | tqdm==4.62.0 68 | typer==0.3.2 69 | typing-extensions==3.10.0.0 70 | typing-utils==0.1.0 71 | urllib3==1.26.6 72 | wasabi==0.8.2 73 | wcwidth==0.2.5 74 | wikipedia==1.4.0 75 | zipp==3.5.0 76 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | from datasets import tqdm 4 | 5 | 6 | class OrderedBlock: 7 | """ 8 | Ordered block object, which contains a block of sequences of similar size, and the maximum sequence length 9 | """ 10 | def __init__(self, block, max_seq): 11 | self.block = block 12 | self.max_seq = max_seq 13 | 14 | 15 | def build_data_blocks(data_path, max_block_size): 16 | """ 17 | Method to build blocks of data with similar size 18 | data_path: path to data path, each example must contain a phrase and a context 19 | max_block_size: max number of examples in a single block 20 | """ 21 | input_examples = [] 22 | with gzip.open(data_path, "r") as f: 23 | for i, l in enumerate(tqdm(f)): 24 | input_examples.append(json.loads(l)) 25 | 26 | if 'phrase' in input_examples[0]: 27 | input_examples.sort(key=lambda x: len(x['phrase']) + len(x['context'])) 28 | else: 29 | input_examples.sort(key=lambda x: len(x['context'])) 30 | ordered_blocks = [] 31 | 32 | while len(input_examples) > 0: 33 | to_take = min(max_block_size, len(input_examples)) 34 | # select = random.randint(0, len(input_examples) - to_take) 35 | select = 0 36 | block = input_examples[select:select + to_take] 37 | if 'phrase' in block[0]: 38 | max_seq = max([len(x['phrase']) + len(x['context']) for x in 39 | block]) 40 | else: 41 | max_seq = max([len(x['context']) for x in block]) 42 | ordered_blocks.append(OrderedBlock(block=block, 43 | max_seq=max_seq)) 44 | del input_examples[select:select + to_take] 45 | return ordered_blocks 46 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from transformers import default_data_collator 4 | from ContinuousPreTraining.Common.file_utils import upper_to_lower_notation_name, find_module 5 | 6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 7 | 8 | 9 | class TrainerFactory: 10 | """ 11 | factory to get the trainer from the config 12 | """ 13 | def __init__(self): 14 | pass 15 | 16 | 17 | def find_trainer(self, trainer_name): 18 | 19 | trainer_name_lower = upper_to_lower_notation_name(trainer_name) 20 | module_name = find_module(os.path.dirname(os.path.abspath(__file__)), trainer_name_lower) 21 | try: 22 | mod = __import__('ContinuousPreTraining.Training.' 
+ module_name, 23 | fromlist=[trainer_name]) 24 | except ImportError: 25 | logger.error(module_name + ' module not found!!') 26 | raise ValueError(f'trainer {trainer_name} not found') 27 | 28 | return getattr(mod, trainer_name) 29 | 30 | def get_trainer(self, trainer_config, trainer_args, tokenizer): 31 | """ 32 | factory method to get a trainer 33 | """ 34 | trainer_name = trainer_config['type'] 35 | trainer = self.find_trainer(trainer_name) 36 | 37 | # check if we need to init a data collator 38 | if 'data_collator' in trainer_config: 39 | if trainer_config['data_collator'] == 'smart': 40 | trainer_args['data_collator'] = SmartDataCollator(tokenizer) 41 | else: 42 | trainer_args['data_collator'] = default_data_collator 43 | 44 | # add the tokenizer, as we may need it for eval 45 | trainer_args['tokenizer'] = tokenizer 46 | 47 | # init trainer and return 48 | return trainer(**trainer_args) 49 | 50 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/evaluators/iirc_eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ContinuousPreTraining.evaluators.drop_eval import DropEval 3 | import numpy as np 4 | import pandas as pd 5 | 6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 7 | 8 | 9 | class IircEval(DropEval): 10 | 11 | def evaluate(self, 12 | predictions, 13 | output_predictions_path, 14 | dataset_name): 15 | """ 16 | evaluate a QA dataset 17 | """ 18 | # todo all this should be in the evaluator, it should get the predictions and return a dict 19 | # calculate em and f1 for every prediction 20 | for pred in predictions: 21 | em_score, f1_score = self.evaluate_single_example_method(pred) 22 | pred['em'] = em_score 23 | pred['f1'] = f1_score 24 | 25 | logger.info(f'saving predictions to {output_predictions_path}') 26 | predictions_df = pd.DataFrame(predictions) 27 | predictions_df.to_csv(output_predictions_path) 28 | 29 | em = np.average([p['em'] for p in predictions]) 30 | f1 = np.average([p['f1'] for p in predictions]) 31 | 32 | result_dict = {f'{dataset_name}_em': em, 33 | f'{dataset_name}_f1': f1} 34 | 35 | # add answer types to dict 36 | answer_types = {p['answer_type'] for p in predictions} 37 | for answer_type in answer_types: 38 | type_em = np.average([p['em'] for p in predictions 39 | if p['answer_type'] == answer_type]) 40 | type_f1 = np.average([p['f1'] for p in predictions 41 | if p['answer_type'] == answer_type]) 42 | result_dict[f'{dataset_name}_{answer_type}_em'] = type_em 43 | result_dict[f'{dataset_name}_{answer_type}_f1'] = type_f1 44 | 45 | return result_dict 46 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ExampleGeneration.datajob_factory import DataJobFactory 4 | 5 | import logging 6 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 7 | level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def main(): 12 | parse = argparse.ArgumentParser("") 13 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 14 | parse.add_argument("-o", "--operation", type=str, help="The task stage to run", default='run_datajob') 15 | parse.add_argument("-in", "--input_file", type=str, help="") 16 | parse.add_argument("-out", "--output_file", type=str, help="") 17 |
parse.add_argument("-config", "--config_file_name", type=str, help="DataJobs config file name", default="config_tests.json") 18 | parse.add_argument("--copy_from", type=str, help="For create new challenge, the chllenge to copy from", default=-1) 19 | parse.add_argument("--datajob_module", type=str, help="For create new challenge, the target challenge path", default='') 20 | parse.add_argument("-wd", "--working_directory", type=str, help="dir of input file, can be s3 path", default='') 21 | parse.add_argument("-af", "--annotated_questions_file", type=str, help="dir of input file, can be s3 path", default=None) 22 | 23 | # In the test no output file will be produced, change -out to create an output 24 | args = parse.parse_args() 25 | 26 | 27 | if args.operation == 'create_new_datajob': 28 | DataJobFactory().create_new_datajob(args.datajob_name, args.datajob_name, args) 29 | else: 30 | datajob = DataJobFactory().get_datajob(args.datajob_name, args.datajob_name, args) 31 | if args.operation == 'run_datajob': 32 | datajob.run_datajob(args) 33 | else: 34 | logger.error('Operation not supported') 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/common/context_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def get_context_name(context): 4 | """ 5 | :param context: 6 | :return: the name of the context in lower case 7 | """ 8 | return context.context[0].title.replace(' ', '_') 9 | 10 | 11 | def get_table(context): 12 | """ 13 | :param context: 14 | :return: the context's table 15 | """ 16 | for d in context.context: 17 | if len(d.table) > 0: 18 | return d 19 | return None 20 | 21 | 22 | def get_images_from_context(context): 23 | """ 24 | :param context: 25 | :return: the context's images 26 | """ 27 | images = [] 28 | for doc in context.context: 29 | if doc.metadata['type'] == 'image': 30 | images.append(doc) 31 | """ 32 | if 'metadata' in doc['documents'] and 'type' in doc['documents']['metadata']: 33 | # image_docs = [doc for doc in context['context']['documents'] if doc['metadata']['type'] == 'image'] 34 | images.append(doc['documents']) 35 | """ 36 | if images: 37 | return images 38 | else: 39 | return None 40 | 41 | 42 | def get_images_mapping_from_context(context): 43 | """ 44 | :param context: 45 | :return: the coords of the images 46 | """ 47 | images_map = {} 48 | for doc in context.context: 49 | if 'images_map' in doc.metadata: 50 | images_map = doc.metadata['images_map'] 51 | break 52 | """ 53 | if 'metadata' in doc['documents'] and 'type' in doc['documents']['metadata']: 54 | # image_docs = [doc for doc in context['context']['documents'] if doc['metadata']['type'] == 'image'] 55 | images.append(doc['documents']) 56 | """ 57 | if images_map: 58 | return images_map 59 | else: 60 | return None 61 | 62 | 63 | REP_PATTERN = re.compile(r'\[[0-9]*\]|\n') 64 | def normalize_paragraph_tag(p_tag): 65 | return re.sub(REP_PATTERN, "", p_tag.text) 66 | 67 | def get_url_from_title(title): 68 | return f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}" 69 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/drop_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "DROP" 4 | }, 5 | "model": { 6 | "PReasM": true, 7 | "sampler": "Err", 8 | "size": 
"Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "DROP_train": { 14 | "reader": { 15 | "type": "UnifiedQaDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/drop/parsed_drop_train_with_lists.json", 18 | "max_input_token_len": 512, 19 | "max_output_token_len": 32, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "single_task_sampler": "LengthGroupedSampler" 24 | } 25 | } 26 | }, 27 | "validation_datasets":{ 28 | "DROP_eval": { 29 | "reader": { 30 | "type": "UnifiedQaDataset", 31 | "pass_tokenizer": true, 32 | "path": "/ContinuousPreTraining/Data/drop/parsed_drop_dev_with_lists.json", 33 | "max_input_token_len": 512, 34 | "max_output_token_len": 32, 35 | "generation_model": true 36 | }, 37 | "dataloader": { 38 | }, 39 | "predictor": "ListGenerativePredictor", 40 | "eval_method": "DropListEval" 41 | } 42 | }, 43 | "optimizer": { 44 | "type": "AdaFactor", 45 | "lr": 1e-4 46 | }, 47 | "scheduler": { 48 | "type": "linear_scheduler_with_warmup", 49 | "num_warmup_steps": 500, 50 | "num_training_steps": 2e32 51 | }, 52 | "training_arguments": { 53 | "num_train_epochs": 20, 54 | "per_device_train_batch_size": 20, 55 | "per_device_eval_batch_size": 24, 56 | "gradient_accumulation_steps": 1, 57 | "log_steps": 100, 58 | "eval_steps": 500, 59 | "save_steps": 100000, 60 | "evaluation_strategy": "epoch", 61 | "weight_decay": 0.01, 62 | "save_total_limit": 5, 63 | "seed": 42, 64 | "prediction_loss_only": true, 65 | "no_cuda": false 66 | }, 67 | "trainer": { 68 | "type": "UpdatedMtTrainer", 69 | "load_train_dataloader_after_eval": false, 70 | "callbacks": [] 71 | } 72 | } -------------------------------------------------------------------------------- /ExampleGeneration/tests/datajobs/format_questions_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import argparse 3 | import os, json 4 | 5 | from ExampleGeneration.datajobs.format_questions import FormatQuestionsDataJob 6 | from ExampleGeneration.common.analysis_utils import dump_synthetic_questions_analysis 7 | 8 | 9 | class TestFormatQuestionsDEBUG: 10 | 11 | def test_format_questoins(self): 12 | config_file = "config_reas.json" 13 | config_entry = "FormatSyntheticQuestions" 14 | working_directory = "data/tab_reas" 15 | 16 | parse = argparse.ArgumentParser("") 17 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 18 | parse.add_argument("-o", "--operation", type=str, help="The task stage to run") 19 | parse.add_argument("-out", "--output_file", type=str, help="") 20 | parse.add_argument("-config", "--config_file_name", type=str, help="", default="config_reas.json") 21 | parse.add_argument("-wd", "--working_directory", type=str, help="dir of input file, can be s3 path") 22 | parse.add_argument("-af", "--annotated_questions_file", type=str, help="dir of input file, can be s3 path", 23 | default=None) 24 | 25 | # In the test no output file will be produced, change -out to create an output 26 | args = parse.parse_args( 27 | ["-c", "Example", "-o", "build_datajob", "-config", config_file, "-wd", working_directory]) 28 | 29 | datajob = FormatQuestionsDataJob(config_entry, args) 30 | 31 | # reducing data size to a sample: 32 | datajob._config['n_processes'] = 1 33 | datajob._config['max_chunk_size'] = 1000 34 | datajob._config['max_number_of_examples'] = 1000 35 | datajob.input_path = 
"data/datajob_samples/synthetic_questions.jsonl" 36 | datajob.output_path = "data/datajob_samples/formatted_synthetic_questions.jsonl" 37 | datajob.run_datajob(args) 38 | 39 | dump_synthetic_questions_analysis('data/datajob_samples/formatted_synthetic_questions.jsonl', \ 40 | 'data/datajob_samples/formatted_synthetic_questions.csv') 41 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_externallink.py: -------------------------------------------------------------------------------- 1 | """Define the ExternalLink class.""" 2 | 3 | from typing import Optional 4 | 5 | from regex import compile as regex_compile 6 | 7 | from ._wikitext import SubWikiText, BRACKET_EXTERNAL_LINK_URL 8 | 9 | 10 | URL_MATCH = regex_compile(BRACKET_EXTERNAL_LINK_URL).match 11 | 12 | 13 | class ExternalLink(SubWikiText): 14 | 15 | """Create a new ExternalLink object.""" 16 | 17 | @property 18 | def url(self) -> str: 19 | """Return the url.""" 20 | if self[0] == '[': 21 | return self[1:URL_MATCH(self._ext_link_shadow, 1).end()] 22 | return self.string 23 | 24 | @url.setter 25 | def url(self, newurl: str) -> None: 26 | """Set a new url.""" 27 | if self[0] == '[': 28 | self[1:len('[' + self.url)] = newurl 29 | else: 30 | self[0:len(self.url)] = newurl 31 | 32 | @property 33 | def text(self) -> Optional[str]: 34 | """Return the text part (the part after the first space). 35 | 36 | Return None if this is a bare link or has no associated text. 37 | """ 38 | string = self.string 39 | if string[0] == '[': 40 | end_match = URL_MATCH(self._ext_link_shadow, 1) 41 | url_end = end_match.end() 42 | end_char = string[url_end] 43 | if end_char == ']': 44 | return None 45 | if end_char == ' ': 46 | return string[url_end + 1:-1] 47 | return string[url_end:-1] 48 | 49 | @text.setter 50 | def text(self, newtext: str) -> None: 51 | """Set a new text. 52 | 53 | Automatically put the ExternalLink in brackets if it's not already. 54 | """ 55 | string = self.string 56 | if string[0] == '[': 57 | text = self.text 58 | if text: 59 | self[-len(text) - 1:-1] = newtext 60 | return 61 | self.insert(-1, ' ' + newtext) 62 | return 63 | self.insert(len(string), ' ' + newtext + ']') 64 | self.insert(0, '[') 65 | 66 | @property 67 | def in_brackets(self) -> bool: 68 | """Return true if the ExternalLink is in brackets. False otherwise.""" 69 | return self[0] == '[' 70 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_section.py: -------------------------------------------------------------------------------- 1 | """Define the Section class.""" 2 | 3 | 4 | from re import compile as re_compile, MULTILINE 5 | 6 | from ._wikitext import WS, SubWikiText 7 | 8 | HEADER_MATCH = re_compile( 9 | rb'(={1,6})[^\n]+?\1[ \t]*$', 10 | MULTILINE, 11 | ).match 12 | 13 | 14 | class Section(SubWikiText): 15 | 16 | """Section class is used to represent page sections.""" 17 | 18 | @property 19 | def level(self) -> int: 20 | """Return level of this section. 21 | 22 | Level is in range(1,7) or 0 for the lead section. 
23 | """ 24 | m = HEADER_MATCH(self._shadow) 25 | if m: 26 | return len(m.group(1)) 27 | return 0 28 | 29 | @level.setter 30 | def level(self, value: int) -> None: 31 | """Change level of this section.""" 32 | old_level = self.level 33 | title = self.title 34 | new_equals = '=' * value 35 | self[0:old_level + len(title) + old_level] =\ 36 | new_equals + title + new_equals 37 | 38 | @property 39 | def title(self) -> str: 40 | """Return title of this section. Return '' for lead sections.""" 41 | level = self.level 42 | if level == 0: 43 | return '' 44 | return self._atomic_partition(10)[0].rstrip(WS)[level:-level] 45 | 46 | @title.setter 47 | def title(self, value: str) -> None: 48 | """Set the new title for this section and update self.lststr.""" 49 | level = self.level 50 | if level == 0: 51 | raise RuntimeError( 52 | "Can't set title for a lead section. " 53 | "Try adding it to contents." 54 | ) 55 | title = self.title 56 | self[level:level + len(title)] = value 57 | 58 | @property 59 | def contents(self) -> str: 60 | """Return contents of this section.""" 61 | if self.level == 0: 62 | return self.string 63 | return self._atomic_partition(10)[2] 64 | 65 | @contents.setter 66 | def contents(self, value: str) -> None: 67 | """Set value as the contents of this section.""" 68 | level = self.level 69 | if level == 0: 70 | self[:] = value 71 | return 72 | contents = self.contents 73 | start = level + len(self.title) + level + 1 74 | self[start:start + len(contents)] = value 75 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/iirc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "IIRC" 4 | }, 5 | "model": { 6 | "PReasM": true, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "IIRCTrain": { 14 | "reader": { 15 | "type": "IircDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/iirc/train.json", 18 | "max_seq_len": 256, 19 | "summ_len": 32, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "LR": 1e-4, 24 | "single_task_sampler": "Random", 25 | "no_collator_in_eval": true 26 | }, 27 | "predictor": "ListGenerativePredictor", 28 | "eval_method": "DropEval" 29 | } 30 | }, 31 | "validation_datasets":{ 32 | "IIRCEval": { 33 | "reader": { 34 | "type": "IircDataset", 35 | "pass_tokenizer": true, 36 | "path": "/ContinuousPreTraining/Data/iirc/dev.json", 37 | "max_seq_len": 256, 38 | "summ_len": 32, 39 | "generation_model": true 40 | }, 41 | "dataloader": { 42 | "batch_size": 8, 43 | "no_collator_in_eval": true 44 | }, 45 | "predictor": "ListGenerativePredictor", 46 | "eval_method": "IircEval", 47 | "save_error_distribution": false 48 | } 49 | }, 50 | "optimizer": { 51 | "type": "AdaFactor", 52 | "lr": 1e-4 53 | }, 54 | "scheduler": { 55 | "type": "linear_scheduler_with_warmup", 56 | "num_warmup_steps": 500, 57 | "num_training_steps": 2e32 58 | }, 59 | "training_arguments": { 60 | "num_train_epochs": 100 , 61 | "per_device_train_batch_size": 20, 62 | "per_device_eval_batch_size": 50, 63 | "gradient_accumulation_steps": 1, 64 | "log_steps": 100, 65 | "evaluation_strategy": "epoch", 66 | "save_steps": 5000, 67 | "eval_steps": 100, 68 | "weight_decay": 0.01, 69 | "save_total_limit": 5, 70 | "seed": 43, 71 | "prediction_loss_only": true, 72 | "no_cuda": false 73 | }, 74 | "trainer": 
{ 75 | "type": "UpdatedMtTrainer", 76 | "override_huggingface_train_method": false, 77 | "load_train_dataloader_after_eval": false, 78 | "callbacks": [] 79 | } 80 | } -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/datajob.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from ExampleGeneration.common.file_utils import upload_jsonl_to_s3, save_jsonl_to_local, is_path_creatable 5 | import pandas as pd 6 | 7 | pd.set_option('display.max_rows', 500) 8 | pd.set_option('display.max_columns', 500) 9 | pd.set_option('display.width', 1000) 10 | pd.set_option('display.max_colwidth', 200) 11 | 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 16 | 17 | 18 | class DataJob(): 19 | 20 | def __init__(self, args, load_context_as_multiqa=False): 21 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations",\ 22 | args.config_file_name) 23 | with open(config_path, 'r') as f: 24 | self._config = json.load(f)[self.datajob_name] 25 | 26 | if "input_file" in args and args.input_file is not None: 27 | self.input_path = self.get_path(args.working_directory, args.input_file) 28 | elif "input_file" in self._config: 29 | self.input_path = self.get_path(args.working_directory, self._config["input_file"]) 30 | 31 | if "output_file" in args and args.output_file is not None: 32 | self.output_path = self.get_path(args.working_directory, args.output_file) 33 | elif "output_file" in self._config: 34 | self.output_path = self.get_path(args.working_directory, self._config["output_file"]) 35 | 36 | @staticmethod 37 | def get_path(dir, file_name): 38 | if "/" in file_name: # if full path simply use it 39 | path = file_name 40 | else: 41 | assert len(dir) > 0, "No directory has been specified" 42 | path = os.path.join(dir, file_name) 43 | return path 44 | 45 | def save_output(self): 46 | if self._config["output_file"].startswith('s3://'): 47 | save_func = upload_jsonl_to_s3 48 | elif is_path_creatable(self._config["output_file"]) and len(self._config["output_file"]) > 0: 49 | save_func = save_jsonl_to_local 50 | else: 51 | # Do nothing 52 | return 53 | 54 | save_func(self._config["output_file"], self.datajob_output['contexts'], self.datajob_output.get('header', None)) 55 | 56 | def run_datajob(self,args): 57 | pass 58 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/iirc_retrieval_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "IIRC_Retrieval" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "IIRCTrain": { 14 | "reader": { 15 | "type": "IircRetrievalDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/iirc/train.json", 18 | "max_seq_len": 512, 19 | "summ_len": 32, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "LR": 1e-4, 24 | "single_task_sampler": "Random", 25 | "no_collator_in_eval": true 26 | }, 27 | "predictor": "ListGenerativePredictor", 28 | "eval_method": "DropEval" 29 | } 30 | }, 31 | "validation_datasets":{ 32 | 
"IIRCEval": { 33 | "reader": { 34 | "type": "IircRetrievalDataset", 35 | "pass_tokenizer": true, 36 | "path": "/ContinuousPreTraining/Data/iirc/train.json", 37 | "max_seq_len": 512, 38 | "summ_len": 32, 39 | "retrieval_file": "PATH TO DEV SET RETRIEVAL RESULTS", 40 | "generation_model": true 41 | }, 42 | "dataloader": { 43 | "batch_size": 8, 44 | "no_collator_in_eval": true 45 | }, 46 | "predictor": "ListGenerativePredictor", 47 | "eval_method": "IircEval", 48 | "save_error_distribution": false 49 | } 50 | }, 51 | "optimizer": { 52 | "type": "AdaFactor", 53 | "lr": 1e-4 54 | }, 55 | "scheduler": { 56 | "type": "linear_scheduler_with_warmup", 57 | "num_warmup_steps": 500, 58 | "num_training_steps": 2e32 59 | }, 60 | "training_arguments": { 61 | "num_train_epochs": 60 , 62 | "per_device_train_batch_size": 20, 63 | "per_device_eval_batch_size": 30, 64 | "gradient_accumulation_steps": 1, 65 | "log_steps": 100, 66 | "evaluation_strategy": "epoch", 67 | "save_steps": 5000, 68 | "eval_steps": 100, 69 | "weight_decay": 0.01, 70 | "save_total_limit": 5, 71 | "seed": 43, 72 | "prediction_loss_only": true, 73 | "no_cuda": false 74 | }, 75 | "trainer": { 76 | "type": "UpdatedMtTrainer", 77 | "override_huggingface_train_method": false, 78 | "load_train_dataloader_after_eval": false, 79 | "callbacks": [] 80 | } 81 | } -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_callbacks/multi_task_callback.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ContinuousPreTraining.Training.trainer_callbacks.basic_qa_callback import BasicQaCallback 3 | 4 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 5 | 6 | 7 | class MultiTaskCallback(BasicQaCallback): 8 | """ 9 | A :class:`~transformers.TrainerCallback` for handling multi task events 10 | """ 11 | 12 | def __init__(self): 13 | self._initialized = False 14 | 15 | def on_train_begin(self, args, state, control, model=None, **kwargs): 16 | # init multi task specific fields 17 | state.task_counter = {} 18 | state.tasks_indices = {} 19 | state.tasks_errors = {} 20 | 21 | control.restart_train_dataloader = False 22 | control.report_task_counter = False 23 | 24 | 25 | def on_batch_begin(self, args, state, control, model=None, **kwargs): 26 | """ 27 | for every step, update the counter for the current step 28 | """ 29 | if 'batch_inputs' in kwargs: 30 | batch_inputs = kwargs['batch_inputs'] 31 | if 'task_name' in batch_inputs: 32 | # get batch task name and delete 33 | train_task_name = batch_inputs['task_name'] 34 | del batch_inputs['task_name'] 35 | 36 | # update the index dict with the tasks name 37 | state.tasks_indices[state.global_step] = train_task_name 38 | 39 | # update the counter for the task 40 | if train_task_name not in state.task_counter: 41 | state.task_counter[train_task_name] = {} 42 | state.task_counter[train_task_name]['Batches_total'] = 0 43 | state.task_counter[train_task_name]['Batches_since_resampling'] = 0 44 | state.task_counter[train_task_name]['Examples'] = 0 45 | 46 | # update the batches and examples counter 47 | state.task_counter[train_task_name]['Batches_total'] += 1 48 | state.task_counter[train_task_name]['Batches_since_resampling'] += 1 49 | state.task_counter[train_task_name]['Examples'] += len(batch_inputs['input_ids']) 50 | 51 | 52 | def on_step_end(self, args, state, control, model=None, **kwargs): 53 | """ 54 | after every step, check if we need to report the task counter 
55 | """ 56 | 57 | if control.should_evaluate: 58 | control.report_task_counter = True 59 | 60 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/mmqa_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "MMQA" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "MMQA_train": { 14 | "reader": { 15 | "type": "UnifiedQaDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_train_oracle.json", 18 | "max_input_token_len": 1536, 19 | "max_output_token_len": 32, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "single_task_sampler": "LengthGroupedSampler" 24 | } 25 | } 26 | }, 27 | "validation_datasets":{ 28 | "MMQA_eval": { 29 | "reader": { 30 | "type": "UnifiedQaDataset", 31 | "pass_tokenizer": true, 32 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_dev_retrieval.json", 33 | "max_input_token_len": 1536, 34 | "max_output_token_len": 32, 35 | "generation_model": true 36 | }, 37 | "dataloader": { 38 | }, 39 | "predictor": "ListGenerativePredictor", 40 | "eval_method": "DropListEval" 41 | }, 42 | "MMQA_eval_oracle": { 43 | "reader": { 44 | "type": "UnifiedQaDataset", 45 | "pass_tokenizer": true, 46 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_dev_oracle.json", 47 | "max_input_token_len": 1536, 48 | "max_output_token_len": 32, 49 | "generation_model": true 50 | }, 51 | "dataloader": { 52 | }, 53 | "predictor": "ListGenerativePredictor", 54 | "eval_method": "DropListEval" 55 | } 56 | }, 57 | "optimizer": { 58 | "type": "AdaFactor", 59 | "lr": 1e-4 60 | }, 61 | "scheduler": { 62 | "type": "linear_scheduler_with_warmup", 63 | "num_warmup_steps": 500, 64 | "num_training_steps": 2e32 65 | }, 66 | "training_arguments": { 67 | "num_train_epochs": 20, 68 | "per_device_train_batch_size": 3, 69 | "per_device_eval_batch_size": 4, 70 | "gradient_accumulation_steps": 6, 71 | "log_steps": 100, 72 | "eval_steps": 500, 73 | "save_steps": 100000, 74 | "evaluation_strategy": "epoch", 75 | "weight_decay": 0.01, 76 | "save_total_limit": 5, 77 | "seed": 40, 78 | "prediction_loss_only": true, 79 | "no_cuda": false 80 | }, 81 | "trainer": { 82 | "type": "UpdatedMtTrainer", 83 | "load_train_dataloader_after_eval": false, 84 | "callbacks": [] 85 | } 86 | } -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/configurations/config_reas.json: -------------------------------------------------------------------------------- 1 | { 2 | "ParseWikiDump": { "enable": false, 3 | "type": "ParseWikiDump", 4 | "description": "Creates a MultiQA object from every table in wikipedia with the specified dimension filter", 5 | "dump_file_path": "s3://tabreas/wiki_dump_example/enwiki-20200101-pages-articles10.xml-p2336425p3046511.bz2", 6 | "output_file": "enwiki-20200101-pages-articles10.xml-p2336425p3046511_parsed.gz", 7 | "max_number_of_examples": 10, 8 | "n_processes": 5, 9 | "max_chunk_size": 10, 10 | "min_table_rows": 10, 11 | "min_table_cols": 2, 12 | "max_table_rows": 25 13 | }, 14 | "ReasClassifyColumnTypes": { 15 | "enable": false, 16 | "type": "ReasClassifyColumnTypes", 17 | "description": "Adds a type to the table columns that are non-index, numbers/dates ... 
", 18 | "input_file": "https://tabreas.s3-us-west-2.amazonaws.com/parsed_data_chunks/chunk_0000.gz", 19 | "output_file": "ClassifyTableColumnsFiltered.jsonl", 20 | "max_number_of_examples": 100000, 21 | "n_processes": 8, 22 | "max_chunk_size": 1000 23 | }, 24 | "GenQuestionsFromTemplates_TabReas": {"enable": false, 25 | "type": "GenSyntheticQuestionsFromTemplates", 26 | "description": "Creates synthetic questions", 27 | "input_file": "ClassifyTableColumnsFiltered.jsonl", 28 | "output_file": "PseudoLangQuestions_All.jsonl", 29 | "reset_qas": true, 30 | "question_templates_file": "tabreas_question_templates.json", 31 | "templates_to_use": [ 32 | "SimpleConjunction", 33 | "SimpleComparison", 34 | "TemporalDistance", 35 | "TemporalComparison", 36 | "NumericSuperlatives", 37 | "TemporalSuperlatives", 38 | "Arithmetic", 39 | "Counting", 40 | "NumericBooleanComparison", 41 | "TemporalBooleanComparison", 42 | "MultihopComposition", 43 | "SimpleConjunction", 44 | "OnlyQuantifier", 45 | "Quantifiers", 46 | "TemporalDistance" 47 | ], 48 | "max_number_of_examples": 10000000, 49 | "n_processes": 20, 50 | "max_chunk_size": 1000 51 | }, 52 | "FormatSyntheticQuestions": 53 | { "enable": false, 54 | "type": "FormatQuestions", 55 | "description": "Format the questions to triplets of question, context, and answer", 56 | "input_file": "PseudoLangQuestions_All.jsonl", 57 | "output_file": "FormattedQuestions.jsonl", 58 | "max_number_of_examples": 100000, 59 | "n_processes": 1, 60 | "max_chunk_size": 20000 61 | } 62 | } -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/configurations/config_tests.json: -------------------------------------------------------------------------------- 1 | { 2 | "ParseWikiDump": { "enable": false, 3 | "type": "ParseWikiDump", 4 | "description": "Creates a MultiQA object from every table in wikipedia with the specified dimension filter", 5 | "dump_file_path": "s3://tabreas/wiki_dump_example/enwiki-20200101-pages-articles10.xml-p2336425p3046511.bz2", 6 | "output_file": "enwiki-20200101-pages-articles10.xml-p2336425p3046511_parsed.gz", 7 | "max_number_of_examples": 10, 8 | "n_processes": 5, 9 | "max_chunk_size": 10, 10 | "min_table_rows": 10, 11 | "min_table_cols": 2, 12 | "max_table_rows": 25 13 | }, 14 | "ReasClassifyColumnTypes": { 15 | "enable": false, 16 | "type": "ReasClassifyColumnTypes", 17 | "description": "Adds a type to the table columns that are non-index, numbers/dates ... 
", 18 | "input_file": "https://tabreas.s3-us-west-2.amazonaws.com/parsed_data_chunks/chunk_0000.gz", 19 | "output_file": "ClassifyTableColumnsFiltered.jsonl", 20 | "max_number_of_examples": 100000, 21 | "n_processes": 8, 22 | "max_chunk_size": 1000 23 | }, 24 | "GenQuestionsFromTemplates_TabReas": {"enable": false, 25 | "type": "GenSyntheticQuestionsFromTemplates", 26 | "description": "Creates synthetic questions", 27 | "input_file": "ClassifyTableColumnsFiltered.jsonl", 28 | "output_file": "PseudoLangQuestions_All.jsonl", 29 | "reset_qas": true, 30 | "question_templates_file": "tabreas_question_templates.json", 31 | "templates_to_use": [ 32 | "SimpleConjunction", 33 | "SimpleComparison", 34 | "TemporalDistance", 35 | "TemporalComparison", 36 | "NumericSuperlatives", 37 | "TemporalSuperlatives", 38 | "Arithmetic", 39 | "Counting", 40 | "NumericBooleanComparison", 41 | "TemporalBooleanComparison", 42 | "MultihopComposition", 43 | "SimpleConjunction", 44 | "OnlyQuantifier", 45 | "Quantifiers", 46 | "TemporalDistance" 47 | ], 48 | "max_number_of_examples": 10000000, 49 | "n_processes": 20, 50 | "max_chunk_size": 1000 51 | }, 52 | "FormatSyntheticQuestions": 53 | { "enable": false, 54 | "type": "FormatQuestions", 55 | "description": "Format the questions to triplets of question, context, and answer", 56 | "input_file": "PseudoLangQuestions_All.jsonl", 57 | "output_file": "FormattedQuestions.jsonl", 58 | "max_number_of_examples": 100000, 59 | "n_processes": 1, 60 | "max_chunk_size": 20000 61 | } 62 | } -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/mmqa_para_classifier_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "MMQA_para_classifier" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "MMQA_train": { 14 | "reader": { 15 | "type": "UnifiedQaDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_paragraph_classifier_train.json", 18 | "max_input_token_len": 1536, 19 | "max_output_token_len": 4, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "single_task_sampler": "LengthGroupedSampler" 24 | } 25 | } 26 | }, 27 | "validation_datasets":{ 28 | "MMQA_eval": { 29 | "reader": { 30 | "type": "UnifiedQaDataset", 31 | "pass_tokenizer": true, 32 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_paragraph_classifier_dev.json", 33 | "max_input_token_len": 1536, 34 | "max_output_token_len": 4, 35 | "generation_model": true 36 | }, 37 | "dataloader": { 38 | }, 39 | "predictor": "ListGenerativePredictor", 40 | "eval_method": "DropListEval" 41 | }, 42 | "MMQA_test": { 43 | "reader": { 44 | "type": "UnifiedQaDataset", 45 | "pass_tokenizer": true, 46 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_paragraph_classifier_test.json", 47 | "max_input_token_len": 1536, 48 | "max_output_token_len": 4, 49 | "generation_model": true 50 | }, 51 | "dataloader": { 52 | }, 53 | "predictor": "ListGenerativePredictor", 54 | "eval_method": "DropListEval" 55 | } 56 | }, 57 | "optimizer": { 58 | "type": "AdaFactor", 59 | "lr": 1e-4 60 | }, 61 | "scheduler": { 62 | "type": "linear_scheduler_with_warmup", 63 | "num_warmup_steps": 500, 64 | "num_training_steps": 2e32 65 | }, 66 | "training_arguments": { 67 | "num_train_epochs": 
10, 68 |     "per_device_train_batch_size": 3, 69 |     "per_device_eval_batch_size": 4, 70 |     "gradient_accumulation_steps": 6, 71 |     "log_steps": 100, 72 |     "eval_steps": 500, 73 |     "save_steps": 100000, 74 |     "evaluation_strategy": "epoch", 75 |     "weight_decay": 0.01, 76 |     "save_total_limit": 5, 77 |     "seed": 40, 78 |     "prediction_loss_only": true, 79 |     "no_cuda": false 80 |   }, 81 |   "trainer": { 82 |     "type": "UpdatedMtTrainer", 83 |     "load_train_dataloader_after_eval": false, 84 |     "callbacks": [] 85 |   } 86 | } -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/mmqa_question_classifier_config.json: -------------------------------------------------------------------------------- 1 | { 2 |   "experiment": { 3 |     "experiment_name": "MMQA_question_classifier" 4 |   }, 5 |   "model": { 6 |     "PReasM": false, 7 |     "sampler": "Err", 8 |     "size": "Base" 9 |   }, 10 |   "tokenizer": "t5-base", 11 |   "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 |   "train_datasets": { 13 |     "MMQA_train": { 14 |       "reader": { 15 |         "type": "UnifiedQaDataset", 16 |         "pass_tokenizer": true, 17 |         "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_train.json", 18 |         "max_input_token_len": 1536, 19 |         "max_output_token_len": 4, 20 |         "generation_model": true 21 |       }, 22 |       "dataloader": { 23 |         "single_task_sampler": "LengthGroupedSampler" 24 |       } 25 |     } 26 |   }, 27 |   "validation_datasets":{ 28 |     "MMQA_eval": { 29 |       "reader": { 30 |         "type": "UnifiedQaDataset", 31 |         "pass_tokenizer": true, 32 |         "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_dev.json", 33 |         "max_input_token_len": 1536, 34 |         "max_output_token_len": 4, 35 |         "generation_model": true 36 |       }, 37 |       "dataloader": { 38 |       }, 39 |       "predictor": "ListGenerativePredictor", 40 |       "eval_method": "DropListEval" 41 |     }, 42 |     "MMQA_test": { 43 |       "reader": { 44 |         "type": "UnifiedQaDataset", 45 |         "pass_tokenizer": true, 46 |         "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_test.json", 47 |         "max_input_token_len": 1536, 48 |         "max_output_token_len": 4, 49 |         "generation_model": true 50 |       }, 51 |       "dataloader": { 52 |       }, 53 |       "predictor": "ListGenerativePredictor", 54 |       "eval_method": "DropListEval" 55 |     } 56 |   }, 57 |   "optimizer": { 58 |     "type": "AdaFactor", 59 |     "lr": 1e-4 60 |   }, 61 |   "scheduler": { 62 |     "type": "linear_scheduler_with_warmup", 63 |     "num_warmup_steps": 500, 64 |     "num_training_steps": 2e32 65 |   }, 66 |   "training_arguments": { 67 |     "num_train_epochs": 10, 68 |     "per_device_train_batch_size": 3, 69 |     "per_device_eval_batch_size": 4, 70 |     "gradient_accumulation_steps": 6, 71 |     "log_steps": 100, 72 |     "eval_steps": 500, 73 |     "save_steps": 100000, 74 |     "evaluation_strategy": "epoch", 75 |     "weight_decay": 0.01, 76 |     "save_total_limit": 5, 77 |     "seed": 40, 78 |     "prediction_loss_only": true, 79 |     "no_cuda": false 80 |   }, 81 |   "trainer": { 82 |     "type": "UpdatedMtTrainer", 83 |     "load_train_dataloader_after_eval": false, 84 |     "callbacks": [] 85 |   } 86 | } -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/dataset_readers/t5_mlm_dataset.py: -------------------------------------------------------------------------------- 1 | # A custom dataset that loads tokenized MLM examples from disk and serves them to the dataloader for fine-tuning and prediction 2 | import json 3 | import random 4 | 5 | import torch 6 | from tqdm import 
tqdm 7 | 8 | 9 | class T5MlmDataset(torch.utils.data.Dataset): 10 | 11 | def __init__(self, data_path, 12 | tokenizer, 13 | max_input_token_len, 14 | max_output_token_len, 15 | num_examples_to_load=1000, 16 | num_wiki_examples=7529903): 17 | 18 | # init 19 | self.num_examples_to_load = num_examples_to_load 20 | self.tokenizer = tokenizer 21 | self.max_input_token_len = max_input_token_len 22 | self.max_output_token_len = max_output_token_len 23 | epoch_indices = random.sample(range(0, num_wiki_examples), self.num_examples_to_load) 24 | epoch_indices = set(epoch_indices) 25 | 26 | self.wiki_prefix = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize('Wiki: ')) 27 | self.wiki_examples = [] 28 | 29 | with open(data_path, "r") as f: 30 | for i, l in tqdm(enumerate(f)): 31 | if i in epoch_indices: 32 | self.wiki_examples.append(json.loads(l)) 33 | 34 | if i == num_wiki_examples: 35 | break 36 | 37 | def __len__(self): 38 | return len(self.wiki_examples) 39 | 40 | def __getitem__(self, index): 41 | 42 | # if wiki source, add pad indices for short sequences and return 43 | num_input_pads = self.max_input_token_len - len(self.wiki_prefix) - len(self.wiki_examples[index]['inputs']) 44 | num_label_pads = self.max_output_token_len - len(self.wiki_examples[index]['labels']) 45 | input_ids_tensor = torch.IntTensor(self.wiki_prefix 46 | + self.wiki_examples[index]['inputs'] 47 | + [self.tokenizer.pad_token_id] * num_input_pads).to(dtype=torch.long) 48 | labels_tensor = torch.IntTensor(self.wiki_examples[index]['labels'] 49 | + [-100] * num_label_pads).to(dtype=torch.long) 50 | 51 | # create attention masks 52 | attention_mask = torch.ones(input_ids_tensor.shape) 53 | attention_mask[input_ids_tensor == self.tokenizer.pad_token_id] = 0 54 | 55 | return { 56 | 'input_ids': input_ids_tensor, 57 | 'attention_mask': attention_mask.to(dtype=torch.long), 58 | 'labels': labels_tensor, 59 | } 60 | -------------------------------------------------------------------------------- /ExampleGeneration/tests/datajobs/gen_synthetic_questions_from_templates_debug.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import argparse 3 | import os, json 4 | 5 | from ExampleGeneration.common.analysis_utils import dump_manual_analysis_facts 6 | from ExampleGeneration.datajobs.gen_synthetic_questions_from_templates import \ 7 | GenSyntheticQuestionsFromTemplatesDataJob 8 | 9 | 10 | class TestSyntheticGenQuestionsFromTemplatesDEBUG: 11 | 12 | def test_synthetic_reasoning_questions(self): 13 | config_file = "config_reas.json" 14 | config_entry = "GenQuestionsFromTemplates_TabReas" 15 | working_directory = "data/tab_reas" 16 | 17 | parse = argparse.ArgumentParser("") 18 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 19 | parse.add_argument("-o", "--operation", type=str, help="The task stage to run") 20 | parse.add_argument("-out", "--output_file", type=str, help="") 21 | parse.add_argument("-config", "--config_file_name", type=str, help="", default="config_reas.json") 22 | parse.add_argument("-wd", "--working_directory", type=str, help="dir of input file, can be s3 path") 23 | parse.add_argument("-af", "--annotated_questions_file", type=str, help="dir of input file, can be s3 path", 24 | default=None) 25 | 26 | # In the test no output file will be produced, change -out to create an output 27 | args = parse.parse_args( 28 | ["-c", "Example", "-o", "build_datajob", "-config", config_file, "-wd", 
working_directory]) 29 | 30 |         # loading the question template externally, to control which questions to produce: 31 |         # with open(os.path.join('ExampleGeneration', 'question_templates', question_templates_file)) as f: 32 |         #     q_templates = process_question_templates(json.load(f)) 33 |         # curr_question_gen_template = [t for t in q_templates if t['name'] == question_generator_name] 34 |         # curr_question_gen_template[0]['enable'] = True 35 | 36 |         datajob = GenSyntheticQuestionsFromTemplatesDataJob(config_entry, args) 37 | 38 |         # reducing data size to a sample: 39 |         datajob._config['n_processes'] = 1 40 |         datajob._config['max_chunk_size'] = 100 41 |         datajob._config['max_number_of_examples'] = 100 42 |         datajob.input_path = "data/datajob_samples/classify_table_column_types_sample.jsonl" 43 |         datajob.output_path = "data/datajob_samples/synthetic_questions.jsonl" 44 | 45 |         datajob.run_datajob(args) 46 | 47 |         dump_manual_analysis_facts('data/datajob_samples/synthetic_questions.jsonl', \ 48 |                                    'data/datajob_samples/synthetic_questions.csv') 49 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/error_distribution_heterogeneous_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 5 | logging.basicConfig(level=logging.INFO) 6 | 7 | 8 | class ErrorDistributionHeterogeneousSampler: 9 |     """ 10 |     samples a distribution over the tasks that is proportional to each task's (temperature-scaled) validation error 11 |     """ 12 | 13 |     def __init__(self, 14 |                  distribution_name, 15 |                  trainer_state, 16 |                  temperature=1 17 |                  ): 18 |         """ 19 |         store the distribution name, trainer state and temperature 20 |         used to compute the sampling distribution over the tasks 21 |         """ 22 |         self.distribution_name = distribution_name 23 |         self.trainer_state = trainer_state 24 |         self.temperature = float(temperature) 25 | 26 |     def update_sampler_trainer_state(self, trainer_state): 27 |         """ 28 |         update trainer state after calculating errors 29 |         """ 30 |         self.trainer_state = trainer_state 31 | 32 |     def sample(self, task_iterables_list): 33 | 34 |         num_tasks = len(task_iterables_list) 35 | 36 |         # try and retrieve the error distribution from the trainer state 37 |         task_errors_name_for_state = f'{self.distribution_name}Errors' 38 |         if hasattr(self.trainer_state, 39 |                    'tasks_errors') and task_errors_name_for_state in self.trainer_state.tasks_errors: 40 | 41 |             # get the last validation errors for error distribution sampling 42 |             validation_errors = self.trainer_state.tasks_errors[task_errors_name_for_state] 43 |             last_validation_errors_key = max(validation_errors.keys()) 44 |             last_validation_errors = validation_errors[last_validation_errors_key] 45 | 46 |             # iterate over the subtasks and find the error for each task 47 |             errors = [] 48 |             for sub_task in task_iterables_list.keys(): 49 | 50 |                 # look for each subtask in the error distribution 51 |                 if sub_task not in last_validation_errors: 52 |                     raise Exception(f"Subtask {sub_task} not found in error distribution") 53 |                 else: 54 |                     errors.append(last_validation_errors[sub_task]) 55 | 56 |             # apply the temperature and normalize the errors 57 |             errors = np.power(errors, self.temperature) 58 |             error_distribution = np.array(errors) / sum(errors) 59 | 60 |         # else, init the uniform distribution 61 |         else: 62 |             error_distribution = [1 / num_tasks for i in range(num_tasks)] 63 | 64 |         print('Error Distributions Heterogeneous:') 65 |         print({k: error_distribution[i] 66 |                for i, k in 
enumerate(task_iterables_list.keys())}) 67 | 68 |         return error_distribution 69 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_parser_function.py: -------------------------------------------------------------------------------- 1 | """Define the ParserFunction class.""" 2 | from bisect import insort 3 | from typing import List 4 | 5 | import regex 6 | 7 | from ._wikitext import SubWikiText 8 | from ._argument import Argument 9 | from ._wikilist import WikiList 10 | 11 | 12 | BAR_SPLITS_FULLMATCH = regex.compile( 13 |     rb'{{' 14 |     rb'[^:|}]*+'  # name 15 |     rb'(?<arg>:[^|}]*+)?+(?<arg>\|[^|}]*+)*+' 16 |     rb'}}' 17 | ).fullmatch 18 | 19 | 20 | class SubWikiTextWithArgs(SubWikiText): 21 | 22 |     """Define common attributes for `Template` and `ParserFunction`.""" 23 | 24 |     _args_matcher = NotImplemented 25 |     _first_arg_sep = 0 26 | 27 |     @property 28 |     def arguments(self) -> List[Argument]: 29 |         """Parse template content. Create self.name and self.arguments.""" 30 |         shadow = self._shadow 31 |         split_spans = self._args_matcher(shadow).spans('arg') 32 |         if not split_spans: 33 |             return [] 34 |         arguments = [] 35 |         arguments_append = arguments.append 36 |         type_to_spans = self._type_to_spans 37 |         ss, se = span = self._span 38 |         type_ = id(span) 39 |         lststr = self._lststr 40 |         string = lststr[0] 41 |         arg_spans = type_to_spans.setdefault(type_, []) 42 |         span_tuple_to_span_get = {(s[0], s[1]): s for s in arg_spans}.get 43 |         for arg_self_start, arg_self_end in split_spans: 44 |             s, e = arg_span = [ss + arg_self_start, ss + arg_self_end] 45 |             old_span = span_tuple_to_span_get((s, e)) 46 |             if old_span is None: 47 |                 insort(arg_spans, arg_span) 48 |             else: 49 |                 arg_span = old_span 50 |             arg = Argument(lststr, type_to_spans, arg_span, type_) 51 |             arg._shadow_cache = ( 52 |                 string[s:e], shadow[arg_self_start:arg_self_end]) 53 |             arguments_append(arg) 54 |         return arguments 55 | 56 |     def lists(self, pattern: str = None) -> List[WikiList]: 57 |         """Return the lists in all arguments. 58 | 59 |         For performance reasons it is usually preferred to get a specific 60 |         Argument and use the `lists` method of that argument instead. 61 |         """ 62 |         return [ 63 |             lst for arg in self.arguments for lst in arg.lists(pattern) if lst] 64 | 65 |     @property 66 |     def name(self) -> str: 67 |         """Return template's name (includes whitespace).""" 68 |         h = self._atomic_partition(self._first_arg_sep)[0] 69 |         if len(h) == len(self.string): 70 |             return h[2:-2] 71 |         return h[2:] 72 | 73 |     @name.setter 74 |     def name(self, newname: str) -> None: 75 |         """Set the new name.""" 76 |         self[2:2 + len(self.name)] = newname 77 | 78 | 79 | class ParserFunction(SubWikiTextWithArgs): 80 | 81 |     """Create a new ParserFunction object.""" 82 | 83 |     _args_matcher = BAR_SPLITS_FULLMATCH 84 |     _first_arg_sep = 58 85 | -------------------------------------------------------------------------------- /ExampleGeneration/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Setup 4 | 5 | ### Setting up a virtual environment 6 | 7 | 1. First, clone the repository: 8 | 9 | ``` 10 | git clone https://github.com/oriyor/turning_tables.git 11 | ``` 12 | 13 | 2. Change your directory to where you cloned the files: 14 | 15 | ``` 16 | cd ExampleGeneration 17 | export PYTHONPATH=${PYTHONPATH}:`pwd` 18 | ``` 19 | 20 | 3. 
Create a virtual environment with Python 3.6 or above: 21 | 22 | ``` 23 | virtualenv venv --python=python3.7 (or python3.7 -m venv venv or conda create -n turning python=3.7) 24 | ``` 25 | 26 | 4. Activate the virtual environment: 27 | ``` 28 | source venv/bin/activate (or source venv/bin/activate.csh or conda activate turning) 29 | ``` 30 | 5. Install the required dependencies: 31 | 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Example Generation 37 | 38 | Our pre-processed data is split into 35 chunks containing 20K tables each. Please note that running on all chunks can take time; multi-processing provides a significant speed-up (~10 hours with 20 processes). The examples are generated using a pipeline of datajobs, which are logical units that perform actions on pre-processed Wikipedia tables. The pipeline includes 3 datajobs: 39 | * ReasClassifyColumnTypes: classifies the table's columns into semantic types 40 | * GenQuestionsFromTemplates_TabReas: generates examples from the pre-processed tables 41 | * FormatSyntheticQuestions: post-processes the examples into pseudo-language question, context, answer triplets 42 | 43 | Please see `ExampleGeneration/configurations/config_reas.json` for the full configuration of each datajob, including the number of processes and the path to the input data. To generate examples: 44 | 45 | 1. Choose start and end chunks between 0 and 34: 46 | ``` 47 | export Start_Chunk=0 48 | export End_Chunk=0 49 | ``` 50 | 51 | 2. Generate examples: 52 | 53 |     python ExampleGeneration/run_multiple_chunks.py -config config_reas.json -dj ReasClassifyColumnTypes,GenQuestionsFromTemplates_TabReas,FormatSyntheticQuestions -wd data/data_chunks/ -sc ${Start_Chunk} -ec ${End_Chunk} 54 | 55 | ## Downloading generated reasoning examples 56 | 57 | Our generated examples are publicly available. To download the examples: 58 | 59 | 1. Clone the repository 60 | 61 | 2. Download the examples: 62 | ``` 63 | ./ExampleGeneration/bash_scripts/download_reasoning_examples.sh 64 | ``` 65 | 66 | ## Other 67 | 68 | A caching infra is used, so please make sure you have enough disk space, and set the cache directory via the `TURNINGTABLES_CACHE_ROOT` environment variable. 69 | 70 | ## Parsing tables from a Wikipedia dump 71 | 72 | Our infra supports parsing a Wikipedia dump into tables, based on [WikiExtractor](https://github.com/attardi/wikiextractor) and [WikiTextParser](https://github.com/5j9/wikitextparser). To parse a full Wikipedia dump, see `ExampleGeneration/ExampleGeneration/bash_scripts/parse_wiki_dump.sh`. 
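## Running a single datajob

The datajobs can also be run one at a time via `ExampleGeneration/run.py`. A sketch (the `-c` value must name a datajob entry in the config file, e.g. one of the three datajobs above, and `-wd` is the working directory that holds the input file):

```
python ExampleGeneration/run.py -c FormatSyntheticQuestions -o run_datajob -config config_reas.json -wd data/data_chunks/
```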
73 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/mmqa_retrieval_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "MMQA" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "DatasetUniformSampler"}, 12 | "train_datasets": { 13 | "MMQA_train": { 14 | "reader": { 15 | "type": "UnifiedQaDataset", 16 | "pass_tokenizer": true, 17 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_train_oracle.json", 18 | "max_input_token_len": 1536, 19 | "max_output_token_len": 32, 20 | "generation_model": true 21 | }, 22 | "dataloader": { 23 | "single_task_sampler": "LengthGroupedSampler" 24 | } 25 | } 26 | }, 27 | "validation_datasets":{ 28 | "MMQA_eval": { 29 | "reader": { 30 | "type": "UnifiedQaDataset", 31 | "pass_tokenizer": true, 32 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_dev_retrieval.json", 33 | "max_input_token_len": 1536, 34 | "max_output_token_len": 32, 35 | "generation_model": true 36 | }, 37 | "dataloader": { 38 | }, 39 | "predictor": "ListGenerativePredictor", 40 | "eval_method": "DropListEval" 41 | }, 42 | "MMQA_eval_oracle": { 43 | "reader": { 44 | "type": "UnifiedQaDataset", 45 | "pass_tokenizer": true, 46 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_dev_oracle.json", 47 | "max_input_token_len": 1536, 48 | "max_output_token_len": 32, 49 | "generation_model": true 50 | }, 51 | "dataloader": { 52 | }, 53 | "predictor": "ListGenerativePredictor", 54 | "eval_method": "DropListEval" 55 | }, 56 | "MMQA_test": { 57 | "reader": { 58 | "type": "UnifiedQaDataset", 59 | "pass_tokenizer": true, 60 | "path": "/ContinuousPreTraining/Data/mmqa/parsed_mmqa_test_retrieval.json", 61 | "max_input_token_len": 1536, 62 | "max_output_token_len": 32, 63 | "generation_model": true 64 | }, 65 | "dataloader": { 66 | }, 67 | "predictor": "ListGenerativePredictor", 68 | "eval_method": "DropListEval" 69 | } 70 | }, 71 | "optimizer": { 72 | "type": "AdaFactor", 73 | "lr": 1e-4 74 | }, 75 | "scheduler": { 76 | "type": "linear_scheduler_with_warmup", 77 | "num_warmup_steps": 500, 78 | "num_training_steps": 2e32 79 | }, 80 | "training_arguments": { 81 | "num_train_epochs": 20, 82 | "per_device_train_batch_size": 3, 83 | "per_device_eval_batch_size": 4, 84 | "gradient_accumulation_steps": 6, 85 | "log_steps": 100, 86 | "eval_steps": 500, 87 | "save_steps": 100000, 88 | "evaluation_strategy": "epoch", 89 | "weight_decay": 0.01, 90 | "save_total_limit": 5, 91 | "seed": 40, 92 | "prediction_loss_only": true, 93 | "no_cuda": false 94 | }, 95 | "trainer": { 96 | "type": "UpdatedMtTrainer", 97 | "load_train_dataloader_after_eval": false, 98 | "callbacks": [] 99 | } 100 | } -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainers/basic_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import os 4 | from datasets import Dataset 5 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 6 | from transformers import Trainer 7 | from ContinuousPreTraining.Training.callback_factory import CallbackFactory 8 | from ContinuousPreTraining.Training.trainer_callbacks.basic_qa_callback_handler import BasicQaCallbackHandler 9 | 10 | 11 | class BasicTrainer(Trainer): 12 | """ 
13 |     trainer that can upload checkpoints to WandB 14 |     """ 15 | 16 |     def __init__(self, **kwargs): 17 |         """ 18 |         we add a new field that states whether we use wandb 19 |         """ 20 |         self.use_wandb = kwargs['use_wandb'] 21 |         del kwargs['use_wandb'] 22 |         super().__init__(**kwargs) 23 | 24 |         # add saving callbacks 25 |         saving_callbacks = [] 26 |         for saving_callback_name in []: 27 |             saving_callbacks.append(CallbackFactory().get_callback(saving_callback_name)) 28 | 29 |         self.qa_callback_handler = BasicQaCallbackHandler(list(set(kwargs['callbacks'] + saving_callbacks)), 30 |                                                           self.model, 31 |                                                           self.tokenizer, 32 |                                                           self.optimizer, 33 |                                                           self.lr_scheduler) 34 | 35 | 36 |     def _save_checkpoint(self, model, trial, metrics=None): 37 |         # call original save checkpoint 38 |         super()._save_checkpoint(model, trial, metrics=metrics) 39 |         self.qa_callback_handler.on_save(self.args, self.state, self.control, 40 |                                          prefix_checkpoint=PREFIX_CHECKPOINT_DIR, 41 |                                          use_wandb=self.use_wandb) 42 | 43 |     def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: 44 |         """ 45 |         Run evaluation and return metrics. 46 | 47 |         The calling script will be responsible for providing a method to compute metrics, as they are 48 |         task-dependent (pass it to the init :obj:`compute_metrics` argument). 49 | 50 |         You can also subclass and override this method to inject custom behavior. 51 | 52 |         Args: 53 |             eval_dataset (:obj:`Dataset`, `optional`): 54 |                 Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 55 |                 columns not accepted by the ``model.forward()`` method are automatically removed. 56 | 57 |         Returns: 58 |             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. 59 |         """ 60 |         metrics = super().evaluate(eval_dataset) 61 | 62 |         # we also pass the tokenizer and eval dataset 63 |         self.qa_callback_handler.on_evaluate(self.args, self.state, self.control, 64 |                                              tokenizer=self.tokenizer, 65 |                                              eval_dataset=self.eval_dataset, 66 |                                              train_dataset=self.train_dataset, 67 |                                              use_wandb=self.use_wandb) 68 |         return metrics 69 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/datajobs/format_questions.py: -------------------------------------------------------------------------------- 1 | import logging, json 2 | import random 3 | 4 | from ExampleGeneration.common.multiqa_format_wrapper import MultiQaModel 5 | from ExampleGeneration.common.multi_process_streaming import multi_process_data_stream 6 | from ExampleGeneration.datajob import DataJob 7 | 8 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 9 | 10 | 11 | class DictToObject(object): 12 | 13 |     def __init__(self, dictionary): 14 |         self.__dict__.update(dictionary) 15 | 16 |     def to_json(self): 17 |         return dict(self.__dict__) 18 | 19 | class FormatQuestionsDataJob(DataJob): 20 |     def __init__(self, datajob_name, args): 21 |         self.datajob_name = datajob_name 22 |         logger.info("loading...") 23 |         super().__init__(args) 24 |         self._args = args 25 |         self.output_obj = 's3' in self.output_path 26 | 27 |     def format_question(self, c, q): 28 |         # a question is a triplet of question, answer and context 29 |         # we will also add some metadata: context id, page url, question id 30 |         page_title = c.context[0].title.strip() 31 |         table_title = c.context[0].table.table_name.strip() 32 |         page_url = c.context[0].url 33 |         context_prefix = f'In {table_title} of {page_title}: ' 34 |         context_content = q.facts + q.distractors 35 | 
random.shuffle(context_content) 36 |         facts = '. '.join(context_content) 37 |         context = context_prefix + facts 38 |         return DictToObject({ 39 |             'qid': q.qid, 40 |             'question': q.question, 41 |             'phrase': q.question, 42 |             'context': context, 43 |             'answer': ', '.join([str(a) for a in q.answers]), 44 |             'question_metadata': q.metadata, 45 |             'type': q.metadata['type'], 46 |             'template': q.metadata['template'], 47 |             'context_id': c.id, 48 |             'url': page_url, 49 |             'page_title': page_title, 50 |             'table_title': table_title 51 |         }) 52 | 53 |     def process_chunk(self, contexts): 54 | 55 |         random.seed(42) 56 |         questions = [] 57 | 58 |         contexts = [MultiQaModel.from_json(json.loads(c)) for c in contexts 59 |                     if c] 60 | 61 |         # removing extra data from contexts: 62 |         for context in contexts: 63 | 64 |             # adding annotation fields to context: 65 |             for q in context.qas: 66 |                 formatted_question = self.format_question(context, q) 67 |                 if self.output_obj: 68 |                     questions.append(json.dumps(formatted_question.to_json())) 69 |                 else: 70 |                     questions.append(formatted_question) 71 |         return questions 72 | 73 |     def run_datajob(self, args): 74 |         multi_process_data_stream(self.input_path, self.output_path, 75 |                                   apply_on_lines_chunk=self.process_chunk, n_processes=self._config["n_processes"], 76 |                                   max_chunk_size=self._config["max_chunk_size"], 77 |                                   max_lines_to_process=self._config.get("max_number_of_examples", None), 78 |                                   copy_header_to_output=False, 79 |                                   args=args) 80 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Common/transfomer_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import transformers 3 | import os 4 | from ContinuousPreTraining.Common.file_utils import s3_get 5 | from transformers import RobertaTokenizer, BertTokenizer, AutoTokenizer, BartTokenizer, T5Tokenizer, T5ForConditionalGeneration 6 | 7 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 8 | transformers.logging.set_verbosity_info() 9 | 10 | def get_tokenizer(tokenizer_name, lower_case=True): 11 |     """ 12 |     :param tokenizer_name: named identifier for the tokenizer 13 |     :param lower_case: whether we want a lower-case tokenizer 14 |     :return: tokenizer from hf 15 |     """ 16 |     if 't5' in tokenizer_name: 17 |         return T5Tokenizer.from_pretrained(tokenizer_name) 18 |     if tokenizer_name == 'roberta-base': 19 |         return RobertaTokenizer.from_pretrained(tokenizer_name, do_lower_case=lower_case) 20 |     if tokenizer_name == 'bert-base-uncased': 21 |         return BertTokenizer.from_pretrained(tokenizer_name, do_lower_case=lower_case) 22 |     if tokenizer_name == 'bert-large-uncased-whole-word-masking': 23 |         return AutoTokenizer.from_pretrained(tokenizer_name) 24 |     if tokenizer_name == 'bart-base': 25 |         return BartTokenizer.from_pretrained('facebook/bart-base') 26 |     return None 27 | 28 | 29 | def get_model(model_config, local_directory=None): 30 |     """ 31 |     :param model_config: model configuration dict, specifying the model size and whether to restore a PReasM checkpoint 32 |     :param local_directory: optional local directory to load the model from 33 |     :return: model from hf 34 |     """ 35 |     if model_config['size'] == 'Base': 36 |         model_name = 't5-base' 37 |     else: 38 |         if model_config['size'] == 'Large': 39 |             model_name = 't5-large' 40 |         else: 41 |             assert False 42 | 43 |     # check if we need to restore the model 44 |     if model_config['PReasM']: 45 |         sampler = model_config['sampler'] 46 |         size = model_config['size'] 47 | 48 |         # create a local dir for the model 49 |         local_dir = 
f'CheckpointsRestored/PReasM-{sampler}-{size}/' 50 |         if not os.path.exists(local_dir): 51 |             os.makedirs(local_dir, exist_ok=True) 52 | 53 |         # get model checkpoint files from s3 54 |         s3_directory_url = f's3://tabreas/PReasM/PReasM-{sampler}-{size}/' 55 |         for file_to_restore in ["config.json", "pytorch_model.bin", "optimizer.pt", 56 |                                 "scheduler.pt", "trainer_state.json", "training_args.bin"]: 57 | 58 |             # download a file from s3 59 |             local_path = local_dir + file_to_restore 60 |             s3_path = s3_directory_url + file_to_restore 61 | 62 |             logger.info(f'Downloading {file_to_restore} to {local_path}') 63 | 64 |             with open(local_path, "wb") as f: 65 |                 s3_get(s3_path, f) 66 | 67 |         logger.info(f'Downloaded checkpoint to {local_dir}') 68 | 69 |     # get the model 70 |     if 't5' in model_name: 71 |         if model_config['PReasM']: 72 |             return T5ForConditionalGeneration.from_pretrained(local_dir, return_dict=True) 73 |         else: 74 |             logger.info('Getting model from huggingface') 75 |             return T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True) 76 | 77 |     else: 78 |         assert False 79 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_parameter.py: -------------------------------------------------------------------------------- 1 | """Define the Parameter class.""" 2 | 3 | 4 | from typing import Optional 5 | 6 | from ._wikitext import SubWikiText, WS 7 | 8 | 9 | class Parameter(SubWikiText): 10 | 11 |     """Create a new {{{parameters}}} object.""" 12 | 13 |     @property 14 |     def name(self) -> str: 15 |         """Return current parameter's name.""" 16 |         name, pipe, default = self._atomic_partition(124) 17 |         if pipe: 18 |             return name[3:] 19 |         return name[3:-3] 20 | 21 |     @name.setter 22 |     def name(self, newname: str) -> None: 23 |         """Set the new name.""" 24 |         self[3:3 + len(self.name)] = newname 25 | 26 |     @property 27 |     def pipe(self) -> str: 28 |         """Return `|` if there is a pipe (default value) in the Parameter. 29 | 30 |         Return '' otherwise. 31 |         """ 32 |         return self._atomic_partition(124)[1] 33 | 34 |     @property 35 |     def default(self) -> Optional[str]: 36 |         """Return the default value. Return None if there is no default.""" 37 |         name, pipe, default = self._atomic_partition(124) 38 |         if pipe: 39 |             return default[:-3] 40 |         return None 41 | 42 |     @default.setter 43 |     def default(self, newdefault: Optional[str]) -> None: 44 |         """Set a new default value. Use None to remove default.""" 45 |         name, pipe, default = self._atomic_partition(124) 46 |         if not pipe: 47 |             # olddefault is None 48 |             if newdefault is None: 49 |                 return 50 |             self.insert(-3, '|' + newdefault) 51 |             return 52 |         if newdefault is None: 53 |             # Only newdefault is None 54 |             del self[len(name):-3] 55 |             return 56 |         # olddefault is not None and newdefault is not None 57 |         self[len(name):-3] = '|' + newdefault 58 | 59 |     def append_default(self, new_default_name: str) -> None: 60 |         """Append a new default parameter in the appropriate place. 61 | 62 |         Add the new default to the inner-most parameter. 63 |         If the parameter already exists among defaults, don't change anything. 
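        (parameter names are compared after stripping surrounding whitespace, via ``strip(WS)``)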
64 | 65 |         Example: 66 |             >>> p = Parameter('{{{p1|{{{p2|}}}}}}') 67 |             >>> p.append_default('p3') 68 |             >>> p 69 |             Parameter("'{{{p1|{{{p2|{{{p3|}}}}}}}}}'") 70 |         """ 71 |         stripped_default_name = new_default_name.strip(WS) 72 |         if stripped_default_name == self.name.strip(WS): 73 |             return 74 |         dig = True 75 |         innermost_param = self 76 |         while dig: 77 |             dig = False 78 |             default = innermost_param.default 79 |             for p in innermost_param.parameters: 80 |                 if p.string == default: 81 |                     if stripped_default_name == p.name.strip(WS): 82 |                         return 83 |                     innermost_param = p 84 |                     dig = True 85 |         innermost_default = innermost_param.default 86 |         if innermost_default is None: 87 |             innermost_param.insert(-3, '|{{{' + new_default_name + '}}}') 88 |         else: 89 |             name = innermost_param.name 90 |             innermost_param[ 91 |                 len('{{{' + name + '|'): 92 |                 len('{{{' + name + '|' + innermost_default) 93 |             ] = '{{{' + new_default_name + '|' + innermost_default + '}}}' 94 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/samplers/adaptive_error_heterogeneous_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 5 | logging.basicConfig(level=logging.INFO) 6 | 7 | 8 | class AdaptiveErrorHeterogeneousSampler: 9 |     """ 10 |     compute sampling weights over the task iterables: uniform until task errors are available, then proportional to recent error gains when is_adaptive is set 11 |     """ 12 | 13 |     def __init__(self, 14 |                  trainer_state, 15 |                  is_adaptive, 16 |                  normalize_with_prev, 17 |                  distribution_name 18 |                  ): 19 |         """ 20 |         store the trainer state and the flags used to compute 21 |         the sampling weights between the tasks 22 |         """ 23 |         self.is_adaptive = is_adaptive 24 |         self.trainer_state = trainer_state 25 |         self.last_adaptive_weights = None 26 |         self.normalize_with_prev = normalize_with_prev 27 |         self.distribution_name = distribution_name 28 | 29 |     def update_sampler_trainer_state(self, trainer_state): 30 |         """ 31 |         update the trainer state after calculating errors 32 |         """ 33 |         self.trainer_state = trainer_state 34 | 35 |     def sample(self, task_iterables_list): 36 | 37 |         num_tasks = len(task_iterables_list) 38 | 39 |         #if self.last_adaptive_weights is None: 40 |         #    self.last_adaptive_weights = [1 / num_tasks for i in range(num_tasks)] 41 | 42 |         task_errors_name_for_state = f'{self.distribution_name}Errors' 43 | 44 |         # try to retrieve the error distribution from the trainer state 45 |         if hasattr(self.trainer_state, 'tasks_errors') and \ 46 |                 len(list(self.trainer_state.tasks_errors.values())[0]) > 1 and self.is_adaptive: 47 | 48 |             # reverse the dict 49 |             errors_dict = {} 50 |             for step_num, tasks_name_values in self.trainer_state.tasks_errors[task_errors_name_for_state].items(): 51 |                 for task_name, task_error_value in tasks_name_values.items(): 52 |                     if task_name not in errors_dict: 53 |                         errors_dict[task_name] = {} 54 |                     errors_dict[task_name][step_num] = task_error_value 55 | 56 |             # iterate over the subtasks and find the error for each one 57 |             delta_errors = [] 58 |             for i, sub_task in enumerate(task_iterables_list.keys()): 59 | 60 |                 # look for each subtask in the error distribution 61 |                 if sub_task not in errors_dict: 62 |                     raise Exception(f"Subtask {sub_task} not found in error distribution") 63 |                 else: 64 |                     error_list = list(errors_dict[sub_task].values()) 65 | 66 |                     #if self.normalize_with_prev: 67 |                     #    task_weight = max(0.01, error_list[-1] - error_list[-2]) / self.last_adaptive_weights[i] 68 |                     #else: 69 |                     window =
error_list[-min(len(error_list), 4):] 70 | window_gains = ((window[-1] + window[-2]) - (window[1] + window[0])) / len(window) 71 | task_weight = max(0.002, abs(window_gains)) 72 | delta_errors.append(task_weight) 73 | 74 | # normalize errors 75 | adaptive_weights = np.array(delta_errors) / sum(delta_errors) 76 | 77 | # else, init the uniform distribution 78 | else: 79 | adaptive_weights = [1 / num_tasks for i in range(num_tasks)] 80 | 81 | print('Error Distributions:') 82 | print({k: adaptive_weights[i] 83 | for i, k in enumerate(task_iterables_list.keys())}) 84 | return adaptive_weights 85 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/scripts/preprocess_drop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | 4 | 5 | def main(): 6 | # read files 7 | with open('ContinuousPreTraining/Data/drop/drop_dataset_dev.json') as json_file: 8 | dev = json.load(json_file) 9 | 10 | with open('ContinuousPreTraining/Data/drop/drop_dataset_train.json') as json_file: 11 | train = json.load(json_file) 12 | 13 | def parse_drop_answer(answer, train_mode=True): 14 | """ 15 | parse drop answer 16 | """ 17 | if len(answer['spans']) == 1: 18 | # print('span') 19 | return ' '.join(answer['spans']) 20 | 21 | if len(answer['spans']) > 1: 22 | if train_mode: 23 | # print('spans') 24 | 25 | return '#'.join(answer['spans']) 26 | else: 27 | return answer['spans'] 28 | 29 | elif len(answer['number']) > 0: 30 | # print('number') 31 | return answer['number'] 32 | 33 | else: 34 | # print('date') 35 | day = answer['date']['day'] 36 | if day: 37 | day += ' ' 38 | 39 | month = answer['date']['month'] 40 | if month: 41 | month += ' ' 42 | 43 | year = answer['date']['year'] 44 | 45 | return f'{day}{month}{year}' 46 | 47 | # go over all ds splits 48 | parsed_train_questions = [] 49 | parsed_dev_questions = [] 50 | 51 | for context in train.values(): 52 | 53 | context_passage = context['passage'] 54 | 55 | for q in context['qa_pairs']: 56 | context_with_question = q['question'] + '\n' + context_passage 57 | parsed_train_questions.append({'context': context_with_question, 58 | 'answer': parse_drop_answer(q['answer']), 59 | 'all_answers': parse_drop_answer(q['answer']), 60 | 'id': q['query_id']}) 61 | 62 | for context in dev.values(): 63 | 64 | context_passage = context['passage'] 65 | 66 | for q in context['qa_pairs']: 67 | context_with_question = q['question'] + '\n' + context_passage 68 | answer = parse_drop_answer(q['answer']) 69 | validated_answers = [parse_drop_answer(annotated_answer, train_mode=False) 70 | for annotated_answer in q['validated_answers']] + [ 71 | parse_drop_answer(q['answer'], train_mode=False)] 72 | parsed_dev_questions.append({'context': context_with_question, 73 | 'answer': answer, 74 | 'all_answers': validated_answers, 75 | 'id': q['query_id']}) 76 | 77 | # write train questions 78 | train_output_file = 'ContinuousPreTraining/Data/drop/parsed_drop_train_with_lists.json' 79 | output_fp = gzip.open(train_output_file, 'wb') 80 | for question in parsed_train_questions: 81 | output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 82 | 83 | # write dev questions 84 | dev_output_file = 'ContinuousPreTraining/Data/drop/parsed_drop_dev_with_lists.json' 85 | output_fp = gzip.open(dev_output_file, 'wb') 86 | for question in parsed_dev_questions: 87 | output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 88 | 89 | 
print('Finished pre-processing drop') 90 | 91 | if __name__ == '__main__': 92 |     """ 93 |     Script to preprocess the drop dataset to UnifiedQA format 94 |     """ 95 |     main() 96 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/bash_scripts/download_reasoning_examples.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading reasoning examples to Data/reasoning_examples ..." 2 | echo "Downloading train reasoning examples to Data/reasoning_examples/train ..." 3 | 4 | mkdir -p reasoning_examples 5 | cd reasoning_examples 6 | 7 | mkdir -p train 8 | cd train 9 | 10 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/arithmetic_addition.gz 11 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/arithmetic_superlatives.gz 12 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/composition_2_hop.gz 13 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/composition.gz 14 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/conjunction.gz 15 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/counting.gz 16 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/every_quantifier.gz 17 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/most_quantifier.gz 18 | wget "https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/numeric comparison.gz" 19 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/numeric_comparison_boolean.gz 20 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/numeric_superlatives.gz 21 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/only_quantifier.gz 22 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/temporal_comparison_boolean.gz 23 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/temporal_comparison.gz 24 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/temporal_difference.gz 25 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_train/temporal_superlatives.gz 26 | 27 | cd .. 28 | 29 | echo "Downloading dev reasoning examples to Data/reasoning_examples/dev ..."
30 | mkdir -p dev 31 | cd dev 32 | 33 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/arithmetic_addition.gz 34 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/arithmetic_superlatives.gz 35 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/composition_2_hop.gz 36 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/composition.gz 37 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/conjunction.gz 38 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/counting.gz 39 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/every_quantifier.gz 40 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/most_quantifier.gz 41 | wget "https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/numeric comparison.gz" 42 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/numeric_comparison_boolean.gz 43 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/numeric_superlatives.gz 44 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/only_quantifier.gz 45 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/temporal_comparison_boolean.gz 46 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/temporal_comparison.gz 47 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/temporal_difference.gz 48 | wget https://tabreas.s3.us-west-2.amazonaws.com/generated_reasoning_examples_dev/temporal_superlatives.gz 49 | cd ../.. -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Training/trainer_callbacks/multi_task_heterogeneous_callback.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ContinuousPreTraining.Training.trainer_callbacks.basic_qa_callback import BasicQaCallback 3 | 4 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 5 | 6 | 7 | class MultiTaskHeterogeneousCallback(BasicQaCallback): 8 |     """ 9 |     A :class:`~transformers.TrainerCallback` for handling multi-task events 10 |     """ 11 | 12 |     def __init__(self): 13 |         self._initialized = False 14 | 15 |     def on_train_begin(self, args, state, control, model=None, **kwargs): 16 |         # init multi task specific fields 17 |         state.task_counter = {} 18 |         state.tasks_indices = {} 19 |         state.tasks_errors = {} 20 | 21 |         control.restart_train_dataloader = False 22 |         control.report_task_counter = False 23 | 24 | 25 |     def on_batch_begin(self, args, state, control, model=None, **kwargs): 26 |         """ 27 |         for every step, update the counter for the current step 28 |         """ 29 |         if 'batch_inputs' in kwargs: 30 |             batch_inputs = kwargs['batch_inputs'] 31 | 32 |             if 'batches' not in state.task_counter: 33 |                 state.task_counter['batches'] = {'without_subtask': 0, 34 |                                                  'with_subtask': 0} 35 | 36 |             # update num of batches without subtask names 37 |             if type(batch_inputs['task_name']) == str: 38 |                 # update the counter for the task 39 |                 state.task_counter['batches']['without_subtask'] += 1 40 |                 train_task_name = batch_inputs['task_name'] 41 |                 del batch_inputs['task_name'] 42 | 43 |                 if train_task_name not in state.task_counter: 44 |                     state.task_counter[train_task_name] = {} 45 |
state.task_counter[train_task_name]['Batches_total'] = 0 46 | state.task_counter[train_task_name]['Batches_since_resampling'] = 0 47 | state.task_counter[train_task_name]['Examples'] = 0 48 | 49 | # update the batches and examples counter 50 | state.task_counter[train_task_name]['Batches_total'] += 1 51 | state.task_counter[train_task_name]['Batches_since_resampling'] += 1 52 | state.task_counter[train_task_name]['Examples'] += len(batch_inputs['input_ids']) 53 | 54 | # else update the number of examples for each sub_task 55 | else: 56 | 57 | state.task_counter['batches']['with_subtask'] += 1 58 | # get batch task name and delete 59 | train_task_name = batch_inputs['task_name'] 60 | del batch_inputs['task_name'] 61 | 62 | # update the index dict with the tasks name 63 | state.tasks_indices[state.global_step] = train_task_name 64 | 65 | # update the counter for the task 66 | for task_name, task_example_count in dict(train_task_name).items(): 67 | if task_name not in state.task_counter: 68 | state.task_counter[task_name] = {} 69 | state.task_counter[task_name]['Examples_total'] = 0 70 | state.task_counter[task_name]['Examples_since_resampling'] = 0 71 | # update the batches and examples counter 72 | state.task_counter[task_name]['Examples_total'] += task_example_count 73 | state.task_counter[task_name]['Examples_since_resampling'] += task_example_count 74 | 75 | def on_step_end(self, args, state, control, model=None, **kwargs): 76 | """ 77 | after every step, check if we need to report the task counter 78 | """ 79 | 80 | if control.should_evaluate: 81 | control.report_task_counter = True 82 | 83 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/common/questions_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from ExampleGeneration.common.multiqa_format_wrapper import Question, Answer 4 | 5 | def sample_questions_per_template(questions, sample_size): 6 | """ 7 | :param questions: 8 | :param sample_size: 9 | :return: for each question template sample take sample_size 10 | note we will restart the random seed for each template, so to enforce the order to be 11 | reproducible. 12 | """ 13 | 14 | template_list = [q.metadata['template_variation'] for q in questions] 15 | sampled_questions = [] 16 | for template_name in set(template_list): 17 | curr_template_questions = [questions[i] for i,t in enumerate(template_list) if t == template_name] 18 | if len(curr_template_questions) < sample_size: 19 | sampled_questions += curr_template_questions 20 | else: 21 | # we will restart the random seed for each template, so to enforce the order to be 22 | # reproducible. 
The shuffle ensures that if we decide to sample more from each context the first 23 |             # K we already sampled remain the same 24 |             random.seed(7) 25 |             random.shuffle(curr_template_questions) 26 |             sampled_questions += curr_template_questions[0:sample_size] 27 | 28 |     return sampled_questions 29 | 30 | def sample_questions(questions, sample_size): 31 |     """ 32 |     :param questions: list of questions to sample from 33 |     :param sample_size: number of questions to sample 34 |     :return: a random sample_size of items from questions 35 |     """ 36 |     random.seed(3) 37 |     if len(questions) < sample_size: 38 |         return questions 39 |     return random.sample(questions, sample_size) 40 | 41 | def get_composite_question(first_comp_question, second_comp_question): 42 |     """ 43 |     :param first_comp_question: the question to inject 44 |     :param second_comp_question: the question to inject into 45 |     :return: the composition question returned by injecting the first question into the second 46 |     """ 47 |     first_comp_answer = first_comp_question.answers.answers[0] 48 |     question_text = second_comp_question.question.replace(first_comp_answer, 49 |                                                           f'({first_comp_question.question} : {first_comp_answer})') 50 |     return Question(qid=f'Comp-{first_comp_question.qid}-{second_comp_question.qid}', 51 |                     question=question_text, 52 |                     answers=second_comp_question.answers, 53 |                     metadata={ 54 |                         'type': f'composition-{first_comp_question.metadata["type"]}-{second_comp_question.metadata["type"]}', 55 |                         'source': 'generated-composition', "schema": "simple-injection-composition", 56 |                         'link_answer': first_comp_answer}, 57 |                     supporting_context=first_comp_question.supporting_context + second_comp_question.supporting_context) 58 | 59 | 60 | def get_conjunction_question(q1, q2, intersecting_answers): 61 |     """ 62 |     :param q1: the first question 63 |     :param q2: the second question 64 |     :param intersecting_answers: the answers shared by both questions 65 |     :return: the conjunction question returned by combining the two questions 66 |     """ 67 |     if random.choice([True, False]): 68 |         question_text = f'({q1.question}) and ({q2.question})' 69 |     else: 70 |         question_text = f'({q2.question}) and ({q1.question})' 71 |     qid = f'Conj-{q1.qid}-{q2.qid}' 72 |     return Question(qid=qid, 73 |                     question=question_text, 74 |                     answers=Answer(list(intersecting_answers)), 75 |                     metadata={ 76 |                         'type': f'conj-{q1.metadata["type"]}-{q2.metadata["type"]}', 77 |                         'source': 'generated-conjunction', "schema": "conj1", 78 |                         'q1_answers': q1.answers.answers, 79 |                         'q2_answers': q2.answers.answers}, 80 |                     supporting_context=q1.supporting_context + q2.supporting_context) 81 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/datajob_factory.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import sys 3 | from setuptools import find_packages 4 | from pkgutil import iter_modules 5 | 6 | class DataJobFactory: 7 |     def __init__(self): 8 |         pass 9 | 10 |     def upper_to_lower_notation_name(self, datajob_name): 11 |         return ''.join(['_' + c.lower() if c.isupper() else c for c in datajob_name ])[1:] 12 | 13 |     def find_datajob(self, path, callange_to_find): 14 |         modules = list() 15 |         for pkg in [''] + find_packages(path): 16 |             pkgpath = path + '/' + pkg.replace('.', '/') 17 |             if sys.version_info.major == 2 or (sys.version_info.major == 3 and sys.version_info.minor < 6): 18 |                 for _, name, ispkg in iter_modules([pkgpath]): 19 |                     if not ispkg: 20 |                         modules.append(pkg + '.' + name) 21 |             else: 22 |                 for info in iter_modules([pkgpath]): 23 |                     if not info.ispkg: 24 |                         modules.append(pkg + '.' + info.name) 25 | 26 |         found_datajob = [module for module in modules if module.find(callange_to_find) > -1] 27 |         if len(found_datajob) > 0: 28 |             found_datajob = found_datajob[0] 29 |             if found_datajob.startswith('.'): 30 |                 found_datajob = found_datajob[1:] 31 |         else: 32 |             found_datajob = None 33 | 34 |         return found_datajob
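    # For illustration (hypothetical datajob name): upper_to_lower_notation_name
    # maps the CamelCase name 'MyNewDataJob' to 'my_new_data_job', and
    # find_datajob then returns the first module under `path` whose dotted
    # name contains that snake_case string, or None if nothing matches.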
35 | 36 |     def get_datajob(self, datajob_name, datajob_type, args): 37 |         datajob_type_lower = self.upper_to_lower_notation_name(datajob_type) 38 |         module_name = self.find_datajob(os.path.dirname(os.path.abspath(__file__)) + '/datajobs', datajob_type_lower) 39 |         try: 40 |             mod = __import__('datajobs.' + module_name, fromlist=[datajob_type]) 41 |         except Exception: 42 |             raise ValueError('datajob_name not found!') 43 | 44 |         return getattr(mod, datajob_type + 'DataJob')(datajob_name, args) 45 | 46 |     def create_new_datajob(self, datajob_name, copy_from, args): 47 |         os.chdir(os.path.dirname(os.path.abspath(__file__)) + '/datajobs') 48 |         copy_from_lower = self.upper_to_lower_notation_name(copy_from) 49 |         datajob_name_lower = self.upper_to_lower_notation_name(datajob_name) 50 |         copy_from_module = self.find_datajob(os.getcwd(), copy_from_lower) 51 |         if copy_from_module is None: 52 |             raise ValueError('copy_from datajob not found!') 53 |         copy_from_path = copy_from_module.replace('.',os.sep) + '.py' 54 | 55 |         with open(copy_from_path,'r') as f: 56 |             copied_datajob_txt = f.read() 57 |         copied_datajob_txt = copied_datajob_txt.replace(copy_from, datajob_name) 58 | 59 |         new_datajob_path = datajob_name_lower + '.py' 60 |         with open(new_datajob_path, 'w') as f: 61 |             f.write(copied_datajob_txt) 62 | 63 |         # duplicating the test 64 |         os.chdir('../../tests/datajobs') 65 |         with open(copy_from_path.replace('.py','_test.py'),'r') as f: 66 |             copied_datajob_txt = f.read() 67 |         copied_datajob_txt = copied_datajob_txt.replace('datajobs.' + copy_from_module, \ 68 |                                                         'datajobs.'
+ datajob_name_lower) 69 | copied_datajob_txt = copied_datajob_txt.replace(copy_from, datajob_name) 70 | copied_datajob_txt = copied_datajob_txt.replace(copy_from_lower, datajob_name_lower) 71 | 72 | new_datajob_path = datajob_name_lower + '_test.py' 73 | with open(new_datajob_path, 'w') as f: 74 | f.write(copied_datajob_txt) 75 | 76 | # adding to config file: 77 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations", 78 | args.config_file_name) , 'r') as f: 79 | config = json.load(f) 80 | config[datajob_name] = config[copy_from] 81 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations", 82 | args.config_file_name) , 'w') as f: 83 | json.dump(config, f ,indent=4) 84 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/scripts/preprocess_mmqa_for_question_classification.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | from datasets import tqdm 4 | 5 | 6 | def main(): 7 | 8 | def parse_image_question_classifier(q): 9 | """ 10 | parse a question to tell if it's an image question 11 | """ 12 | question = q['question'] 13 | id = q['qid'] 14 | 15 | # if this is a test question, we don't no the answer 16 | if ds_split['test']: 17 | ds_array.append({'context': question, 18 | 'answer': "", 19 | 'all_answers': [""], 20 | 'id': id}) 21 | else: 22 | # check if the question has an image modality 23 | question_modalities = q['metadata']['modalities'] 24 | if 'image' in question_modalities: 25 | answer = 'yes' 26 | else: 27 | answer = 'no' 28 | 29 | ds_array.append({'context': question, 30 | 'answer': answer, 31 | 'all_answers': [answer], 32 | 'id': id}) 33 | 34 | train = [] 35 | dev = [] 36 | test = [] 37 | tables = [] 38 | texts = [] 39 | 40 | f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_train.jsonl.gz?raw=true", 'r') 41 | for l in f: 42 | train.append(json.loads(l)) 43 | 44 | f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_dev.jsonl.gz?raw=true", 'r') 45 | for l in f: 46 | dev.append(json.loads(l)) 47 | 48 | f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_test.jsonl.gz?raw=true", 'r') 49 | for l in f: 50 | test.append(json.loads(l)) 51 | 52 | f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_tables.jsonl.gz?raw=true", 'r') 53 | for l in f: 54 | tables.append(json.loads(l)) 55 | 56 | f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_texts.jsonl.gz?raw=true", 'r') 57 | for l in f: 58 | texts.append(json.loads(l)) 59 | 60 | # go over all ds splits 61 | parsed_train_questions = [] 62 | parsed_dev_questions = [] 63 | parsed_test_questions = [] 64 | 65 | train = {'lines': train, 'array': parsed_train_questions, 'balanced_sampling': True, 'test': False, 66 | 'add_table': False} 67 | dev = {'lines': dev, 'array': parsed_dev_questions, 'balanced_sampling': False, 'test': False, 'add_table': False} 68 | test = {'lines': test, 'array': parsed_test_questions, 'balanced_sampling': False, 'test': True, 'add_table': False} 69 | 70 | print(f'Preprocessing MMQA for image question classification') 71 | for ds_split in [train, dev, test]: 72 | 73 | ds_lines = ds_split['lines'] 74 | ds_array = ds_split['array'] 75 | 76 | for q in tqdm(ds_lines): 77 | 78 | parse_image_question_classifier(q) 79 | 80 | # write train questions 81 | train_output_file = 'ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_train.json' 82 | output_fp = gzip.open(train_output_file, 'wb') 83 | for question in 
parsed_train_questions: 84 | output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 85 | 86 | # write dev questions 87 | dev_output_file = 'ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_dev.json' 88 | output_fp = gzip.open(dev_output_file, 'wb') 89 | for question in parsed_dev_questions: 90 | output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 91 | 92 | # write test questions 93 | dev_output_file = 'ContinuousPreTraining/Data/mmqa/parsed_mmqa_question_classifier_test.json' 94 | output_fp = gzip.open(dev_output_file, 'wb') 95 | for question in parsed_test_questions: 96 | output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 97 | 98 | print('Finished pre-processing MMQA for question classification') 99 | 100 | 101 | if __name__ == '__main__': 102 | """ 103 | Script to preprocess MMQA for image question classification 104 | """ 105 | main() 106 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/dataset_readers/synthetic_questions_multi_datasets.py: -------------------------------------------------------------------------------- 1 | # Creating a custom dataset for reading the dataframe and loading it into the dataloader to pass it to the neural network at a later stage for finetuning the model and to prepare it for predictions 2 | import gzip 3 | import json 4 | import torch 5 | import pandas as pd 6 | import random 7 | from datasets import tqdm 8 | 9 | 10 | class SyntheticQuestionsMultiDatasets(torch.utils.data.Dataset): 11 | """ 12 | Dataset reader for Synthetic Quetions IIRC 13 | """ 14 | 15 | def __init__(self, 16 | data_path, 17 | tokenizer, 18 | max_input_token_len, 19 | generation_model=False, 20 | num_examples_to_load=None, 21 | max_output_token_len=None): 22 | 23 | # load lot data from gzip 24 | self.generation_model = generation_model 25 | 26 | random.seed(42) 27 | examples = [] 28 | with gzip.open(data_path, "r") as f: 29 | for i, l in enumerate(tqdm(f)): 30 | examples.append(json.loads(l)) 31 | # 32 | # if i == 100: 33 | # break 34 | 35 | # sample num_examples_to_load examples 36 | if num_examples_to_load is not None: 37 | num_examples_to_load = min(len(examples), num_examples_to_load) 38 | examples = random.sample(examples, num_examples_to_load) 39 | 40 | 41 | self.data = pd.DataFrame([[example['qid'], 42 | example['question'], 43 | example['context'], 44 | example['answer'], 45 | example['template']] 46 | for example in examples], columns=['qids', 'phrases', 'contexts', 'gold', 'type']) 47 | 48 | self.lengths = [len(self.data.contexts[k]) + len(self.data.phrases[k]) 49 | for k in range(len(self.data))] 50 | 51 | self.qids = self.data.qids 52 | self.phrases = self.data.phrases 53 | self.contexts = self.data.contexts 54 | self.gold = self.data.gold 55 | self.types = self.data.type 56 | 57 | self.tokenizer = tokenizer 58 | self.max_input_token_len = max_input_token_len 59 | self.max_output_token_len = max_output_token_len 60 | 61 | def __len__(self): 62 | return len(self.phrases) 63 | 64 | def __getitem__(self, index): 65 | 66 | phrase = str(self.phrases[index]) 67 | context = str(self.contexts[index]) 68 | type = str(self.types[index]) 69 | qid = str(self.qids[index]) 70 | 71 | source_text = phrase + '\n' + context 72 | source_text = 'QA: ' + source_text 73 | 74 | gold_text = self.gold[index] 75 | 76 | labels = self.tokenizer.batch_encode_plus([gold_text], max_length=self.max_output_token_len, 77 | 
truncation=True, padding='max_length', 78 | return_tensors='pt').input_ids.squeeze() \ 79 | .to(dtype=torch.long) 80 | labels[labels == 0] = -100 81 | tokenized_inputs = self.tokenizer.encode_plus(text=source_text, 82 | add_special_tokens=True, 83 | max_length=self.max_input_token_len, 84 | pad_to_max_length=False, 85 | return_token_type_ids=False, 86 | return_attention_mask=True, 87 | return_overflowing_tokens=False, 88 | return_special_tokens_mask=False, 89 | ) 90 | return { 91 | 'input_ids': tokenized_inputs.input_ids, 92 | 'attention_mask': tokenized_inputs.attention_mask, 93 | 'labels': labels, 94 | 'answer_type': type, 95 | 'id': qid, 96 | 'question': phrase, 97 | 'context': context, 98 | 'answer': gold_text 99 | } 100 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Data/dataset_readers/iirc_dataset.py: -------------------------------------------------------------------------------- 1 | # Creating a custom dataset for reading the dataframe and loading it into the dataloader to pass it to the neural network at a later stage for finetuning the model and to prepare it for predictions 2 | import json 3 | import torch 4 | import pandas as pd 5 | 6 | 7 | class IircDataset(torch.utils.data.Dataset): 8 | """ 9 | Dataset reader for IIRC 10 | """ 11 | 12 | def __init__(self, data_path, tokenizer, max_seq_len, 13 | prefix=None, generation_model=False, summ_len=None): 14 | 15 | # load lot data from gzip 16 | self.generation_model = generation_model 17 | self.prefix = 'QA: ' if prefix is None else prefix 18 | 19 | with open(data_path, "r") as f: 20 | data = json.load(f) 21 | 22 | examples = [] 23 | question_index = 0 24 | for context in data: 25 | for q in context['questions']: 26 | example = {'qid': q['qid'] if 'qid' in q else str(question_index), 27 | 'phrase': q['question'], 28 | 'context': self.preprocess_iirc_context(q['context']), 29 | 'answer': self.preprocess_iirc_answer(q['answer']), 30 | 'answer_type': q['answer']['type']} 31 | examples.append(example) 32 | question_index += 1 33 | 34 | self.data = pd.DataFrame([[example['qid'], 35 | example['phrase'], 36 | example['context'], 37 | example['answer'], 38 | example['answer_type']] 39 | for example in examples], columns=['qids', 'phrases', 'contexts', 'gold', 'type']) 40 | 41 | self.qids = self.data.qids 42 | self.phrases = self.data.phrases 43 | self.contexts = self.data.contexts 44 | self.gold = self.data.gold 45 | self.types = self.data.type 46 | 47 | self.tokenizer = tokenizer 48 | self.source_len = max_seq_len 49 | self.summ_len = summ_len 50 | 51 | def preprocess_iirc_context(self, context): 52 | """ 53 | preprocess a context from iirc dataset to a text, which can be used by a generative model 54 | """ 55 | return '\n'.join([c['passage'] + ': ' + c['text'] for c in context]) 56 | 57 | def preprocess_iirc_answer(self, answer): 58 | """ 59 | preprocess an answer from iirc dataset to a text, which can be used by a generative model 60 | """ 61 | if answer['type'] == 'none': 62 | return 'none' 63 | if answer['type'] == 'span': 64 | return '#'.join([a['text'] for a in answer['answer_spans']]) 65 | if answer['type'] in ['binary', 'value']: 66 | return answer['answer_value'] 67 | return None 68 | 69 | def __len__(self): 70 | return len(self.phrases) 71 | 72 | def __getitem__(self, index): 73 | 74 | phrase = str(self.phrases[index]) 75 | context = str(self.contexts[index]) 76 | type = str(self.types[index]) 77 | qid = str(self.qids[index]) 78 | 79 | source_text = phrase + '\n' 
+ context 80 | source_text = self.prefix + source_text 81 | 82 | gold_text = self.gold[index] 83 | 84 | labels = self.tokenizer.batch_encode_plus([gold_text], max_length=self.summ_len, 85 | truncation=True, padding='max_length', 86 | return_tensors='pt').input_ids.squeeze() \ 87 | .to(dtype=torch.long) 88 | labels[labels == 0] = -100 89 | tokenized_inputs = self.tokenizer.batch_encode_plus([source_text], max_length=self.source_len, 90 | truncation=True, padding='max_length', 91 | return_tensors='pt') 92 | 93 | return { 94 | 'input_ids': tokenized_inputs.input_ids.squeeze().to(dtype=torch.long), 95 | 'attention_mask': tokenized_inputs.attention_mask, 96 | 'labels': labels, 97 | 'answer_type': type, 98 | 'id': qid, 99 | 'question': phrase, 100 | 'context': context, 101 | 'answer': gold_text, 102 | 'all_answers': json.dumps(gold_text.split('#')) 103 | } 104 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/predictors/span_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Set 4 | 5 | 6 | class SpanPrediction: 7 | """ 8 | span prediction object 9 | """ 10 | def __init__(self, 11 | correct_predictions, 12 | tokens_predictions_dict, 13 | tokens_labels_dict, 14 | precision, 15 | f1): 16 | """ 17 | init relevant fields for a span prediction 18 | """ 19 | self.correct_predictions = correct_predictions 20 | self.tokens_predictions_dict = tokens_predictions_dict 21 | self.tokens_labels_dict = tokens_labels_dict 22 | self.precision = precision 23 | self.f1 = f1 24 | 25 | 26 | def _compute_f1(predicted_bag: Set[int], gold_bag: Set[int]) -> float: 27 | intersection = len(gold_bag.intersection(predicted_bag)) 28 | if not predicted_bag: 29 | precision = 1.0 30 | else: 31 | precision = intersection / float(len(predicted_bag)) 32 | if not gold_bag: 33 | recall = 1.0 34 | else: 35 | recall = intersection / float(len(gold_bag)) 36 | f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0 37 | return f1 38 | 39 | 40 | def get_span_prediction_dict(predictions): 41 | """ 42 | method to get dictionary between spans and their predictions 43 | """ 44 | span_prediction_dict = {} 45 | i = 0 46 | num_predicted_tokens = len(predictions) 47 | last_extra_id = None 48 | 49 | # iterate over all tokens 50 | while i < num_predicted_tokens: 51 | pred_token = predictions[i] 52 | 53 | # check if this is an extra id token 54 | if 32000 <= pred_token <= 33000: 55 | span_prediction_dict[pred_token] = [] 56 | last_extra_id = pred_token 57 | 58 | else: 59 | if last_extra_id is not None: 60 | span_prediction_dict[last_extra_id].append(pred_token) 61 | 62 | # raise cnt for i 63 | i += 1 64 | return span_prediction_dict 65 | 66 | 67 | def get_precision(gold_spans, predicted_spans): 68 | """ 69 | calculate span precision 70 | """ 71 | # make sure we don't calculate by zero 72 | if len(gold_spans) == 0: 73 | return 0 74 | 75 | num_spans = len(gold_spans) 76 | correct_predictions = 0 77 | for span_key in gold_spans.keys(): 78 | if span_key in predicted_spans: 79 | if predicted_spans[span_key] == gold_spans[span_key]: 80 | correct_predictions += 1 81 | 82 | return correct_predictions / num_spans 83 | 84 | 85 | def get_average_f1(gold_spans, predicted_spans): 86 | """ 87 | get average f1 between a gold and predicted span 88 | """ 89 | # make sure we don't calculate by zero 90 | if len(gold_spans) == 0: 91 | return 0 92 | 
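    # A worked example of the computation below: with
    # gold_spans = {32000: [5, 6], 32001: [9]} and
    # predicted_spans = {32000: [5], 32001: [9]},
    # _compute_f1 gives 2 * (1.0 * 0.5) / (1.0 + 0.5) = 2/3 for the first span
    # and 1.0 for the second, so the average F1 returned is (2/3 + 1) / 2 = 0.83.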
93 | num_spans = len(gold_spans) 94 | tot_f1 = 0 95 | for span_key in gold_spans.keys(): 96 | if span_key in predicted_spans: 97 | tot_f1 += _compute_f1(set(gold_spans[span_key]), set(predicted_spans[span_key])) 98 | 99 | return tot_f1 / num_spans 100 | 101 | 102 | def SpanPredictor(tokenizer, model, input_ids, attention_mask, labels): 103 | """ 104 | span predictor for mlm task 105 | """ 106 | 107 | # init fields 108 | span_predictions = [] 109 | outputs = model(input_ids=input_ids, labels=labels) 110 | logits = outputs.logits 111 | 112 | # mlm preds 113 | preds = torch.argmax(logits, dim=-1).detach().cpu().numpy() 114 | labels_cpu = labels.detach().cpu().numpy() 115 | correct_preds = (preds == labels_cpu) 116 | 117 | # calculate the precision and f1 for every sample 118 | for k, pred in enumerate(preds): 119 | preds_spans_dict = get_span_prediction_dict(preds[k]) 120 | labels_spans_dict = get_span_prediction_dict(labels_cpu[k]) 121 | precision = get_precision(preds_spans_dict, labels_spans_dict) 122 | f1 = get_average_f1(preds_spans_dict, labels_spans_dict) 123 | 124 | # append prediction 125 | span_predictions.append(SpanPrediction(correct_predictions=correct_preds, 126 | tokens_predictions_dict=preds_spans_dict, 127 | tokens_labels_dict=labels_spans_dict, 128 | precision=precision, 129 | f1=f1)) 130 | 131 | return span_predictions 132 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/question_generators/question_generator_factory.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import sys 3 | from setuptools import find_packages 4 | from pkgutil import iter_modules 5 | import logging 6 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 7 | 8 | 9 | class QGenFactory: 10 | def __init__(self): 11 | pass 12 | 13 | def upper_to_lower_notation_name(self, qgen_name): 14 | return ''.join(['_' + c.lower() if c.isupper() else c for c in qgen_name ])[1:] 15 | 16 | def find_qgen(self, path, callange_to_find): 17 | modules = list() 18 | for pkg in [''] + find_packages(path): 19 | pkgpath = path + '/' + pkg.replace('.', '/') 20 | if sys.version_info.major == 2 or (sys.version_info.major == 3 and sys.version_info.minor < 6): 21 | for _, name, ispkg in iter_modules([pkgpath]): 22 | if not ispkg: 23 | modules.append(pkg + '.' + name) 24 | else: 25 | for info in iter_modules([pkgpath]): 26 | if not info.ispkg: 27 | modules.append(pkg + '.' + info.name) 28 | 29 | found_qgen = [module for module in modules if module.find(callange_to_find) > -1] 30 | if len(found_qgen) > 0: 31 | found_qgen = found_qgen[0] 32 | if found_qgen.startswith('.'): 33 | found_qgen = found_qgen[1:] 34 | else: 35 | found_qgen = None 36 | 37 | return found_qgen 38 | 39 | def get_qgen(self, qgen_name, template, args): 40 | qgen_name_lower = self.upper_to_lower_notation_name(qgen_name) 41 | module_name = self.find_qgen(os.path.dirname(os.path.abspath(__file__)), qgen_name_lower) 42 | try: 43 | mod = __import__('ExampleGeneration.question_generators.' 
+ module_name, fromlist=[qgen_name]) 44 |         except Exception: 45 |             logger.error(module_name + ' module not found!!') 46 |             raise ValueError('qgen_name not found!') 47 | 48 |         return getattr(mod, qgen_name)(template, args) 49 | 50 |     def create_new_qgen(self, qgen_name, qgen_module, copy_from, args): 51 |         os.chdir(os.path.dirname(os.path.abspath(__file__)) + '/qgens') 52 |         copy_from_lower = self.upper_to_lower_notation_name(copy_from) 53 |         qgen_name_lower = self.upper_to_lower_notation_name(qgen_name) 54 |         copy_from_module = self.find_qgen(os.getcwd(), copy_from_lower) 55 |         if copy_from_module is None: 56 |             raise ValueError('copy_from qgen not found!') 57 |         copy_from_path = copy_from_module.replace('.',os.sep) + '.py' 58 | 59 |         if not os.path.isdir(qgen_module): 60 |             os.mkdir(qgen_module) 61 |             open(os.path.join(qgen_module,'__init__.py'), 'a').close() 62 | 63 |         with open(copy_from_path,'r') as f: 64 |             copied_qgen_txt = f.read() 65 |         copied_qgen_txt = copied_qgen_txt.replace(copy_from, qgen_name) 66 | 67 |         if len(qgen_module) > 0: 68 |             new_qgen_path = os.path.join(qgen_module, qgen_name_lower) + '.py' 69 |         else: 70 |             new_qgen_path = qgen_name_lower + '.py' 71 |         with open(new_qgen_path, 'w') as f: 72 |             f.write(copied_qgen_txt) 73 | 74 |         # duplicating the test 75 |         os.chdir('../../tests/qgens') 76 |         if not os.path.isdir(qgen_module): 77 |             os.mkdir(qgen_module) 78 |         with open(copy_from_path.replace('.py','_test.py'),'r') as f: 79 |             copied_qgen_txt = f.read() 80 |         copied_qgen_txt = copied_qgen_txt.replace('qgens.' + copy_from_module, \ 81 |                                                   'qgens.' + qgen_module + '.' + qgen_name_lower) 82 |         copied_qgen_txt = copied_qgen_txt.replace(copy_from, qgen_name) 83 |         copied_qgen_txt = copied_qgen_txt.replace(copy_from_lower, qgen_name_lower) 84 | 85 |         if len(qgen_module) > 0: 86 |             new_qgen_path = os.path.join(qgen_module, qgen_name_lower) + '_test.py' 87 |         else: 88 |             new_qgen_path = qgen_name_lower + '_test.py' 89 |         with open(new_qgen_path, 'w') as f: 90 |             f.write(copied_qgen_txt) 91 | 92 |         # adding to config file: 93 |         with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations", 94 |                                args.config_file_name) , 'r') as f: 95 |             config = json.load(f) 96 |         config[qgen_name] = config[copy_from] 97 |         with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations", 98 |                                args.config_file_name) , 'w') as f: 99 |             json.dump(config, f ,indent=4) 100 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Common/config.py: -------------------------------------------------------------------------------- 1 | global logger 2 | import datetime 3 | import json 4 | import logging 5 | 6 | 7 | tracer = logging.getLogger('config') 8 | tracer.setLevel(logging.CRITICAL)  # or desired level 9 | tracer.addHandler(logging.FileHandler('indexer.log')) 10 | 11 | 12 | class Config: 13 |     """ A python singleton """ 14 | 15 |     # Usage example 16 |     # Config().write_log('INFO', 'train_metric', context_dict=elastic_train_metrics) 17 | 18 |     class __impl: 19 |         """ Implementation of the singleton interface """ 20 | 21 |         def __init__(self): 22 |             self._start_time = datetime.datetime.utcnow().strftime("%d-%m-%y_%H:%M") 23 |             self._config = {'config_path': '', 24 |                             'use_wandb': False, 25 |                             'use_elastic': False} 26 | 27 |         def load(self, config_path): 28 |             """ 29 |             Loads a full json config file 30 |             """ 31 |             with open(config_path) as f: 32 |                 self._config.update(json.load(f)) 33 | 34 |         def override_dict(self, new_config): 35 |             """ 36 |             Adds a value to the config
if missing, or overrides if already in config 37 |             key: string path in config separated by '.' e.g. training_arguments.num_train_epochs 38 |             """ 39 |             for key,val in new_config.items(): 40 |                 tokens = key.split('.') 41 |                 config_internal = self._config 42 |                 for token in tokens[:-1]: 43 |                     if token in config_internal: 44 |                         config_internal = config_internal[token] 45 |                     else: 46 |                         raise ValueError(f"could not find {key} in config!") 47 | 48 |                 if tokens[-1] in config_internal: 49 |                     if val is not None: 50 |                         if val == 'False': 51 |                             config_internal[tokens[-1]] = False 52 |                         elif val == 'True': 53 |                             config_internal[tokens[-1]] = True 54 |                         else: 55 |                             config_internal[tokens[-1]] = val 56 |                 else: 57 |                     raise ValueError(f"could not find {key} in config!") 58 | 59 |         def add_value(self, key, value): 60 |             """ 61 |             Placeholder for adding a value to the config; not implemented yet 62 |             key: string path in config separated by '.' e.g. training_arguments.num_train_epochs 63 |             """ 64 |             pass 65 | 66 |         def iter_on_config(self, d, key, full_key=''): 67 |             """ 68 |             Recursively collect the (dotted_path, value) pairs for every entry 69 |             in the nested config dict d whose key matches the given key 70 |             """ 71 |             keys, vals = [], [] 72 |             for k, v in d.items(): 73 |                 child_key = full_key + '.' + k if full_key else k 74 |                 if k == key: 75 |                     keys.append(child_key) 76 |                     vals.append(v) 77 |                 if isinstance(v, dict): 78 |                     sub_keys, sub_vals = self.iter_on_config(v, key, full_key=child_key) 79 |                     keys.extend(sub_keys) 80 |                     vals.extend(sub_vals) 81 |             return keys, vals 82 | 83 |         def get(self, key): 84 |             """ 85 |             Return the value stored under a '.'-separated key path, or None if the path is missing 86 |             """ 87 |             tokens = key.split('.') 88 |             result = self._config 89 |             for token in tokens[:-1]: 90 |                 if token in result: 91 |                     result = result[token] 92 |             if tokens[-1] in result: 93 |                 return result[tokens[-1]] 94 |             else: 95 |                 #return self.iter_on_config(self._config, key) 96 |                 tracer.info(f"could not find {key} in config, return none!") 97 |                 return None 98 | 99 | 100 | 101 |     # storage for the instance reference 102 |     __instance = None 103 | 104 |     def __init__(self): 105 |         """ Create singleton instance """ 106 |         # Check whether we already have an instance 107 |         if Config.__instance is None: 108 |             # Create and remember instance 109 |             Config.__instance = Config.__impl() 110 | 111 |         # Store instance reference as the only member in the handle 112 |         self.__dict__['_Singleton__instance'] = Config.__instance 113 | 114 |     def __getattr__(self, attr): 115 |         """ Delegate access to implementation """ 116 |         return getattr(self.__instance, attr) 117 | 118 |     def __setattr__(self, attr, value): 119 |         """ Delegate access to implementation """ 120 |         return setattr(self.__instance, attr, value) 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | --------------------------------------------------------------------------------
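A short usage sketch of the Config singleton above; the config path and the override value below are illustrative, but load, get, and override_dict are the methods defined in config.py, reached through the singleton's attribute delegation:

```
from ContinuousPreTraining.Common.config import Config

# load a config file, then read and override values via '.'-separated key paths
Config().load('ContinuousPreTraining/configurations/drop_config.json')
lr = Config().get('optimizer.lr')  # returns None when the path is missing
Config().override_dict({'training_arguments.num_train_epochs': 10})
```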
/Training/ContinuousPreTraining/Data/dataset_readers/unified_qa_dataset.py: -------------------------------------------------------------------------------- 1 | # Custom dataset that reads UnifiedQA-format examples into a dataframe and serves tokenized inputs to the dataloader for fine-tuning and prediction 2 | import json 3 | import torch 4 | import gzip 5 | import pandas as pd 6 | import hashlib 7 | from datasets import tqdm 8 | from ContinuousPreTraining.Common.config import Config 9 | from torch.utils.data.dataset import Dataset 10 | import random 11 | 12 | 13 | class UnifiedQaDataset(torch.utils.data.Dataset): 14 | 15 |     def __init__(self, 16 |                  data_path, 17 |                  tokenizer, 18 |                  max_input_token_len, 19 |                  max_output_token_len=32, 20 |                  prefix=None, 21 |                  sample_size=None, 22 |                  generation_model=False 23 |                  ): 24 | 25 |         # load the data from gzip 26 |         random.seed(42) 27 |         self.generation_model = generation_model 28 |         examples = [] 29 | 30 |         self.prefix = 'QA: ' if prefix is None else prefix 31 | 32 |         with gzip.open(data_path, "r") as f: 33 |             for i, l in enumerate(tqdm(f)): 34 |                 examples.append(json.loads(l)) 35 | 36 |         # sample sample_size examples 37 |         if sample_size is not None: 38 |             sample_size = min(len(examples), sample_size) 39 |             examples = random.sample(examples, sample_size) 40 | 41 |         for example in examples: 42 |             if 'id' not in example: 43 |                 m = hashlib.md5() 44 |                 m.update(example['context'].encode()) 45 |                 m.update(example['answer'].encode()) 46 |                 example['id'] = m.hexdigest() 47 | 48 |         self.data = pd.DataFrame([[example['context'], 49 |                                    example['answer'], 50 |                                    example['all_answers'], 51 |                                    example['id']] 52 |                                   for example in examples], 53 |                                  columns=['contexts', 'gold', 'all_answers', 'id']) 54 | 55 |         self.tokenizer = tokenizer 56 |         self.max_input_token_len = max_input_token_len 57 |         self.max_output_token_len = max_output_token_len 58 | 59 |         # lengths for fast grouping 60 |         self.lengths = [len(context) 61 |                         for context in self.data.contexts] 62 | 63 |     def __len__(self): 64 |         return len(self.data.id) 65 | 66 |     def __getitem__(self, index): 67 | 68 |         source_text = str(self.data.contexts[index]) 69 | 70 |         if self.generation_model: 71 |             #if Config().get('train_dataset_reader.datasets.unifiedqa.add_prefix'): 72 |             source_text = self.prefix + source_text 73 |             gold_text = self.data.gold[index] 74 |             labels = self.tokenizer.batch_encode_plus([gold_text], max_length=self.max_output_token_len, 75 |                                                       truncation=True, padding='max_length', 76 |                                                       return_tensors='pt').input_ids.squeeze() \ 77 |                 .to(dtype=torch.long) 78 |             labels[labels == 0] = -100 79 |             tokenized_inputs = self.tokenizer.encode_plus(text=source_text, 80 |                                                           add_special_tokens=True, 81 |                                                           max_length=self.max_input_token_len, 82 |                                                           pad_to_max_length=False, 83 |                                                           return_token_type_ids=False, 84 |                                                           return_attention_mask=True, 85 |                                                           return_overflowing_tokens=False, 86 |                                                           return_special_tokens_mask=False, 87 |                                                           ) 88 |             return { 89 |                 'input_ids': tokenized_inputs.input_ids, 90 |                 'attention_mask': tokenized_inputs.attention_mask, 91 |                 'labels': labels, 92 |                 'question': source_text, 93 |                 'answer': gold_text, 94 |                 'id': self.data.id[index], 95 |                 'all_answers': json.dumps(self.data.all_answers[index]) 96 |             } 97 | 98 |         else: 99 |             tokenized = self.tokenizer.batch_encode_plus([source_text], max_length=self.max_input_token_len, 100 |                                                          truncation=True, padding='max_length', 101 |                                                          return_tensors='pt') 102 |             return { 103 |                 'input_ids': tokenized.input_ids.squeeze().to(dtype=torch.long), 104 |                 'attention_mask': tokenized.attention_mask, 105 |                 'labels': torch.tensor(self.data.gold[index]) 106 |             } 107 | --------------------------------------------------------------------------------
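The reader above expects a gzipped JSON-lines file in which every record carries context, answer, and all_answers, with id optional (missing ids are filled with an MD5 over context and answer). A minimal sketch of writing one compatible record; the file name and field values are illustrative:

```
import gzip, json

record = {
    'context': 'Who scored first?\nThe Panthers scored first on a field goal...',
    'answer': 'Panthers',
    'all_answers': ['Panthers', 'the Panthers'],
}
with gzip.open('unified_qa_sample.jsonl.gz', 'wt') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')
```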
/Training/ContinuousPreTraining/scripts/preprocess_mmqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | from datasets import tqdm 4 | from transformers import T5Tokenizer 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def main(): 9 |     def parse_question(q, linearized_table): 10 |         """ 11 |         parse the question to contain the oracle text documents and table 12 |         """ 13 |         question = q['question'] 14 |         id = q['qid'] 15 |         context_text_docs = q['metadata']['text_doc_ids'] 16 |         question_text_docs = [texts_dict[doc] for doc in context_text_docs] 17 | 18 |         supporting_contexts_ids = set([supporting_context['doc_id'] 19 |                                        for supporting_context in q['supporting_context'] 20 |                                        if supporting_context['doc_part'] == 'text']) 21 | 22 |         gold_text_contexts = [d for d in question_text_docs 23 |                               if d['id'] in supporting_contexts_ids] 24 |         text_context = '\n'.join([d['text'] for d in gold_text_contexts]) 25 | 26 |         all_answers = [str(a['answer']) for a in q['answers']] 27 |         answer = '#'.join(all_answers) 28 | 29 |         ds_array.append({'context': f'{question} \n {text_context} \n {linearized_table}', 30 |                          'answer': answer, 31 |                          'all_answers': [all_answers], 32 |                          'id': id}) 33 | 34 | 35 |     train = [] 36 |     dev = [] 37 |     test = [] 38 |     tables = [] 39 |     texts = [] 40 | 41 |     f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_train.jsonl.gz?raw=true", 'r') 42 |     for l in f: 43 |         train.append(json.loads(l)) 44 | 45 |     f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_dev.jsonl.gz?raw=true", 'r') 46 |     for l in f: 47 |         dev.append(json.loads(l)) 48 | 49 |     f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_test.jsonl.gz?raw=true", 'r') 50 |     for l in f: 51 |         test.append(json.loads(l)) 52 | 53 |     f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_tables.jsonl.gz?raw=true", 'r') 54 |     for l in f: 55 |         tables.append(json.loads(l)) 56 | 57 |     f = gzip.open("ContinuousPreTraining/Data/mmqa/MMQA_texts.jsonl.gz?raw=true", 'r') 58 |     for l in f: 59 |         texts.append(json.loads(l)) 60 | 61 |     # go over all ds splits 62 |     parsed_train_questions = [] 63 |     parsed_dev_questions = [] 64 |     parsed_test_questions = [] 65 |     tables_dict = {t['id']: t for t in tables} 66 |     texts_dict = {t['id']: t for t in texts} 67 | 68 |     train = {'lines': train, 'array': parsed_train_questions, 'balanced_sampling': True, 'test': False, 69 |              'add_table': False} 70 |     dev = {'lines': dev, 'array': parsed_dev_questions, 'balanced_sampling': False, 'test': False, 'add_table': False} 71 | 72 |     print('Preprocessing MMQA') 73 |     for ds_split in [train, dev]: 74 | 75 |         ds_lines = ds_split['lines'] 76 |         ds_array = ds_split['array'] 77 | 78 |         for q in tqdm(ds_lines): 79 | 80 |             # filter image questions 81 |             if 'image' not in q['metadata']['modalities']: 82 | 83 |                 # linearize the table 84 |                 linearized_table = '' 85 |                 table_id = q['metadata']['table_id'] 86 |                 table = tables_dict[table_id]['table'] 87 |                 for i, row in enumerate(table['table_rows']): 88 | 89 |                     row_text = f'R{i}: ' 90 |                     for j, cell in enumerate(row): 91 |                         column_name = table['header'][j]['column_name'] 92 |                         column_value = cell['text'] 93 |                         row_text += f'{column_name} is {column_value};' 94 |                     linearized_table += row_text 95 | 96 |                 parse_question(q, linearized_table) 97 |                 # parse_question_paragraph_classifier(q)
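                # For example, a row with header columns 'Year' and 'Team' and
                # cell values '1996' and 'Panthers' (illustrative values) is
                # linearized by the loop above as: 'R0: Year is 1996;Team is Panthers;'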
98 | 99 |     tokenizer = T5Tokenizer.from_pretrained('t5-base') 100 |     tokenized_inputs = [tokenizer.tokenize(input['context']) for input in dev['array']] 101 |     input_lengths = [len(input) for input in tokenized_inputs] 102 |     plt.hist(input_lengths, bins=128, cumulative=True, density=True) 103 |     plt.plot() 104 |     plt.title('Cumulative Linearized Table Lengths') 105 |     plt.xlabel('# tokens') 106 |     plt.ylabel('% of examples') 107 |     plt.grid() 108 |     plt.show() 109 |     # write train questions 110 |     train_output_file = 'ContinuousPreTraining/Data/mmqa/parsed_mmqa_train_oracle.json' 111 |     output_fp = gzip.open(train_output_file, 'wb') 112 |     for question in parsed_train_questions: 113 |         output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 114 | 115 |     # write dev questions 116 |     dev_output_file = 'ContinuousPreTraining/Data/mmqa/parsed_mmqa_dev_oracle.json' 117 |     output_fp = gzip.open(dev_output_file, 'wb') 118 |     for question in parsed_dev_questions: 119 |         output_fp.write((json.dumps(question, ensure_ascii=False) + '\n').encode('utf-8')) 120 | 121 |     print('Finished pre-processing MMQA') 122 | 123 | 124 | if __name__ == '__main__': 125 |     """ 126 |     Script to preprocess the MMQA dataset to UnifiedQA format 127 |     """ 128 |     main() 129 | -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_argument.py: -------------------------------------------------------------------------------- 1 | """Define the Argument class.""" 2 | 3 | from regex import compile as regex_compile, MULTILINE, DOTALL 4 | 5 | from ._wikitext import SubWikiText, SECTION_HEADING 6 | from ._spans import parse_to_spans 7 | 8 | ARG_SHADOW_FULLMATCH = regex_compile( 9 |     rb'[|:](?<pre_eq>(?:[^=]*+(?:' + SECTION_HEADING + 10 |     rb'\n)?+)*+)(?:\Z|(?<eq>=)(?<post_eq>.*+))', 11 |     MULTILINE | DOTALL).fullmatch 12 | 13 | 14 | class Argument(SubWikiText): 15 | 16 |     """Create a new Argument Object. 17 | 18 |     Note that in MediaWiki documentation `arguments` are (also) called 19 |     parameters. In this module the convention is: 20 |     {{{parameter}}}, {{template|argument}}. 21 |     See https://www.mediawiki.org/wiki/Help:Templates for more information. 22 |     """ 23 | 24 |     @property 25 |     def _shadow_match(self): 26 |         cached_shadow_match, cache_string = getattr( 27 |             self, '_shadow_match_cache', (None, None)) 28 |         self_string = str(self) 29 |         if cache_string == self_string: 30 |             return cached_shadow_match 31 |         shadow_match = ARG_SHADOW_FULLMATCH(self._shadow) 32 |         self._shadow_match_cache = shadow_match, self_string 33 |         return shadow_match 34 | 35 |     @property 36 |     def name(self) -> str: 37 |         """Return argument's name. 38 | 39 |         For positional arguments return the position as a string. 40 |         """ 41 |         lststr0 = self._lststr[0] 42 |         ss = self._span[0] 43 |         shadow_match = self._shadow_match 44 |         if shadow_match['eq']: 45 |             s, e = shadow_match.span('pre_eq') 46 |             return lststr0[ss + s:ss + e] 47 |         # positional argument 48 |         position = 1 49 |         # Todo: if we had the index of self._span, we could only look-up 50 |         # the head of the self._type_to_spans. 51 |         for s, e in self._type_to_spans[self._type]: 52 |             if ss <= s: 53 |                 break 54 |             arg_str = lststr0[s:e] 55 |             if '=' in arg_str: 56 |                 # The argument may still be positional if the equal sign is 57 |                 # inside an atomic sub-span. 58 |                 byte_array = bytearray(arg_str, 'ascii', 'replace') 59 |                 parse_to_spans(byte_array)  # Remove sub-spans from byte_array 60 |                 if b'=' in byte_array: 61 |                     # This is a keyword argument. 62 |                     continue 63 |             # This is a preceding positional argument. 64 |             position += 1 65 |         return str(position) 66 | 67 |     @name.setter 68 |     def name(self, newname: str) -> None: 69 |         """Set the name for this argument. 70 | 71 |         If this is a positional argument, convert it to keyword argument. 72 |         """ 73 |         oldname = self.name 74 |         if self._shadow_match['eq']: 75 |             self[1:1 + len(oldname)] = newname 76 |         else: 77 |             self[0:1] = '|' + newname + '=' 78 | 79 |     @property 80 |     def positional(self) -> bool: 81 |         """Return False if there is an equal sign in the argument, else True.""" 82 |         return False if self._shadow_match['eq'] else True 83 | 84 |     @positional.setter 85 |     def positional(self, to_positional: bool) -> None: 86 |         """Change to keyword or positional accordingly. 87 | 88 |         Raise ValueError on trying to convert positional to keyword argument.
89 | """ 90 | shadow_match = self._shadow_match 91 | if shadow_match['eq']: 92 | # Keyword argument 93 | if to_positional: 94 | del self[1:shadow_match.end('eq')] 95 | else: 96 | return 97 | if to_positional: 98 | # Positional argument. to_positional is True. 99 | return 100 | # Positional argument. to_positional is False. 101 | raise ValueError( 102 | 'Converting positional argument to keyword argument is not ' 103 | 'possible without knowing the new name. ' 104 | 'You can use `self.name = somename` instead.') 105 | 106 | @property 107 | def value(self) -> str: 108 | """Return value of a keyword argument.""" 109 | shadow_match = self._shadow_match 110 | if shadow_match['eq']: 111 | return self[shadow_match.start('post_eq'):] 112 | return self[1:] 113 | 114 | @value.setter 115 | def value(self, newvalue: str) -> None: 116 | """Assign the newvalue to self.""" 117 | shadow_match = self._shadow_match 118 | if shadow_match['eq']: 119 | self[shadow_match.start('post_eq'):] = newvalue 120 | else: 121 | self[1:] = newvalue 122 | 123 | @property 124 | def _lists_shadow_ss(self): 125 | shadow_match = self._shadow_match 126 | if shadow_match['eq']: 127 | post_eq = shadow_match['post_eq'] 128 | ls_post_eq = post_eq.lstrip() 129 | return ( 130 | ls_post_eq, 131 | self._span[0] + shadow_match.start('post_eq') 132 | + len(post_eq) - len(ls_post_eq)) 133 | return shadow_match[0][1:], self._span[0] + 1 134 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/PReasM_uniform_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "PReasM_Uni" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "size": "Base" 8 | }, 9 | "tokenizer": "t5-base", 10 | "datasets_sampler": {"type": "LambdaMlmSampler", "lmbda": 0.5}, 11 | "train_datasets": { 12 | "WikiTrain": { 13 | "reader": { 14 | "type": "T5MlmDataset", 15 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_train_256.jsonl", 16 | "pass_tokenizer": true, 17 | "num_examples_to_load": 1000000, 18 | "max_input_token_len": 260, 19 | "max_output_token_len": 57 20 | }, 21 | "dataloader": { 22 | "LR": 3e-5, 23 | "single_task_sampler": "Random", 24 | "no_collator_in_eval": true 25 | }, 26 | "predictor": "SpanPredictor", 27 | "eval_method": "SpanEvaluator" 28 | }, 29 | "SyntheticQuestionTrain": { 30 | "reader": { 31 | "type": "SyntheticQuestionsMultiDatasets", 32 | "pass_tokenizer": true, 33 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/generated_reasoning_examples_train/", 34 | "skills": [ 35 | "counting", 36 | "numeric_superlatives", 37 | "numeric comparison", 38 | "composition_2_hop", 39 | "composition", 40 | "numeric_comparison_boolean", 41 | "temporal_comparison", 42 | "temporal_difference", 43 | "temporal_comparison_boolean", 44 | "conjunction", 45 | "arithmetic_superlatives", 46 | "arithmetic_addition", 47 | "most_quantifier", 48 | "only_quantifier", 49 | "every_quantifier", 50 | "temporal_superlatives" 51 | ], "max_input_token_len": 384, 52 | "max_output_token_len": 32, 53 | "generation_model": true 54 | }, 55 | "dataloader": { 56 | "LR": 1e-4, 57 | "single_task_sampler": "LengthGroupedSampler" 58 | }, 59 | "predictor": "GenerativePredictor", 60 | "eval_method": "DropEval", 61 | "dataset_sampler": { 62 | "type": "ErrorDistributionHeterogeneousSampler", 63 | "pass_trainer_state": true, 64 | "distribution_name": "SyntheticQuestionValidation" 65 | } 66 | } 67 | }, 68 | 
"validation_datasets":{ 69 | "SyntheticQuestionValidation": { 70 | "reader": { 71 | "type": "SyntheticQuestionsMultiDatasets", 72 | "pass_tokenizer": true, 73 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/generated_reasoning_examples_dev/", 74 | "skills": [ 75 | "counting", 76 | "numeric_superlatives", 77 | "numeric comparison", 78 | "composition_2_hop", 79 | "composition", 80 | "numeric_comparison_boolean", 81 | "temporal_comparison", 82 | "temporal_difference", 83 | "temporal_comparison_boolean", 84 | "conjunction", 85 | "arithmetic_superlatives", 86 | "arithmetic_addition", 87 | "most_quantifier", 88 | "only_quantifier", 89 | "every_quantifier", 90 | "temporal_superlatives" 91 | ], "max_input_token_len": 384, 92 | "max_output_token_len": 32, 93 | "num_examples_to_load": 1000, 94 | "generation_model": true 95 | }, 96 | "dataloader": { 97 | }, 98 | "predictor": "GenerativePredictor", 99 | "eval_method": "DropEval", 100 | "save_error_distribution": true 101 | }, 102 | "WikiEval": { 103 | "reader": { 104 | "type": "T5MlmDataset", 105 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_eval_256.jsonl", 106 | "pass_tokenizer": true, 107 | "num_examples_to_load": 1000, 108 | "num_wiki_examples": 75299, 109 | "max_input_token_len": 260, 110 | "max_output_token_len": 57 111 | }, 112 | "dataloader": { 113 | "no_collator_in_eval": true 114 | }, 115 | "predictor": "SpanPredictor", 116 | "eval_method": "SpanEvaluator", 117 | "save_error_distribution": false 118 | } 119 | }, 120 | "optimizer": { 121 | "type": "AdaFactor", 122 | "lr": 1e-4 123 | }, 124 | "scheduler": { 125 | "type": "linear_scheduler_with_warmup", 126 | "num_warmup_steps": 500, 127 | "num_training_steps": 2e32 128 | }, 129 | "training_arguments": { 130 | "num_train_epochs": 1000, 131 | "per_device_train_batch_size": 40, 132 | "per_device_eval_batch_size": 60, 133 | "gradient_accumulation_steps": 1, 134 | "log_steps": 100, 135 | "save_steps": 5000, 136 | "eval_steps": 500, 137 | "weight_decay": 0.01, 138 | "save_total_limit": 5, 139 | "seed": 43, 140 | "prediction_loss_only": true, 141 | "no_cuda": false 142 | }, 143 | "trainer": { 144 | "type": "UpdatedMtTrainer", 145 | "override_huggingface_train_method": true, 146 | "load_train_dataloader_after_eval": false, 147 | "callbacks": ["MultiTaskHeterogeneousCallback"] 148 | } 149 | } -------------------------------------------------------------------------------- /Training/README.md: -------------------------------------------------------------------------------- 1 | # Turning Tables 2 | 3 | ## Setup 4 | 5 | ### Setting up a virtual environment 6 | 7 | 1. First, clone the repository: 8 | 9 | ``` 10 | git clone https://github.com/oriyor/turning_tables.git 11 | ``` 12 | 13 | 2. Change your directory to where you cloned the files: 14 | 15 | ``` 16 | cd Training 17 | export PYTHONPATH=${PYTHONPATH}:`pwd` 18 | ``` 19 | 20 | 3. Create a virtual environment with Python 3.6 or above: 21 | 22 | ``` 23 | virtualenv venv --python=python3.7 (or python3.7 -m venv venv or conda create -n turningtables python=3.7) 24 | ``` 25 | 26 | 4. Activate the virtual environment. You will need to activate the venv environment in each terminal in which you want to use ContinuousPreTraining. 27 | 28 | ``` 29 | source venv/bin/activate (or source venv/bin/activate.csh or conda activate turningtables) 30 | ``` 31 | 5. Install the required dependencies: 32 | 33 | ``` 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | 6. 
Download the relevant torch version for your machine (see [torch versions](https://pytorch.org/)) 38 | 39 | 40 | ### Downloading PReasM 41 | 42 | To download a pre-trained PReasM model please follow the instructions below: 43 | 44 | 1. Clone the repository 45 | 46 | 2. Choose a PReasM model: 47 | 48 | ``` 49 | export Sampler=Err/Moment/Uni 50 | export Size=Base/Large 51 | ``` 52 | 53 | 3. Download PReasM: 54 | ``` 55 | ./bash_scripts/download_preasm.sh 56 | ``` 57 | 58 | 4. Using PReasM is similar to using T5ForConditionalGeneration in the transformers library: 59 | ``` 60 | from transformers import T5ForConditionalGeneration 61 | model = T5ForConditionalGeneration.from_pretrained(f'CheckpointsRestored/PReasM-{sampler}-{size}/', return_dict=True) 62 | 63 | ``` 64 | 65 | ### Fine-tune PReasM 66 | 67 | To fine-tune PReasM please follow the instructions below: 68 | 69 | 1. Clone the repository 70 | 71 | 2. Set up the datasets: 72 | ``` 73 | ./bash_scripts/setup_datasets.sh 74 | ``` 75 | 3. Choose a PReasM model (setting PReasM to False will fine-tune the T5 baseline): 76 | 77 | ``` 78 | export PReasM=True/False 79 | export Sampler=Err/Moment/Uni 80 | export Size=Base/Large 81 | export Dataset=drop/iirc/mmqa 82 | ``` 83 | 84 | 4. Fine-tune PReasM. Hyperparameters such as the learning rate and batch size can be updated using the experiment's config file: 85 | 86 | ``` 87 | python ContinuousPreTraining/scripts/train.py -c ContinuousPreTraining/configurations/${Dataset}_config.json --model.PReasM ${PReasM} --model.sampler ${Sampler} --model.size ${Size} 88 | ``` 89 | 90 | 5. Verify your results with the official evaluation scripts: 91 | ``` 92 | python ContinuousPreTraining/scripts/verify_${Dataset}_eval.py --prediction_path /{experiment_name}/{prediction_json_file} --gold_path {gold_file_for_the_chosen_dataset, e.g. ContinuousPreTraining/Data/iirc/iirc_dev_drop_format.json for iirc} 93 | ``` 94 | 95 | ### Training PReasM from scratch 96 | 97 | To train PReasM from scratch please follow the instructions below. Note that training PReasM will download data for the original T5 pre-training task and the generated reasoning examples (overall ~13GB). Hyperparameters such as the learning rate and batch size can be updated using the experiment's config file: 98 | 99 | 1. Clone the repository 100 | 101 | 2. Set the sampling strategy and model size: 102 | 103 | ``` 104 | export SAMPLER=uniform/momentum/errors 105 | export SIZE=Base/Large 106 | ``` 107 | 108 | 3. Run the following command: 109 | 110 | ``` 111 | python ContinuousPreTraining/scripts/train.py -c ContinuousPreTraining/configurations/PReasM_${SAMPLER}_config.json --model.size ${SIZE} -t t5-base -tbs 64 -ebs 128 -gas 1 --training_arguments.eval_steps 5000 --training_arguments.save_steps 5000 --optimizer.lr 1e-4 --experiment.experiment_name PReasM_${SIZE}_${SAMPLER} 112 | ``` 113 | 114 | ### Training the MMQA pipeline 115 | 116 | To train the MMQA pipeline retrieval models described in the paper, please follow the instructions below: 117 | 118 | 1. Clone the repository 119 | 120 | 121 | 2. Train the question classifier: 122 | ``` 123 | python ContinuousPreTraining/scripts/preprocess_mmqa_for_question_classification.py 124 | python ContinuousPreTraining/scripts/train.py -c ContinuousPreTraining/configurations/mmqa_question_classifier_config.json 125 | ``` 126 | 127 | 3. 
Train the paragraph classifier: 128 | ``` 129 | python ContinuousPreTraining/scripts/preprocess_mmqa_for_paragraph_classification.py 130 | python ContinuousPreTraining/scripts/train.py -c ContinuousPreTraining/configurations/mmqa_para_classifier_config.json 131 | ``` 132 | 133 | 4. Unify the classifiers' predictions to create the retrieval contexts: 134 | ``` 135 | python ContinuousPreTraining/scripts/unify_mmqa_context_with_retriever.py --dev_questions_classifier_predictions_path {question_classification_dev_predictions.csv} --test_questions_classifier_predictions_path {question_classification_test_predictions.csv} --dev_paragraphs_classifier_predictions_path {paragraph_classification_dev_predictions.csv} --test_paragraphs_classifier_predictions_path {paragraph_classification_test_predictions.csv} 136 | ``` 137 | 138 | 5. Train MMQA with the retrieval contexts: 139 | ``` 140 | python ContinuousPreTraining/scripts/train.py -c ContinuousPreTraining/configurations/mmqa_retrieval_config.json --model.PReasM ${PReasM} --model.sampler ${Sampler} --model.size ${Size} 141 | ``` 142 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/PReasM_errors_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "PReasM_Errors" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "sampler": "Err", 8 | "size": "Base" 9 | }, 10 | "tokenizer": "t5-base", 11 | "datasets_sampler": {"type": "LambdaMlmSampler", "lmbda": 0.5}, 12 | "train_datasets": { 13 | "WikiTrain": { 14 | "reader": { 15 | "type": "T5MlmDataset", 16 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_train_256.jsonl", 17 | "pass_tokenizer": true, 18 | "num_examples_to_load": 1000000, 19 | "max_input_token_len": 260, 20 | "max_output_token_len": 57 21 | }, 22 | "dataloader": { 23 | "LR": 3e-5, 24 | "single_task_sampler": "Random", 25 | "no_collator_in_eval": true 26 | }, 27 | "predictor": "SpanPredictor", 28 | "eval_method": "SpanEvaluator" 29 | }, 30 | "SyntheticQuestionTrain": { 31 | "reader": { 32 | "type": "SyntheticQuestionsMultiDatasets", 33 | "pass_tokenizer": true, 34 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/generated_reasoning_examples_train/", 35 | "skills": [ 36 | "counting", 37 | "numeric_superlatives", 38 | "numeric comparison", 39 | "composition_2_hop", 40 | "composition", 41 | "numeric_comparison_boolean", 42 | "temporal_comparison", 43 | "temporal_difference", 44 | "temporal_comparison_boolean", 45 | "conjunction", 46 | "arithmetic_superlatives", 47 | "arithmetic_addition", 48 | "most_quantifier", 49 | "only_quantifier", 50 | "every_quantifier", 51 | "temporal_superlatives" 52 | ], 53 | "max_input_token_len": 384, 54 | "max_output_token_len": 32, 55 | "generation_model": true 56 | }, 57 | "dataloader": { 58 | "LR": 1e-4, 59 | "single_task_sampler": "LengthGroupedSampler" 60 | }, 61 | "predictor": "GenerativePredictor", 62 | "eval_method": "DropEval", 63 | "dataset_sampler": { 64 | "type": "ErrorDistributionHeterogeneousSampler", 65 | "pass_trainer_state": true, 66 | "distribution_name": "SyntheticQuestionValidation" 67 | } 68 | } 69 | }, 70 | "validation_datasets":{ 71 | "SyntheticQuestionValidation": { 72 | "reader": { 73 | "type": "SyntheticQuestionsMultiDatasets", 74 | "pass_tokenizer": true, 75 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/generated_reasoning_examples_dev/", 76 | "skills": [ 77 | "counting", 78 | "numeric_superlatives", 79 | "numeric 
comparison", 80 | "composition_2_hop", 81 | "composition", 82 | "numeric_comparison_boolean", 83 | "temporal_comparison", 84 | "temporal_difference", 85 | "temporal_comparison_boolean", 86 | "conjunction", 87 | "arithmetic_superlatives", 88 | "arithmetic_addition", 89 | "most_quantifier", 90 | "only_quantifier", 91 | "every_quantifier", 92 | "temporal_superlatives" 93 | ], 94 | "max_input_token_len": 384, 95 | "max_output_token_len": 32, 96 | "num_examples_to_load": 1000, 97 | "generation_model": true 98 | }, 99 | "dataloader": { 100 | }, 101 | "predictor": "GenerativePredictor", 102 | "eval_method": "DropEval", 103 | "save_error_distribution": true 104 | }, 105 | "WikiEval": { 106 | "reader": { 107 | "type": "T5MlmDataset", 108 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_eval_256.jsonl", 109 | "pass_tokenizer": true, 110 | "num_examples_to_load": 1000, 111 | "num_wiki_examples": 75299, 112 | "max_input_token_len": 260, 113 | "max_output_token_len": 57 114 | }, 115 | "dataloader": { 116 | "no_collator_in_eval": true 117 | }, 118 | "predictor": "SpanPredictor", 119 | "eval_method": "SpanEvaluator", 120 | "save_error_distribution": false 121 | } 122 | }, 123 | "optimizer": { 124 | "type": "AdaFactor", 125 | "lr": 1e-4 126 | }, 127 | "scheduler": { 128 | "type": "linear_scheduler_with_warmup", 129 | "num_warmup_steps": 500, 130 | "num_training_steps": 2e32 131 | }, 132 | "training_arguments": { 133 | "num_train_epochs": 1000, 134 | "per_device_train_batch_size": 40, 135 | "per_device_eval_batch_size": 60, 136 | "gradient_accumulation_steps": 1, 137 | "log_steps": 100, 138 | "save_steps": 5000, 139 | "eval_steps": 500, 140 | "weight_decay": 0.01, 141 | "save_total_limit": 5, 142 | "seed": 43, 143 | "prediction_loss_only": true, 144 | "no_cuda": false 145 | }, 146 | "trainer": { 147 | "type": "UpdatedMtTrainer", 148 | "override_huggingface_train_method": true, 149 | "load_train_dataloader_after_eval": true, 150 | "callbacks": ["MultiTaskHeterogeneousCallback"] 151 | } 152 | } -------------------------------------------------------------------------------- /ExampleGeneration/wikitextparser/_wikilist.py: -------------------------------------------------------------------------------- 1 | """Define the class for List objects.""" 2 | 3 | from typing import List, Union, Tuple, Dict, MutableSequence, Match 4 | 5 | from regex import escape, fullmatch, MULTILINE 6 | 7 | from ._wikitext import SubWikiText 8 | 9 | 10 | SUBLIST_PATTERN = rb'(?>^(?{pattern})[:;#*].*+(?>\n|\Z))*+' 11 | LIST_PATTERN_FORMAT = ( 12 | rb'(?' 
13 | rb'^' 14 | rb'(?<pattern>{pattern})' 15 | rb'(?(?<=;\s*+)' 16 | # mark inline definition as an item 17 | rb'(?<item>[^:\n]*+)(?::(?<item>.*+))?+' 18 | rb'(?>\n|\Z)' + SUBLIST_PATTERN + 19 | rb'|' 20 | # non-definition 21 | rb'(?<item>.*+)' 22 | rb'(?>\n|\Z)' + SUBLIST_PATTERN + 23 | rb')' 24 | rb')++' 25 | ) 26 | 27 | 28 | class WikiList(SubWikiText): 29 | 30 | """Class to represent ordered, unordered, and definition lists.""" 31 | 32 | def __init__( 33 | self, 34 | string: Union[str, MutableSequence[str]], 35 | pattern: str, 36 | _match: Match = None, 37 | _type_to_spans: Dict[str, List[List[int]]] = None, 38 | _span: List[int] = None, 39 | _type: str = None, 40 | ) -> None: 41 | super().__init__(string, _type_to_spans, _span, _type) 42 | self.pattern = pattern 43 | if _match: 44 | self._match_cache = _match, self.string 45 | else: 46 | self._match_cache = fullmatch( 47 | LIST_PATTERN_FORMAT.replace(b'{pattern}', pattern.encode()), 48 | self._shadow, 49 | MULTILINE, 50 | ), self.string 51 | 52 | @property 53 | def _match(self): 54 | """Return the match object for the current list.""" 55 | cache_match, cache_string = self._match_cache 56 | string = self.string 57 | if cache_string == string: 58 | return cache_match 59 | cache_match = fullmatch( 60 | LIST_PATTERN_FORMAT.replace(b'{pattern}', self.pattern.encode()), 61 | self._shadow, 62 | MULTILINE, 63 | ) 64 | self._match_cache = cache_match, string 65 | return cache_match 66 | 67 | @property 68 | def items(self) -> List[str]: 69 | """Return items as a list of strings. 70 | 71 | Don't include sub-items and the start pattern. 72 | """ 73 | items = [] # type: List[str] 74 | append = items.append 75 | string = self.string 76 | match = self._match 77 | ms = match.start() 78 | for s, e in match.spans('item'): 79 | append(string[s - ms:e - ms]) 80 | return items 81 | 82 | @property 83 | def fullitems(self) -> List[str]: 84 | """Return list of item strings. Includes their start and sub-items.""" 85 | fullitems = [] # type: List[str] 86 | append = fullitems.append 87 | string = self.string 88 | match = self._match 89 | ms = match.start() 90 | for s, e in match.spans('fullitem'): 91 | append(string[s - ms:e - ms]) 92 | return fullitems 93 | 94 | @property 95 | def level(self) -> int: 96 | """Return level of nesting for the current list. 97 | 98 | Level is a one-based index, for example the level for `* a` will be 1. 99 | """ 100 | return len(self._match['pattern']) 101 | 102 | def sublists( 103 | self, i: int = None, pattern: str = None 104 | ) -> List['WikiList']: 105 | """Return the Lists inside the item with the given index. 106 | 107 | :param i: The index of the item whose sub-lists are desired. 108 | The performance is likely to be better if `i` is None. 109 | 110 | :param pattern: The starting symbol for the desired sub-lists. 111 | The `pattern` of the current list will be automatically added 112 | as prefix. 113 | Although this parameter is optional, specifying it can improve 114 | the performance. 115 | """ 116 | patterns = (r'\#', r'\*', '[:;]') if pattern is None \ 117 | else (pattern,) # type: Tuple[str, ...]
118 | self_pattern = self.pattern 119 | lists = self.lists 120 | sublists = [] # type: List['WikiList'] 121 | sublists_append = sublists.append 122 | if i is None: 123 | # Any sublist is acceptable 124 | for pattern in patterns: 125 | for lst in lists(self_pattern + pattern): 126 | sublists_append(lst) 127 | return sublists 128 | # Only return sub-lists that are within the given item 129 | match = self._match 130 | fullitem_spans = match.spans('fullitem') 131 | ss = self._span[0] 132 | ms = match.start() 133 | s, e = fullitem_spans[i] 134 | e -= ms - ss 135 | s -= ms - ss 136 | for pattern in patterns: 137 | for lst in lists(self_pattern + pattern): 138 | # noinspection PyProtectedMember 139 | ls, le = lst._span 140 | if s < ls and le <= e: 141 | sublists_append(lst) 142 | return sublists 143 | 144 | def convert(self, newstart: str) -> None: 145 | """Convert to another list type by replacing starting pattern.""" 146 | match = self._match 147 | ms = match.start() 148 | for s, e in reversed(match.spans('pattern')): 149 | self[s - ms:e - ms] = newstart 150 | self.pattern = escape(newstart) 151 | -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/run_multiple_chunks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import json 4 | import os 5 | import logging 6 | import platform 7 | from ExampleGeneration.datajob_factory import DataJobFactory 8 | 9 | from ExampleGeneration.common.analysis_utils import dump_synthetic_questions_analysis 10 | from ExampleGeneration.common.file_utils import cached_path, upload_local_file_to_s3 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main(): 16 | parse = argparse.ArgumentParser("") 17 | parse.add_argument("-c", "--datajob_name", type=str, help="The name of the datajob class and config to use") 18 | parse.add_argument("-config", "--config_file_name", type=str, help="DataJobs config file name", default="config_tests.json") 19 | parse.add_argument("-wd", "--working_directory", type=str, help="working directory, can be an s3 path", default='') 20 | parse.add_argument("-sc", "--start_chunk", type=int, help="index of the first chunk to process", default=0) 21 | parse.add_argument("-ec", "--end_chunk", type=int, help="index of the last chunk to process (inclusive)", default=None) 22 | parse.add_argument("-af", "--annotated_questions_file", type=str, help="annotated questions file, relative to the working directory", default=None) 23 | parse.add_argument("-mf", "--filename_to_merge", type=str, help="comma-separated names of per-chunk files to merge", default=None) 24 | parse.add_argument("-dj", "--datajobs_to_run", type=str, help="A list of datajobs (it will check if each is enabled)", default=None) 25 | parse.add_argument("--build_train_dev_sets", action='store_true', help="upload dev train splits", default=False) 26 | 27 | # In the test no output file will be produced, change -out to create an output 28 | args = parse.parse_args() 29 | 30 | 31 | if args.datajobs_to_run is not None: 32 | args.datajobs_to_run = args.datajobs_to_run.split(',') 33 | base_working_directory = args.working_directory 34 | 35 | args.base_working_directory = base_working_directory 36 | if args.annotated_questions_file is not None: 37 | args.annotated_questions_file = base_working_directory + args.annotated_questions_file 38 | 39 | for chunk_num in range(args.start_chunk, args.end_chunk + 1): 40 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configurations", \
args.config_file_name) 42 | with open(config_path, 'r') as f: 43 | _config = json.load(f) 44 | 45 | # TODO: make this configurable 46 | logger.info(f"----------- Working on chunk {'{:04d}'.format(chunk_num)} ------------- ") 47 | 48 | args.working_directory = base_working_directory + 'chunk_' + '{:04d}'.format(chunk_num) + '/' 49 | 50 | # loop over all the datajobs 51 | for datajob_name, datajob_config in _config.items(): 52 | if (args.datajobs_to_run is not None and datajob_name in args.datajobs_to_run) or \ 53 | (args.datajobs_to_run is None and datajob_config['enable']): 54 | if datajob_name == 'FilterWikiTables': 55 | args.input_file = _config['FilterWikiTables']["input_file"].replace('chunk_0000','chunk_' + '{:04d}'.format(chunk_num)) 56 | if datajob_name == 'ReasClassifyColumnTypes': 57 | args.input_file = _config['ReasClassifyColumnTypes']["input_file"].replace('chunk_0000','chunk_' + '{:04d}'.format(chunk_num)) 58 | elif 'input_file' in args: 59 | del args.input_file 60 | 61 | logger.info("-------------- Running: " + datajob_name + " --------------------") 62 | datajob = DataJobFactory().get_datajob(datajob_name, datajob_config['type'], args) 63 | datajob.run_datajob(args) 64 | 65 | # Append all chunk datasets to the full dataset. 66 | if args.filename_to_merge is not None: 67 | for filename_to_merge in args.filename_to_merge.split(','): 68 | 69 | logger.info(f'\n-------Merging file {filename_to_merge}-----------\n') 70 | 71 | questions = [] 72 | for chunk_num in range(args.start_chunk, args.end_chunk + 1): 73 | chunk_dataset_path = cached_path(base_working_directory + 'chunk_' + '{:04d}'.format(chunk_num) + '/' + filename_to_merge) 74 | with gzip.open(chunk_dataset_path, 'r') as f: 75 | #header = f.readline() 76 | for line in f: 77 | question = json.loads(line) 78 | questions.append(question) 79 | 80 | with gzip.open(filename_to_merge, 'w') as f: 81 | for line in questions: 82 | f.write((json.dumps(line) + '\n').encode('utf-8')) 83 | 84 | upload_local_file_to_s3(filename_to_merge, base_working_directory.replace('s3://', '') + filename_to_merge) 85 | 86 | # dump csv if local 87 | local_platform = platform.node() == 'Oris-MacBook-Pro.local' 88 | 89 | logger.info(f'---- Calculating stats, local platform {local_platform} ----') 90 | 91 | dump_synthetic_questions_analysis(filename_to_merge, \ 92 | 'data/tab_reas/samples/template_q_sample_full_wip.csv', 93 | dump_csv=local_platform) 94 | 95 | # remove local copy if not on the local platform 96 | if not local_platform: 97 | os.remove(filename_to_merge) 98 | 99 | print() 100 | 101 | if __name__ == '__main__': 102 | main() -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/configurations/PReasM_momentum_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "experiment": { 3 | "experiment_name": "PReasM_Moment" 4 | }, 5 | "model": { 6 | "PReasM": false, 7 | "size": "Base" 8 | }, 9 | "tokenizer": "t5-base", 10 | "datasets_sampler": {"type": "LambdaMlmSampler", "lmbda": 0.5}, 11 | "train_datasets": { 12 | "WikiTrain": { 13 | "reader": { 14 | "type": "T5MlmDataset", 15 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_train_256.jsonl", 16 | "pass_tokenizer": true, 17 | "num_examples_to_load": 1000000, 18 | "max_input_token_len": 260, 19 | "max_output_token_len": 57 20 | }, 21 | "dataloader": { 22 | "LR": 3e-5, 23 | "single_task_sampler": "Random", 24 | "no_collator_in_eval": true 25 | }, 26 | "predictor": "SpanPredictor", 27 |
"eval_method": "SpanEvaluator" 28 | }, 29 | "SyntheticQuestionTrain": { 30 | "reader": { 31 | "type": "SyntheticQuestionsMultiDatasets", 32 | "pass_tokenizer": true, 33 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/synthetic_questions_train_wip/", 34 | "skills": [ 35 | "counting", 36 | "numeric_superlatives", 37 | "numeric comparison", 38 | "composition_2_hop", 39 | "composition", 40 | "numeric_comparison_boolean", 41 | "temporal_comparison", 42 | "temporal_difference", 43 | "temporal_comparison_boolean", 44 | "conjunction", 45 | "arithmetic_superlatives", 46 | "arithmetic_addition", 47 | "most_quantifier", 48 | "only_quantifier", 49 | "every_quantifier", 50 | "temporal_superlatives" 51 | ], 52 | "max_input_token_len": 384, 53 | "max_output_token_len": 32, 54 | "generation_model": true 55 | }, 56 | "dataloader": { 57 | "LR": 1e-4, 58 | "single_task_sampler": "LengthGroupedSampler" 59 | }, 60 | "predictor": "GenerativePredictor", 61 | "eval_method": "DropEval", 62 | "dataset_sampler": { 63 | "type": "AdaptiveErrorHeterogeneousSampler", 64 | "pass_trainer_state": true, 65 | "is_adaptive": true, 66 | "normalize_with_prev": true, 67 | "distribution_name": "SyntheticQuestionValidation" 68 | } 69 | } 70 | }, 71 | "validation_datasets":{ 72 | "SyntheticQuestionValidation": { 73 | "reader": { 74 | "type": "SyntheticQuestionsMultiDatasets", 75 | "pass_tokenizer": true, 76 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/generated_reasoning_examples_train/", 77 | "skills": [ 78 | "counting", 79 | "numeric_superlatives", 80 | "numeric comparison", 81 | "composition_2_hop", 82 | "composition", 83 | "numeric_comparison_boolean", 84 | "temporal_comparison", 85 | "temporal_difference", 86 | "temporal_comparison_boolean", 87 | "conjunction", 88 | "arithmetic_superlatives", 89 | "arithmetic_addition", 90 | "most_quantifier", 91 | "only_quantifier", 92 | "every_quantifier", 93 | "temporal_superlatives" 94 | ], 95 | "max_input_token_len": 384, 96 | "max_output_token_len": 32, 97 | "num_examples_to_load": 1000, 98 | "generation_model": true 99 | }, 100 | "dataloader": { 101 | }, 102 | "predictor": "GenerativePredictor", 103 | "eval_method": "DropEval", 104 | "save_error_distribution": true 105 | }, 106 | "WikiEval": { 107 | "reader": { 108 | "type": "T5MlmDataset", 109 | "path": "https://tabreas.s3-us-west-2.amazonaws.com/mlm_data/t5_wiki_eval_256.jsonl", 110 | "pass_tokenizer": true, 111 | "num_examples_to_load": 1000, 112 | "num_wiki_examples": 75299, 113 | "max_input_token_len": 260, 114 | "max_output_token_len": 57 115 | }, 116 | "dataloader": { 117 | "no_collator_in_eval": true 118 | }, 119 | "predictor": "SpanPredictor", 120 | "eval_method": "SpanEvaluator", 121 | "save_error_distribution": false 122 | } 123 | }, 124 | "optimizer": { 125 | "type": "AdaFactor", 126 | "lr": 1e-4 127 | }, 128 | "scheduler": { 129 | "type": "linear_scheduler_with_warmup", 130 | "num_warmup_steps": 500, 131 | "num_training_steps": 2e32 132 | }, 133 | "training_arguments": { 134 | "num_train_epochs": 1000, 135 | "per_device_train_batch_size": 40, 136 | "per_device_eval_batch_size": 60, 137 | "gradient_accumulation_steps": 1, 138 | "log_steps": 100, 139 | "save_steps": 5000, 140 | "eval_steps": 500, 141 | "weight_decay": 0.01, 142 | "save_total_limit": 5, 143 | "seed": 43, 144 | "prediction_loss_only": true, 145 | "no_cuda": false 146 | }, 147 | "trainer": { 148 | "type": "UpdatedMtTrainer", 149 | "override_huggingface_train_method": true, 150 | "load_train_dataloader_after_eval": true, 151 | "callbacks": 
["MultiTaskHeterogeneousCallback"] 152 | } 153 | } -------------------------------------------------------------------------------- /ExampleGeneration/ExampleGeneration/question_generators/tab_reas/simple.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import random 3 | import json, re 4 | import logging 5 | from copy import copy 6 | 7 | from ExampleGeneration.common.table_wrapper import WikiTable 8 | 9 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 10 | sample = [] 11 | question_index = {'i': 0} 12 | from ExampleGeneration.question_generators.question_generator import QuestionGenerator 13 | from ExampleGeneration.common.multiqa_format_wrapper import Question, SyntheticQuestion 14 | 15 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 16 | 17 | 18 | class Simple(QuestionGenerator): 19 | def __init__(self, qgen_config, args): 20 | self.qgen_name = 'Simple' 21 | self.reasoning_types = ['Simple'] 22 | super().__init__(args) 23 | self._qgen_config = qgen_config 24 | 25 | def filter_simple_distractor(self, f, f1): 26 | """ 27 | check if f can be a distractor for a compoisiton question from f1 to f2 28 | """ 29 | 30 | # # filter if this distractor is the explicit answer to the question 31 | # if f.target_column_ind == f2.target_column_ind and f.src_column_ind == f1.src_column_ind: 32 | # if f.source_val_indices[0] == f1.source_val_indices[0]: 33 | # return True 34 | 35 | # filter if the distractor columns are irrelevent 36 | relevant_columns = {f1.src_column_ind, f1.target_column_ind} 37 | if not {f.src_column_ind, f.target_column_ind}.intersection(relevant_columns): 38 | return True 39 | 40 | # filter if the distractor is about the relevant rows 41 | if set(f.source_val_indices).intersection({f1.source_val_indices[0]}): 42 | return True 43 | 44 | # else return false 45 | return False 46 | 47 | def generate(self, context): 48 | """ 49 | :param from_question: 50 | :param to_question: 51 | :return: the composition question returned by injection of the first question to the second 52 | """ 53 | # Generate facts 54 | table = WikiTable(context) 55 | all_facts = self.generate_facts(table) 56 | facts = [f for f in all_facts 57 | if not f.filtered] 58 | random.seed(42) 59 | 60 | simple_questions = [] 61 | 62 | for template_config in self._qgen_config['templates']: 63 | template = template_config['question_template'] 64 | none_ratio = self._qgen_config['templates'][0]['none_ratio'] 65 | 66 | # generate composition questions by looping the facts 67 | for f1 in facts: 68 | # look for facts with one source 69 | if len(f1.source_val_indices) == 1: 70 | source_column = f1.src_column_ind 71 | target_column = f1.target_column_ind 72 | 73 | phrase = template 74 | phrase = phrase.replace("[page_title]", f1.page_title) 75 | phrase = phrase.replace("[table_title]", f1.table_title.strip()) 76 | phrase = phrase.replace("[source_column]", f1.src_column_header.strip()) 77 | phrase = phrase.replace("[target_column]", f1.target_column_header.strip()) 78 | phrase = phrase.replace("[from_cell_text]", str(f1.src_column_value).strip()) 79 | 80 | # sample distractors 81 | possible_distractors = self.sample_distractors(facts, f1, f1, 'simple') 82 | num_distractors = min(len(possible_distractors), 8) 83 | distractors = random.sample(possible_distractors, num_distractors) 84 | 85 | m = hashlib.md5() 86 | m.update(context.id.encode()) 87 | m.update(str(source_column).encode()) 88 | m.update(str(target_column).encode()) 89 | 
m.update(f1.src_column_value.encode()) 90 | qid = 'Simple-' + m.hexdigest() 91 | 92 | # answer 93 | answer = f1.target_column_values 94 | # randomly downsample facts 95 | question_facts = [f1.format_fact()] 96 | 97 | num_facts = len(question_facts) 98 | question_distractors = [d.format_fact() for d in 99 | distractors[num_facts:]] 100 | 101 | simple_questions.append(SyntheticQuestion(qid=qid, 102 | question=phrase, 103 | answers=answer, 104 | facts=question_facts, 105 | distractors=question_distractors, 106 | metadata={'type': 'simple', 107 | 'reasoning': self.reasoning_types, 108 | 'answer_type': 'entity', 109 | 'reversed_facts': [], 110 | 'template': 'simple', 111 | } 112 | )) 113 | 114 | if len(simple_questions) > 1: 115 | simple_questions = random.sample(simple_questions, 1) 116 | return simple_questions 117 | -------------------------------------------------------------------------------- /Training/ContinuousPreTraining/Common/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Dict, List, Set, Tuple, Union, Optional 3 | import string 4 | import re 5 | import numpy as np 6 | from scipy.optimize import linear_sum_assignment 7 | # drop eval script methods 8 | ### copied code from drop eval: start 9 | ### https://github.com/jferguson144/IIRC-baseline/blob/fa397b2bbee54c71861abbb7a379d6999552bcac/numnet_plus/drop_eval.py#L74 10 | 11 | def _remove_articles(text: str) -> str: 12 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 13 | return re.sub(regex, ' ', text) 14 | 15 | 16 | def _white_space_fix(text: str) -> str: 17 | return ' '.join(text.split()) 18 | 19 | 20 | EXCLUDE = set(string.punctuation) 21 | 22 | 23 | def _remove_punc(text: str) -> str: 24 | if not _is_number(text): 25 | return ''.join(ch for ch in text if ch not in EXCLUDE) 26 | else: 27 | return text 28 | 29 | 30 | def _lower(text: str) -> str: 31 | return text.lower() 32 | 33 | 34 | def _tokenize(text: str) -> List[str]: 35 | return re.split(" |-", text) 36 | 37 | 38 | def _normalize_answer(text: str) -> str: 39 | """Lower text and remove punctuation, articles and extra whitespace.""" 40 | 41 | parts = [_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token))))) 42 | for token in _tokenize(text)] 43 | parts = [part for part in parts if part.strip()] 44 | normalized = ' '.join(parts).strip() 45 | return normalized 46 | 47 | 48 | def _is_number(text: str) -> bool: 49 | try: 50 | float(text) 51 | return True 52 | except ValueError: 53 | return False 54 | 55 | 56 | def _normalize_number(text: str) -> str: 57 | if _is_number(text): 58 | return str(float(text)) 59 | else: 60 | return text 61 | 62 | 63 | def _normalize_answer(text: str) -> str: 64 | """Lower text and remove punctuation, articles and extra whitespace.""" 65 | 66 | parts = [_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token))))) 67 | for token in _tokenize(text)] 68 | parts = [part for part in parts if part.strip()] 69 | normalized = ' '.join(parts).strip() 70 | return normalized 71 | 72 | 73 | def _answer_to_bags(answer: Union[str, List[str], Tuple[str, ...]]) -> Tuple[List[str], List[Set[str]]]: 74 | if isinstance(answer, (list, tuple)): 75 | raw_spans = answer 76 | else: 77 | raw_spans = [answer] 78 | normalized_spans: List[str] = [] 79 | token_bags = [] 80 | for raw_span in raw_spans: 81 | normalized_span = _normalize_answer(raw_span) 82 | normalized_spans.append(normalized_span) 83 |
token_bags.append(set(normalized_span.split())) 84 | return normalized_spans, token_bags 85 | 86 | 87 | def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float: 88 | intersection = len(gold_bag.intersection(predicted_bag)) 89 | if not predicted_bag: 90 | precision = 1.0 91 | else: 92 | precision = intersection / float(len(predicted_bag)) 93 | if not gold_bag: 94 | recall = 1.0 95 | else: 96 | recall = intersection / float(len(gold_bag)) 97 | f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0 98 | return f1 99 | 100 | 101 | def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: 102 | gold_numbers = set() 103 | predicted_numbers = set() 104 | for word in gold_bag: 105 | if _is_number(word): 106 | gold_numbers.add(word) 107 | for word in predicted_bag: 108 | if _is_number(word): 109 | predicted_numbers.add(word) 110 | if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): 111 | return True 112 | return False 113 | 114 | def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: 115 | """ 116 | Takes gold and predicted answer sets and first finds the optimal 1-1 alignment 117 | between them and gets maximum metric values over all the answers. 118 | """ 119 | scores = np.zeros([len(gold), len(predicted)]) 120 | for gold_index, gold_item in enumerate(gold): 121 | for pred_index, pred_item in enumerate(predicted): 122 | if _match_numbers_if_present(gold_item, pred_item): 123 | scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item) 124 | row_ind, col_ind = linear_sum_assignment(-scores) 125 | 126 | max_scores = np.zeros([max(len(gold), len(predicted))]) 127 | for row, column in zip(row_ind, col_ind): 128 | max_scores[row] = max(max_scores[row], scores[row, column]) 129 | return max_scores 130 | 131 | 132 | def get_drop_metrics(predicted: Union[str, List[str], Tuple[str, ...]], 133 | gold: Union[str, List[str], Tuple[str, ...]]) -> Tuple[float, float]: 134 | """ 135 | Takes a predicted answer and a gold answer (that are both either a string or a list of 136 | strings), and returns exact match and the DROP F1 metric for the prediction. If you are 137 | writing a script for evaluating objects in memory (say, the output of predictions during 138 | validation, or while training), this is the function you want to call, after using 139 | :func:`answer_json_to_strings` when reading the gold answer from the released data file. 140 | """ 141 | predicted_bags = _answer_to_bags(predicted) 142 | gold_bags = _answer_to_bags(gold) 143 | 144 | if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): 145 | exact_match = 1.0 146 | else: 147 | exact_match = 0.0 148 | 149 | f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) 150 | f1 = np.mean(f1_per_bag) 151 | f1 = round(f1, 2) 152 | return exact_match, f1 153 | ### copied code from drop eval: end --------------------------------------------------------------------------------
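For reference, here is a minimal usage sketch of `get_drop_metrics`, the entry point the docstring above recommends for evaluating predictions in memory. It assumes `Training/` is on `PYTHONPATH` (as in the setup instructions); the answer strings are hypothetical and chosen only to illustrate the normalization and alignment behavior:

```
from ContinuousPreTraining.Common.evaluation_utils import get_drop_metrics

# Single-span answer: normalization lowercases and strips articles and
# punctuation, so these two strings are scored as an exact match.
em, f1 = get_drop_metrics('The Eiffel Tower', 'eiffel tower')
print(em, f1)  # 1.0 1.0

# Multi-span answer: token bags are aligned 1-1 (order-insensitive) via
# linear_sum_assignment, and a gold/predicted bag pair only scores a
# non-zero F1 when their numbers match.
em, f1 = get_drop_metrics(['10 points', '7 points'], ['7 points', '10 points'])
print(em, f1)  # 1.0 1.0
```

Note that `_normalize_answer` rewrites numeric tokens as floats ('10' becomes '10.0'), so '10' and '10.0' count as the same answer.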