├── .github └── workflows │ ├── ci.yml │ └── publish_pypi.yml ├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── examples ├── pos_tagging │ ├── .gitignore │ ├── .trapper_plugins │ ├── README.md │ ├── experiments │ │ └── roberta │ │ │ └── experiment.jsonnet │ ├── requirements.txt │ ├── scripts │ │ ├── __init__.py │ │ ├── cache_hf_datasets_fixtures.py │ │ └── run_tests.py │ ├── src │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── data_adapter.py │ │ │ ├── data_processor.py │ │ │ ├── label_mapper.py │ │ │ └── tokenizer_wrapper.py │ │ └── pipeline.py │ ├── test_fixtures │ │ └── hf_datasets │ │ │ └── conll2003_test_fixture │ │ │ ├── README.md │ │ │ ├── conll2003_test_fixture.py │ │ │ ├── test.txt │ │ │ ├── train.txt │ │ │ └── valid.txt │ ├── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_data_adapter.py │ │ ├── test_data_processor.py │ │ ├── test_pipeline.py │ │ └── test_trainer.py │ └── train.py └── question_answering │ ├── README.md │ ├── experiment.jsonnet │ └── question_answering.ipynb ├── pyproject.toml ├── requirements.txt ├── resources └── trapper_diagram.png ├── scripts ├── __init__.py ├── cache_hf_datasets_fixtures.py ├── run_code_style.py └── run_tests.py ├── setup.cfg ├── setup.py ├── test_fixtures ├── commands │ └── experiment_trivial.jsonnet ├── data │ └── question_answering │ │ └── squad_qa │ │ ├── dev.json │ │ └── train.json ├── hf_datasets │ └── squad_qa_test_fixture │ │ ├── dev.json │ │ ├── squad_qa_test_fixture.py │ │ └── train.json ├── metrics │ ├── label_ids.pkl │ └── predictions.pkl ├── pipelines │ └── pipeline_integration_experiment.jsonnet ├── plugins │ ├── .trapper_plugins │ ├── __init__.py │ ├── custom_dummy_module.py │ └── custom_dummy_package │ │ ├── __init__.py │ │ ├── dummy_loader.py │ │ └── dummy_loader2.py └── training │ └── question_answering │ └── squad │ ├── dev.json │ └── train.json ├── tests ├── __init__.py ├── commands │ └── test_run.py ├── common │ ├── __init__.py │ └── test_plugins.py ├── 
conftest.py ├── data │ ├── __init__.py │ ├── conftest.py │ ├── data_collators │ │ ├── __init__.py │ │ └── squad │ │ │ ├── __init__.py │ │ │ └── test_question_answering_collator_and_adapter.py │ ├── data_processors │ │ ├── __init__.py │ │ ├── squad │ │ │ ├── __init__.py │ │ │ └── test_question_answering_processor.py │ │ └── test_data_processor.py │ ├── test_dataset_reader.py │ ├── test_label_mapper.py │ └── test_tokenizer_wrapper.py ├── metrics │ ├── __init__.py │ └── test_metric_handler.py ├── pipelines │ ├── __init__.py │ ├── test_functional.py │ ├── test_integration.py │ └── test_squad_pipeline.py └── training │ ├── __init__.py │ ├── conftest.py │ ├── test_basic_trainer_for_question_answering.py │ └── test_seq2seq_trainer.py └── trapper ├── __init__.py ├── __main__.py ├── commands.py ├── common ├── __init__.py ├── constants.py ├── io.py ├── lazy.py ├── notebook_utils │ ├── __init__.py │ ├── file_transfer.py │ └── prepare_data.py ├── params.py ├── plugins.py ├── registrable.py ├── testing_utils │ ├── __init__.py │ ├── hf_datasets_caching.py │ ├── pytest_fixtures │ │ ├── __init__.py │ │ ├── data.py │ │ └── training.py │ └── shell_utils.py └── utils.py ├── data ├── __init__.py ├── data_adapters │ ├── __init__.py │ ├── data_adapter.py │ └── question_answering_adapter.py ├── data_collator.py ├── data_processors │ ├── __init__.py │ ├── data_processor.py │ └── squad │ │ ├── __init__.py │ │ ├── question_answering_processor.py │ │ └── squad_processor.py ├── dataset_loader.py ├── dataset_reader.py ├── label_mapper.py └── tokenizers │ ├── __init__.py │ ├── squad.py │ └── tokenizer_wrapper.py ├── metrics ├── __init__.py ├── input_handlers │ ├── __init__.py │ ├── input_handler.py │ ├── language_generation_input_handler.py │ ├── question_answering_input_handler.py │ └── token_classification_input_handler.py ├── jury.py ├── metric.py └── output_handlers │ ├── __init__.py │ ├── output_handler.py │ └── token_classification_output_handler.py ├── models ├── __init__.py ├── 
auto_wrappers.py └── model_wrapper.py ├── pipelines ├── __init__.py ├── functional.py ├── pipeline.py └── question_answering_pipeline.py ├── training ├── __init__.py ├── callbacks.py ├── optimizers.py ├── train.py ├── trainer.py └── training_args.py └── version.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | operating-system: [ubuntu-latest, windows-latest, macos-latest] 15 | python-version: [3.7, 3.8, 3.9] 16 | fail-fast: false 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v2 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Restore Ubuntu cache 28 | uses: actions/cache@v1 29 | if: matrix.operating-system == 'ubuntu-latest' 30 | with: 31 | path: ~/.cache/pip 32 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 33 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- 34 | 35 | - name: Restore MacOS cache 36 | uses: actions/cache@v1 37 | if: matrix.operating-system == 'macos-latest' 38 | with: 39 | path: ~/Library/Caches/pip 40 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 41 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- 42 | 43 | - name: Restore Windows cache 44 | uses: actions/cache@v1 45 | if: matrix.operating-system == 'windows-latest' 46 | with: 47 | path: ~\AppData\Local\pip\Cache 48 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 49 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- 50 | 51 | - name: Update pip 52 | run: python -m pip install --upgrade pip 53 | 54 | - name: Install dependencies 55 | run: > 56 | pip install -e .[dev] 57 | 58 | - name: Lint 
with flake8 and black 59 | run: | 60 | python -m scripts.run_code_style check 61 | 62 | - name: Cache `HuggingFace datasets` test fixture. 63 | run: | 64 | python -m scripts.cache_hf_datasets_fixtures 65 | 66 | - name: Run tests. 67 | run: | 68 | python -m scripts.run_tests 69 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-20.04 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install setuptools wheel twine 21 | - name: Build and publish 22 | env: 23 | TWINE_USERNAME: __token__ 24 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 25 | run: | 26 | python setup.py sdist bdist_wheel 27 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution / packaging 2 | *.egg-info/ 3 | 4 | # Environments 5 | .env 6 | .venv 7 | env/ 8 | venv/ 9 | ENV/ 10 | env.bak/ 11 | venv.bak/ 12 | 13 | # dev tools 14 | .idea 15 | .vscode 16 | 17 | # IPython 18 | profile_default/ 19 | ipython_config.py 20 | .ipynb_checkpoints 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # Unit test / coverage reports 28 | htmlcov/ 29 | .tox/ 30 | .nox/ 31 | .coverage 32 | .coverage.* 33 | .cache 34 | nosetests.xml 35 | coverage.xml 36 | *.cover 37 | *.py,cover 38 | .hypothesis/ 39 | .pytest_cache/ 40 | 41 | # datasets library's dataset artifacts 42 | *.lock 43 | *dataset_infos.json 44 | 45 | # mypy 46 | .mypy_cache 47 | 48 | # Misc 
49 | debug_scripts/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Open Business Software Solutions 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /examples/pos_tagging/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore the following folders 2 | results/ 3 | outputs/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # datasets library's dataset artifacts 11 | *.lock 12 | *dataset_infos.json 13 | -------------------------------------------------------------------------------- /examples/pos_tagging/.trapper_plugins: -------------------------------------------------------------------------------- 1 | src 2 | -------------------------------------------------------------------------------- /examples/pos_tagging/experiments/roberta/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local output_dir = "experiments/roberta/outputs"; 2 | local result_dir = "experiments/roberta/results"; 3 | local conll2003_test_fixture="test_fixtures/hf_datasets/conll2003_test_fixture"; 4 | local save_steps = 292; 5 | { 6 | "pretrained_model_name_or_path": "roberta-base", 7 | "train_split_name": "train", 8 | "dev_split_name": "validation", 9 | "tokenizer_wrapper": { 10 | "type": "pos_tagging_example", 11 | "model_max_sequence_length": 512, 12 | "add_prefix_space": true 13 | }, 14 | "dataset_loader": { 15 | "dataset_reader": { 16 | "path": "conll2003", # actual dataset 17 | // "path": conll2003_test_fixture, # for testing the project 18 | }, 19 | "data_processor": {"type": "conll2003_pos_tagging_example"}, 20 | "data_adapter": {"type": "conll2003_pos_tagging_example"}, 21 | }, 22 | "data_collator": {"type": "default"}, 23 | "model_wrapper": {"type": "token_classification", "num_labels": 47}, 24 | "compute_metrics": {"metric_params": "seqeval"}, 25 | "metric_handler": {"type": "pos-tagging"}, 26 | "label_mapper": {"type": "conll2003_pos_tagging_example"}, 27 | "args": { 28 | "type": "default", 
29 | "output_dir": output_dir + "/checkpoints", 30 | "result_dir": result_dir, 31 | "logging_first_step": true, 32 | "num_train_epochs": 3, 33 | "per_device_train_batch_size": 16, 34 | "per_device_eval_batch_size": 32, 35 | "gradient_accumulation_steps": 1, 36 | "logging_dir": output_dir + "/logs", 37 | "no_cuda": true, 38 | "logging_steps": save_steps, 39 | "eval_steps": save_steps, 40 | "evaluation_strategy": "steps", 41 | "save_steps": save_steps, 42 | "label_names": ["labels"], 43 | "lr_scheduler_type": "linear", 44 | "warmup_steps": 157, 45 | "do_train": true, 46 | "do_eval": true, 47 | "save_total_limit": 1, 48 | "metric_for_best_model": "eval_f1", 49 | "greater_is_better": true, 50 | }, 51 | "optimizer": { 52 | "type": "huggingface_adamw", 53 | "weight_decay": 0.1, 54 | "parameter_groups": [ 55 | [ 56 | ["bias", "LayerNorm\\\\.weight", "layer_norm\\\\.weight"], 57 | {"weight_decay": 0}, 58 | ] 59 | ], 60 | "lr": 3e-5, 61 | "eps": 1e-8, 62 | }, 63 | } 64 | -------------------------------------------------------------------------------- /examples/pos_tagging/requirements.txt: -------------------------------------------------------------------------------- 1 | trapper==0.0.4 2 | -------------------------------------------------------------------------------- /examples/pos_tagging/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/examples/pos_tagging/scripts/__init__.py -------------------------------------------------------------------------------- /examples/pos_tagging/scripts/cache_hf_datasets_fixtures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches the tests dataset to HuggingFace's `datasets` library's cache so that the 3 | interpreter can find it when we try to load it through the `datasets` library. 
4 | """ 5 | from src import POS_TAGGING_FIXTURES_ROOT 6 | 7 | from trapper.common.testing_utils.hf_datasets_caching import ( 8 | renew_hf_datasets_fixtures_cache, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | renew_hf_datasets_fixtures_cache(POS_TAGGING_FIXTURES_ROOT / "hf_datasets") 13 | -------------------------------------------------------------------------------- /examples/pos_tagging/scripts/run_tests.py: -------------------------------------------------------------------------------- 1 | from trapper.common.testing_utils.shell_utils import shell, validate_and_exit 2 | 3 | if __name__ == "__main__": 4 | sts_tests = shell( 5 | "pytest --cov src --cov-report term-missing --cov-report xml -vvv tests" 6 | ) 7 | validate_and_exit(tests=sts_tests) 8 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from src import data 4 | from src.pipeline import ExamplePosTaggingPipeline 5 | 6 | POS_TAGGING_PROJECT_ROOT = Path(__file__).parent.parent.resolve() 7 | POS_TAGGING_TESTS_ROOT = POS_TAGGING_PROJECT_ROOT / "tests" 8 | POS_TAGGING_FIXTURES_ROOT = POS_TAGGING_PROJECT_ROOT / "test_fixtures" 9 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from src.data.data_adapter import ExampleDataAdapterForPosTagging 2 | from src.data.data_processor import ExampleConll2003PosTaggingDataProcessor 3 | from src.data.label_mapper import ExampleLabelMapperForPosTagging 4 | from src.data.tokenizer_wrapper import ExamplePosTaggingTokenizerWrapper 5 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/data/data_adapter.py: 
-------------------------------------------------------------------------------- 1 | from trapper.common.constants import IGNORED_LABEL_ID 2 | from trapper.data.data_adapters.data_adapter import DataAdapter 3 | from trapper.data.data_processors import IndexedInstance 4 | 5 | 6 | @DataAdapter.register("conll2003_pos_tagging_example") 7 | class ExampleDataAdapterForPosTagging(DataAdapter): 8 | """ 9 | This class takes the processed instance dict from the data processor and 10 | creates a new dict that has the "input_ids" and "labels" keys required by the 11 | models. It also takes care of the special BOS and EOS tokens while constructing 12 | these fields. 13 | """ 14 | 15 | def __call__(self, raw_instance: IndexedInstance) -> IndexedInstance: 16 | """ 17 | Create a sequence with the following field: 18 | input_ids: ...tokens... 19 | """ 20 | # We return an instance having the key fields specified in 21 | # `trapper.models.auto_wrappers._TASK_TO_INPUT_FIELDS["token_classification"]` 22 | input_ids = [self._bos_token_id] + raw_instance["tokens"] 23 | input_ids.append(self._eos_token_id) 24 | labels = [IGNORED_LABEL_ID] + raw_instance["pos_tags"] 25 | labels.append(IGNORED_LABEL_ID) 26 | return {"input_ids": input_ids, "labels": labels} 27 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/data/data_processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, List, Optional 3 | 4 | from trapper.data.data_processors import DataProcessor 5 | from trapper.data.data_processors.data_processor import IndexedInstance 6 | 7 | logger = logging.getLogger(__file__) 8 | 9 | 10 | @DataProcessor.register("conll2003_pos_tagging_example") 11 | class ExampleConll2003PosTaggingDataProcessor(DataProcessor): 12 | """ 13 | This class extracts the "tokens", "pos_tags" and "id" fields from an input 14 | data instance. 
It tokenizes the `tokens` field since it actually consists of 15 | words which may need further tokenization. Then, it generates the corresponding 16 | token ids and stores them. Finally, the `pos_tags` are stored directly without 17 | any processing since this field consists of integer label ids instead of 18 | categorical labels. 19 | """ 20 | 21 | NUM_EXTRA_SPECIAL_TOKENS_IN_SEQUENCE = 2 # tokens 22 | 23 | def process(self, instance_dict: Dict[str, Any]) -> Optional[IndexedInstance]: 24 | return self.text_to_instance( 25 | id_=instance_dict["id"], 26 | tokens=instance_dict["tokens"], 27 | pos_tags=instance_dict["pos_tags"], 28 | ) 29 | 30 | def text_to_instance( 31 | self, 32 | tokens: List[str], 33 | id_: str = 0, 34 | pos_tags: Optional[List[int]] = None, 35 | ) -> IndexedInstance: 36 | expanded_tokens = [] 37 | expanded_token_counts = [] 38 | for token in tokens: 39 | expanded_token = self.tokenizer.tokenize(token) 40 | expanded_tokens.extend(expanded_token) 41 | expanded_token_counts.append(len(expanded_token)) 42 | 43 | instance = {"id": id_} 44 | 45 | if pos_tags is not None: 46 | expanded_pos_tags = [] 47 | for expanded_len, pos_tag in zip(expanded_token_counts, pos_tags): 48 | expanded_pos_tag = [pos_tag] * expanded_len 49 | expanded_pos_tags.extend(expanded_pos_tag) 50 | self._chop_excess_tokens(expanded_pos_tags, len(expanded_pos_tags)) 51 | instance["pos_tags"] = expanded_pos_tags 52 | 53 | self._chop_excess_tokens(expanded_tokens, len(expanded_tokens)) 54 | instance["tokens"] = self.tokenizer.convert_tokens_to_ids(expanded_tokens) 55 | 56 | return instance 57 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/data/label_mapper.py: -------------------------------------------------------------------------------- 1 | from trapper.data.label_mapper import LabelMapper 2 | 3 | 4 | @LabelMapper.register("conll2003_pos_tagging_example", constructor="from_labels") 5 | class 
ExampleLabelMapperForPosTagging(LabelMapper): 6 | # Obtained by executing `dataset["train"].features["pos_tags"].feature.names` 7 | _LABELS = ( 8 | '"', 9 | "''", 10 | "#", 11 | "$", 12 | "(", 13 | ")", 14 | ",", 15 | ".", 16 | ":", 17 | "``", 18 | "CC", 19 | "CD", 20 | "DT", 21 | "EX", 22 | "FW", 23 | "IN", 24 | "JJ", 25 | "JJR", 26 | "JJS", 27 | "LS", 28 | "MD", 29 | "NN", 30 | "NNP", 31 | "NNPS", 32 | "NNS", 33 | "NN|SYM", 34 | "PDT", 35 | "POS", 36 | "PRP", 37 | "PRP$", 38 | "RB", 39 | "RBR", 40 | "RBS", 41 | "RP", 42 | "SYM", 43 | "TO", 44 | "UH", 45 | "VB", 46 | "VBD", 47 | "VBG", 48 | "VBN", 49 | "VBP", 50 | "VBZ", 51 | "WDT", 52 | "WP", 53 | "WP$", 54 | "WRB", 55 | ) 56 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/data/tokenizer_wrapper.py: -------------------------------------------------------------------------------- 1 | from trapper.data import TokenizerWrapper 2 | 3 | 4 | @TokenizerWrapper.register("pos_tagging_example", constructor="from_pretrained") 5 | class ExamplePosTaggingTokenizerWrapper(TokenizerWrapper): 6 | """A `tokenizer wrapper` that is used for demonstrating how to implement 7 | your own by extending the base class.""" 8 | 9 | # Although we could have used the `TokenizerWrapper` directly, this class is 10 | # implemented for demonstration purposes. You can override 11 | # `_TASK_SPECIFIC_SPECIAL_TOKENS` here if your task requires custom extra 12 | # special tokens. 13 | -------------------------------------------------------------------------------- /examples/pos_tagging/src/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Open Business Software Solutions, the HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This implementation is adapted from the token classification pipeline in 16 | HuggingFace's transformers library. Original code is available at: 17 | `<https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py>`_. 18 | """ 19 | from typing import Optional 20 | 21 | # needed for registering the data-related classes 22 | # noinspection PyUnresolvedReferences 23 | # pylint: disable=unused-import 24 | import src.data 25 | from transformers import ( 26 | ModelCard, 27 | PreTrainedModel, 28 | PreTrainedTokenizer, 29 | TokenClassificationPipeline, 30 | ) 31 | from transformers.feature_extraction_utils import PreTrainedFeatureExtractor 32 | from transformers.pipelines import ( 33 | SUPPORTED_TASKS, 34 | ArgumentHandler, 35 | TokenClassificationArgumentHandler, 36 | ) 37 | 38 | from trapper.data import DataAdapter, DataCollator, DataProcessor 39 | from trapper.pipelines import PipelineMixin 40 | 41 | 42 | @PipelineMixin.register("example-pos-tagging", constructor="from_partial_objects") 43 | class ExamplePosTaggingPipeline(PipelineMixin, TokenClassificationPipeline): 44 | """ 45 | CONLL2003 POS tagging pipeline that extracts POS tags from a given sentence 46 | or a list of sentences. 
47 | """ 48 | 49 | def __init__( 50 | self, 51 | model: PreTrainedModel, 52 | tokenizer: PreTrainedTokenizer, 53 | data_processor: DataProcessor, 54 | data_adapter: DataAdapter, 55 | data_collator: DataCollator, 56 | feature_extractor: Optional[PreTrainedFeatureExtractor] = None, 57 | modelcard: Optional[ModelCard] = None, 58 | framework: Optional[str] = None, 59 | task: str = "token-classification", 60 | args_parser: Optional[ArgumentHandler] = None, 61 | device: int = -1, 62 | binary_output: bool = False, 63 | ): 64 | super(ExamplePosTaggingPipeline, self).__init__( 65 | model=model, 66 | tokenizer=tokenizer, 67 | data_processor=data_processor, 68 | data_adapter=data_adapter, 69 | data_collator=data_collator, 70 | feature_extractor=feature_extractor, 71 | modelcard=modelcard, 72 | framework=framework, 73 | task=task, 74 | args_parser=args_parser, 75 | device=device, 76 | binary_output=binary_output, 77 | ) 78 | self._args_parser = TokenClassificationArgumentHandler() 79 | 80 | 81 | SUPPORTED_TASKS["pos_tagging_example"] = { 82 | "impl": ExamplePosTaggingPipeline, 83 | "pt": PreTrainedModel, 84 | } 85 | -------------------------------------------------------------------------------- /examples/pos_tagging/test_fixtures/hf_datasets/conll2003_test_fixture/test.txt: -------------------------------------------------------------------------------- 1 | Japan NNP B-NP B-LOC 2 | began VBD B-VP O 3 | the DT B-NP O 4 | defence NN I-NP O 5 | of IN B-PP O 6 | their PRP$ B-NP O 7 | Asian JJ I-NP B-MISC 8 | Cup NNP I-NP I-MISC 9 | title NN I-NP O 10 | with IN B-PP O 11 | a DT B-NP O 12 | lucky JJ I-NP O 13 | 2-1 CD I-NP O 14 | win VBP B-VP O 15 | against IN B-PP O 16 | Syria NNP B-NP B-LOC 17 | in IN B-PP O 18 | a DT B-NP O 19 | Group NNP I-NP O 20 | C NNP I-NP O 21 | championship NN I-NP O 22 | match NN I-NP O 23 | on IN B-PP O 24 | Friday NNP B-NP O 25 | . . 
O O 26 | -------------------------------------------------------------------------------- /examples/pos_tagging/test_fixtures/hf_datasets/conll2003_test_fixture/train.txt: -------------------------------------------------------------------------------- 1 | The DT B-NP O 2 | European NNP I-NP B-ORG 3 | Commission NNP I-NP I-ORG 4 | said VBD B-VP O 5 | on IN B-PP O 6 | Thursday NNP B-NP O 7 | it PRP B-NP O 8 | disagreed VBD B-VP O 9 | with IN B-PP O 10 | German JJ B-NP B-MISC 11 | advice NN I-NP O 12 | to TO B-PP O 13 | consumers NNS B-NP O 14 | to TO B-VP O 15 | shun VB I-VP O 16 | British JJ B-NP B-MISC 17 | lamb NN I-NP O 18 | until IN B-SBAR O 19 | scientists NNS B-NP O 20 | determine VBP B-VP O 21 | whether IN B-SBAR O 22 | mad JJ B-NP O 23 | cow NN I-NP O 24 | disease NN I-NP O 25 | can MD B-VP O 26 | be VB I-VP O 27 | transmitted VBN I-VP O 28 | to TO B-PP O 29 | sheep NN B-NP O 30 | . . O O 31 | -------------------------------------------------------------------------------- /examples/pos_tagging/test_fixtures/hf_datasets/conll2003_test_fixture/valid.txt: -------------------------------------------------------------------------------- 1 | West NNP B-NP B-MISC 2 | Indian NNP I-NP I-MISC 3 | all-rounder NN I-NP O 4 | Phil NNP I-NP B-PER 5 | Simmons NNP I-NP I-PER 6 | took VBD B-VP O 7 | four CD B-NP O 8 | for IN B-PP O 9 | 38 CD B-NP O 10 | on IN B-PP O 11 | Friday NNP B-NP O 12 | as IN B-PP O 13 | Leicestershire NNP B-NP B-ORG 14 | beat VBD B-VP O 15 | Somerset NNP B-NP B-ORG 16 | by IN B-PP O 17 | an DT B-NP O 18 | innings NN I-NP O 19 | and CC O O 20 | 39 CD B-NP O 21 | runs NNS I-NP O 22 | in IN B-PP O 23 | two CD B-NP O 24 | days NNS I-NP O 25 | to TO B-VP O 26 | take VB I-VP O 27 | over IN B-PP O 28 | at IN B-PP O 29 | the DT B-NP O 30 | head NN I-NP O 31 | of IN B-PP O 32 | the DT B-NP O 33 | county NN I-NP O 34 | championship NN I-NP O 35 | . . 
O O 36 | -------------------------------------------------------------------------------- /examples/pos_tagging/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/examples/pos_tagging/tests/__init__.py -------------------------------------------------------------------------------- /examples/pos_tagging/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src import POS_TAGGING_FIXTURES_ROOT 3 | 4 | from trapper.common import Params 5 | 6 | # noinspection PyUnresolvedReferences 7 | # pylint: disable=unused-import 8 | from trapper.common.testing_utils.pytest_fixtures import ( 9 | create_data_collator_args, 10 | create_data_processor_args, 11 | get_raw_dataset, 12 | make_data_collator, 13 | make_sequential_sampler, 14 | temp_output_dir, 15 | temp_result_dir, 16 | ) 17 | 18 | _HF_DATASETS_FIXTURES_ROOT = POS_TAGGING_FIXTURES_ROOT / "hf_datasets" 19 | 20 | 21 | @pytest.fixture(scope="package") 22 | def get_hf_datasets_fixture_path(): 23 | def _get_hf_datasets_fixture_path(dataset: str) -> str: 24 | return str(_HF_DATASETS_FIXTURES_ROOT / dataset) 25 | 26 | return _get_hf_datasets_fixture_path 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def experiment_params( 31 | temp_output_dir, temp_result_dir, get_hf_datasets_fixture_path 32 | ): 33 | params_dict = { 34 | "pretrained_model_name_or_path": "distilbert-base-uncased", 35 | "train_split_name": "train", 36 | "dev_split_name": "validation", 37 | "tokenizer_wrapper": { 38 | "type": "pos_tagging_example", 39 | "add_prefix_space": True, 40 | }, 41 | "dataset_loader": { 42 | "dataset_reader": { 43 | "path": get_hf_datasets_fixture_path("conll2003_test_fixture"), 44 | }, 45 | "data_processor": { 46 | "type": "conll2003_pos_tagging_example", 47 | "model_max_sequence_length": 512, 48 | }, 49 | "data_adapter": 
{"type": "conll2003_pos_tagging_example"}, 50 | }, 51 | "data_collator": {}, 52 | "model_wrapper": {"type": "token_classification", "num_labels": 47}, 53 | "compute_metrics": {"metric_params": "seqeval"}, 54 | "metric_input_handler": {"type": "token-classification"}, 55 | "metric_output_handler": {"type": "default"}, 56 | "label_mapper": {"type": "conll2003_pos_tagging_example"}, 57 | "args": { 58 | "type": "default", 59 | "output_dir": temp_output_dir + "/checkpoints", 60 | "result_dir": temp_result_dir, 61 | "num_train_epochs": 3, 62 | "per_device_train_batch_size": 1, 63 | "per_device_eval_batch_size": 1, 64 | "logging_dir": temp_output_dir + "/logs", 65 | "no_cuda": True, 66 | "logging_steps": 1, 67 | "evaluation_strategy": "steps", 68 | "save_steps": 2, 69 | "label_names": ["labels"], 70 | "lr_scheduler_type": "linear", 71 | "warmup_steps": 2, 72 | "do_train": True, 73 | "do_eval": True, 74 | "save_total_limit": 1, 75 | "metric_for_best_model": "eval_loss", 76 | "greater_is_better": False, 77 | "seed": 100, 78 | }, 79 | "optimizer": { 80 | "type": "huggingface_adamw", 81 | "weight_decay": 0.01, 82 | "parameter_groups": [ 83 | [ 84 | ["bias", "LayerNorm\\\\.weight", "layer_norm\\\\.weight"], 85 | {"weight_decay": 0}, 86 | ] 87 | ], 88 | "lr": 5e-5, 89 | "eps": 1e-8, 90 | }, 91 | } 92 | return Params(params_dict) 93 | -------------------------------------------------------------------------------- /examples/pos_tagging/tests/test_data_adapter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src.data.data_adapter import ExampleDataAdapterForPosTagging 3 | from src.data.data_processor import ExampleConll2003PosTaggingDataProcessor 4 | from src.data.tokenizer_wrapper import ExamplePosTaggingTokenizerWrapper 5 | 6 | from trapper.common.constants import IGNORED_LABEL_ID 7 | from trapper.data import InputBatch 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def 
data_collator_args(create_data_collator_args): 12 | return create_data_collator_args( 13 | task_type="token_classification", 14 | train_batch_size=1, 15 | validation_batch_size=1, 16 | is_distributed=False, 17 | model_max_sequence_length=512, 18 | tokenizer_factory=ExamplePosTaggingTokenizerWrapper, 19 | tokenizer_model_name="roberta-base", 20 | add_prefix_space=True, 21 | ) 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def raw_conll03_postagging_dataset(get_raw_dataset, get_hf_datasets_fixture_path): 26 | return get_raw_dataset( 27 | path=get_hf_datasets_fixture_path("conll2003_test_fixture"), split="train" 28 | ) 29 | 30 | 31 | @pytest.fixture(scope="module") 32 | def adapted_conll03_postagging_dataset( 33 | raw_conll03_postagging_dataset, data_collator_args 34 | ): 35 | data_adapter = ExampleDataAdapterForPosTagging( 36 | data_collator_args.tokenizer_wrapper 37 | ) 38 | data_processor = ExampleConll2003PosTaggingDataProcessor( 39 | data_collator_args.tokenizer_wrapper 40 | ) 41 | processed_dataset = raw_conll03_postagging_dataset.map(data_processor) 42 | return processed_dataset.map(data_adapter) 43 | 44 | 45 | @pytest.fixture(scope="module") 46 | def data_collator(make_data_collator, data_collator_args): 47 | return make_data_collator(data_collator_args) 48 | 49 | 50 | def test_batch_content_on_squad_dev_dataset( 51 | raw_conll03_postagging_dataset, 52 | adapted_conll03_postagging_dataset, 53 | data_collator_args, 54 | data_collator, 55 | ): 56 | expected_sentence = "The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep." 
57 | if data_collator_args.is_tokenizer_uncased: 58 | expected_sentence = expected_sentence.lower() 59 | 60 | collated_batch = data_collator.build_model_inputs( 61 | adapted_conll03_postagging_dataset 62 | ) 63 | input_ids = collated_batch["input_ids"][0] 64 | labels = collated_batch["labels"][0] 65 | 66 | tokenizer = data_collator_args.tokenizer_wrapper.tokenizer 67 | decoded_sentence = tokenizer.decode( 68 | input_ids, skip_special_tokens=True 69 | ).lstrip() 70 | assert expected_sentence == decoded_sentence 71 | assert len(input_ids) == len(labels) 72 | 73 | encoding = tokenizer(expected_sentence, add_special_tokens=False) 74 | raw_pos_tags = [ 75 | 12, 76 | 22, 77 | 22, 78 | 38, 79 | 15, 80 | 22, 81 | 28, 82 | 38, 83 | 15, 84 | 16, 85 | 21, 86 | 35, 87 | 24, 88 | 35, 89 | 37, 90 | 16, 91 | 21, 92 | 15, 93 | 24, 94 | 41, 95 | 15, 96 | 16, 97 | 21, 98 | 21, 99 | 20, 100 | 37, 101 | 40, 102 | 35, 103 | 21, 104 | 7, 105 | ] 106 | expected_labels = [raw_pos_tags[ind] for ind in encoding.word_ids()] 107 | expected_labels.insert(0, IGNORED_LABEL_ID) # BOS 108 | expected_labels.append(IGNORED_LABEL_ID) # EOS 109 | assert expected_labels == labels 110 | validate_attention_mask(collated_batch) 111 | 112 | 113 | def validate_attention_mask(instance_batch: InputBatch): 114 | for input_ids, attention_mask in zip( 115 | instance_batch["input_ids"], instance_batch["attention_mask"] 116 | ): 117 | assert len(attention_mask) == len(input_ids) 118 | assert all(val == 1 for val in attention_mask) 119 | -------------------------------------------------------------------------------- /examples/pos_tagging/tests/test_data_processor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src.data.data_processor import ExampleConll2003PosTaggingDataProcessor 3 | from src.data.tokenizer_wrapper import ExamplePosTaggingTokenizerWrapper 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def args(create_data_processor_args): 8 | return 
create_data_processor_args( 9 | tokenizer_factory=ExamplePosTaggingTokenizerWrapper, 10 | tokenizer_model_name="roberta-base", 11 | model_max_sequence_length=512, 12 | add_prefix_space=True, 13 | ) 14 | 15 | 16 | def test_data_processor(get_raw_dataset, args, get_hf_datasets_fixture_path): 17 | expected_sentence = "The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep." 18 | if args.is_tokenizer_uncased: 19 | expected_sentence = expected_sentence.lower() 20 | data_processor = ExampleConll2003PosTaggingDataProcessor(args.tokenizer_wrapper) 21 | raw_dataset = get_raw_dataset( 22 | path=get_hf_datasets_fixture_path("conll2003_test_fixture"), split="train" 23 | ) 24 | processed_instance = raw_dataset.map(data_processor)[0] 25 | tokenizer = args.tokenizer_wrapper.tokenizer 26 | decoded_sentence = tokenizer.decode(processed_instance["tokens"]).lstrip() 27 | assert expected_sentence == decoded_sentence 28 | assert len(processed_instance["tokens"]) == len(processed_instance["pos_tags"]) 29 | 30 | encoding = tokenizer(expected_sentence, add_special_tokens=False) 31 | raw_pos_tags = [ 32 | 12, 33 | 22, 34 | 22, 35 | 38, 36 | 15, 37 | 22, 38 | 28, 39 | 38, 40 | 15, 41 | 16, 42 | 21, 43 | 35, 44 | 24, 45 | 35, 46 | 37, 47 | 16, 48 | 21, 49 | 15, 50 | 24, 51 | 41, 52 | 15, 53 | 16, 54 | 21, 55 | 21, 56 | 20, 57 | 37, 58 | 40, 59 | 35, 60 | 21, 61 | 7, 62 | ] 63 | expected_pos_tags = [raw_pos_tags[ind] for ind in encoding.word_ids()] 64 | assert expected_pos_tags == processed_instance["pos_tags"] 65 | -------------------------------------------------------------------------------- /examples/pos_tagging/tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | from deepdiff import DeepDiff 5 | from transformers import set_seed 6 | 7 | from trapper.pipelines 
import create_pipeline_from_params 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def distilbert_conll_pipeline(experiment_params): 12 | set_seed(100) 13 | return create_pipeline_from_params( 14 | experiment_params, 15 | pipeline_type="example-pos-tagging", 16 | ) 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def distilbert_pipeline_sample_input(): 21 | return ["I love Istanbul."] 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def distilbert_pipeline_expected_output(): 26 | return [ 27 | [ 28 | { 29 | "entity": "LABEL_12", 30 | "score": 0.035119053, 31 | "index": 1, 32 | "word": "i", 33 | "start": 0, 34 | "end": 1, 35 | }, 36 | { 37 | "entity": "LABEL_12", 38 | "score": 0.036859084, 39 | "index": 2, 40 | "word": "love", 41 | "start": 2, 42 | "end": 6, 43 | }, 44 | { 45 | "entity": "LABEL_46", 46 | "score": 0.03283123, 47 | "index": 3, 48 | "word": "istanbul", 49 | "start": 7, 50 | "end": 15, 51 | }, 52 | { 53 | "entity": "LABEL_27", 54 | "score": 0.040444903, 55 | "index": 4, 56 | "word": ".", 57 | "start": 15, 58 | "end": 16, 59 | }, 60 | ] 61 | ] 62 | 63 | 64 | def test_distilbert_conll_pipeline_execution( 65 | distilbert_conll_pipeline, 66 | distilbert_pipeline_sample_input, 67 | distilbert_pipeline_expected_output, 68 | ): 69 | actual_output = distilbert_conll_pipeline(distilbert_pipeline_sample_input) 70 | diff = DeepDiff( 71 | distilbert_pipeline_expected_output, 72 | actual_output, 73 | significant_digits=3, 74 | ignore_numeric_type_changes=True, 75 | ) 76 | assert ( 77 | diff == {} 78 | ), f"Actual and Desired Dicts are not Almost Equal:\n {json.dumps(diff, indent=2)}" 79 | -------------------------------------------------------------------------------- /examples/pos_tagging/tests/test_trainer.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import pytest 3 | 4 | # needed for registering the data-related classes 5 | # noinspection PyUnresolvedReferences 6 | # pylint: 
disable=unused-import 7 | import src 8 | from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast 9 | 10 | from trapper.data.data_collator import DataCollator 11 | from trapper.training import TransformerTrainer, TransformerTrainingArguments 12 | from trapper.training.optimizers import HuggingfaceAdamWOptimizer 13 | from trapper.training.train import run_experiment_using_trainer 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def trainer(experiment_params) -> TransformerTrainer: 18 | return TransformerTrainer.from_params(experiment_params) 19 | 20 | 21 | def test_trainer_fields(trainer): 22 | assert isinstance(trainer, TransformerTrainer) 23 | assert isinstance(trainer.model, DistilBertForTokenClassification) 24 | assert isinstance(trainer.args, TransformerTrainingArguments) 25 | assert isinstance(trainer.data_collator, DataCollator) 26 | assert isinstance(trainer.train_dataset, datasets.Dataset) 27 | assert isinstance(trainer.eval_dataset, datasets.Dataset) 28 | assert isinstance(trainer.tokenizer, DistilBertTokenizerFast) 29 | assert isinstance(trainer.optimizer, HuggingfaceAdamWOptimizer) 30 | 31 | 32 | def test_trainer_can_train(trainer): 33 | run_experiment_using_trainer(trainer) 34 | -------------------------------------------------------------------------------- /examples/pos_tagging/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script shows how to train a POS tagger using trapper. However, 3 | intentionally, the dependency injection mechanism is not used; instead, the 4 | objects are instantiated directly from concrete classes suitable for the job. 5 | Although we recommend config-file-based training using the trapper CLI instead, 6 | you can use this file to create a custom training script for your needs.
7 | """ 8 | -------------------------------------------------------------------------------- /examples/question_answering/README.md: -------------------------------------------------------------------------------- 1 | ## Question Answering Demo 2 | 3 | Open in Colab 4 | 5 | 6 | This notebook serves as an example for demonstrating training and inference using `trapper`. The question-answering task is already supported by `trapper`, so in this notebook we only provide a basic [configuration file](./experiment.jsonnet) and let `trapper` take care of the rest. To implement a custom task using `trapper`, see the [POS tagging example](../pos_tagging). 7 | -------------------------------------------------------------------------------- /examples/question_answering/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local checkpoint_dir = std.extVar("CHECKPOINT_PATH"); 2 | local result_dir = std.extVar("OUTPUT_PATH"); 3 | { 4 | "train_split_name": "train", 5 | "dev_split_name": "validation", 6 | "pretrained_model_name_or_path": "roberta-base", 7 | "tokenizer_wrapper": { 8 | "type": "question-answering" 9 | }, 10 | "dataset_loader": { 11 | "type": "default", 12 | "dataset_reader": { 13 | "type": "default", 14 | "path": "squad_qa_test_fixture" 15 | }, 16 | "data_processor": { 17 | "type": "squad-question-answering" 18 | }, 19 | "data_adapter": { 20 | "type": "question-answering" 21 | } 22 | }, 23 | "data_collator":{ 24 | "type": "default" 25 | }, 26 | "model_wrapper": { 27 | "type": "question_answering" 28 | }, 29 | "metric_input_handler": { 30 | "type": "question-answering" 31 | }, 32 | "compute_metrics": { 33 | "metric_params": [ 34 | "squad" 35 | ] 36 | }, 37 | "args": { 38 | "type": "default", 39 | "output_dir": checkpoint_dir, 40 | "result_dir": result_dir, 41 | "num_train_epochs": 10, 42 | "per_device_train_batch_size": 2, 43 | "gradient_accumulation_steps": 12, 44 | "per_device_eval_batch_size": 2, 45 |
"logging_dir": checkpoint_dir + "/logs", 46 | "no_cuda": false, 47 | "logging_steps": 500, 48 | "evaluation_strategy": "steps", 49 | "save_steps": 500, 50 | "label_names": ["start_positions", "end_positions"], 51 | "lr_scheduler_type": "linear", 52 | "warmup_steps": 500, 53 | "do_train": true, 54 | "do_eval": true, 55 | "save_total_limit": 1 56 | }, 57 | "optimizer": { 58 | "type": "huggingface_adamw", 59 | "weight_decay": 0.01, 60 | "parameter_groups": [ 61 | [["bias", "LayerNorm\\\\.weight", "layer_norm\\\\.weight"], 62 | {"weight_decay": 0}]], 63 | "lr": 5e-5, 64 | "eps": 1e-6 65 | } 66 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 84 3 | exclude = ''' 4 | ( 5 | ^/( 6 | | .git 7 | | .github 8 | | .idea 9 | | .pytest_cache 10 | | .venv 11 | | resources 12 | | trapper.egg-info 13 | | venv 14 | | __pycache__ 15 | )/ 16 | ) 17 | ''' 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp>=2.10.0,<2.11 2 | datasets>=2.3.0,<2.5 3 | deepdiff>=5.2.0 4 | jury>=2.2.3,<2.3 5 | numpy>=1.21.2 6 | seqeval==1.2.2 7 | tensorboardX==2.1 8 | transformers>=4.18,<4.21 9 | -------------------------------------------------------------------------------- /resources/trapper_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/resources/trapper_diagram.png -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/scripts/__init__.py 
-------------------------------------------------------------------------------- /scripts/cache_hf_datasets_fixtures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches the tests dataset to HuggingFace's `datasets` library's cache so that the 3 | interpreter can find it when we try to load it through the `datasets` library. 4 | """ 5 | from trapper import FIXTURES_ROOT 6 | from trapper.common.testing_utils.hf_datasets_caching import ( 7 | renew_hf_datasets_fixtures_cache, 8 | ) 9 | 10 | if __name__ == "__main__": 11 | renew_hf_datasets_fixtures_cache(FIXTURES_ROOT / "hf_datasets") 12 | -------------------------------------------------------------------------------- /scripts/run_code_style.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from trapper.common.testing_utils.shell_utils import shell, validate_and_exit 4 | 5 | if __name__ == "__main__": 6 | arg = sys.argv[1] 7 | 8 | if arg == "check": 9 | sts_flake = shell("flake8 trapper tests examples --config setup.cfg") 10 | sts_isort = shell("isort . --check --settings setup.cfg") 11 | sts_black = shell("black . --check --config pyproject.toml") 12 | validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black) 13 | elif arg == "format": 14 | sts_isort = shell("isort . --settings setup.cfg") 15 | sts_black = shell("black . 
--config pyproject.toml") 16 | validate_and_exit(isort=sts_isort, black=sts_black) 17 | -------------------------------------------------------------------------------- /scripts/run_tests.py: -------------------------------------------------------------------------------- 1 | from trapper.common.testing_utils.shell_utils import shell, validate_and_exit 2 | 3 | if __name__ == "__main__": 4 | sts_tests = shell( 5 | "pytest --cov trapper --cov-report term-missing --cov-report xml -vvv tests" 6 | ) 7 | sts_tests_examples = shell( 8 | "cd examples/pos_tagging && python -m scripts.run_tests" 9 | ) 10 | validate_and_exit(tests=sts_tests, tests_examples=sts_tests_examples) 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 84 3 | select = E9,F63,F7,F82 4 | per-file-ignores = __init__.py: F401 5 | max-complexity = 10 6 | 7 | [isort] 8 | line_length=84 9 | profile=black -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | VERSION = {} # type: ignore 4 | with open("trapper/version.py", "r") as version_file: 5 | exec(version_file.read(), VERSION) 6 | 7 | 8 | def get_requirements(): 9 | with open("requirements.txt") as f: 10 | return f.read().splitlines() 11 | 12 | 13 | extras_require = { 14 | "dev": [ 15 | "black==22.3.0", 16 | "flake8==3.9.2", 17 | "isort==5.9.2", 18 | "pytest>=6.2.4", 19 | "importlib-metadata>=1.1.0,<4.3;python_version<'3.8'", 20 | "pytest-cov>=2.12.1", 21 | "pylint>=2.11", 22 | "mypy>=0.9", 23 | ], 24 | } 25 | 26 | setup( 27 | name="trapper", 28 | version=VERSION["VERSION"], 29 | author="OBSS", 30 | url="https://github.com/obss/trapper", 31 | description="State-of-the-art NLP through transformer models in a modular design 
and consistent APIs.", 32 | long_description=open("README.md").read(), 33 | long_description_content_type="text/markdown", 34 | packages=find_packages( 35 | exclude=[ 36 | "*.tests", 37 | "*.tests.*", 38 | "tests.*", 39 | "tests", 40 | "test_fixtures", 41 | "test_fixtures.*", 42 | "scripts", 43 | "scripts.*", 44 | ] 45 | ), 46 | entry_points={"console_scripts": ["trapper=trapper.__main__:run"]}, 47 | python_requires=">=3.7.1", 48 | install_requires=get_requirements(), 49 | extras_require=extras_require, 50 | include_package_data=True, 51 | classifiers=[ 52 | "Operating System :: OS Independent", 53 | "Intended Audience :: Developers", 54 | "Intended Audience :: Science/Research", 55 | "Programming Language :: Python :: 3", 56 | "Programming Language :: Python :: 3.7", 57 | "Programming Language :: Python :: 3.8", 58 | "Programming Language :: Python :: 3.9", 59 | "Topic :: Software Development :: Libraries", 60 | "Topic :: Software Development :: Libraries :: Python Modules", 61 | "Topic :: Scientific/Engineering", 62 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 63 | ], 64 | keywords="python, nlp, natural-language-processing, deep-learning, transformer, pytorch, transformers, allennlp, pytorch-transformers", 65 | ) 66 | -------------------------------------------------------------------------------- /test_fixtures/commands/experiment_trivial.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_model_name_or_path": "albert-base-v2", 3 | "train_split_name": "train", 4 | "dev_split_name": "validation", 5 | "tokenizer_wrapper": { 6 | "type": "question-answering" 7 | }, 8 | "dataset_loader": { 9 | "type": "default", 10 | "dataset_reader": { 11 | "type": "default", 12 | "path": "test_fixtures/hf_datasets/squad_qa_test_fixture" 13 | }, 14 | "data_processor": { 15 | "type": "squad-question-answering" 16 | }, 17 | "data_adapter": { 18 | "type": "question-answering" 19 | } 20 | }, 21 | 
"data_collator":{ 22 | "type": "default" 23 | }, 24 | "model_wrapper": { 25 | "type": "question_answering" 26 | }, 27 | "metric_input_handler": { 28 | "type": "question-answering" 29 | }, 30 | "metric_output_handler": {"type": "default"}, 31 | "compute_metrics": { 32 | "metric_params": [ 33 | "squad" 34 | ] 35 | }, 36 | "args": { 37 | "type": "default", 38 | "num_train_epochs": 10, 39 | "per_device_train_batch_size": 2, 40 | "gradient_accumulation_steps": 12, 41 | "per_device_eval_batch_size": 2, 42 | "no_cuda": true, 43 | "logging_steps": 500, 44 | "evaluation_strategy": "steps", 45 | "save_steps": 500, 46 | "label_names": ["start_positions", "end_positions"], 47 | "lr_scheduler_type": "linear", 48 | "warmup_steps": 500, 49 | "do_train": false, 50 | "do_eval": false, 51 | "save_total_limit": 1 52 | }, 53 | "optimizer": { 54 | "type": "huggingface_adamw", 55 | "weight_decay": 0.01, 56 | "parameter_groups": [ 57 | [["bias", "LayerNorm\\\\.weight", "layer_norm\\\\.weight"], 58 | {"weight_decay": 0}]], 59 | "lr": 5e-5, 60 | "eps": 1e-6 61 | } 62 | } -------------------------------------------------------------------------------- /test_fixtures/data/question_answering/squad_qa/dev.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "title": "University_of_Notre_Dame", 5 | "paragraphs": [ 6 | { 7 | "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. 
As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.", 8 | "qas": [ 9 | { 10 | "answers": [ 11 | { 12 | "text": "Denver Broncos", 13 | "answer_start": 177 14 | } 15 | ], 16 | "question": "Which NFL team represented the AFC at Super Bowl 50?", 17 | "id": "56be4db0acb8001400a502ec" 18 | }, 19 | { 20 | "answers": [ 21 | { 22 | "text": "Carolina Panthers", 23 | "answer_start": 249 24 | } 25 | ], 26 | "question": "Which NFL team represented the NFC at Super Bowl 50?", 27 | "id": "56be4db0acb8001400a502ed" 28 | }, 29 | { 30 | "answers": [ 31 | { 32 | "text": "Santa Clara, California", 33 | "answer_start": 403 34 | } 35 | ], 36 | "question": "Where did Super Bowl 50 take place?", 37 | "id": "56be4db0acb8001400a502ee" 38 | } 39 | ] 40 | }, 41 | { 42 | "context": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. 
At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.", 43 | "qas": [ 44 | { 45 | "answers": [ 46 | { 47 | "answer_start": 515, 48 | "text": "Saint Bernadette Soubirous" 49 | } 50 | ], 51 | "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?", 52 | "id": "5733be284776f41900661182" 53 | }, 54 | { 55 | "answers": [ 56 | { 57 | "answer_start": 92, 58 | "text": "a golden statue of the Virgin Mary" 59 | }, 60 | { 61 | "answer_start": 92, 62 | "text": "a golden statue of the Virgin Mary" 63 | } 64 | ], 65 | "question": "What sits on top of the Main Building at Notre Dame?", 66 | "id": "5733be284776f4190066117e" 67 | } 68 | ] 69 | }, 70 | { 71 | "context": "In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nuclear Astrophysics.", 72 | "qas": [ 73 | { 74 | "answers": [ 75 | { 76 | "answer_start": 3, 77 | "text": "1882" 78 | } 79 | ], 80 | "question": "In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?", 81 | "id": "5733b1da4776f41900661068" 82 | } 83 | ] 84 | } 85 | ] 86 | } 87 | ] 88 | } -------------------------------------------------------------------------------- /test_fixtures/data/question_answering/squad_qa/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "title": "University_of_Notre_Dame", 5 | "paragraphs": [ 6 | { 7 | "context": "Architecturally, the school has a Catholic character. 
Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.", 8 | "qas": [ 9 | { 10 | "answers": [ 11 | { 12 | "answer_start": 515, 13 | "text": "Saint Bernadette Soubirous" 14 | } 15 | ], 16 | "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?", 17 | "id": "5733be284776f41900661182" 18 | }, 19 | { 20 | "answers": [ 21 | { 22 | "answer_start": 92, 23 | "text": "a golden statue of the Virgin Mary" 24 | }, 25 | { 26 | "answer_start": 92, 27 | "text": "a golden statue of the Virgin Mary" 28 | }, 29 | { 30 | "answer_start": 0, 31 | "text": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." 
32 | } 33 | ], 34 | "question": "What sits on top of the Main Building at Notre Dame?", 35 | "id": "5733be284776f4190066117e" 36 | } 37 | ] 38 | }, 39 | { 40 | "context": "In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nuclear Astrophysics.", 41 | "qas": [ 42 | { 43 | "answers": [ 44 | { 45 | "answer_start": 3, 46 | "text": "1882" 47 | } 48 | ], 49 | "question": "In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?", 50 | "id": "5733b1da4776f41900661068" 51 | }, 52 | { 53 | "answers": [ 54 | { 55 | "answer_start": 222, 56 | "text": "Father Julius Nieuwl" 57 | } 58 | ], 59 | "question": "Which individual worked on projects at Notre Dame that eventually created neoprene?", 60 | "id": "5733b1da4776f4190066106b" 61 | }, 62 | { 63 | "answers": [ 64 | { 65 | "answer_start": 49, 66 | "text": "an early wind tunnel" 67 | } 68 | ], 69 | "question": "What did the brother of John Zahm construct at Notre Dame?", 70 | "id": "5733b1da4776f41900661067" 71 | } 72 | ] 73 | } 74 | ] 75 | } 76 | ] 77 | } -------------------------------------------------------------------------------- /test_fixtures/hf_datasets/squad_qa_test_fixture/dev.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "title": "University_of_Notre_Dame", 5 | "paragraphs": [ 6 | { 7 | "context": "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. 
The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.", 8 | "qas": [ 9 | { 10 | "answers": [ 11 | { 12 | "text": "Denver Broncos", 13 | "answer_start": 177 14 | } 15 | ], 16 | "question": "Which NFL team represented the AFC at Super Bowl 50?", 17 | "id": "56be4db0acb8001400a502ec" 18 | }, 19 | { 20 | "answers": [ 21 | { 22 | "text": "Carolina Panthers", 23 | "answer_start": 249 24 | } 25 | ], 26 | "question": "Which NFL team represented the NFC at Super Bowl 50?", 27 | "id": "56be4db0acb8001400a502ed" 28 | }, 29 | { 30 | "answers": [ 31 | { 32 | "text": "Santa Clara, California", 33 | "answer_start": 403 34 | } 35 | ], 36 | "question": "Where did Super Bowl 50 take place?", 37 | "id": "56be4db0acb8001400a502ee" 38 | } 39 | ] 40 | }, 41 | { 42 | "context": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. 
At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.", 43 | "qas": [ 44 | { 45 | "answers": [ 46 | { 47 | "answer_start": 515, 48 | "text": "Saint Bernadette Soubirous" 49 | } 50 | ], 51 | "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?", 52 | "id": "5733be284776f41900661182" 53 | }, 54 | { 55 | "answers": [ 56 | { 57 | "answer_start": 92, 58 | "text": "a golden statue of the Virgin Mary" 59 | }, 60 | { 61 | "answer_start": 92, 62 | "text": "a golden statue of the Virgin Mary" 63 | } 64 | ], 65 | "question": "What sits on top of the Main Building at Notre Dame?", 66 | "id": "5733be284776f4190066117e" 67 | } 68 | ] 69 | }, 70 | { 71 | "context": "In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nuclear Astrophysics.", 72 | "qas": [ 73 | { 74 | "answers": [ 75 | { 76 | "answer_start": 3, 77 | "text": "1882" 78 | } 79 | ], 80 | "question": "In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?", 81 | "id": "5733b1da4776f41900661068" 82 | } 83 | ] 84 | } 85 | ] 86 | } 87 | ] 88 | } -------------------------------------------------------------------------------- /test_fixtures/hf_datasets/squad_qa_test_fixture/squad_qa_test_fixture.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Open Business Software Solutions, the HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Test fixture for version 1.1. of SQUAD: The Stanford Question Answering Dataset. 16 | This implementation is adapted from the question answering pipeline from the 17 | HuggingFace's transformers library. Original code is available at: 18 | ``_. 19 | 20 | """ 21 | 22 | import json 23 | 24 | import datasets 25 | 26 | logger = datasets.logging.get_logger(__name__) 27 | 28 | _DESCRIPTION = """\ 29 | Small dataset taken from SQuAD v1.1 and used for testing purposes. 30 | """ 31 | _URLS = { 32 | "train": "train.json", 33 | "dev": "dev.json", 34 | } 35 | 36 | 37 | class SquadTestFixtureConfig(datasets.BuilderConfig): 38 | """ 39 | BuilderConfig for SQUAD test data. 40 | Args: 41 | **kwargs (): 42 | """ 43 | 44 | def __init__(self, **kwargs): 45 | super().__init__(**kwargs) 46 | 47 | 48 | class SquadTestFixture(datasets.GeneratorBasedBuilder): 49 | """A test dataset taken from SQUAD Version 1.1. 
for trapper's QA modules""" 50 | 51 | BUILDER_CONFIGS = [ 52 | SquadTestFixtureConfig( 53 | name="qa_test_fixture", 54 | version=datasets.Version("1.0.0", ""), 55 | description="QA test fixtures", 56 | ), 57 | ] 58 | 59 | def _info(self): 60 | return datasets.DatasetInfo( 61 | description=_DESCRIPTION, 62 | features=datasets.Features( 63 | { 64 | "id": datasets.Value("string"), 65 | "title": datasets.Value("string"), 66 | "context": datasets.Value("string"), 67 | "paragraph_ind": datasets.Value("int32"), 68 | "question": datasets.Value("string"), 69 | "answers": datasets.features.Sequence( 70 | { 71 | "text": datasets.Value("string"), 72 | "answer_start": datasets.Value("int32"), 73 | } 74 | ), 75 | } 76 | ), 77 | supervised_keys=None, 78 | homepage="https://rajpurkar.github.io/SQuAD-explorer/", 79 | ) 80 | 81 | def _split_generators(self, dl_manager): 82 | downloaded_files = dl_manager.download_and_extract(_URLS) 83 | 84 | return [ 85 | datasets.SplitGenerator( 86 | name=datasets.Split.TRAIN, 87 | gen_kwargs={"filepath": downloaded_files["train"]}, 88 | ), 89 | datasets.SplitGenerator( 90 | name=datasets.Split.VALIDATION, 91 | gen_kwargs={"filepath": downloaded_files["dev"]}, 92 | ), 93 | ] 94 | 95 | def _generate_examples(self, filepath): 96 | """This function returns the examples in the raw (text) form.""" 97 | logger.info("generating examples from = %s", filepath) 98 | key = 0 99 | with open(filepath, encoding="utf-8") as f: 100 | squad = json.load(f) 101 | for article in squad["data"]: 102 | title = article.get("title", "") 103 | for paragraph_ind, paragraph in enumerate(article["paragraphs"]): 104 | context = paragraph["context"] 105 | for qa in paragraph["qas"]: 106 | answer_starts = [ 107 | answer["answer_start"] for answer in qa["answers"] 108 | ] 109 | answers = [answer["text"] for answer in qa["answers"]] 110 | yield key, { 111 | "title": title, 112 | "context": context, 113 | "question": qa["question"], 114 | "paragraph_ind": paragraph_ind, 115 | "id": 
qa["id"], 116 | "answers": { 117 | "answer_start": answer_starts, 118 | "text": answers, 119 | }, 120 | } 121 | key += 1 122 | -------------------------------------------------------------------------------- /test_fixtures/hf_datasets/squad_qa_test_fixture/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "title": "University_of_Notre_Dame", 5 | "paragraphs": [ 6 | { 7 | "context": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.", 8 | "qas": [ 9 | { 10 | "answers": [ 11 | { 12 | "answer_start": 515, 13 | "text": "Saint Bernadette Soubirous" 14 | } 15 | ], 16 | "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?", 17 | "id": "5733be284776f41900661182" 18 | }, 19 | { 20 | "answers": [ 21 | { 22 | "answer_start": 92, 23 | "text": "a golden statue of the Virgin Mary" 24 | }, 25 | { 26 | "answer_start": 92, 27 | "text": "a golden statue of the Virgin Mary" 28 | }, 29 | { 30 | "answer_start": 0, 31 | "text": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". 
Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." 32 | } 33 | ], 34 | "question": "What sits on top of the Main Building at Notre Dame?", 35 | "id": "5733be284776f4190066117e" 36 | } 37 | ] 38 | }, 39 | { 40 | "context": "In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nuclear Astrophysics.", 41 | "qas": [ 42 | { 43 | "answers": [ 44 | { 45 | "answer_start": 3, 46 | "text": "1882" 47 | } 48 | ], 49 | "question": "In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?", 50 | "id": "5733b1da4776f41900661068" 51 | }, 52 | { 53 | "answers": [ 54 | { 55 | "answer_start": 222, 56 | "text": "Father Julius Nieuwl" 57 | } 58 | ], 59 | "question": "Which individual worked on projects at Notre Dame that eventually created neoprene?", 60 | "id": "5733b1da4776f4190066106b" 61 | }, 62 | { 63 | "answers": [ 64 | { 65 | "answer_start": 49, 66 | "text": "an early wind tunnel" 67 | } 68 | ], 69 | "question": "What did the brother of John Zahm construct at Notre Dame?", 70 | "id": "5733b1da4776f41900661067" 71 | } 72 | ] 73 | } 74 | ] 75 | } 76 | ] 77 | } -------------------------------------------------------------------------------- 
/test_fixtures/metrics/label_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/test_fixtures/metrics/label_ids.pkl -------------------------------------------------------------------------------- /test_fixtures/metrics/predictions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/test_fixtures/metrics/predictions.pkl -------------------------------------------------------------------------------- /test_fixtures/pipelines/pipeline_integration_experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local checkpoint_dir = std.extVar("CHECKPOINT_PATH"); 2 | local result_dir = std.extVar("OUTPUT_PATH"); 3 | { 4 | "train_split_name": "train", 5 | "dev_split_name": "validation", 6 | "pretrained_model_name_or_path": "albert-base-v2", 7 | "tokenizer_wrapper": { 8 | "type": "question-answering" 9 | }, 10 | "dataset_loader": { 11 | "type": "default", 12 | "dataset_reader": { 13 | "type": "default", 14 | "path": "test_fixtures/hf_datasets/squad_qa_test_fixture" 15 | }, 16 | "data_processor": { 17 | "type": "squad-question-answering" 18 | }, 19 | "data_adapter": { 20 | "type": "question-answering" 21 | } 22 | }, 23 | "data_collator":{ 24 | "type": "default" 25 | }, 26 | "model_wrapper": { 27 | "type": "question_answering" 28 | }, 29 | "metric_input_handler": { 30 | "type": "question-answering" 31 | }, 32 | "metric_output_handler": {"type": "default"}, 33 | "compute_metrics": { 34 | "metric_params": [ 35 | "squad" 36 | ] 37 | }, 38 | "args": { 39 | "type": "default", 40 | "output_dir": checkpoint_dir, 41 | "result_dir": result_dir, 42 | "num_train_epochs": 10, 43 | "per_device_train_batch_size": 2, 44 | "gradient_accumulation_steps": 12, 45 | "per_device_eval_batch_size": 2, 46 | 
"logging_dir": checkpoint_dir + "/logs", 47 | "no_cuda": true, 48 | "logging_steps": 500, 49 | "evaluation_strategy": "steps", 50 | "save_steps": 500, 51 | "label_names": ["start_positions", "end_positions"], 52 | "lr_scheduler_type": "linear", 53 | "warmup_steps": 500, 54 | "do_train": true, 55 | "do_eval": true, 56 | "save_total_limit": 1 57 | }, 58 | "optimizer": { 59 | "type": "huggingface_adamw", 60 | "weight_decay": 0.01, 61 | "parameter_groups": [ 62 | [["bias", "LayerNorm\\\\.weight", "layer_norm\\\\.weight"], 63 | {"weight_decay": 0}]], 64 | "lr": 5e-5, 65 | "eps": 1e-6 66 | } 67 | } -------------------------------------------------------------------------------- /test_fixtures/plugins/.trapper_plugins: -------------------------------------------------------------------------------- 1 | custom_dummy_package 2 | custom_dummy_module 3 | -------------------------------------------------------------------------------- /test_fixtures/plugins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/test_fixtures/plugins/__init__.py -------------------------------------------------------------------------------- /test_fixtures/plugins/custom_dummy_module.py: -------------------------------------------------------------------------------- 1 | from trapper.data import DatasetLoader 2 | 3 | 4 | @DatasetLoader.register("dummy_dataset_loader_inside_module") 5 | class DummyDatasetLoader3(DatasetLoader): 6 | pass 7 | -------------------------------------------------------------------------------- /test_fixtures/plugins/custom_dummy_package/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, no-name-in-module 2 | # type: ignore[import] 3 | from custom_dummy_package.dummy_loader import DummyDatasetLoader 4 | from custom_dummy_package.dummy_loader2 import 
DummyDatasetLoader2 5 | -------------------------------------------------------------------------------- /test_fixtures/plugins/custom_dummy_package/dummy_loader.py: -------------------------------------------------------------------------------- 1 | from trapper.data import DatasetLoader 2 | 3 | 4 | @DatasetLoader.register("dummy_dataset_loader_inside_package") 5 | class DummyDatasetLoader(DatasetLoader): 6 | pass 7 | -------------------------------------------------------------------------------- /test_fixtures/plugins/custom_dummy_package/dummy_loader2.py: -------------------------------------------------------------------------------- 1 | from trapper.data import DatasetLoader 2 | 3 | 4 | @DatasetLoader.register("dummy_dataset_loader_inside_package2") 5 | class DummyDatasetLoader2(DatasetLoader): 6 | pass 7 | -------------------------------------------------------------------------------- /test_fixtures/training/question_answering/squad/dev.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "title": "University_of_Notre_Dame", 5 | "paragraphs": [ 6 | { 7 | "context": "Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. 
At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.", 8 | "qas": [ 9 | { 10 | "answers": [ 11 | { 12 | "answer_start": 515, 13 | "text": "Saint Bernadette Soubirous" 14 | } 15 | ], 16 | "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?", 17 | "id": "5733be284776f41900661182", 18 | "clue": { 19 | "text": "the Virgin Mary", 20 | "answer_start": 111 21 | } 22 | }, 23 | { 24 | "answers": [ 25 | { 26 | "answer_start": 92, 27 | "text": "a golden statue of the Virgin Mary" 28 | }, 29 | { 30 | "answer_start": 92, 31 | "text": "a golden statue of the Virgin Mary" 32 | } 33 | ], 34 | "question": "What sits on top of the Main Building at Notre Dame?", 35 | "id": "5733be284776f4190066117e", 36 | "clue": { 37 | "text": "the Main Building", 38 | "answer_start": 152 39 | } 40 | } 41 | ] 42 | }, 43 | { 44 | "context": "In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. 
Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nuclear Astrophysics.", 45 | "qas": [ 46 | { 47 | "answers": [ 48 | { 49 | "answer_start": 3, 50 | "text": "1882" 51 | } 52 | ], 53 | "question": "In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?", 54 | "id": "5733b1da4776f41900661068", 55 | "clue": { 56 | "text": "Albert Zahm", 57 | "answer_start": 9 58 | } 59 | } 60 | ] 61 | } 62 | ] 63 | } 64 | ] 65 | } -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/__init__.py -------------------------------------------------------------------------------- /tests/commands/test_run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from trapper import FIXTURES_ROOT 7 | from trapper.commands import main 8 | 9 | COMMAND_FIXTURES = FIXTURES_ROOT / "commands" 10 | 11 | 12 | def test_run_command_missing(capsys): 13 | with pytest.raises(SystemExit, match="2"): 14 | run_args = ["trapper", "run"] 15 | with patch.object(sys, "argv", run_args): 16 | main("trapper") 17 | 18 | captured = capsys.readouterr() 19 | assert ( 20 | "trapper run: error: the following arguments are required" in captured.err 21 | ) 22 | 23 | 24 | def test_run_command_trivial(tmp_path): 25 | run_args = [ 26 | "trapper", 27 | "run", 28 | str(COMMAND_FIXTURES / "experiment_trivial.jsonnet"), 29 | "-o", 30 | ] 31 | overrides = { 32 | "args.output_dir": str(tmp_path / "output"), 33 | "args.result_dir": str(tmp_path / "output"), 34 | } 35 | run_args.append(str(overrides)) 36 | with patch.object(sys, "argv", run_args): 37 | 
main("trapper") 38 | -------------------------------------------------------------------------------- /tests/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/common/__init__.py -------------------------------------------------------------------------------- /tests/common/test_plugins.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from allennlp.common.util import pushd 3 | 4 | from trapper.common.plugins import discover_plugins, import_plugins 5 | from trapper.data import DatasetLoader 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def plugins_root(fixtures_root): 10 | return fixtures_root / "plugins" 11 | 12 | 13 | def test_no_plugins(plugins_root): 14 | available_plugins = set(discover_plugins()) 15 | assert available_plugins == set() 16 | 17 | 18 | def test_file_plugin(plugins_root): 19 | test_no_plugins(plugins_root) 20 | 21 | with pushd(plugins_root): 22 | available_plugins = set(discover_plugins()) 23 | assert available_plugins == {"custom_dummy_package", "custom_dummy_module"} 24 | 25 | import_plugins() 26 | dataset_loaders_available = DatasetLoader.list_available() 27 | for name in ( 28 | "dummy_dataset_loader_inside_package", 29 | "dummy_dataset_loader_inside_package2", 30 | "dummy_dataset_loader_inside_module", 31 | ): 32 | assert name in dataset_loaders_available 33 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from trapper import FIXTURES_ROOT, TESTS_ROOT 6 | 7 | _HF_DATASETS_FIXTURES_ROOT = FIXTURES_ROOT / "hf_datasets" 8 | 9 | 10 | @pytest.fixture(scope="package") 11 | def tests_root(): 12 | return TESTS_ROOT 13 | 14 | 15 | @pytest.fixture(scope="package") 
16 | def fixtures_root() -> Path: 17 | return FIXTURES_ROOT 18 | 19 | 20 | @pytest.fixture(scope="package") 21 | def get_hf_datasets_fixture_path(): 22 | def _get_hf_datasets_fixture_path(dataset: str) -> str: 23 | return str(_HF_DATASETS_FIXTURES_ROOT / dataset) 24 | 25 | return _get_hf_datasets_fixture_path 26 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/conftest.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | # pylint: disable=unused-import 3 | from trapper.common.testing_utils.pytest_fixtures import ( 4 | create_data_collator_args, 5 | create_data_processor_args, 6 | get_raw_dataset, 7 | make_data_collator, 8 | make_sequential_sampler, 9 | ) 10 | -------------------------------------------------------------------------------- /tests/data/data_collators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/data/data_collators/__init__.py -------------------------------------------------------------------------------- /tests/data/data_collators/squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/data/data_collators/squad/__init__.py -------------------------------------------------------------------------------- /tests/data/data_collators/squad/test_question_answering_collator_and_adapter.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from torch.utils.data import DataLoader 3 | from transformers import PreTrainedTokenizerBase 4 | 5 | from trapper.data.data_adapters.question_answering_adapter import ( 6 | DataAdapterForQuestionAnswering, 7 | ) 8 | from trapper.data.data_collator import InputBatch 9 | from trapper.data.data_processors.squad import SquadQuestionAnsweringDataProcessor 10 | from trapper.data.tokenizers import QuestionAnsweringTokenizerWrapper 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def data_collator_args(create_data_collator_args): 15 | return create_data_collator_args( 16 | tokenizer_factory=QuestionAnsweringTokenizerWrapper, 17 | train_batch_size=2, 18 | validation_batch_size=1, 19 | tokenizer_model_name="roberta-base", 20 | task_type="question_answering", 21 | is_distributed=False, 22 | ) 23 | 24 | 25 | @pytest.fixture(scope="module") 26 | def processed_dataset( 27 | get_raw_dataset, data_collator_args, get_hf_datasets_fixture_path 28 | ): 29 | data_processor = SquadQuestionAnsweringDataProcessor( 30 | data_collator_args.tokenizer_wrapper 31 | ) 32 | raw_dataset = get_raw_dataset( 33 | path=get_hf_datasets_fixture_path("squad_qa_test_fixture") 34 | ) 35 | return raw_dataset.map(data_processor) 36 | 37 | 38 | @pytest.fixture(scope="module") 39 | def adapted_dataset(processed_dataset, data_collator_args): 40 | data_adapter = DataAdapterForQuestionAnswering( 41 | data_collator_args.tokenizer_wrapper 42 | ) 43 | return processed_dataset.map(data_adapter) 44 | 45 | 46 | @pytest.fixture(scope="module") 47 | def qa_data_collator(make_data_collator, data_collator_args): 48 | return make_data_collator(data_collator_args) 49 | 50 | 51 | @pytest.mark.parametrize( 52 | ["split", "expected_batch_size", "expected_dataset_size"], 53 | [ 54 | ("train", 2, 3), 55 | ("validation", 1, 6), 56 | ], 57 | ) 58 | def test_data_sizes( 59 | qa_data_collator, 60 | make_sequential_sampler, 61 | 
adapted_dataset, 62 | split, 63 | data_collator_args, 64 | expected_batch_size, 65 | expected_dataset_size, 66 | ): 67 | dataset_split = adapted_dataset[split] 68 | sampler = make_sequential_sampler( 69 | is_distributed=data_collator_args.is_distributed, dataset=dataset_split 70 | ) 71 | loader = DataLoader( 72 | dataset_split, 73 | batch_size=getattr(data_collator_args, f"{split}_batch_size"), 74 | sampler=sampler, 75 | collate_fn=qa_data_collator, 76 | ) 77 | assert loader.batch_size == expected_batch_size 78 | assert len(loader) == expected_dataset_size 79 | 80 | 81 | @pytest.fixture(scope="module") 82 | def collated_batch(qa_data_collator, adapted_dataset): 83 | return qa_data_collator.build_model_inputs(adapted_dataset["validation"]) 84 | 85 | 86 | @pytest.mark.parametrize( 87 | ["index", "expected_question"], 88 | [ 89 | (0, "Which NFL team represented the AFC at Super Bowl 50?"), 90 | (1, "Which NFL team represented the NFC at Super Bowl 50?"), 91 | (2, "Where did Super Bowl 50 take place?"), 92 | ], 93 | ) 94 | def test_batch_content( 95 | data_collator_args, processed_dataset, collated_batch, index, expected_question 96 | ): 97 | if data_collator_args.is_tokenizer_uncased: 98 | expected_question = expected_question.lower() 99 | validate_target_question_positions_using_decoded_tokens( 100 | expected_question, 101 | index, 102 | data_collator_args.tokenizer_wrapper.tokenizer, 103 | collated_batch, 104 | ) 105 | 106 | instance = processed_dataset["validation"][index] 107 | token_type_ids = collated_batch["token_type_ids"][index] 108 | validate_token_type_ids(token_type_ids, instance) 109 | validate_attention_mask(collated_batch) 110 | 111 | 112 | def validate_target_question_positions_using_decoded_tokens( 113 | expected_question, 114 | index, 115 | tokenizer: PreTrainedTokenizerBase, 116 | input_batch: InputBatch, 117 | ): 118 | input_ids = input_batch["input_ids"][index] 119 | question_start = -sum(input_batch["token_type_ids"][index]) 120 | question_end = 
-1 # EOS 121 | assert ( 122 | tokenizer.decode(input_ids[question_start:question_end]).lstrip() 123 | == expected_question 124 | ) 125 | 126 | 127 | def validate_token_type_ids(token_type_ids, instance): 128 | question_len = len(instance["question"]) 129 | context_end = len(instance["context"]) + 2 # BOS, EOS 130 | question_end = context_end + question_len + 1 # EOS 131 | 132 | # context tokens (including BOS and EOS) come first 133 | assert all(token_type_id == 0 for token_type_id in token_type_ids[:context_end]) 134 | 135 | # question tokens (and the final EOS) at the end 136 | assert all(token_type_id == 1 for token_type_id in token_type_ids[context_end:]) 137 | 138 | assert len(token_type_ids) == question_end 139 | 140 | 141 | def validate_attention_mask(instance_batch: InputBatch): 142 | for input_ids, attention_mask in zip( 143 | instance_batch["input_ids"], instance_batch["attention_mask"] 144 | ): 145 | assert len(attention_mask) == len(input_ids) 146 | assert all(val == 1 for val in attention_mask) 147 | -------------------------------------------------------------------------------- /tests/data/data_processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/data/data_processors/__init__.py -------------------------------------------------------------------------------- /tests/data/data_processors/squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/data/data_processors/squad/__init__.py -------------------------------------------------------------------------------- /tests/data/data_processors/squad/test_question_answering_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import pytest 4 | from transformers import PreTrainedTokenizerBase 5 | 6 |
from trapper.data.data_processors.squad import SquadQuestionAnsweringDataProcessor 7 | from trapper.data.tokenizers import QuestionAnsweringTokenizerWrapper 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def args(create_data_processor_args): 12 | return create_data_processor_args( 13 | tokenizer_factory=QuestionAnsweringTokenizerWrapper, 14 | tokenizer_model_name="roberta-base", 15 | ) 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def processed_dev_dataset(get_raw_dataset, args, get_hf_datasets_fixture_path): 20 | data_processor = SquadQuestionAnsweringDataProcessor(args.tokenizer_wrapper) 21 | raw_dataset = get_raw_dataset( 22 | path=get_hf_datasets_fixture_path("squad_qa_test_fixture"), 23 | split="validation", 24 | ) 25 | return raw_dataset.map(data_processor) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | ["index", "expected_question"], 30 | [ 31 | (0, "Which NFL team represented the AFC at Super Bowl 50?"), 32 | (1, "Which NFL team represented the NFC at Super Bowl 50?"), 33 | (2, "Where did Super Bowl 50 take place?"), 34 | ], 35 | ) 36 | def test_data_processor(processed_dev_dataset, args, index, expected_question): 37 | if args.is_tokenizer_uncased: 38 | expected_question = expected_question.lower() 39 | processed_instance = processed_dev_dataset[index] 40 | assert ( 41 | decode_question(processed_instance, args.tokenizer_wrapper.tokenizer) 42 | == expected_question 43 | ) 44 | 45 | 46 | def decode_question(instance: Dict, tokenizer: PreTrainedTokenizerBase) -> str: 47 | field_value = tokenizer.decode(instance["question"]) 48 | return field_value.lstrip() 49 | -------------------------------------------------------------------------------- /tests/data/data_processors/test_data_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import datasets 4 | import pytest 5 | 6 | from trapper.common.utils import is_equal 7 | from trapper.data import DataProcessor, 
IndexedInstance 8 | 9 | 10 | class MockTokenizer: 11 | @property 12 | def model_max_length(self): 13 | return 1 14 | 15 | @staticmethod 16 | def convert_tokens_to_ids(text: str) -> List[int]: 17 | return [int(tok.split("token")[-1]) for tok in text] 18 | 19 | @staticmethod 20 | def tokenize(text: str) -> List[str]: 21 | return text.split() 22 | 23 | 24 | class MockTokenizerWrapper: 25 | def __init__(self): 26 | self._tokenizer = MockTokenizer() 27 | 28 | @property 29 | def tokenizer(self): 30 | return self._tokenizer 31 | 32 | 33 | class MockDataProcessor(DataProcessor): 34 | def text_to_instance( 35 | self, ind: int, info1: str, info2: str = None 36 | ) -> IndexedInstance: 37 | info1_tokenized = self.tokenizer.tokenize(info1) 38 | info2_tokenized = self.tokenizer.tokenize(info2) 39 | return { 40 | "index": ind, 41 | "info1": self._tokenizer.convert_tokens_to_ids(info1_tokenized), 42 | "info2": self._tokenizer.convert_tokens_to_ids(info2_tokenized), 43 | } 44 | 45 | def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance: 46 | return self.text_to_instance( 47 | ind=instance_dict["id"], 48 | info1=instance_dict["info1"], 49 | info2=instance_dict["info2_with_suffix"], 50 | ) 51 | 52 | 53 | @pytest.fixture 54 | def dummy_dataset(): 55 | return datasets.Dataset.from_dict( 56 | { 57 | "id": [0, 1], 58 | "info1": ["token1 token2", "token3 token4"], 59 | "info2_with_suffix": ["token4 token5", "token6"], 60 | } 61 | ) 62 | 63 | 64 | @pytest.fixture 65 | def mock_processor(): 66 | mock_tokenizer = MockTokenizerWrapper() 67 | return MockDataProcessor(mock_tokenizer) # type: ignore 68 | 69 | 70 | @pytest.mark.parametrize( 71 | ["index", "expected_instance"], 72 | [ 73 | (0, {"index": 0, "info1": [1, 2], "info2": [4, 5]}), 74 | (1, {"index": 1, "info1": [3, 4], "info2": [6]}), 75 | ], 76 | ) 77 | def test_data_processor(dummy_dataset, mock_processor, index, expected_instance): 78 | actual_instance = mock_processor(dummy_dataset[index]) 79 | assert 
is_equal(expected_instance, actual_instance) 80 | -------------------------------------------------------------------------------- /tests/data/test_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import DownloadConfig 3 | 4 | from trapper.data import DatasetReader 5 | 6 | 7 | @pytest.fixture 8 | def dataset_reader(get_hf_datasets_fixture_path): 9 | return DatasetReader( 10 | path=get_hf_datasets_fixture_path("squad_qa_test_fixture"), 11 | download_config=DownloadConfig(local_files_only=True), 12 | ) 13 | 14 | 15 | def test_read(dataset_reader): 16 | assert len(dataset_reader.read()) == 2 # dict with two splits 17 | assert len(dataset_reader.read("train")) == 5 18 | assert len(dataset_reader.read("validation")) == 6 19 | assert len(dataset_reader.read("all")) == 11 # splits are combined 20 | with pytest.raises(ValueError): 21 | dataset_reader.read("test") 22 | -------------------------------------------------------------------------------- /tests/data/test_label_mapper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from trapper.common.utils import is_equal 4 | from trapper.data.label_mapper import LabelMapper 5 | 6 | DUMMY_LABELS = ("label1", "label2", "label3") 7 | 8 | 9 | class MockLabelMapperWithClassVariable(LabelMapper): 10 | _LABELS = DUMMY_LABELS 11 | 12 | 13 | class MockLabelMapperWithConstructor(LabelMapper): 14 | def __init__(self, label_to_id_map=None, ignored_labels=None): 15 | super().__init__(label_to_id_map, ignored_labels=ignored_labels) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "start_id", 20 | [0, 7], 21 | ) 22 | def test_mapper_created_from_labels(start_id): 23 | dummy_ids = (start_id, start_id + 1, start_id + 2) 24 | dummy_label_to_id_map = { 25 | "label1": dummy_ids[0], 26 | "label2": dummy_ids[1], 27 | "label3": dummy_ids[2], 28 | } 29 | dummy_id_to_label_map = { 30 | id_: label for label, 
id_ in dummy_label_to_id_map.items() 31 | } 32 | 33 | mapper_from_label_to_id_map = MockLabelMapperWithConstructor( 34 | dummy_label_to_id_map 35 | ) 36 | mapper_from_class_variable_labels = ( 37 | MockLabelMapperWithClassVariable.from_labels(start_id=start_id) 38 | ) 39 | mapper_from_labels = MockLabelMapperWithConstructor.from_labels( 40 | labels=DUMMY_LABELS, start_id=start_id 41 | ) 42 | 43 | for mapper in ( 44 | mapper_from_label_to_id_map, 45 | mapper_from_class_variable_labels, 46 | mapper_from_labels, 47 | ): 48 | assert mapper.labels == DUMMY_LABELS 49 | assert mapper.ids == dummy_ids 50 | assert is_equal(mapper.label_to_id_map, dummy_label_to_id_map) 51 | assert is_equal(mapper.id_to_label_map, dummy_id_to_label_map) 52 | for label, id_ in dummy_label_to_id_map.items(): 53 | assert mapper.get_id(label) == id_ 54 | assert mapper.get_label(id_) == label 55 | 56 | 57 | IGNORED_LABELS = ("O1", "O2") 58 | 59 | 60 | class MockLabelMapperWithIgnoredLabelsFromClassVariable(LabelMapper): 61 | _LABELS = DUMMY_LABELS 62 | _IGNORED_LABELS = IGNORED_LABELS 63 | 64 | 65 | class MockLabelMapperWithIgnoredLabelsFromConstructor(LabelMapper): 66 | _LABELS = DUMMY_LABELS 67 | 68 | def __init__(self, label_to_id_map=None, ignored_labels=None): 69 | super().__init__(label_to_id_map, ignored_labels=ignored_labels) 70 | 71 | 72 | class MockLabelMapperWithoutIgnoredLabels(LabelMapper): 73 | _LABELS = DUMMY_LABELS 74 | 75 | 76 | def test_mappers_for_ignored_labels(): 77 | mapper_with_ignored_labels_from_class_variable = ( 78 | MockLabelMapperWithIgnoredLabelsFromClassVariable.from_labels() 79 | ) 80 | mapper_with_ignored_labels_from_constructor = ( 81 | MockLabelMapperWithIgnoredLabelsFromConstructor( 82 | label_to_id_map={}, ignored_labels=list(IGNORED_LABELS) 83 | ) 84 | ) 85 | for mapper in ( 86 | mapper_with_ignored_labels_from_class_variable, 87 | mapper_with_ignored_labels_from_constructor, 88 | ): 89 | assert mapper.ignored_labels == IGNORED_LABELS 90 | 91 | 
mapper_without_ignored_labels = ( 92 | MockLabelMapperWithoutIgnoredLabels.from_labels() 93 | ) 94 | assert not mapper_without_ignored_labels.ignored_labels 95 | -------------------------------------------------------------------------------- /tests/data/test_tokenizer_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test if the tokenizer wrapper correctly handles the general special tokens such 3 | as BOS and EOS as well as the task-specific special tokens such as CONTEXT_TOKEN. 4 | """ 5 | import pytest 6 | from transformers import PreTrainedTokenizerBase 7 | 8 | from trapper.common.constants import ANSWER_TOKEN, CONTEXT_TOKEN 9 | from trapper.data import TokenizerWrapper 10 | 11 | 12 | @pytest.fixture 13 | def bert_tokenizer_with_context_token(): 14 | class TokenizerWrapperWithContextToken(TokenizerWrapper): 15 | _TASK_SPECIFIC_SPECIAL_TOKENS = [CONTEXT_TOKEN] 16 | 17 | return TokenizerWrapperWithContextToken.from_pretrained("bert-base-uncased") 18 | 19 | 20 | @pytest.fixture 21 | def gpt2_tokenizer_with_context_and_answer_tokens(): 22 | class TokenizerWrapperWithContextAndAnswerTokens(TokenizerWrapper): 23 | _TASK_SPECIFIC_SPECIAL_TOKENS = [CONTEXT_TOKEN, ANSWER_TOKEN] 24 | 25 | return TokenizerWrapperWithContextAndAnswerTokens.from_pretrained("gpt2") 26 | 27 | 28 | def test_bert_tokenizer(bert_tokenizer_with_context_token): 29 | assert ( 30 | bert_tokenizer_with_context_token.num_added_special_tokens == 1 31 | ) # CONTEXT_TOKEN 32 | tokenizer = bert_tokenizer_with_context_token.tokenizer 33 | assert tokenizer.bos_token == tokenizer.cls_token == "[CLS]" 34 | assert tokenizer.bos_token_id == tokenizer.cls_token_id == 101 35 | assert tokenizer.eos_token == tokenizer.sep_token == "[SEP]" 36 | assert tokenizer.sep_token_id == tokenizer.eos_token_id == 102 37 | assert_special_tokens_are_preserved(tokenizer, CONTEXT_TOKEN) 38 | assert_all_common_special_tokens_are_present(tokenizer) 39 | 40 | 41 | def
def assert_all_common_special_tokens_are_present(
    tokenizer: PreTrainedTokenizerBase,
):
    """
    Assert that every common special token is actually set on the tokenizer.

    Args:
        tokenizer: The tokenizer to be checked.

    Raises:
        AssertionError: If at least one of the common special tokens is
            missing or unset. The message lists the offending attributes.
    """
    COMMON_TOKENS = (
        "bos_token",
        "eos_token",
        "cls_token",
        "sep_token",
        "pad_token",
        "mask_token",
        "unk_token",
    )
    # `hasattr` alone is not a meaningful check: Hugging Face tokenizers expose
    # the special tokens as attributes that exist even when the token was never
    # set (they evaluate to `None`). Require an actual value instead.
    missing = [key for key in COMMON_TOKENS if getattr(tokenizer, key, None) is None]
    assert not missing, f"Special tokens are not set: {missing}"
class MockMetricInputHandler(MetricInputHandler):
    """
    Minimal input handler used in the tests. It turns the raw model outputs
    and the label ids into decoded strings using the wrapped tokenizer.
    """

    def __init__(self, tokenizer_wrappper: TokenizerWrapper):
        super().__init__()
        self._tokenizer_wrapper = tokenizer_wrappper

    @property
    def tokenizer(self):
        """The tokenizer held by the tokenizer wrapper."""
        return self._tokenizer_wrapper.tokenizer

    def __call__(
        self,
        eval_pred: EvalPrediction,
    ) -> EvalPrediction:
        """
        Decode the arg-max of the predictions and the label ids.

        Args:
            eval_pred: Raw predictions and label ids.

        Returns:
            An `EvalPrediction` whose fields are numpy arrays of the decoded
            strings.
        """
        decode = self.tokenizer.batch_decode
        most_likely_ids = eval_pred.predictions.argmax(-1)
        decoded_predictions = np.array(decode(most_likely_ids))
        decoded_labels = np.array(decode(eval_pred.label_ids))
        return EvalPrediction(
            predictions=decoded_predictions, label_ids=decoded_labels
        )
def actual_references(): 68 | return [ 69 | "[unused65] [unused120] [unused212] [unused52] [unused126] [unused149] [unused4] [unused228]", 70 | "[unused66] [unused22] [unused122] [unused229] [unused181] [unused56] [unused239] [unused73]", 71 | "[unused75] [unused88] [unused39] [unused182] [unused199] [unused26] [unused152] [unused116]", 72 | ] 73 | 74 | 75 | def test_metric_handler( 76 | eval_pred, mock_metric_input_handler, actual_predictions, actual_references 77 | ): 78 | eval_pred = mock_metric_input_handler(eval_pred) 79 | predictions = eval_pred.predictions.tolist() 80 | references = eval_pred.label_ids.tolist() 81 | 82 | assert is_equal(predictions, actual_predictions) 83 | assert is_equal(references, actual_references) 84 | -------------------------------------------------------------------------------- /tests/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obss/trapper/bbc82450097bd48466d2c47cfbb3bb194319410a/tests/pipelines/__init__.py -------------------------------------------------------------------------------- /tests/pipelines/test_functional.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from trapper.pipelines import create_pipeline_from_checkpoint 4 | 5 | 6 | def test_create_pipeline_from_hf_hub(tmp_path): 7 | hf_hub_model = "albert-base-v1" 8 | hf_hub_model_nonexist = hf_hub_model + "abcdef" 9 | with pytest.raises(FileNotFoundError): 10 | # Must be raising an exception since assumed config path 11 | # (under `tmp_path`) does not exist 12 | create_pipeline_from_checkpoint(tmp_path, experiment_config_path=None) 13 | 14 | with pytest.raises(ValueError): 15 | # Must be raising an exception as: 16 | # "If a model is given in HF-hub, `experiment_config.json` must be included in 17 | # the model hub repository." 
18 | create_pipeline_from_checkpoint(hf_hub_model) 19 | 20 | with pytest.raises(ValueError): 21 | # Must be raising an exception as: 22 | # "Input path must be an existing directory or an existing repository at huggingface model hub." 23 | create_pipeline_from_checkpoint(hf_hub_model_nonexist) 24 | -------------------------------------------------------------------------------- /tests/pipelines/test_squad_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | from deepdiff import DeepDiff 5 | from transformers import set_seed 6 | 7 | from trapper.common.constants import SpanTuple 8 | from trapper.common.params import Params 9 | from trapper.pipelines import PipelineMixin 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def roberta_squad_pipeline_params(): 14 | params = { 15 | "type": "squad-question-answering", 16 | "pretrained_model_name_or_path": "smallbenchnlp/roberta-small", 17 | "tokenizer_wrapper": {"type": "question-answering"}, 18 | "data_processor": {"type": "squad-question-answering"}, 19 | "data_adapter": {"type": "question-answering"}, 20 | "data_collator": {"type": "default"}, 21 | "model_wrapper": {"type": "question_answering"}, 22 | } 23 | return Params(params) 24 | 25 | 26 | @pytest.fixture(scope="module") 27 | def roberta_squad_pipeline(roberta_squad_pipeline_params): 28 | set_seed(100) 29 | return PipelineMixin.from_params(roberta_squad_pipeline_params) 30 | 31 | 32 | @pytest.fixture(scope="module") 33 | def roberta_squad_pipeline_sample_input(): 34 | return [ 35 | { 36 | "context": 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. 
def test_roberta_squad_pipeline_execution(
    roberta_squad_pipeline,
    roberta_squad_pipeline_sample_input,
    roberta_squad_pipeline_expected_output,
):
    """
    Run the pipeline on the sample input and compare the result with the
    expected output, tolerating float differences beyond 3 significant digits.
    """
    actual_output = roberta_squad_pipeline(roberta_squad_pipeline_sample_input)
    diff = DeepDiff(
        roberta_squad_pipeline_expected_output, actual_output, significant_digits=3
    )
    # `default=str` keeps the failure message printable even when the diff
    # holds members that `json` cannot serialize natively; otherwise building
    # the message could raise `TypeError` and hide the real assertion failure.
    assert (
        diff == {}
    ), f"Actual and Desired Dicts are not Almost Equal:\n {json.dumps(diff, indent=2, default=str)}"
disable=unused-import 8 | from trapper.common.testing_utils.pytest_fixtures import ( 9 | temp_output_dir, 10 | temp_result_dir, 11 | ) 12 | 13 | HF_DATASETS_FIXTURES_PATH = FIXTURES_ROOT / "hf_datasets" 14 | 15 | 16 | @pytest.fixture(scope="package") 17 | def get_hf_datasets_fixture_path_from_root(): 18 | def _get_hf_datasets_fixture_path(dataset: str) -> str: 19 | return str(HF_DATASETS_FIXTURES_PATH / dataset) 20 | 21 | return _get_hf_datasets_fixture_path 22 | -------------------------------------------------------------------------------- /tests/training/test_basic_trainer_for_question_answering.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import pytest 3 | from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast 4 | 5 | from trapper.common import Params 6 | from trapper.data.data_collator import DataCollator 7 | from trapper.training import TransformerTrainer, TransformerTrainingArguments 8 | from trapper.training.optimizers import HuggingfaceAdamWOptimizer 9 | from trapper.training.train import run_experiment_using_trainer 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def trainer_params( 14 | temp_output_dir, temp_result_dir, get_hf_datasets_fixture_path_from_root 15 | ): 16 | params_dict = { 17 | "pretrained_model_name_or_path": "distilbert-base-uncased", 18 | "train_split_name": "train", 19 | "dev_split_name": "validation", 20 | "tokenizer_wrapper": {"type": "question-answering"}, 21 | "dataset_loader": { 22 | "dataset_reader": { 23 | "path": get_hf_datasets_fixture_path_from_root( 24 | "squad_qa_test_fixture" 25 | ) 26 | }, 27 | "data_processor": {"type": "squad-question-answering"}, 28 | "data_adapter": {"type": "question-answering"}, 29 | }, 30 | "data_collator": {}, 31 | "model_wrapper": {"type": "question_answering"}, 32 | "compute_metrics": {"metric_params": ["squad"]}, 33 | "metric_input_handler": {"type": "question-answering"}, 34 | "metric_output_handler": 
def test_trainer_fields(trainer):
    """
    Check that the trainer built from the params, and each of its main
    components, has the expected type.
    """
    assert isinstance(trainer, TransformerTrainer)

    expected_component_types = (
        ("model", DistilBertForQuestionAnswering),
        ("args", TransformerTrainingArguments),
        ("data_collator", DataCollator),
        ("train_dataset", datasets.Dataset),
        ("eval_dataset", datasets.Dataset),
        ("tokenizer", DistilBertTokenizerFast),
        ("optimizer", HuggingfaceAdamWOptimizer),
    )
    for attribute_name, expected_type in expected_component_types:
        assert isinstance(getattr(trainer, attribute_name), expected_type)
2 | Exposes trapper subpackages so that they can be accessed through `trapper` 3 | namespace in other projects. 4 | """ 5 | from pathlib import Path 6 | 7 | import trapper.common 8 | import trapper.data 9 | import trapper.metrics 10 | import trapper.models 11 | import trapper.pipelines 12 | import trapper.training 13 | from trapper.version import VERSION as __version__ # noqa 14 | 15 | PROJECT_ROOT = Path(__file__).parent.parent.resolve() 16 | TESTS_ROOT = PROJECT_ROOT / "tests" 17 | FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures" 18 | -------------------------------------------------------------------------------- /trapper/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert( 5 | 0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))) 6 | ) 7 | 8 | 9 | def run(): 10 | from trapper.commands import main # noqa 11 | 12 | main(prog="trapper") 13 | 14 | 15 | if __name__ == "__main__": 16 | run() 17 | -------------------------------------------------------------------------------- /trapper/common/__init__.py: -------------------------------------------------------------------------------- 1 | from trapper.common.lazy import Lazy 2 | from trapper.common.params import Params 3 | from trapper.common.registrable import Registrable 4 | -------------------------------------------------------------------------------- /trapper/common/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains constants like the textual representations of the special 3 | tokens commonly used while training the transformer models on NLP tasks. 4 | Moreover, some NamedTuple and type definitions are supplied for convenience while 5 | working on tasks dealing with spans. 
def json_load(fp: str) -> Union[Dict, None]:
    """
    Load a json file and return the parsed content.

    Args:
        fp: Path of the json file.

    Returns:
        The parsed content, or `None` if the file does not exist or does not
        contain valid json.
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(fp, "r", encoding="utf-8") as jf:
            return json.load(jf)
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def json_save(obj: Dict, fp: str, overwrite: bool = True) -> None:
    """
    Save a dictionary as a json file.

    Args:
        obj: The object to be serialized.
        fp: Destination path.
        overwrite: If False, an already existing `fp` is not replaced.

    Raises:
        ValueError: If `fp` exists and `overwrite` is False.
    """
    if os.path.exists(fp) and not overwrite:
        raise ValueError(
            f"Path {fp} already exists. To overwrite, use `overwrite=True`."
        )

    with open(fp, "w", encoding="utf-8") as jf:
        json.dump(obj, jf)


def pickle_load(fp: str) -> Any:
    """
    Load a pickled object.

    NOTE: unpickling can execute arbitrary code, so this must only be used on
    files coming from a trusted source.

    Args:
        fp: Path of the pickle file.

    Returns:
        The unpickled object, or `None` if the file does not exist.
    """
    try:
        with open(fp, "rb") as pkl:
            return pickle.load(pkl)
    except FileNotFoundError:
        return None


def pickle_save(obj: Any, fp: str, overwrite: bool = True) -> None:
    """
    Save an object as a pickle file.

    Args:
        obj: The (picklable) object to be serialized.
        fp: Destination path.
        overwrite: If False, an already existing `fp` is not replaced.

    Raises:
        ValueError: If `fp` exists and `overwrite` is False.
    """
    if os.path.exists(fp) and not overwrite:
        raise ValueError(
            f"Path {fp} already exists. To overwrite, use overwrite=True."
        )

    with open(fp, "wb") as pkl:
        pickle.dump(obj, pkl)
def download_from_url(url: str, destination: Optional[str] = None) -> None:
    """
    Utility function to download data from a specified url.

    Nothing is downloaded if `destination` already exists. Missing parent
    directories of `destination` are created.

    Args:
        url: Source url of data to be downloaded.
        destination: Destination where the downloaded data is placed. If None,
            base name of the url path is used, i.e if url="a/b.txt", it will be
            downloaded to "./b.txt". Query strings and fragments of the url
            are not part of the derived file name.
    """
    if destination is None:
        # Local import: only needed to derive the default file name.
        from urllib.parse import urlsplit

        # Use only the path component so that "?query" or "#fragment" parts
        # do not end up in the file name.
        destination = os.path.basename(urlsplit(url).path)

    Path(destination).parent.mkdir(parents=True, exist_ok=True)

    if not os.path.exists(destination):
        urllib.request.urlretrieve(
            url,
            destination,
        )
`allennlp`. 10 | """ 11 | -------------------------------------------------------------------------------- /trapper/common/plugins.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Open Business Software Solutions, the AllenNLP library authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | # Plugin discovery and import functions. 16 | 17 | trapper supports registering classes from custom modules/packages etc written by the 18 | user. Parts of this file is adapted from the AllenNLP library at 19 | https://github.com/allenai/allennlp. 20 | """ 21 | 22 | import importlib 23 | import os 24 | import sys 25 | from typing import Iterable, Set 26 | 27 | from allennlp.common.plugins import discover_file_plugins 28 | from allennlp.common.util import push_python_path 29 | from transformers.utils import logging 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | LOCAL_PLUGINS_FILENAME = ".trapper_plugins" 34 | """ 35 | Local plugin files should have this name. 36 | """ 37 | 38 | 39 | def discover_plugins() -> Iterable[str]: 40 | """ 41 | Returns an iterable of the plugins found in the local plugin file. 
def import_plugins() -> None:
    """
    Imports the plugins found with `discover_plugins()` i.e. the custom
    registrable components written by the user.

    A plugin module that cannot be found is logged as an error and skipped;
    the remaining plugins are still imported.
    """
    # For a presumed Python issue that makes the spawned processes unable
    # to find modules in the current directory.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    for module_name in discover_plugins():
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            # Lazy %-style arguments, consistent with the info call below.
            logger.error("Plugin %s could not be loaded: %s", module_name, e)
        else:
            logger.info("Plugin %s available", module_name)
def remove_hf_datasets_fixtures_cache(hf_datasets_fixtures_path: Pathlike) -> None:
    """
    Remove the cached artifacts of the fixture datasets.

    For every entry directly under `hf_datasets_fixtures_path`, this removes

    * the `dataset_infos.json` file and the `*.lock` files inside the fixture
      directory itself (best effort),
    * the matching entries under the `datasets` directory of the HuggingFace
      cache,
    * the matching directory under the datasets modules cache.

    Args:
        hf_datasets_fixtures_path: Directory containing the fixture datasets.
    """
    hf_datasets_fixtures_path = Path(hf_datasets_fixtures_path)
    hf_cache_dir = get_hf_cache_dir()
    hf_cached_datasets_dir = hf_cache_dir / "datasets"
    hf_cached_dataset_modules_dir = (
        hf_cache_dir / "modules/datasets_modules/datasets"
    )

    for fixture_dataset in hf_datasets_fixtures_path.glob("*"):
        # Remove from the original fixture directory. Each file is handled
        # on its own so that a missing `dataset_infos.json` does not prevent
        # the lock files from being cleaned up.
        if fixture_dataset.is_dir():
            stale_files = [fixture_dataset / "dataset_infos.json"]
            stale_files.extend(fixture_dataset.glob("*.lock"))
            for stale_file in stale_files:
                try:
                    os.remove(stale_file)
                except OSError:
                    # Best effort: the file may not exist or be removable.
                    pass

        # Remove from the global HuggingFace dataset cache
        for p in hf_cached_datasets_dir.glob(f"*{fixture_dataset.name}*"):
            if os.path.isfile(p):
                os.remove(p)
            shutil.rmtree(p, ignore_errors=True)

        # Remove from the global HuggingFace datasets modules cache
        shutil.rmtree(
            hf_cached_dataset_modules_dir / f"{fixture_dataset.name}",
            ignore_errors=True,
        )
@dataclass(frozen=True)
class DatasetReaderIdentifier:
    """
    Immutable (hence hashable) key identifying a dataset reader, so that it
    can be used as a dictionary key for caching.
    """

    # Local path or dataset identifier of the dataset.
    path: str
    # Dataset configuration name; `None` when no configuration is specified.
    name: Optional[str]


@dataclass(frozen=True)
class RawDatasetIdentifier(DatasetReaderIdentifier):
    """
    Immutable (hence hashable) key identifying a dataset read by a reader,
    i.e. a reader identifier extended with the split.
    """

    # Name of the split; `None` stands for all available splits.
    split: Optional[str]
@dataclass
class DataProcessorArguments:
    """
    Arguments used for setting up the data processor tests. The tokenizer
    wrapper is created eagerly from the given factory, model name and keyword
    arguments once the instance is initialized.
    """

    model_max_sequence_length: int = None
    tokenizer_factory: Type[TokenizerWrapper] = TokenizerWrapper
    tokenizer_model_name: str = "roberta-base"
    tokenizer_kwargs: Dict = None

    def __post_init__(self):
        model_name = self.tokenizer_model_name
        self.is_tokenizer_uncased = "uncased" in model_name

        extra_kwargs = self.tokenizer_kwargs if self.tokenizer_kwargs else {}
        self.tokenizer_wrapper = self.tokenizer_factory.from_pretrained(
            model_name, **extra_kwargs
        )

        # The factory and its kwargs are only needed for building the wrapper;
        # drop the values stored on the instance afterwards.
        del self.tokenizer_factory
        del self.tokenizer_kwargs
@pytest.fixture(scope="session")
def create_data_collator_args():
    """
    Session-scoped fixture returning a factory for `DataCollatorArguments`.

    The factory forwards its arguments to `DataCollatorArguments`; any extra
    keyword arguments are collected and passed on as `tokenizer_kwargs`.
    """

    def _create_data_collator_args(
        train_batch_size: int,
        validation_batch_size: int,
        model_max_sequence_length: Optional[int] = None,
        tokenizer_factory: Type[TokenizerWrapper] = TokenizerWrapper,
        tokenizer_model_name: str = "roberta-base",
        task_type: str = "question_answering",
        is_distributed: bool = False,
        **tokenizer_kwargs,
    ) -> DataCollatorArguments:
        """
        Build the collator arguments used by the tests.

        Args:
            train_batch_size (): stored as `train_batch_size`
            validation_batch_size (): stored as `validation_batch_size`
            model_max_sequence_length (): stored as is; may be `None`
            tokenizer_factory (): tokenizer wrapper class used to create the
                tokenizer wrapper via `from_pretrained`
            tokenizer_model_name (): model name given to `from_pretrained`
            task_type (): key used to look up the model input fields
            is_distributed (): stored as `is_distributed`
            **tokenizer_kwargs (): forwarded to the tokenizer factory

        Returns:
            a `DataCollatorArguments` instance (a subclass of
            `DataProcessorArguments`)
        """
        return DataCollatorArguments(
            model_max_sequence_length=model_max_sequence_length,
            tokenizer_factory=tokenizer_factory,
            tokenizer_kwargs=tokenizer_kwargs,
            train_batch_size=train_batch_size,
            validation_batch_size=validation_batch_size,
            tokenizer_model_name=tokenizer_model_name,
            task_type=task_type,
            is_distributed=is_distributed,
        )

    return _create_data_collator_args
def shell(command, exit_status=0):
    """
    Run a command through the system shell and compare its exit code with the
    expected one.

    Args:
        command: (str) Command string which runs through system shell.
        exit_status: (int) Expected exit code of the given command run.

    Returns:
        0 if the command exited with the expected exit code. Otherwise a
        non-zero value: the actual exit code, or 1 when the command
        unexpectedly exited with 0.
    """
    # Imported locally so that the block does not depend on a new
    # module-level import.
    import subprocess

    # `os.system` returns a wait()-encoded status on POSIX (e.g. 256 for
    # `exit 1`), which never equals a non-zero expected exit code.
    # `subprocess.call` returns the plain exit code on every platform.
    actual_exit_status = subprocess.call(command, shell=True)
    if actual_exit_status == exit_status:
        return 0
    # A mismatch must never be reported as 0 (success), even when the command
    # itself exited with 0 while a failure was expected.
    return actual_exit_status or 1
def add_property(inst, name_to_method: Dict[str, Callable]):
    """
    Attach properties to a single object without touching its original class.

    On the first call for an object, a subclass with the same name as the
    object's class is created and assigned to the object. Later calls reuse
    that subclass. Each entry of `name_to_method` becomes a property, named by
    its key and computed by its value, on that subclass.

    Args:
        inst : the object that receives the properties
        name_to_method : maps each property name to its getter
    """
    marker = "__perinstance"
    target_cls = type(inst)
    if not hasattr(target_cls, marker):
        # The marker attribute tells later calls that the dedicated subclass
        # already exists and can be extended in place.
        target_cls = type(target_cls.__name__, (target_cls,), {marker: True})
        inst.__class__ = target_cls

    for prop_name, getter in name_to_method.items():
        setattr(target_cls, prop_name, property(getter))
class DataAdapter(ABC, Registrable):
    """
    Callable base class that converts already tokenized and indexed data
    instances into the format fed to a transformer model.

    An adapter typically consumes the output of a `DataProcessor` and renames
    the input fields to the names accepted by the models, e.g. "input_ids" or
    "token_type_ids". It is also the place where the special tokens marking
    the beginning or the end of a sequence (such as `[CLS]`, `[SEP]`) are
    inserted. Subclasses implement `__call__` as suitable for their task; see
    `DataAdapterForQuestionAnswering` for an example.

    Args:
        tokenizer_wrapper (): Required to access the ids of special tokens
            such as BOS, EOS etc
        label_mapper (): Only used in some tasks that require mapping between
            categorical labels and integer ids such as token classification.
    """

    def __init__(
        self,
        tokenizer_wrapper: TokenizerWrapper,
        label_mapper: Optional[LabelMapper] = None,
    ):
        self._label_mapper = label_mapper
        self._tokenizer: PreTrainedTokenizerBase = tokenizer_wrapper.tokenizer
        # Keep the special token ids at hand for the subclasses.
        self._bos_token_id: int = self._tokenizer.bos_token_id
        self._eos_token_id: int = self._tokenizer.eos_token_id

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """The tokenizer taken from the wrapper given to the constructor."""
        return self._tokenizer

    @abstractmethod
    def __call__(self, instance: IndexedInstance) -> IndexedInstance:
        """
        Receives a raw `IndexedInstance`, processes it and returns an
        `IndexedInstance` again. The result should typically contain the
        fields listed in
        `trapper.models.auto_wrappers._TASK_TO_INPUT_FIELDS[task]`, except
        for "attention_mask", which is generated by the DataCollator. Here,
        `task` is the way you choose to model your downstream task with
        `transformers`; each task in `_TASK_TO_INPUT_FIELDS` corresponds to
        an `AutoModel...` class implemented in `transformers`.

        `DataAdapterForQuestionAnswering.__call__` may serve as an example.

        Args:
            instance (): the instance to be adapted
        """
        raise NotImplementedError
def _context_token_type_ids(self, context_tokens: List[int]) -> List[int]:
    """
    Return the segment encoding for the BOS token followed by the context.

    Every position gets `CONTEXT_TOKEN_TYPE_ID`. The result is one element
    longer than `context_tokens` because of the leading BOS token.
    """
    # +1 accounts for the BOS token prepended to the context tokens
    return [self.CONTEXT_TOKEN_TYPE_ID] * (len(context_tokens) + 1)
class DataCollator(Registrable):
    """
    This class takes a batch of `IndexedInstance`s, typically generated using
    a dataset reader, and returns an `InputBatchTensor`. It is responsible for
    padding the required instances, collecting them into a batch and
    converting the batch into an `InputBatchTensor`. Moreover, it also creates
    the attention_mask and inserts it into the batch if required. It is used
    as a callable just like the `DataCollator` in the `transformers` library.

    Args:
        tokenizer_wrapper (): its `tokenizer` provides the pad token id and
            the pad token type id used while padding
        model_forward_params (): names of the fields that are kept in the
            instances; also decides whether an attention mask is generated
            by default
    """

    default_implementation = "default"

    def __init__(
        self,
        tokenizer_wrapper: TokenizerWrapper,
        model_forward_params: Tuple[str, ...],
    ):
        self._tokenizer: PreTrainedTokenizerBase = tokenizer_wrapper.tokenizer
        self._model_forward_params: Tuple[str, ...] = model_forward_params

    def __call__(
        self,
        instances: Iterable[IndexedInstance],
        should_eliminate_model_incompatible_keys: bool = True,
    ) -> InputBatchTensor:
        """Prepare the dataset for training and evaluation"""
        batch = self.build_model_inputs(
            instances,
            should_eliminate_model_incompatible_keys=should_eliminate_model_incompatible_keys,
        )
        self.pad(batch)
        return self._convert_to_tensor(batch)

    def build_model_inputs(
        self,
        instances: Iterable[IndexedInstance],
        return_attention_mask: Optional[bool] = None,
        should_eliminate_model_incompatible_keys: bool = True,
    ) -> InputBatch:
        """
        Collect the instances into a single, not yet padded batch.

        Note that the instances are modified in place: keys may be removed
        from them and an "attention_mask" may be added to them.

        Args:
            instances (): the instances to be collected
            return_attention_mask (): whether to add an attention mask to the
                instances lacking one. If left as None, a mask is added only
                if "attention_mask" is among the model forward parameters.
            should_eliminate_model_incompatible_keys (): whether to drop the
                fields that are not among the model forward parameters

        Returns:
            a dict mapping each field name to the list of its values
        """
        # Only fall back to the model parameters when the caller did not
        # decide; an explicit `False` must be respected.
        if return_attention_mask is None:
            return_attention_mask = (
                "attention_mask" in self._model_forward_params
            )
        batch = self._create_empty_batch()
        for instance in instances:
            if should_eliminate_model_incompatible_keys:
                self._eliminate_model_incompatible_keys(instance)
            if return_attention_mask:
                self._add_attention_mask(instance)
            self._add_instance(batch, instance)

        self._eliminate_empty_inputs(batch)
        return batch

    @staticmethod
    def _create_empty_batch() -> InputBatch:
        """Create a batch whose fields default to empty lists."""
        return defaultdict(list)

    def _eliminate_model_incompatible_keys(self, instance: IndexedInstance):
        """Delete, in place, the keys that are not model forward parameters."""
        # Collect first, then delete: a dict can not be resized while it is
        # being iterated.
        incompatible_keys = [
            key for key in instance if key not in self._model_forward_params
        ]
        for key in incompatible_keys:
            del instance[key]

    @staticmethod
    def _add_attention_mask(
        instance: IndexedInstance,
    ):
        """Add an all-ones mask as long as "input_ids", unless one exists."""
        if "attention_mask" not in instance:
            instance["attention_mask"] = [1] * len(instance["input_ids"])

    @staticmethod
    def _add_instance(batch: InputBatch, instance: IndexedInstance):
        """Append each field of the instance to the same field of the batch."""
        for field_name, encodings in instance.items():
            batch[field_name].append(encodings)

    @staticmethod
    def _eliminate_empty_inputs(batch: InputBatch):
        """Delete, in place, the batch fields that hold no values."""
        empty_keys = [key for key, val in batch.items() if len(val) == 0]
        for key in empty_keys:
            del batch[key]

    def pad(
        self,
        batch: InputBatch,
        max_length: Optional[int] = None,
        padding_side: str = "right",
    ):
        """
        Pad the sequence fields of the batch in place.

        Fields whose first value is an `int` are treated as scalar fields and
        left untouched. Sequences that are already as long as, or longer than,
        the target length are not changed, i.e. nothing is truncated.

        Args:
            batch (): the batch to be padded
            max_length (): the target length. If left as None, the length of
                the longest sequence of each field is used.
            padding_side (): "right" or "left"

        Raises:
            ValueError: if a sequence field has no known pad value
        """
        for feature_key, feature_values in batch.items():
            # An empty field has nothing to pad and no first value to inspect.
            if not feature_values:
                continue
            # Scalar fields hold one int per instance and need no padding.
            if not isinstance(feature_values[0], int):
                max_seq_len = max(len(ids) for ids in feature_values)
                padded_len = max_seq_len if max_length is None else max_length
                pad_id = self._pad_id(feature_key)
                self._pad_encodings(
                    encodings=feature_values,
                    pad_id=pad_id,
                    padded_len=padded_len,
                    padding_side=padding_side,
                )

    def _pad_id(self, padded_field: str) -> int:
        """
        Return the value used for padding the given field.

        Raises:
            ValueError: if the field is not one of the known fields
        """
        if padded_field == "input_ids":
            pad_id = self._tokenizer.pad_token_id
        elif padded_field == "token_type_ids":
            pad_id = self._tokenizer.pad_token_type_id
        elif padded_field == "labels":
            pad_id = IGNORED_LABEL_ID
        elif padded_field == "attention_mask":
            pad_id = 0
        elif padded_field == "special_tokens_mask":
            pad_id = 1
        else:
            raise ValueError(f"{padded_field} is not a valid field for padding")
        return pad_id

    @staticmethod
    def _pad_encodings(
        encodings: List[List[int]],
        pad_id: int,
        padded_len: int,
        padding_side: str = "right",
    ):
        """
        Pad each sequence shorter than `padded_len` with `pad_id`, in place.

        Any `padding_side` other than "right" or "left" leaves the sequences
        unchanged.
        """
        for i, encoded_inputs in enumerate(encodings):
            difference = padded_len - len(encoded_inputs)
            # Nothing to add for sequences at or beyond the target length.
            if difference <= 0:
                continue
            pad_values = [pad_id] * difference
            if padding_side == "right":
                encoded_inputs.extend(pad_values)
            elif padding_side == "left":
                encodings[i] = pad_values + encoded_inputs

    @staticmethod
    def _convert_to_tensor(batch: InputBatch) -> InputBatchTensor:
        """Convert the values of each batch field into a `torch` tensor."""
        return {
            input_key: torch.tensor(encodings)
            for input_key, encodings in batch.items()
        }
def process(self, instance_dict: Dict[str, Any]) -> IndexedInstance:
    """
    Convert a raw SQuAD style instance into an `IndexedInstance`.

    Only the first answer of the instance is used. The instance returned by
    `filtered_instance` is produced instead when the input is too long or
    when `text_to_instance` raises an `ImproperDataInstanceError`.

    Args:
        instance_dict (): must have the keys "id", "context", "question" and
            "answers", where "answers" holds the "answer_start" and "text"
            sequences
    """
    qa_id, context, question = (
        instance_dict[key] for key in ("id", "context", "question")
    )
    if self._is_input_too_long(context, question):
        return self.filtered_instance()

    answers = instance_dict["answers"]
    # Rename SQuAD answer_start as start for trapper tuple conversion.
    first_answer_span = {
        "start": answers["answer_start"][0],
        "text": answers["text"][0],
    }
    first_answer = convert_spandict_to_spantuple(first_answer_span)
    try:
        instance = self.text_to_instance(
            context=context,
            question=question,
            id_=qa_id,
            answer=first_answer,
        )
    except ImproperDataInstanceError:
        instance = self.filtered_instance()
    return instance
class SquadDataProcessor(DataProcessor, metaclass=ABCMeta):
    """
    Provides utility methods that can be used in SQuAD style tasks involving a
    context and information fields inside the context e.g. answers in the case
    of question answering.
    """

    @staticmethod
    def _join_whitespace_prefix(context: str, field: SpanTuple) -> SpanTuple:
        """
        Prepend the whitespace prefix if it exists. Some tokenizers like
        `roberta` and `gpt2` do not ignore the whitespaces, which leads to a
        difference in the token ids between directly tokenizing the field and
        tokenizing it with the context.

        Args:
            context (str): context e.g. paragraph or document.
            field (SpanTuple): a special segment of the context such as
                `answer` or `clue` in the Question Generation task.

        Returns:
            a new `SpanTuple` that starts one character earlier and includes
            the preceding space, or `field` itself if its start is unknown,
            it starts at the very beginning of the context or it is not
            preceded by a space.
        """
        start = getattr(field, "start", None)
        # A field without a known start, or one starting at index 0, has no
        # preceding character. Without this guard, `context[start - 1]` would
        # read the LAST character of the context for `start == 0`.
        if start is None or start <= 0:
            return field
        if context[start - 1] == " ":
            return SpanTuple(text=" " + field.text, start=start - 1)
        return field

    def _tokenized_field_position(
        self,
        context: str,
        context_token_ids: List[int],
        field_token_ids: List[int],
        field_start_ind: int,
    ) -> PositionTuple:
        """
        Locate the field inside the tokenized context.

        The part of the context preceding the field is tokenized on its own,
        and the resulting ids are compared with the context ids.

        Args:
            context (): the raw context
            context_token_ids (): token ids of the context
            field_token_ids (): token ids of the field
            field_start_ind (): character index of the field in the context
        """
        tokenized_prefix = self._tokenizer.tokenize(context[:field_start_ind])
        prefix_ids = self._tokenizer.convert_tokens_to_ids(tokenized_prefix)
        return self._get_position(context_token_ids, field_token_ids, prefix_ids)

    @staticmethod
    def _get_position(context_ids, field_ids, field_prefix_ids) -> PositionTuple:
        """
        Returns the start and end indices of the field (e.g. answer or clue)
        span in the paragraph (context).

        The start is the first index where the context ids and the prefix ids
        differ. If they never differ, it is the length of the shorter one.
        The end is the start plus the field length, clipped to the context
        length.
        """
        diff_ind = min(len(field_prefix_ids), len(context_ids))
        # Find the first token where the context_ids and field_prefix ids differ
        for i, (context_id, field_prefix_id) in enumerate(
            zip(context_ids, field_prefix_ids)
        ):
            if context_id != field_prefix_id:
                diff_ind = i
                break
        return PositionTuple(
            diff_ind, min(diff_ind + len(field_ids), len(context_ids))
        )

    def _indexed_field(
        self,
        context: str,
        context_token_ids: List[int],
        field: SpanTuple,
        field_type: str,
    ) -> Dict[str, Union[List[int], PositionDict]]:
        """
        Tokenize the field and, if its start is known, locate it in the
        tokenized context.

        Args:
            context (): the raw context
            context_token_ids (): token ids of the context
            field (): the field to be indexed
            field_type (): name of the field e.g. "answer"; used as the key
                of the token ids and as the prefix of the position key

        Returns:
            a dict holding the token ids of the field under `field_type` and,
            if `field.start` is not None, its position under
            `f"{field_type}_position_tokenized"`

        Raises:
            ImproperDataInstanceError: if the located field does not fit into
                the tokenized context
        """
        field_tokens = self._tokenizer.tokenize(field.text)
        field_token_ids = self._tokenizer.convert_tokens_to_ids(field_tokens)
        indexed_field = {field_type: field_token_ids}
        if field.start is not None:
            field_position = self._tokenized_field_position(
                context, context_token_ids, field_token_ids, field.start
            )
            if field_position.start + len(field_tokens) > len(context_token_ids):
                raise ImproperDataInstanceError(
                    f"Indexed {field_type} position is out of the bound. Check the "
                    f"input field lengths!"
                )
            indexed_field[
                f"{field_type}_position_tokenized"
            ] = field_position.to_dict()
        return indexed_field

    def _chop_excess_context_tokens(
        self, tokenized_context: List, *other_tokenized_subsequences: List
    ):
        """
        Shorten the tokenized context if the total sequence length exceeds
        `model_max_sequence_length`.

        NOTE(review): `_total_seq_len` and `_chop_excess_tokens` are defined
        outside this class; the context is presumably chopped in place since
        no return value is used. Confirm against `DataProcessor`.
        """
        subsequences = [tokenized_context, *other_tokenized_subsequences]
        seq_len = self._total_seq_len(*subsequences)
        if seq_len > self.model_max_sequence_length:
            self._chop_excess_tokens(tokenized_context, seq_len)
            seq_len = self._total_seq_len(*subsequences)
            assert seq_len <= self.model_max_sequence_length
@property
def dataset_reader(self):
    """The component that reads the raw dataset."""
    return self._dataset_reader

@dataset_reader.setter
def dataset_reader(self, value: DatasetReader):
    """
    Replace the dataset reader.

    Raises:
        ValueError: if `value` is not a `DatasetReader` instance
    """
    if not isinstance(value, DatasetReader):
        raise ValueError(f"The value must be an instance of a {DatasetReader}")
    self._dataset_reader = value
class LabelMapper(Registrable):
    """
    Used in tasks that require mapping between categorical labels and integer
    ids such as token classification. The list of labels can be provided in the
    constructor as well as the class variable `_LABELS`. The reason for the
    latter is to document the labels and enable writing the labels only once for
    some tasks that have a large number of labels, both of which aim to
    reduce the possibility of errors while specifying the labels.

    Optional class variables:

    _LABELS (Tuple[str, ...]): An optional class variable that may be used for
        setting the list of labels in `LabelMapper.from_labels` method. Note
        that the order you provide will be preserved while assigning ids to the
        labels.
    _IGNORED_LABELS (Tuple[str, ...]): An optional class variable that may be
        used for setting the list of ignored labels.

    Args:
        label_to_id_map (): The mapping from string labels to integer ids. A
            shallow copy is stored, so mutating the argument afterwards does
            not affect the mapper.
        ignored_labels (): The list of labels to be stored in `ignored_labels`
            attribute. If left as None, the class variable `_IGNORED_LABELS`
            will be used instead.

    Raises:
        ValueError: if `label_to_id_map` is None.
    """

    default_implementation = "from_labels"
    _LABELS: Optional[Tuple[str, ...]] = None
    _IGNORED_LABELS: Tuple[str, ...] = ()

    def __init__(
        self,
        label_to_id_map: Optional[Dict[str, int]] = None,
        ignored_labels: Optional[Sequence[str]] = None,
    ):
        # We need to make `label_to_id_map` optional with default of None,
        # since otherwise allennlp tries to invoke __init__ although we register
        # a classmethod as a default constructor and demand it via the "type"
        # parameter inside the from_params method or a config file.
        if label_to_id_map is None:
            raise ValueError("`label_to_id_map` can not be None!")

        if ignored_labels is None:
            ignored_labels = self._IGNORED_LABELS
        self._ignored_labels = tuple(ignored_labels)

        # Store a snapshot: the reverse map below is computed once, so sharing
        # the caller's dict would let later mutations desynchronise the two.
        self._label_to_id_map = dict(label_to_id_map)
        self._id_to_label = {
            id_: label for label, id_ in self._label_to_id_map.items()
        }

    @classmethod
    def from_labels(
        cls,
        labels: Optional[Sequence[str]] = None,
        start_id: Optional[int] = 0,
        ignored_labels: Optional[Sequence[str]] = None,
    ) -> LabelMapper:
        """
        Create a `LabelMapper` from a list of labels. The ids will be
        the enumeration of labels starting from `start_id`.

        Args:
            labels (): The list of labels. Note that the order you provide will
                be preserved while assigning ids to the labels. If left as None,
                the class variable `_LABELS` will be used instead.
            start_id (): The start value for enumeration of label ids. By
                default (and when None is given), we start from 0 and increment.
            ignored_labels (): The list of labels to be stored in `ignored_labels`
                attribute. If left as None, the class variable `_IGNORED_LABELS`
                will be used instead.

        Raises:
            ValueError: if neither `labels` nor the `_LABELS` class variable
                provides the labels.
        """
        if labels is None:
            labels = cls._LABELS
        if labels is None:
            raise ValueError(
                "No labels were provided: pass `labels` or define the "
                "`_LABELS` class variable."
            )
        if start_id is None:
            start_id = 0
        label_to_id = {
            label: id_ for id_, label in enumerate(labels, start=start_id)
        }
        if ignored_labels is None:
            ignored_labels = cls._IGNORED_LABELS
        return cls(label_to_id, ignored_labels=ignored_labels)

    @property
    def labels(self) -> Tuple[str, ...]:
        """The labels, in the order of the stored mapping."""
        return tuple(self._label_to_id_map.keys())

    @property
    def ids(self) -> Tuple[int, ...]:
        """The ids, in the order of the stored mapping."""
        return tuple(self._label_to_id_map.values())

    @property
    def label_to_id_map(self) -> Dict[str, int]:
        """A shallow copy of the label -> id mapping."""
        return copy(self._label_to_id_map)

    @property
    def id_to_label_map(self) -> Dict[int, str]:
        """This method may be used by the pipeline object for inference."""
        return copy(self._id_to_label)

    def get_id(self, label: str) -> int:
        """Return the id of `label`; raises KeyError for an unknown label."""
        return self._label_to_id_map[label]

    def get_label(self, id_: int) -> str:
        """Return the label of `id_`; raises KeyError for an unknown id."""
        return self._id_to_label[id_]

    @property
    def ignored_labels(self) -> Tuple[str, ...]:
        """The ignored labels as a tuple."""
        return self._ignored_labels


LabelMapper.register("from_label_to_id_map")(LabelMapper)
LabelMapper.register("from_labels", constructor="from_labels")(LabelMapper)
class TokenizerWrapper(Registrable):
    """
    Factory-style wrapper around a `PreTrainedTokenizerBase` that registers
    the task-specific special tokens on the wrapped tokenizer. The pretrained
    tokenizer itself is created through `transformers.AutoTokenizer` in
    `from_pretrained`.

    The wrapper also evens out the naming differences between models for the
    start and end of sequence tokens, so that e.g. `tokenizer.bos_token` can
    be used without thinking about which model is in use. (bos_token,
    cls_token) and (eos_token, sep_token) are treated as interchangeable
    pairs: when the tokenizer lacks one member of a pair, that member is set
    to the value of the other one; when it lacks both, each of them gets its
    own default value. For instance, sep_token takes the value of eos_token
    for a tokenizer that only defines eos_token. pad_token, mask_token and
    unk_token are given default values as well if they are not set.

    Class variables that can be overridden:

    _TASK_SPECIFIC_SPECIAL_TOKENS (List[str]): The extra special tokens
        needed for the task at hand, e.g. the `CONTEXT` token of
        `QuestionAnsweringTokenizerWrapper`.

    Args:
        pretrained_tokenizer (): The pretrained tokenizer to be wrapped
    """

    default_implementation = "from_pretrained"
    _BOS_TOKEN_KEYS = ("bos_token", "cls_token")
    _EOS_TOKEN_KEYS = ("eos_token", "sep_token")
    _SPECIAL_TOKENS_DICT: Dict[str, str] = {
        "bos_token": BOS_TOKEN,
        "eos_token": EOS_TOKEN,
        "cls_token": CLS_TOKEN,
        "sep_token": SEP_TOKEN,
        "pad_token": PAD_TOKEN,
        "mask_token": MASK_TOKEN,
        "unk_token": UNK_TOKEN,
    }
    _TASK_SPECIFIC_SPECIAL_TOKENS: List[str] = []

    def __init__(
        self, pretrained_tokenizer: Optional[PreTrainedTokenizerBase] = None
    ):
        # `pretrained_tokenizer` has to be optional with a default of None:
        # otherwise allennlp tries to invoke __init__ even though a
        # classmethod is registered as the default constructor and requested
        # through the "type" parameter in from_params or a config file.
        if pretrained_tokenizer is None:
            raise ValueError("`pretrained_tokenizer` can not be None!")
        self._pretrained_tokenizer = pretrained_tokenizer
        self._num_added_special_tokens = self._add_task_specific_tokens()

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """The wrapped tokenizer."""
        return self._pretrained_tokenizer

    @property
    def num_added_special_tokens(self) -> int:
        """The value returned by `add_special_tokens` during construction."""
        return self._num_added_special_tokens

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *inputs,
        **kwargs,
    ) -> TokenizerWrapper:
        """
        Create the tokenizer with `AutoTokenizer.from_pretrained`, forwarding
        all the arguments, and wrap it.
        """
        return cls(
            AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path, *inputs, **kwargs
            )
        )

    def _add_task_specific_tokens(self) -> int:
        """
        Add the task-specific tokens and the unset standard special tokens
        to the wrapped tokenizer.

        Returns:
            the value returned by the tokenizer's `add_special_tokens`
        """
        wrapped = self._pretrained_tokenizer
        tokens_to_add = {
            "additional_special_tokens": deepcopy(
                self._TASK_SPECIFIC_SPECIAL_TOKENS
            )
        }
        interchangeable_pairs = (self._BOS_TOKEN_KEYS, self._EOS_TOKEN_KEYS)
        for name, default_value in self._SPECIAL_TOKENS_DICT.items():
            if getattr(wrapped, name) is not None:
                # Already defined by the pretrained tokenizer: leave as is.
                continue
            pair = next(
                (keys for keys in interchangeable_pairs if name in keys), None
            )
            if pair is None:
                tokens_to_add[name] = default_value
            else:
                tokens_to_add[name] = self._find_alternative_token_value(
                    name, default_value, pair
                )
        return wrapped.add_special_tokens(tokens_to_add)

    def _find_alternative_token_value(
        self,
        original_token_name: str,
        original_token_value: str,
        alternative_pair: Tuple[str, str],
    ) -> str:
        """
        Return the value the tokenizer holds for the other member of
        `alternative_pair`, or `original_token_value` if that is unset too.
        """
        if original_token_name == alternative_pair[0]:
            counterpart = alternative_pair[1]
        else:
            counterpart = alternative_pair[0]
        return (
            getattr(self._pretrained_tokenizer, counterpart)
            or original_token_value
        )


TokenizerWrapper.register("from_pretrained", constructor="from_pretrained")(
    TokenizerWrapper
)
class MetricInputHandler(Registrable):
    """
    This callable class is responsible for processing evaluation output
    :py:class:`transformers.EvalPrediction` used in
    :py:class:`trapper.training.TransformerTrainer`. It is used to convert to
    suitable evaluation format for the specified metrics before metric computation.
    If your task needs additional information for the conversion, then override
    `self._extract_metadata()`. See
    `MetricInputHandlerForQuestionAnswering` for an example.
    """

    default_implementation = "default"

    def extract_metadata(self, dataset: datasets.Dataset) -> None:
        """
        This method applies `self._extract_metadata()` to each instance of the
        dataset, exactly once and in dataset order. Do not override this method
        in child class, instead override `self._extract_metadata()`.

        Note:
            This method is only called once in trainer for each dataset. By default,
            only eval_dataset is called.

        Args:
            dataset: datasets.Dataset object

        Returns: None

        Raises:
            TypeError: if `_extract_metadata` returns anything but None.
        """
        # Iterate directly instead of probing `dataset[0]` and then calling
        # `dataset.map`: that combination ran the extractor twice on the first
        # instance (shifting any per-instance state by one) and failed on an
        # empty dataset. A plain loop also does not depend on `map` actually
        # invoking the function (NOTE(review): `map` results can be cached by
        # `datasets` - confirm against its documentation).
        for instance in dataset:
            if self._extract_metadata(instance) is not None:
                raise TypeError(
                    "`_extract_metadata` method is designed to be read-only, and hence must return None."
                )

    def _extract_metadata(self, instance: IndexedInstance) -> None:
        """
        Child class may implement this method for metadata extraction from an instance.
        It is designed to be read-only, i.e do not manipulate instance in any
        way. You can store the additional content as attributes or class variables to use
        them later in `__call__()`. It should return None for efficiency purposes.

        Args:
            instance: Current instance processed

        Returns: None
        """
        return None

    def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
        """
        Default conversion: `argmax(-1)` is applied to the predictions and the
        label ids are passed through unchanged.

        Args:
            eval_pred: EvalPrediction object returned by model.

        Returns: Processed EvalPrediction.
        """
        processed_predictions = eval_pred.predictions.argmax(-1)
        processed_label_ids = eval_pred.label_ids
        return EvalPrediction(
            predictions=processed_predictions, label_ids=processed_label_ids
        )

    def __call__(
        self,
        eval_pred: EvalPrediction,
    ) -> EvalPrediction:
        """
        This method is called before metric computation, the default behavior is set
        in this method as returning predictions and label_ids unchanged except
        `argmax()` is applied to predictions. However, this behaviour is likely to be
        changed in some tasks, such as question-answering, etc.

        Args:
            eval_pred: EvalPrediction object returned by model.

        Returns: Processed EvalPrediction.
        """
        return self.preprocess(eval_pred)


MetricInputHandler.register("default")(MetricInputHandler)
@MetricInputHandler.register("question-answering")
class MetricInputHandlerForQuestionAnswering(MetricInputHandler):
    """
    `MetricInputHandlerForQuestionAnswering` provides the conversion of predictions
    and labels which are the beginning and the end indices to actual answers
    extracted from the context. Since this conversion also requires context, this
    class also overrides `_extract_metadata()` to store context information from
    dataset instances.

    Args:
        tokenizer_wrapper (): Required to convert token ids to strings.
    """

    def __init__(
        self,
        tokenizer_wrapper: TokenizerWrapper,
    ):
        super().__init__()
        self._tokenizer_wrapper = tokenizer_wrapper
        # Kept per instance: a class-level list would be shared by every
        # handler and keep growing across datasets, so the i-th stored
        # context would stop matching the i-th prediction.
        self._contexts: List[List[int]] = []

    @property
    def tokenizer(self):
        """The tokenizer held by the tokenizer wrapper."""
        return self._tokenizer_wrapper.tokenizer

    def _extract_metadata(self, instance: IndexedInstance) -> None:
        """Remember the `context` field of `instance` for later decoding."""
        self._contexts.append(instance["context"])

    def _decode_answer(self, context: List[int], start, end) -> str:
        """
        Decode the answer span of `context`.

        Both indices are shifted by one before slicing.
        NOTE(review): presumably this compensates for a token the adapter
        puts in front of the context - confirm against the data adapter.
        """
        answer = context[start - 1 : end - 1]
        return self.tokenizer.decode(answer).lstrip()

    def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
        """
        Convert start/end predictions and labels to decoded answer strings.

        Args:
            eval_pred: `predictions[0]` / `predictions[1]` hold the start / end
                scores (argmax over the last axis is taken) and `label_ids[0]`
                / `label_ids[1]` hold the reference start / end indices.

        Returns:
            An `EvalPrediction` whose predictions and label_ids are arrays of
            the decoded predicted and reference answers.

        Raises:
            IndexError: if fewer contexts were collected through
                `extract_metadata` than there are samples to decode.
        """
        predictions, references = eval_pred.predictions, eval_pred.label_ids
        predicted_starts = predictions[0].argmax(-1)
        predicted_ends = predictions[1].argmax(-1)
        reference_starts, reference_ends = references[0], references[1]
        n_samples = predictions[0].shape[0]

        if len(self._contexts) < n_samples:
            raise IndexError(
                f"{n_samples} samples to decode but only "
                f"{len(self._contexts)} contexts were extracted; call "
                f"`extract_metadata` on the evaluated dataset first."
            )

        predicted_answers = []
        reference_answers = []
        for i, context in enumerate(self._contexts[:n_samples]):
            predicted_answers.append(
                self._decode_answer(
                    context, predicted_starts[i], predicted_ends[i]
                )
            )
            reference_answers.append(
                self._decode_answer(
                    context, reference_starts[i], reference_ends[i]
                )
            )

        return EvalPrediction(
            predictions=np.array(predicted_answers),
            label_ids=np.array(reference_answers),
        )
@MetricInputHandler.register("token-classification")
class MetricInputHandlerForTokenClassification(MetricInputHandler):
    """
    `MetricInputHandlerForTokenClassification` provides the conversion of predictions
    and labels from ids to labels by using a `LabelMapper`.

    Args:
        label_mapper (): Required to convert ids to matching labels.
    """

    def __init__(
        self,
        label_mapper: LabelMapper,
    ):
        super().__init__()
        self._label_mapper = label_mapper

    @property
    def label_mapper(self):
        """The `LabelMapper` used for the id -> label conversion."""
        return self._label_mapper

    def _id_to_label(self, id_: int) -> str:
        return self.label_mapper.get_label(id_)

    def preprocess(self, eval_pred: EvalPrediction) -> EvalPrediction:
        """
        Convert predicted and reference ids to label strings, dropping the
        positions whose reference id is `IGNORED_LABEL_ID`.

        Args:
            eval_pred: predictions are scores whose argmax over axis 2 gives
                the predicted ids; label_ids are the reference ids.

        Returns:
            An `EvalPrediction` holding, per sequence, the list of predicted
            labels and the list of reference labels.
        """
        predictions, references = eval_pred.predictions, eval_pred.label_ids
        all_predicted_ids = np.argmax(predictions, axis=2)

        actual_predictions = []
        actual_labels = []
        for predicted_ids, label_ids in zip(all_predicted_ids, references):
            kept = [
                (predicted_id, label_id)
                for predicted_id, label_id in zip(predicted_ids, label_ids)
                if label_id != IGNORED_LABEL_ID
            ]
            actual_predictions.append(
                [self._id_to_label(predicted_id) for predicted_id, _ in kept]
            )
            actual_labels.append(
                [self._id_to_label(label_id) for _, label_id in kept]
            )

        # After filtering, the sequences usually differ in length. NumPy
        # refuses to build an array from such ragged lists unless the dtype is
        # object (an error since NumPy 1.24, a deprecation warning before).
        return EvalPrediction(
            predictions=np.array(actual_predictions, dtype=object),
            label_ids=np.array(actual_labels, dtype=object),
        )
@Metric.register("default")
class JuryMetric(Metric):
    """
    `Metric` implementation that computes the scores with `jury.Jury`.

    Args:
        metric_params (): a metric specification or a list of them; allennlp
            `Params` objects are unwrapped to their underlying `params`.
        json_normalize (): if truthy, nested score dicts are flattened one
            level by `normalize` before being returned.
        **kwargs: forwarded to `Metric.__init__`.
    """

    def __init__(
        self,
        metric_params: Union[MetricParam, List[MetricParam]],
        json_normalize: Optional[bool] = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._metric_params = self._convert_metric_params_to_dict(metric_params)
        self.json_normalize = json_normalize

    @property
    def metric_params(self):
        return self._metric_params

    def __call__(self, eval_pred: EvalPrediction) -> Dict[str, Any]:
        """
        Run the input handler, score with jury, then run the output handler.
        Returns an empty dict when no metric parameters were given.
        """
        params = self._metric_params
        if params is None:
            return {}
        scorer = jury.Jury(params, run_concurrent=False)

        prepared = self.input_handler(eval_pred)
        raw_score = scorer(
            predictions=prepared.predictions.tolist(),
            references=prepared.label_ids.tolist(),
        )
        handled_score = self.output_handler(raw_score)

        return self.normalize(handled_score) if self.json_normalize else handled_score

    def _convert_metric_params_to_dict(
        self, metric_params: Union[MetricParam, List[MetricParam]]
    ) -> Dict:
        """
        Replace every `Params` object, given alone or as a list element, by
        its `params` attribute; anything else is passed through unchanged.
        """

        def unwrap(item):
            return item.params if isinstance(item, Params) else item

        if isinstance(metric_params, list):
            return [unwrap(item) for item in metric_params]
        return unwrap(metric_params)

    @staticmethod
    def normalize(score: Dict) -> Dict:
        """
        Flatten one level of nesting: an entry `key -> {name: val}` becomes
        `"{key}_{name}" -> val`; entries that are not dicts are kept as is.
        """
        flattened = {}
        for metric_name, result in score.items():
            if not isinstance(result, dict):
                flattened[metric_name] = result
                continue
            flattened.update(
                {f"{metric_name}_{name}": val for name, val in result.items()}
            )
        return flattened
19 | 20 | Args: 21 | input_handler (): 22 | """ 23 | 24 | default_implementation = "default" 25 | 26 | def __init__( 27 | self, 28 | input_handler: Optional[MetricInputHandler] = None, 29 | output_handler: Optional[MetricOutputHandler] = None, 30 | ): 31 | self._input_handler = input_handler or MetricInputHandler() 32 | self._output_handler = output_handler or MetricOutputHandler() 33 | 34 | @property 35 | def input_handler(self): 36 | return self._input_handler 37 | 38 | @property 39 | def output_handler(self): 40 | return self._output_handler 41 | 42 | @abstractmethod 43 | def __call__(self, eval_pred: EvalPrediction) -> Dict[str, Any]: 44 | pass 45 | -------------------------------------------------------------------------------- /trapper/metrics/output_handlers/__init__.py: -------------------------------------------------------------------------------- 1 | from trapper.metrics.output_handlers.output_handler import MetricOutputHandler 2 | from trapper.metrics.output_handlers.token_classification_output_handler import ( 3 | MetricOutputHandlerForTokenClassification, 4 | ) 5 | -------------------------------------------------------------------------------- /trapper/metrics/output_handlers/output_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from trapper.common import Registrable 4 | 5 | 6 | class MetricOutputHandler(Registrable): 7 | """ 8 | This callable class is responsible for manipulating the resulting score object 9 | returned after metric computation. This base class reflects the default 10 | behavior as returning to result as is. See 11 | `MetricOutputHandlerForTokenClassification` for an example. 12 | """ 13 | 14 | default_implementation = "default" 15 | 16 | def __call__(self, score: Dict) -> Dict: 17 | """ 18 | This method is called after metric computation, the default behavior is set 19 | in this method as directly returning the score as is. 
@MetricOutputHandler.register("token-classification")
class MetricOutputHandlerForTokenClassification(MetricOutputHandler):
    """
    Output handler that can reduce a score dict to its overall values.

    Args:
        overall_only (): if True, only the `overall_precision`,
            `overall_recall`, `overall_f1` and `overall_accuracy` entries are
            returned, under the keys `precision`, `recall`, `f1` and
            `accuracy`. Otherwise the score is returned unchanged.
    """

    def __init__(self, overall_only: bool = False):
        super().__init__()
        self.overall_only = overall_only

    def __call__(self, score: Dict) -> Dict:
        if self.overall_only:
            return {
                name: score[f"overall_{name}"]
                for name in ("precision", "recall", "f1", "accuracy")
            }
        return score
# Fields shared by every task.
_COMMON_INPUT_FIELDS = ["input_ids", "attention_mask"]

# Task name -> input fields. The input fields are required by the data
# collators and the task-specific ModelWrapper subclasses. Keyword arguments
# keep their order, so the tasks are stored in the order written here.
_TASK_TO_INPUT_FIELDS = OrderedDict(
    question_answering=(
        *_COMMON_INPUT_FIELDS,
        "token_type_ids",
        "start_positions",
        "end_positions",
    ),
    token_classification=(*_COMMON_INPUT_FIELDS, "labels"),
    causal_lm=(*_COMMON_INPUT_FIELDS, "token_type_ids", "labels"),
    masked_lm=(*_COMMON_INPUT_FIELDS, "token_type_ids", "labels"),
    seq2seq_lm=(*_COMMON_INPUT_FIELDS, "labels"),
    sequence_classification=(
        *_COMMON_INPUT_FIELDS,
        "token_type_ids",
        "labels",
    ),
    multiple_choice=(*_COMMON_INPUT_FIELDS, "token_type_ids", "labels"),
    next_sentence_prediction=(
        *_COMMON_INPUT_FIELDS,
        "token_type_ids",
        "labels",
    ),
)
def _get_transformer_subclass_doc(auto_model_name: str, task: str):
    """
    Build the text used as `__doc__` of a dynamically created wrapper class.

    Args:
        auto_model_name (str): name of the wrapped `transformers` auto class
        task (str): the task name inserted to the docstring

    Returns:
        The docstring text mentioning the auto class and the task
    """
    subclass_doc = f"""
    Wrapper for `transformers.{auto_model_name}`. Registered as the `ModelWrapper`
    factory for `{task}` style tasks.
    """
    return subclass_doc
# --------------------------------------------------------------------------

# The classes that have been tested are below

# The model wrapper factory that yields a wrapped base model with a question
# answering head
ModelWrapperForQuestionAnswering = _create_and_register_transformer_subclass(
    AutoModelForQuestionAnswering, "question_answering"
)

# The model wrapper factory that yields a wrapped base model with a token
# classification head
ModelWrapperForTokenClassification = _create_and_register_transformer_subclass(
    AutoModelForTokenClassification, "token_classification"
)
# --------------------------------------------------------------------------

# Experimental classes that have not been tested yet are below. Note that some of
# them may be removed in the future.

# The model wrapper factory that yields a wrapped base model with a causal language
# modeling head
ModelWrapperForCausalLM = _create_and_register_transformer_subclass(
    AutoModelForCausalLM, "causal_lm"
)

# The model wrapper factory that yields a wrapped base model with a masked language
# modeling head
ModelWrapperForMaskedLM = _create_and_register_transformer_subclass(
    AutoModelForMaskedLM, "masked_lm"
)

# The model wrapper factory that yields a wrapped base model with a seq-to-seq
# language modeling head
ModelWrapperForSeq2SeqLM = _create_and_register_transformer_subclass(
    AutoModelForSeq2SeqLM, "seq2seq_lm"
)

# The model wrapper factory that yields a wrapped base model with a sequence
# classification head
ModelWrapperForSequenceClassification = _create_and_register_transformer_subclass(
    AutoModelForSequenceClassification, "sequence_classification"
)

# The model wrapper factory that yields a wrapped base model with a multiple choice
# head
ModelWrapperForMultipleChoice = _create_and_register_transformer_subclass(
    AutoModelForMultipleChoice, "multiple_choice"
)

# The model wrapper factory that yields a wrapped base model with a next sentence
# prediction head
ModelWrapperForNextSentencePrediction = _create_and_register_transformer_subclass(
    AutoModelForNextSentencePrediction, "next_sentence_prediction"
)

# --------------------------------------------------------------------------------
# /trapper/pipelines/__init__.py:
# --------------------------------------------------------------------------------
from trapper.pipelines.functional import (
    create_pipeline_from_checkpoint,
    create_pipeline_from_params,
)
from trapper.pipelines.pipeline import PipelineMixin
from trapper.pipelines.question_answering_pipeline import (
    SquadQuestionAnsweringPipeline,
)

# --------------------------------------------------------------------------------
# /trapper/pipelines/functional.py:
# --------------------------------------------------------------------------------
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

from trapper.common.params import Params
from trapper.pipelines.pipeline import PIPELINE_CONFIG_ARGS, PipelineMixin

DEFAULT_CFG_NAME = "experiment_config.json"


def _read_pipeline_params(
    config_path: Union[str, Path],
    params_overrides: Union[str, Dict[str, Any], None] = None,
) -> Params:
    """
    Reads the experiment configuration file into a `Params` object.

    Args:
        config_path: path of a `.json` or `.jsonnet` config file. `Path`
            objects are accepted and converted to `str`.
        params_overrides: optional overrides forwarded to `Params.from_file`

    Returns:
        The `Params` read from `config_path` with the overrides applied

    Raises:
        ValueError: if the file extension is neither `.json` nor `.jsonnet`,
            or if `params_overrides` contains a forbidden key
    """
    # The public entry point accepts `Path` objects, which lack `endswith`.
    config_path = str(config_path)
    if not config_path.endswith((".json", ".jsonnet")):
        raise ValueError(
            "Illegal file format. Please provide a json or jsonnet file!"
        )
    _validate_params_overrides(params_overrides)
    return Params.from_file(
        params_file=config_path,
        params_overrides=params_overrides,
    )


def _validate_checkpoint_dir(path: Union[str, Path]) -> None:
    """
    Checks that `path` is an existing directory.

    Raises:
        ValueError: if `path` is not an existing directory
    """
    if not Path(path).is_dir():
        raise ValueError("Input path must be an existing directory")


def _validate_params_overrides(
    params_overrides: Union[str, Dict[str, Any], None]
) -> None:
    """
    Rejects overrides that try to replace the pipeline type or the model path.

    Only `dict` overrides are inspected; `None` and `str` overrides pass
    through unchecked.

    Raises:
        ValueError: if a `dict` override contains the key `type` or
            `pretrained_model_name_or_path`
    """
    if isinstance(params_overrides, dict) and (
        "type" in params_overrides
        or "pretrained_model_name_or_path" in params_overrides
    ):
        raise ValueError(
            "'type' and 'pretrained_model_name_or_path' are not allowed "
            "to be used in 'params_overrides'."
        )


def create_pipeline_from_params(
    params, pipeline_type: Optional[str] = "default", **kwargs
) -> PipelineMixin:
    """
    Builds a pipeline from experiment parameters.

    The entries under `dataset_loader` are lifted to the top level, then only
    the keys listed in `PIPELINE_CONFIG_ARGS` are kept. `pipeline_type` is set
    as the `type` key and `kwargs` are merged in before construction.

    Note:
        The `params` object passed in is updated in place with the
        `dataset_loader` entries.

    Args:
        params: experiment parameters containing a `dataset_loader` entry
        pipeline_type: registered name of the pipeline class to construct
        **kwargs: extra entries added to the pipeline parameters

    Returns:
        The pipeline constructed by `PipelineMixin.from_params`
    """
    data_components = params.get("dataset_loader").params
    params.update(data_components)
    params = Params({k: v for k, v in params.items() if k in PIPELINE_CONFIG_ARGS})
    params.update(
        {
            "type": pipeline_type,
            **kwargs,
        }
    )
    return PipelineMixin.from_params(params)


def repo_exists(repo_id: str) -> bool:
    """
    Returns whether `HfApi().repo_info(repo_id)` finds the repository.

    Only `RepositoryNotFoundError` is mapped to `False`; any other error
    raised by the lookup propagates to the caller.
    """
    hf_api = HfApi()
    try:
        hf_api.repo_info(repo_id)
    except RepositoryNotFoundError:
        return False
    return True


def _sanitize_checkpoint(
    checkpoint_path: Union[str, Path],
    experiment_config_path: Union[str, Path, None],
    use_auth_token: Union[bool, str, None],
) -> str:
    """
    Validates the checkpoint location and resolves the experiment config path.

    If `checkpoint_path` is an existing directory and no config path is given,
    the config is assumed to be `<checkpoint_path>/experiment_config.json`.
    Otherwise, if `repo_exists` reports a repository with that id and no config
    path is given, `experiment_config.json` is downloaded from it.

    Args:
        checkpoint_path: local checkpoint directory or a repository id
        experiment_config_path: explicit config path, or None to resolve it
        use_auth_token: forwarded to `hf_hub_download`

    Returns:
        The experiment config path as a `str`

    Raises:
        ValueError: if `checkpoint_path` is neither an existing directory nor
            an existing repository, or if the repository does not contain
            `experiment_config.json`
    """
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_dir():  # Try local checkpoint
        if experiment_config_path is None:
            warnings.warn(
                "`experiment_config_path` is not given and assumed to be located under `checkpoint_path`."
            )
            experiment_config_path = checkpoint_path / DEFAULT_CFG_NAME
    elif repo_exists(checkpoint_path.as_posix()):  # Try HF Model-hub
        if experiment_config_path is None:
            try:
                experiment_config_path = hf_hub_download(
                    checkpoint_path.as_posix(),
                    DEFAULT_CFG_NAME,
                    use_auth_token=use_auth_token,
                )
            except EntryNotFoundError as err:
                # Chain the cause so the missing file stays visible in the trace.
                raise ValueError(
                    "If a model is given in HF-hub, `experiment_config.json` must be included in "
                    "the model hub repository."
                ) from err
    else:
        raise ValueError(
            "Input path must be an existing directory or an existing "
            "repository at huggingface model hub."
        )
    # Honour the annotated `str` return type even when a `Path` was supplied.
    return str(experiment_config_path)


def create_pipeline_from_checkpoint(
    checkpoint_path: Union[str, Path],
    experiment_config_path: Union[str, Path, None] = None,
    params_overrides: Union[str, Dict[str, Any], None] = None,
    pipeline_type: Optional[str] = "default",
    use_auth_token: Union[str, bool, None] = None,
    **kwargs
) -> PipelineMixin:
    """
    Creates a pipeline from a local checkpoint directory or a model hub repo.

    When no `experiment_config_path` is given but `params_overrides` is,
    the overrides are used as the complete parameter set and no config file
    is read. Otherwise the config file is resolved and read, and the overrides
    are applied on top of it.

    Args:
        checkpoint_path: local checkpoint directory or a repository id; it is
            also stored as `pretrained_model_name_or_path`
        experiment_config_path: optional path of the experiment config file
        params_overrides: optional parameter overrides
        pipeline_type: registered name of the pipeline class to construct
        use_auth_token: token option used when downloading the config
        **kwargs: forwarded to `create_pipeline_from_params`

    Returns:
        The constructed pipeline
    """
    if experiment_config_path is None and params_overrides is not None:
        # NOTE(review): this branch skips `_validate_params_overrides` and
        # passes a `str` override straight to `Params`; confirm that only
        # dict overrides are expected here.
        params = Params(params_overrides)
    else:
        experiment_config_path = _sanitize_checkpoint(
            checkpoint_path, experiment_config_path, use_auth_token=use_auth_token
        )
        params = _read_pipeline_params(experiment_config_path, params_overrides)
    params.update({"pretrained_model_name_or_path": checkpoint_path})
    return create_pipeline_from_params(
        params, pipeline_type=pipeline_type, **kwargs
    )

# --------------------------------------------------------------------------------
# /trapper/pipelines/pipeline.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Open Business Software Solutions, the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from transformers import Pipeline as _Pipeline

from trapper.common import Lazy, Registrable
from trapper.common.plugins import import_plugins
from trapper.common.utils import append_parent_docstr
from trapper.data import (
    DataAdapter,
    DataCollator,
    DataProcessor,
    LabelMapper,
    TokenizerWrapper,
)
from trapper.models import ModelWrapper

# Keys that `create_pipeline_from_params` keeps when building pipeline params.
PIPELINE_CONFIG_ARGS = [
    "pretrained_model_name_or_path",
    "model_wrapper",
    "tokenizer_wrapper",
    "data_processor",
    "data_adapter",
    "data_collator",
    "args_parser",
    "model_max_sequence_length",
    "label_mapper",
    "feature_extractor",
    "modelcard",
    "framework",
    "task",
    "device",
    "binary_output",
    "use_auth_token",
]


@append_parent_docstr(parent_id=0)
class PipelineMixin(_Pipeline, Registrable):
    """
    A Mixin class for constructing pipelines that utilize data components of trapper.
    This class must take precedence in multiple inheritance, i.e. `PipelineMixin`
    should be listed before the transformers pipeline class in the bases (see
    the example below).

    Note:
        In theory and practice this class can be used as a base class to create a
        custom concrete class; however, this class is designed as a mixin to be used
        with transformers' pipeline classes, and should never be used solely as it
        is not a concrete class.

        Although not recommended, it can be used like a base class that extends
        transformers Pipeline class. In this case, the subclass must implement
        the abstract and required methods.

    Examples:
        from transformers.pipelines import QuestionAnsweringPipeline

        class CustomQAPipeline(PipelineMixin, QuestionAnsweringPipeline):
            ...
    """

    default_implementation = "default"

    def __init__(
        self,
        data_processor: DataProcessor,
        data_adapter: DataAdapter,
        data_collator: Optional[DataCollator] = None,
        **kwargs
    ):
        """
        Stores the trapper data components and forwards the remaining keyword
        arguments to the next `__init__` in the MRO.
        """
        super(PipelineMixin, self).__init__(**kwargs)
        self._data_processor = data_processor
        self._data_adapter = data_adapter
        self._data_collator = data_collator

    @property
    def data_processor(self):
        """The data processor given at construction."""
        return self._data_processor

    @property
    def data_adapter(self):
        """The data adapter given at construction."""
        return self._data_adapter

    @property
    def data_collator(self):
        """The data collator given at construction, or None."""
        return self._data_collator

    @classmethod
    def from_partial_objects(
        cls,
        pretrained_model_name_or_path: str,
        model_wrapper: Lazy[ModelWrapper],
        tokenizer_wrapper: Lazy[TokenizerWrapper],
        data_processor: Lazy[DataProcessor],
        data_adapter: Lazy[DataAdapter],
        data_collator: Optional[Lazy[DataCollator]] = None,
        label_mapper: Optional[Lazy[LabelMapper]] = None,
        model_max_sequence_length: Optional[int] = None,
        framework: Optional[str] = "pt",
        task: str = "",
        device: int = -1,
        binary_output: bool = False,
        use_auth_token: bool = None,
        **kwargs
    ) -> "PipelineMixin":
        """
        Constructs the lazily specified components and builds the pipeline.

        The model and tokenizer wrappers are constructed from
        `pretrained_model_name_or_path`. The optional label mapper is
        constructed once and shared by the data processor and the data
        adapter. The data collator, when given, receives the model's
        forward params.

        Args:
            pretrained_model_name_or_path: passed to the model and tokenizer
                wrappers
            model_wrapper: lazy model wrapper
            tokenizer_wrapper: lazy tokenizer wrapper
            data_processor: lazy data processor
            data_adapter: lazy data adapter
            data_collator: optional lazy data collator
            label_mapper: optional lazy label mapper
            model_max_sequence_length: passed to the data processor
            framework: passed to the pipeline constructor
            task: passed to the pipeline constructor
            device: passed to the pipeline constructor
            binary_output: passed to the pipeline constructor
            use_auth_token: passed to the model and tokenizer wrappers
            **kwargs: passed to the pipeline constructor

        Returns:
            An instance of `cls`
        """
        # The signature annotations are left untouched on purpose: this method
        # is the registered constructor, so they may be introspected when the
        # pipeline is built from params.

        # To find the registrable components from the user-defined packages
        import_plugins()

        model_wrapper_ = model_wrapper.construct(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            use_auth_token=use_auth_token,
        )
        model_forward_params = model_wrapper_.forward_params

        label_mapper_ = label_mapper.construct() if label_mapper else None

        tokenizer_wrapper_ = tokenizer_wrapper.construct(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            use_auth_token=use_auth_token,
        )

        data_processor_ = data_processor.construct(
            tokenizer_wrapper=tokenizer_wrapper_,
            label_mapper=label_mapper_,
            model_max_sequence_length=model_max_sequence_length,
        )

        if data_collator:
            data_collator_ = data_collator.construct(
                tokenizer_wrapper=tokenizer_wrapper_,
                model_forward_params=model_forward_params,
            )
        else:
            data_collator_ = None

        # Pass the *constructed* label mapper, exactly as the data processor
        # receives it; the lazy wrapper itself is not a LabelMapper.
        data_adapter_ = data_adapter.construct(
            tokenizer_wrapper=tokenizer_wrapper_, label_mapper=label_mapper_
        )

        return cls(
            model=model_wrapper_.model,
            tokenizer=tokenizer_wrapper_.tokenizer,
            data_processor=data_processor_,
            data_adapter=data_adapter_,
            data_collator=data_collator_,
            framework=framework,
            task=task,
            device=device,
            binary_output=binary_output,
            **kwargs
        )


PipelineMixin.register("default", constructor="from_partial_objects")(PipelineMixin)

# --------------------------------------------------------------------------------
# /trapper/training/__init__.py:
# --------------------------------------------------------------------------------
from trapper.training.optimizers import Optimizer
from trapper.training.trainer import TransformerTrainer
from trapper.training.training_args import TransformerTrainingArguments
# --------------------------------------------------------------------------------
# /trapper/training/callbacks.py:
# --------------------------------------------------------------------------------
from transformers import TrainerCallback as _TrainerCallback

from trapper.common import Registrable


class TrainerCallback(_TrainerCallback, Registrable):
    """
    `Registrable` counterpart of `transformers.TrainerCallback`.

    Custom callbacks are added by subclassing this class and registering
    the subclass under a name.
    """


# --------------------------------------------------------------------------------
# /trapper/training/optimizers.py:
# --------------------------------------------------------------------------------
from allennlp.training import optimizers as _allennlp_optimizers

from trapper.common import Registrable
from trapper.common.utils import append_parent_docstr


@append_parent_docstr(parent_id=1)
class Optimizer(Registrable, _allennlp_optimizers.Optimizer):
    """
    Base `Registrable` optimizer that stands in for the `allennlp`
    optimizer base class.
    """


@Optimizer.register("adam")
class AdamOptimizer(Optimizer, _allennlp_optimizers.AdamOptimizer):
    """
    The `Adam` optimizer, available under the name "adam".
    """


@Optimizer.register("sparse_adam")
class SparseAdamOptimizer(Optimizer, _allennlp_optimizers.SparseAdamOptimizer):
    """
    The `SparseAdam` optimizer, available under the name "sparse_adam".
    """


@Optimizer.register("adamax")
class AdamaxOptimizer(Optimizer, _allennlp_optimizers.AdamaxOptimizer):
    """
    The `Adamax` optimizer, available under the name "adamax".
    """


@Optimizer.register("adamw")
class AdamWOptimizer(Optimizer, _allennlp_optimizers.AdamWOptimizer):
    """
    The `AdamW` optimizer, available under the name "adamw".
    """


@Optimizer.register("huggingface_adamw")
class HuggingfaceAdamWOptimizer(
    Optimizer, _allennlp_optimizers.HuggingfaceAdamWOptimizer
):
    """
    The `HuggingfaceAdamW` optimizer, available under the name
    "huggingface_adamw".
    """


@Optimizer.register("adagrad")
class AdagradOptimizer(Optimizer, _allennlp_optimizers.AdagradOptimizer):
    """
    The `Adagrad` optimizer, available under the name "adagrad".
    """


@Optimizer.register("adadelta")
class AdadeltaOptimizer(Optimizer, _allennlp_optimizers.AdadeltaOptimizer):
    """
    The `Adadelta` optimizer, available under the name "adadelta".
    """


@Optimizer.register("sgd")
class SgdOptimizer(Optimizer, _allennlp_optimizers.SgdOptimizer):
    """
    The `Sgd` optimizer, available under the name "sgd".
    """


@Optimizer.register("rmsprop")
class RmsPropOptimizer(Optimizer, _allennlp_optimizers.RmsPropOptimizer):
    """
    The `RmsProp` optimizer, available under the name "rmsprop".
    """


@Optimizer.register("averaged_sgd")
class AveragedSgdOptimizer(Optimizer, _allennlp_optimizers.AveragedSgdOptimizer):
    """
    The `AveragedSgd` optimizer, available under the name "averaged_sgd".
    """


@append_parent_docstr(parent_id=1)
@Optimizer.register("dense_sparse_adam")
class DenseSparseAdamOptimizer(
    Optimizer,
    _allennlp_optimizers.DenseSparseAdam,
):
    """
    The `DenseSparseAdam` optimizer, available under the name
    "dense_sparse_adam".
    """


# --------------------------------------------------------------------------------
# /trapper/training/training_args.py:
# --------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Optional

from transformers import Seq2SeqTrainingArguments
from transformers.training_args import TrainingArguments as _TrainingArguments
from transformers.utils import logging

from trapper.common import Registrable
from trapper.common.utils import append_parent_docstr

logger = logging.get_logger(__name__)


@append_parent_docstr(parent_id=0)
@dataclass
class TransformerTrainingArguments(_TrainingArguments, Registrable):
    """
    `Registrable` wrapper around `TrainingArguments` from the `Transformers`
    library, extended with a `result_dir` field.
    """

    # Where the final metrics and the trainer state are written.
    result_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "The directory to save the metrics for the final model "
            "and the trainer state at the end of the training."
        },
    )

    def __post_init__(self):
        """
        Replaces an unset `report_to` with `["tensorboard"]`, then runs the
        parent `__post_init__`.
        """
        report_to_unset = self.report_to is None
        if report_to_unset:
            logger.info(
                "Transformers v4.5.1 defaults `--report_to` to 'all', "
                "so we change it to 'tensorboard'."
            )
            self.report_to = ["tensorboard"]
        super().__post_init__()


TransformerTrainingArguments.register("default")(TransformerTrainingArguments)


# NOTE(review): this class is not decorated with `@dataclass`, so fields
# declared only by `Seq2SeqTrainingArguments` are presumably not accepted by
# the inherited `__init__` -- confirm this is intended.
@TransformerTrainingArguments.register("seq2seq")
class Seq2SeqTransformerTrainingArguments(
    TransformerTrainingArguments, Seq2SeqTrainingArguments
):
    pass


# --------------------------------------------------------------------------------
# /trapper/version.py:
# --------------------------------------------------------------------------------
_MAJOR = "0"
_MINOR = "0"
_PATCH = "13"

# Dotted "major.minor.patch" version string.
VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}"

# --------------------------------------------------------------------------------