├── process_datasets ├── __init__.py ├── utils │ ├── __init__.py │ └── general.py ├── strategies │ ├── paragraph │ │ ├── __init__.py │ │ └── base.py │ ├── __init__.py │ └── strategy.py ├── loaders │ ├── __init__.py │ ├── loader.py │ ├── disk.py │ ├── dataset.py │ └── json.py └── __main__.py ├── transformers_framework ├── __init__.py ├── architectures │ ├── __init__.py │ ├── roberta │ │ ├── __init__.py │ │ └── modeling_config.py │ ├── modeling_output.py │ └── modeling_head.py ├── models │ ├── base │ │ ├── __init__.py │ │ ├── mlm.py │ │ ├── base.py │ │ ├── classification.py │ │ └── as2.py │ ├── joint │ │ ├── __init__.py │ │ ├── as2 │ │ │ ├── __init__.py │ │ │ ├── roberta.py │ │ │ └── base.py │ │ ├── mlm │ │ │ ├── __init__.py │ │ │ ├── roberta.py │ │ │ └── base.py │ │ ├── mlm_as2 │ │ │ ├── __init__.py │ │ │ ├── roberta.py │ │ │ └── base.py │ │ └── fact_checking │ │ │ ├── roberta.py │ │ │ └── base.py │ └── __init__.py ├── transformations │ ├── __init__.py │ ├── transformation.py │ └── conversion_transformation.py ├── adapters │ ├── map_adapters │ │ ├── __init__.py │ │ ├── arrow │ │ │ ├── __init__.py │ │ │ ├── pairwise_adapter.py │ │ │ ├── arrow_adapter.py │ │ │ └── jointwise_adapter.py │ │ └── map_adapter.py │ ├── __init__.py │ └── transformer_adapter.py ├── datamodules │ ├── __init__.py │ └── transformers_datamodule.py ├── utilities │ ├── .DS_Store │ ├── datamodules.py │ ├── __init__.py │ ├── structures.py │ ├── classes.py │ ├── processors.py │ ├── tokenization.py │ └── functional.py ├── samplers │ ├── __init__.py │ └── keys_sampler.py └── __main__.py ├── .gitattributes ├── datasets └── scores_as2 │ ├── scores_roberta_base_asnq │ ├── dataset_dict.json │ ├── .DS_Store │ ├── test │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ ├── train │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ └── validation │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ ├── scores_roberta_base_trecqa │ ├── dataset_dict.json │ ├── .DS_Store │ ├── test │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ ├── train │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ └── validation │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ ├── scores_roberta_base_wikiqa │ ├── dataset_dict.json │ ├── .DS_Store │ ├── test │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ ├── train │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ └── validation │ │ ├── dataset.arrow │ │ ├── state.json │ │ └── dataset_info.json │ └── .DS_Store ├── requirements.txt ├── CODE_OF_CONDUCT.md ├── transformers_experiments ├── pretraining │ ├── run_roberta_base_joint_pretraining_ae_k.sh │ └── run_roberta_base_joint_pretraining_ie_k.sh └── finetuning │ ├── fever │ ├── run_roberta_base_joint_ae_1.sh │ └── run_roberta_base_joint_ie_1.sh │ ├── asnq │ ├── run_roberta_base_joint_ae_k_no_shuf.sh │ ├── run_roberta_base_joint_ie_k_no_shuf.sh │ ├── run_roberta_base_joint_ae_k.sh │ ├── run_roberta_base_joint_ie_k.sh │ ├── run_roberta_base_joint_ae_k_best.sh │ └── run_roberta_base_joint_ie_k_best.sh │ ├── trecqa │ ├── run_roberta_base_joint_ae_k_no_shuf.sh │ ├── run_roberta_base_joint_ie_k_no_shuf.sh │ ├── run_roberta_base_joint_ae_k.sh │ ├── run_roberta_base_joint_ie_k.sh │ ├── run_roberta_base_joint_ae_k_best.sh │ └── run_roberta_base_joint_ie_k_best.sh │ └── wikiqa │ ├── run_roberta_base_joint_ae_k_no_shuf.sh │ ├── run_roberta_base_joint_ie_k_no_shuf.sh │ ├── run_roberta_base_joint_ae_k.sh │ ├── 
run_roberta_base_joint_ie_k.sh │ ├── run_roberta_base_joint_ae_k_best.sh │ └── run_roberta_base_joint_ie_k_best.sh ├── setup.cfg ├── transformers_utilities └── datasets │ ├── merge_datasets.py │ ├── create_trecqa_dataset.py │ ├── create_wikiqa_dataset.py │ ├── create_fever_dataset.py │ └── create_asnq_dataset.py ├── Makefile └── CONTRIBUTING.md /process_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /process_datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/models/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /process_datasets/strategies/paragraph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/as2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/architectures/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm_as2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.arrow filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/arrow/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["test", "validation", "train"]} -------------------------------------------------------------------------------- /datasets/scores_as2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/wqa-multi-sentence-inference/HEAD/datasets/scores_as2/.DS_Store -------------------------------------------------------------------------------- /process_datasets/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from process_datasets.strategies.paragraph.sentence_sentence import Sentence2SentenceStrategy # noqa: F401 2 | -------------------------------------------------------------------------------- /transformers_framework/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_framework.datamodules.transformers_datamodule import TransformersDataModule # noqa: F401 2 | -------------------------------------------------------------------------------- /transformers_framework/utilities/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/wqa-multi-sentence-inference/HEAD/transformers_framework/utilities/.DS_Store -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/wqa-multi-sentence-inference/HEAD/datasets/scores_as2/scores_roberta_base_asnq/.DS_Store -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/wqa-multi-sentence-inference/HEAD/datasets/scores_as2/scores_roberta_base_trecqa/.DS_Store -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/wqa-multi-sentence-inference/HEAD/datasets/scores_as2/scores_roberta_base_wikiqa/.DS_Store -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/test/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:600cc561a75e71874c83d60565babd8ea1051f2196d213124afecad234ed8eee 3 | 
size 3735136 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/test/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:afff656a0334390f8956e7b32382eac3d12c12d52bdb91dc8d9ac5b5f6a69275 3 | size 12416 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/test/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5b3fa3bc6ea7146eaa296f2d3bcd69f45a6759c2160622e9635d96037d1e9228 3 | size 20280 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/train/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:27669ac869b480d45919d2ddcaa4f8569becdfc9c83afb835796b7fd043b8e0c 3 | size 163251120 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/validation/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0aae8f751c0cce505b99b3285828a2a13c0f3da6c19c7726907138276e2645ec 3 | size 3717264 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/train/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0b1f929534a9c936663f0635b6f8a66410fa56da1254c548058248ae552c26d9 3 | size 432856 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/validation/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:608ea2e48dc7564005a5e8263e5ec3ee050fbbc8f68d36ee6313d6cee26b1893 3 | size 11624 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/train/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8b20dd7c3d6cc201129f7a8a37a0c73c1af805e379606e8ed0d1b2907c1e8220 3 | size 171960 4 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/validation/dataset.arrow: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:86214dcdc5472bff2a24f95b0d88c355d30e59b0d5b4bfb01f56db3aa7872182 3 | size 10104 4 | -------------------------------------------------------------------------------- /transformers_framework/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_framework.samplers.keys_sampler import DistributedKeysSampler # noqa: F401 2 | from transformers_framework.samplers.keys_sampler import KeysSampler # noqa: F401 3 | -------------------------------------------------------------------------------- 
/process_datasets/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from process_datasets.loaders.dataset import DatasetLoader # noqa: F401 2 | from process_datasets.loaders.disk import DiskLoader # noqa: F401 3 | from process_datasets.loaders.json import JsonLoader # noqa: F401 4 | -------------------------------------------------------------------------------- /transformers_framework/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_framework.adapters.map_adapters.arrow.jointwise_adapter import JointwiseArrowAdapter # noqa: F401 2 | from transformers_framework.adapters.map_adapters.arrow.pairwise_adapter import PairwiseArrowAdapter # noqa: F401 3 | -------------------------------------------------------------------------------- /transformers_framework/utilities/datamodules.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.trainer.states import TrainerFn 2 | 3 | 4 | STAGES_TO_NAMES = { 5 | TrainerFn.FITTING: 'train', 6 | TrainerFn.VALIDATING: 'valid', 7 | TrainerFn.TESTING: 'test', 8 | TrainerFn.PREDICTING: 'predict', 9 | } 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.1 2 | torchmetrics>=0.7.2,<0.8 3 | rich>=10.14 4 | scipy>=1.7.3 5 | torchvision>=0.11.3 6 | sentencepiece 7 | matplotlib 8 | tokenizers>=0.11.6 9 | datasets>=2.1.0,<3 10 | transformers>=4.17,<5 11 | transformers-lightning>=0.7.8,<8 12 | pytorch-lightning>=1.5.10,<1.6 13 | blingfire>=0.1.7 14 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "c2efeb7ebddb4718", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "a33b72ac5b0ffcd5", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "34b70bde23ffad8b", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "9b5ebbbfdece10f2", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "d4780e5c8123ad6f", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "d0825c436dad8485", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "02d07652be489029", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | 
"_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "a171f78ad864f721", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "3cb067964b89f490", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /transformers_framework/models/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_framework.models.joint.as2.roberta import RobertaJointAS2 # noqa: F401 2 | from transformers_framework.models.joint.fact_checking.roberta import RobertaJointFactChecking # noqa: F401 3 | from transformers_framework.models.joint.mlm.roberta import RobertaJointMLM # noqa: F401 4 | from transformers_framework.models.joint.mlm_as2.roberta import RobertaJointMLMAndClassification # noqa: F401 5 | -------------------------------------------------------------------------------- /transformers_framework/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_framework.utilities.classes import * # noqa: F401, F403 2 | from transformers_framework.utilities.datamodules import * # noqa: F401, F403 3 | from transformers_framework.utilities.functional import * # noqa: F401, F403 4 | from transformers_framework.utilities.processors import * # noqa: F401, F403 5 | from transformers_framework.utilities.structures import * # noqa: F401, F403 6 | from transformers_framework.utilities.tokenization import * # noqa: F401, F403 7 | -------------------------------------------------------------------------------- /process_datasets/loaders/loader.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | 3 | from datasets import Dataset 4 | 5 | 6 | class Loader: 7 | 8 | dataset: Dataset 9 | 10 | def __init__(self, hparams: Namespace): 11 | self.hparams = hparams 12 | 13 | def __call__(self) -> Dataset: 14 | r""" Return dataset for input data. 
""" 15 | return self.dataset 16 | 17 | def add_loader_specific_args(parser: ArgumentParser): 18 | parser.add_argument('--keep_in_memory', action="store_true", help="Whether to keep in memory input dataset.") 19 | parser.add_argument('--split', default=None, required=False, type=str, help="Split to be loaded") 20 | -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 
19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_asnq/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_trecqa/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | 
"license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /datasets/scores_as2/scores_roberta_base_wikiqa/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "scores_roberta_base": { 11 | "feature": { 12 | "dtype": "float64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | } 20 | }, 21 | "homepage": "", 22 | "license": "", 23 | "post_processed": null, 24 | "post_processing_size": null, 25 | "size_in_bytes": null, 26 | "splits": null, 27 | "supervised_keys": null, 28 | "task_templates": null, 29 | "version": null 30 | } -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/map_adapter.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict 3 | 4 | from transformers_framework.adapters.transformer_adapter import TransformersAdapter 5 | 6 | 7 | class MapAdapter(TransformersAdapter): 8 | r""" 9 | MapAdapters provide a map-like interface to retrieve data. Each subclass should 10 | override the __getitem__ and __len__ and __iter__ methods. 11 | """ 12 | 13 | @abstractmethod 14 | def __getitem__(self, idx) -> Dict: 15 | r""" 16 | This function should use the arguments in `hyperparameters` to 17 | return a map over the (parsed) lines. This is the right place to return indexable data. 18 | 19 | >>> return self.data[idx] 20 | """ 21 | 22 | @abstractmethod 23 | def __len__(self): 24 | r""" Returns the number of examples in the source dataset. """ 25 | -------------------------------------------------------------------------------- /process_datasets/loaders/disk.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser, Namespace 4 | 5 | import datasets 6 | 7 | from process_datasets.loaders.loader import Loader 8 | 9 | 10 | class DiskLoader(Loader): 11 | r""" Load a dataset from disk. """ 12 | 13 | def __init__(self, hparams: Namespace): 14 | super().__init__(hparams) 15 | assert os.path.isdir(hparams.input_folder), "Input folder does not exist" 16 | 17 | logging.info(f"Loading input dataset from disk") 18 | dataset = datasets.load_from_disk(hparams.input_folder, keep_in_memory=hparams.keep_in_memory) 19 | self.dataset = dataset if hparams.split is None else dataset[hparams.split] 20 | 21 | def add_loader_specific_args(parser: ArgumentParser): 22 | super(DiskLoader, DiskLoader).add_loader_specific_args(parser) 23 | parser.add_argument('--input_folder', type=str, required=True) 24 | -------------------------------------------------------------------------------- /transformers_experiments/pretraining/run_roberta_base_joint_pretraining_ae_k.sh: -------------------------------------------------------------------------------- 1 | # change datasets path and output folder as needed. 
We suggest running this experiment on a single P4 2 | python -m transformers_framework \ 3 | --model RobertaJointMLMAndClassification \ 4 | --devices 8 \ 5 | --accelerator gpu --strategy deepspeed_stage_2 \ 6 | --precision 16 \ 7 | --pre_trained_model roberta-base \ 8 | --name roberta-base-joint-ae-k \ 9 | --output_dir outputs/joint-pretraining \ 10 | \ 11 | --adapter JointwiseArrowAdapter \ 12 | --batch_size 64 \ 13 | --train_filepath /path/to/datasets \ 14 | --field_names premise consequence \ 15 | --label_name label \ 16 | \ 17 | --log_every_n_steps 100 \ 18 | --accumulate_grad_batches 8 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection random --separated \ 21 | --learning_rate 1e-04 \ 22 | --max_steps 100000 \ 23 | --weight_decay 0.01 \ 24 | --num_warmup_steps 10000 \ 25 | --num_workers 8 \ 26 | --head_type AE_k \ 27 | --seed 1337 28 | -------------------------------------------------------------------------------- /transformers_experiments/pretraining/run_roberta_base_joint_pretraining_ie_k.sh: -------------------------------------------------------------------------------- 1 | # change datasets path and output folder as needed. We suggest running this experiment on a single P4 2 | python -m transformers_framework \ 3 | --model RobertaJointMLMAndClassification \ 4 | --devices 8 \ 5 | --accelerator gpu --strategy deepspeed_stage_2 \ 6 | --precision 16 \ 7 | --pre_trained_model roberta-base \ 8 | --name roberta-base-joint-ie-k \ 9 | --output_dir outputs/joint-pretraining \ 10 | \ 11 | --adapter JointwiseArrowAdapter \ 12 | --batch_size 64 \ 13 | --train_filepath /path/to/datasets \ 14 | --field_names premise consequence \ 15 | --label_name label \ 16 | \ 17 | --log_every_n_steps 100 \ 18 | --accumulate_grad_batches 8 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection random --separated \ 21 | --learning_rate 1e-04 \ 22 | --max_steps 100000 \ 23 | --weight_decay 0.01 \ 24 | --num_warmup_steps 10000 \ 25 | --num_workers 8 \ 26 | --head_type IE_k \ 27 | --seed 1337 28 | -------------------------------------------------------------------------------- /process_datasets/loaders/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser, Namespace 3 | 4 | import datasets 5 | 6 | from process_datasets.loaders.loader import Loader 7 | 8 | 9 | class DatasetLoader(Loader): 10 | r""" Load a dataset from the datasets library. 
""" 11 | 12 | def __init__(self, hparams: Namespace): 13 | super().__init__(hparams) 14 | 15 | if len(hparams.name) > 1: 16 | hparams.name, hparams.config = hparams.name 17 | else: 18 | hparams.config = None 19 | hparams.name = hparams.name[0] 20 | 21 | logging.info(f"Loading input dataset {hparams.name} with config {hparams.config}") 22 | dataset = datasets.load_dataset(hparams.name, hparams.config, keep_in_memory=hparams.keep_in_memory) 23 | self.dataset = dataset if hparams.split is None else dataset[hparams.split] 24 | 25 | def add_loader_specific_args(parser: ArgumentParser): 26 | super(DatasetLoader, DatasetLoader).add_loader_specific_args(parser) 27 | parser.add_argument('--name', type=str, required=True, nargs='+') 28 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | E203, # not pep8, black adds whitespace before ':' 4 | W503, # not pep8, black adds line break before binary operator 5 | E301, # (E301) Expected 1 blank line, found 0 6 | E302, # (E302) Expected 2 blank lines, found 0 7 | E303, # (E303) Too many blank lines (3) 8 | E305, # (E305) Expected 2 blank lines after end of function or class; Sometimes I use 3 to break up a file 9 | E306, # (E306) expected 1 blank line before a nested definition [pep257] 10 | E309, # (E309) expected 1 blank line after class declaration [pep257] 11 | W293, # Blank line contains whitespace. When I put my cursor inside a function, it should be at the same 12 | # indentation (or one out) from the previous block in that function, not at the left side of the file. 13 | # C901, # Function is too complex. I'll judge the appropriate level of complexity thank you very much 14 | max_line_length = 120 15 | exclude = scripts/,experiments/ 16 | 17 | [isort] 18 | line_length = 120 19 | multi_line_output = 3 20 | include_trailing_comma = true 21 | lines_after_imports = 2 22 | -------------------------------------------------------------------------------- /transformers_framework/architectures/roberta/modeling_config.py: -------------------------------------------------------------------------------- 1 | from transformers.models.roberta.configuration_roberta import RobertaConfig 2 | 3 | 4 | MONOLITHIC_HEAD_TYPES = ("IE_1", "IE_k", "AE_1", "AE_k") 5 | 6 | 7 | class RobertaJointConfig(RobertaConfig): 8 | r""" 9 | The :class:`~RobertaJointConfig` class directly inherits :class:`~transformers.RobertaConfig`. It reuses the 10 | same defaults. Please check the parent class for more information. 11 | 12 | Args: 13 | k (:obj:`int`, `optional`, defaults to 5): 14 | Number of candidates to consider for each query. 15 | sentence_msl (:obj:`int`, `optional`, defaults to 64): 16 | Max length of each query or candidate. 17 | head_type (:obj:`str`, `optional`, defaults to 'IE_1'): 18 | The classification head type. 
19 | """ 20 | 21 | def __init__(self, k: int = 5, sentence_msl: int = 64, head_type: str = "IE_1", **kwargs): 22 | super().__init__(is_decoder=False, **kwargs) 23 | assert head_type in MONOLITHIC_HEAD_TYPES 24 | self.k = k 25 | self.sentence_msl = sentence_msl 26 | self.head_type = head_type 27 | -------------------------------------------------------------------------------- /process_datasets/strategies/strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser, Namespace 3 | from typing import Any, Dict, List 4 | 5 | from process_datasets.utils.general import dict2list, list2dict 6 | 7 | 8 | class Strategy(ABC): 9 | r"""Given a stream of input text, creates pretraining examples.""" 10 | 11 | def __init__(self, hparams: Namespace): 12 | super().__init__() 13 | self.hparams = hparams 14 | 15 | def __call__(self, batch: Dict[Any, List]) -> Dict[Any, List]: 16 | r""" Receive a batch of documents and return processed version. 17 | """ 18 | batch = dict2list(batch) 19 | batch = self.process_batch(batch) 20 | batch = list2dict(batch) 21 | return batch 22 | 23 | @abstractmethod 24 | def process_batch(self, batch: List[Dict]) -> List[Dict]: 25 | r""" Process a list of batches. """ 26 | 27 | @staticmethod 28 | def add_arguments_to_argparse(parser: ArgumentParser): 29 | r""" Add strategy specific parameters to the cmd argument parser. """ 30 | parser.add_argument( 31 | '--field', type=str, required=False, default='text', help="Field names in the dataset to consider." 32 | ) 33 | -------------------------------------------------------------------------------- /transformers_framework/models/base/mlm.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from torchmetrics import Accuracy 4 | from transformers_lightning.language_modeling import MaskedLanguageModeling 5 | 6 | from transformers_framework.models.base.base import BaseModel 7 | 8 | 9 | class BaseModelMLM(BaseModel): 10 | 11 | def __init__(self, hyperparameters): 12 | super().__init__(hyperparameters) 13 | 14 | self.mlm = MaskedLanguageModeling( 15 | self.tokenizer, 16 | probability=hyperparameters.probability, 17 | probability_masked=hyperparameters.probability_masked, 18 | probability_replaced=hyperparameters.probability_replaced, 19 | ) 20 | 21 | self.train_mlm_acc = Accuracy() 22 | self.valid_mlm_acc = Accuracy() 23 | self.test_mlm_acc = Accuracy() 24 | 25 | def add_model_specific_args(parser: ArgumentParser): 26 | super(BaseModelMLM, BaseModelMLM).add_model_specific_args(parser) 27 | # mlm specific arguments 28 | parser.add_argument('--probability', type=float, default=0.15) 29 | parser.add_argument('--probability_masked', type=float, default=0.80) 30 | parser.add_argument('--probability_replaced', type=float, default=0.10) 31 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/fever/run_roberta_base_joint_ae_1.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointFactChecking \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --pre_trained_model \ 5 | --name roberta-base-joint-fever-AE-1 \ 6 | --output_dir outputs/joint-fever \ 7 | \ 8 | --adapter JointwiseArrowAdapter \ 9 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 10 | --train_filepath --train_split train \ 11 
| --valid_filepath --valid_split validation \ 12 | --field_names claim evidence \ 13 | --label_name label \ 14 | --key_name key \ 15 | \ 16 | --accumulate_grad_batches 1 \ 17 | --max_sequence_length 64 \ 18 | -k 5 --selection all --separated --force_load_dataset_in_memory --reduce_labels \ 19 | --learning_rate 1e-05 \ 20 | --max_epochs 15 \ 21 | --early_stopping \ 22 | --patience 8 \ 23 | --weight_decay 0.0 \ 24 | --num_warmup_steps 1000 \ 25 | --monitor validation/accuracy \ 26 | --val_check_interval 0.5 \ 27 | --num_workers 8 \ 28 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 29 | --num_labels 3 \ 30 | --head_type AE_1 \ 31 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/fever/run_roberta_base_joint_ie_1.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointFactChecking \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --pre_trained_model \ 5 | --name roberta-base-joint-fever-IE-1 \ 6 | --output_dir outputs/joint-fever \ 7 | \ 8 | --adapter JointwiseArrowAdapter \ 9 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 10 | --train_filepath --train_split train \ 11 | --valid_filepath --valid_split validation \ 12 | --field_names claim evidence \ 13 | --label_name label \ 14 | --key_name key \ 15 | \ 16 | --accumulate_grad_batches 1 \ 17 | --max_sequence_length 64 \ 18 | -k 5 --selection all --separated --force_load_dataset_in_memory --reduce_labels \ 19 | --learning_rate 1e-05 \ 20 | --max_epochs 15 \ 21 | --early_stopping \ 22 | --patience 8 \ 23 | --weight_decay 0.0 \ 24 | --num_warmup_steps 1000 \ 25 | --monitor validation/accuracy \ 26 | --val_check_interval 0.5 \ 27 | --num_workers 8 \ 28 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 29 | --num_labels 3 \ 30 | --head_type IE_1 \ 31 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ae_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-AE-k-no-shuf \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 6 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 5000 \ 27 | --monitor validation/map \ 28 | --val_check_interval 0.5 \ 29 | --num_workers 8 \ 30 | --head_type AE_k \ 31 | --seed 1337 32 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ie_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp 
\ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-IE-k-no-shuf \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 6 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 5000 \ 27 | --monitor validation/map \ 28 | --val_check_interval 0.5 \ 29 | --num_workers 8 \ 30 | --head_type IE_k \ 31 | --seed 1337 32 | -------------------------------------------------------------------------------- /process_datasets/loaders/json.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser, Namespace 4 | 5 | from datasets import load_dataset 6 | 7 | from process_datasets.loaders.loader import Loader 8 | 9 | 10 | class JsonLoader(Loader): 11 | r""" Load a dataset from a bunch of JSONL files. """ 12 | 13 | def __init__(self, hparams: Namespace): 14 | super().__init__(hparams) 15 | assert os.path.isdir(hparams.input_folder), "Input folder does not exist" 16 | 17 | logging.info(f"Loading dataset from local json files in folder {hparams.input_folder}") 18 | filepaths = [ 19 | os.path.join(hparams.input_folder, f) 20 | for f in os.listdir(hparams.input_folder) 21 | if f.endswith('.json') or f.endswith('.jsonl') 22 | ] 23 | dataset = load_dataset( 24 | 'json', data_files=filepaths, keep_in_memory=hparams.keep_in_memory 25 | ) 26 | self.dataset = dataset if hparams.split is None else dataset[hparams.split] 27 | 28 | def add_loader_specific_args(parser: ArgumentParser): 29 | super(JsonLoader, JsonLoader).add_loader_specific_args(parser) 30 | parser.add_argument('-i', '--input_folder', type=str, required=True, help="Input data folder.") 31 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ae_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-AE-k-no-shuf \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --head_type AE_k \ 31 | --seed 1337 32 | 
-------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ie_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-IE-k-no-shuf \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --head_type IE_k \ 31 | --seed 1337 32 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ae_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-AE-k-no-shuf \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --head_type AE_k \ 31 | --seed 1337 32 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ie_k_no_shuf.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-IE-k-no-shuf \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 
40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --head_type IE_k \ 31 | --seed 1337 32 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm/roberta.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizerFast 2 | 3 | from transformers_framework.architectures.roberta.modeling_config import RobertaJointConfig 4 | from transformers_framework.architectures.roberta.modeling_joint_roberta import JointRobertaForMaskedLM 5 | from transformers_framework.models.joint.mlm.base import BaseJointMLM 6 | 7 | 8 | class RobertaJointMLM(BaseJointMLM): 9 | 10 | def setup_config(self) -> RobertaJointConfig: 11 | return RobertaJointConfig.from_pretrained( 12 | self.hyperparameters.pre_trained_config, 13 | k=self.hyperparameters.k, 14 | sentence_msl=self.hyperparameters.max_sequence_length, 15 | ) 16 | 17 | def setup_model(self, config: RobertaJointConfig) -> JointRobertaForMaskedLM: 18 | if self.hyperparameters.pre_trained_model is None: 19 | return JointRobertaForMaskedLM(config) 20 | else: 21 | return JointRobertaForMaskedLM.from_pretrained( 22 | self.hyperparameters.pre_trained_model, config=config, ignore_mismatched_sizes=True 23 | ) 24 | 25 | def setup_tokenizer(self) -> RobertaTokenizerFast: 26 | return RobertaTokenizerFast.from_pretrained(self.hyperparameters.pre_trained_tokenizer) 27 | -------------------------------------------------------------------------------- /transformers_utilities/datasets/merge_datasets.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | 4 | from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk 5 | 6 | 7 | def main(args): 8 | assert all(os.path.isdir(filepath) for filepath in args.input), ( 9 | "could not find some of the input datasets" 10 | ) 11 | 12 | datas = [load_from_disk(filepath) for filepath in args.input] 13 | 14 | assert all(isinstance(d, Dataset) for d in datas) or all(isinstance(d, DatasetDict) for d in datas), ( 15 | "datasets must be all DatasetDict or all Dataset" 16 | ) 17 | 18 | if all(isinstance(d, DatasetDict) for d in datas): 19 | res = DatasetDict({ 20 | split: concatenate_datasets([d[split] for d in datas], axis=1) for split in datas[0].keys() 21 | }) 22 | else: 23 | res = concatenate_datasets(datas, axis=1) 24 | 25 | res.save_to_disk(args.output) 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = ArgumentParser() 30 | parser.add_argument('--input', type=str, nargs='+', required=True, help="Input datasets to merge") 31 | parser.add_argument('--output', type=str, required=True, help="Output folder for resulting dataset") 32 | args = parser.parse_args() 33 | main(args) 34 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ae_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-AE-k \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 
--test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 6 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 5000 \ 27 | --monitor validation/map \ 28 | --val_check_interval 0.5 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type AE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ie_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-IE-k \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 6 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 5000 \ 27 | --monitor validation/map \ 28 | --val_check_interval 0.5 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type IE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_framework/utilities/structures.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Sequence, Union 3 | 4 | 5 | @dataclass 6 | class Sample: 7 | pass # just a super class 8 | 9 | 10 | @dataclass 11 | class DataSample(Sample): 12 | data: Sequence[str] # textual data inputs 13 | key: Union[int, None] # a unique identifier of the input example 14 | label: Union[int, None] # label as integer 15 | score: Union[float, None] # score assigned to the textual example 16 | 17 | 18 | @dataclass 19 | class PairwiseSample(Sample): 20 | key: Union[int, None] # a unique identifier of the input example 21 | first: str # first sentence as string 22 | second: str # second sentence as string 23 | label: Union[int, None] # label as integer 24 | score: Union[float, None] # score assigned to the textual example 25 | 26 | 27 | @dataclass 28 | class JointwiseSample(Sample): 29 | key: Union[int, None] # a unique identifier of the input example 30 | first: str # first sentence as string 31 | seconds: Sequence[str] # second sentences as strings 32 | label: Union[Sequence[int], None] # labels as integers 33 | score: Union[Sequence[float], None] # scores assigned to the textual examples 34 | valid: Union[Sequence[float], None] # valid positions (not padded) 35 | 
-------------------------------------------------------------------------------- /transformers_framework/utilities/classes.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | 4 | class ExtendedNamespace(Namespace): 5 | r""" Simple object for storing attributes. 6 | 7 | Implements equality by attribute names and values, and provides a simple 8 | string representation. 9 | 10 | This version is enhanced with dictionary capabilities. 11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | for name in kwargs: 15 | setattr(self, name, kwargs[name]) 16 | 17 | @classmethod 18 | def from_namespace(cls, other_namespace: Namespace): 19 | new = cls() 20 | new.__dict__ = other_namespace.__dict__ 21 | return new 22 | 23 | def __eq__(self, other): 24 | if not isinstance(other, Namespace): 25 | return NotImplemented 26 | return vars(self) == vars(other) 27 | 28 | def __contains__(self, key): 29 | return key in self.__dict__ 30 | 31 | def __getitem__(self, key): 32 | return self.__dict__[key] 33 | 34 | def __setitem__(self, key, value): 35 | self.__dict__[key] = value 36 | 37 | def __delitem__(self, key): 38 | del self.__dict__[key] 39 | 40 | def __iter__(self): 41 | yield from self.__dict__.items() 42 | 43 | def __len__(self): 44 | return len(self.__dict__) 45 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ae_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-AE-k \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type AE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ie_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-IE-k \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 
5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type IE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ae_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-AE-k \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type AE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ie_k.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-IE-k \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | \ 18 | --accumulate_grad_batches 2 \ 19 | --max_sequence_length 64 \ 20 | -k 5 --selection all --force_load_dataset_in_memory --separated \ 21 | --learning_rate 1e-05 \ 22 | --max_epochs 40 \ 23 | --early_stopping \ 24 | --patience 8 \ 25 | --weight_decay 0.0 \ 26 | --num_warmup_steps 100 \ 27 | --monitor validation/map \ 28 | --val_check_interval 1.0 \ 29 | --num_workers 8 \ 30 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 31 | --head_type IE_k \ 32 | --seed 1337 33 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ae_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-AE-k-best \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter 
JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | --accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 6 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 5000 \ 28 | --monitor validation/map \ 29 | --val_check_interval 0.5 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type AE_k \ 33 | --seed 1337 34 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/asnq/run_roberta_base_joint_ie_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 8 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-asnq-IE-k-best \ 7 | --output_dir outputs/joint-asnq \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 128 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | --accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 6 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 5000 \ 28 | --monitor validation/map \ 29 | --val_check_interval 0.5 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type IE_k \ 33 | --seed 1337 34 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ae_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-AE-k-best \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | --accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 40 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 100 \ 28 | --monitor validation/map \ 29 | --val_check_interval 1.0 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type AE_k \ 33 | --seed 1337 34 | 
-------------------------------------------------------------------------------- /transformers_experiments/finetuning/trecqa/run_roberta_base_joint_ie_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-trecqa-IE-k-best \ 7 | --output_dir outputs/joint-trecqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | --accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 40 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 100 \ 28 | --monitor validation/map \ 29 | --val_check_interval 1.0 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type IE_k \ 33 | --seed 1337 34 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ae_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-AE-k-best \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | --accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 40 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 100 \ 28 | --monitor validation/map \ 29 | --val_check_interval 1.0 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type AE_k \ 33 | --seed 1337 34 | -------------------------------------------------------------------------------- /transformers_experiments/finetuning/wikiqa/run_roberta_base_joint_ie_k_best.sh: -------------------------------------------------------------------------------- 1 | python -m transformers_framework \ 2 | --model RobertaJointAS2 \ 3 | --devices 2 --accelerator gpu --strategy ddp \ 4 | --precision 16 \ 5 | --pre_trained_model \ 6 | --name roberta-base-joint-wikiqa-IE-k-best \ 7 | --output_dir outputs/joint-wikiqa \ 8 | \ 9 | --adapter JointwiseArrowAdapter \ 10 | --batch_size 32 --val_batch_size 128 --test_batch_size 128 \ 11 | --train_filepath --train_split train \ 12 | --valid_filepath --valid_split validation \ 13 | --test_filepath --test_split test \ 14 | --field_names question answer \ 15 | --label_name label \ 16 | --key_name key \ 17 | --score_name scores_roberta_base \ 18 | \ 19 | 
--accumulate_grad_batches 2 \ 20 | --max_sequence_length 64 \ 21 | -k 5 --selection best --separated \ 22 | --learning_rate 1e-05 \ 23 | --max_epochs 40 \ 24 | --early_stopping \ 25 | --patience 8 \ 26 | --weight_decay 0.0 \ 27 | --num_warmup_steps 100 \ 28 | --monitor validation/map \ 29 | --val_check_interval 1.0 \ 30 | --num_workers 8 \ 31 | --shuffle_candidates --reload_dataloaders_every_n_epoch 1 \ 32 | --head_type IE_k \ 33 | --seed 1337 34 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/as2/roberta.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizerFast 2 | 3 | from transformers_framework.architectures.roberta.modeling_config import RobertaJointConfig 4 | from transformers_framework.architectures.roberta.modeling_joint_roberta import JointRobertaForSequenceClassification 5 | from transformers_framework.models.joint.as2.base import BaseJointAS2 6 | 7 | 8 | class RobertaJointAS2(BaseJointAS2): 9 | 10 | def setup_config(self) -> RobertaJointConfig: 11 | return RobertaJointConfig.from_pretrained( 12 | self.hyperparameters.pre_trained_config, 13 | k=self.hyperparameters.k, 14 | sentence_msl=self.hyperparameters.max_sequence_length, 15 | num_labels=self.hyperparameters.num_labels, 16 | head_type=self.hyperparameters.head_type, 17 | ) 18 | 19 | def setup_model(self, config: RobertaJointConfig) -> JointRobertaForSequenceClassification: 20 | if self.hyperparameters.pre_trained_model is None: 21 | return JointRobertaForSequenceClassification(config) 22 | else: 23 | return JointRobertaForSequenceClassification.from_pretrained( 24 | self.hyperparameters.pre_trained_model, config=config, ignore_mismatched_sizes=True 25 | ) 26 | 27 | def setup_tokenizer(self) -> RobertaTokenizerFast: 28 | return RobertaTokenizerFast.from_pretrained(self.hyperparameters.pre_trained_tokenizer) 29 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/fact_checking/roberta.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizerFast 2 | 3 | from transformers_framework.architectures.roberta.modeling_config import RobertaJointConfig 4 | from transformers_framework.architectures.roberta.modeling_joint_roberta import JointRobertaForSequenceClassification 5 | from transformers_framework.models.joint.fact_checking.base import BaseJointFactChecking 6 | 7 | 8 | class RobertaJointFactChecking(BaseJointFactChecking): 9 | 10 | def setup_config(self) -> RobertaJointConfig: 11 | return RobertaJointConfig.from_pretrained( 12 | self.hyperparameters.pre_trained_config, 13 | k=self.hyperparameters.k, 14 | sentence_msl=self.hyperparameters.max_sequence_length, 15 | num_labels=self.hyperparameters.num_labels, 16 | head_type=self.hyperparameters.head_type, 17 | ) 18 | 19 | def setup_model(self, config: RobertaJointConfig) -> JointRobertaForSequenceClassification: 20 | if self.hyperparameters.pre_trained_model is None: 21 | return JointRobertaForSequenceClassification(config) 22 | else: 23 | return JointRobertaForSequenceClassification.from_pretrained( 24 | self.hyperparameters.pre_trained_model, config=config, ignore_mismatched_sizes=True 25 | ) 26 | 27 | def setup_tokenizer(self) -> RobertaTokenizerFast: 28 | return RobertaTokenizerFast.from_pretrained(self.hyperparameters.pre_trained_tokenizer) 29 | 
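All of these RoBERTa wrappers are constructed the same way: a single hyperparameters namespace is handed to the constructor inherited from BaseModel, and the three setup_* hooks read everything they need from it. A rough, illustrative sketch of a programmatic instantiation follows; the field values are invented, and the transformers_lightning base classes typically expect further training-related fields that are normally supplied by the command-line parser in __main__, so treat this as a sketch rather than a working recipe:

from transformers_framework.models.joint.as2.roberta import RobertaJointAS2
from transformers_framework.utilities.classes import ExtendedNamespace

# Invented values; in practice this namespace is produced by the argument parser.
hyperparameters = ExtendedNamespace(
    pre_trained_model='roberta-base',
    pre_trained_config=None,      # BaseModel falls back to pre_trained_model
    pre_trained_tokenizer=None,   # BaseModel falls back to pre_trained_model
    k=5,
    max_sequence_length=64,
    num_labels=2,
    head_type='AE_k',
)
model = RobertaJointAS2(hyperparameters)  # builds config, model and tokenizer via the setup_* hooks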
-------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm_as2/roberta.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizerFast 2 | 3 | from transformers_framework.architectures.roberta.modeling_config import RobertaJointConfig 4 | from transformers_framework.architectures.roberta.modeling_joint_roberta import ( 5 | JointRobertaForMaskedLMAndSequenceClassification, 6 | ) 7 | from transformers_framework.models.joint.mlm_as2.base import BaseJointMLMAndAS2 8 | 9 | 10 | class RobertaJointMLMAndClassification(BaseJointMLMAndAS2): 11 | 12 | def setup_config(self) -> RobertaJointConfig: 13 | return RobertaJointConfig.from_pretrained( 14 | self.hyperparameters.pre_trained_config, 15 | k=self.hyperparameters.k, 16 | sentence_msl=self.hyperparameters.max_sequence_length, 17 | num_labels=self.hyperparameters.num_labels, 18 | head_type=self.hyperparameters.head_type, 19 | ) 20 | 21 | def setup_model(self, config: RobertaJointConfig) -> JointRobertaForMaskedLMAndSequenceClassification: 22 | if self.hyperparameters.pre_trained_model is None: 23 | return JointRobertaForMaskedLMAndSequenceClassification(config) 24 | else: 25 | return JointRobertaForMaskedLMAndSequenceClassification.from_pretrained( 26 | self.hyperparameters.pre_trained_model, config=config, ignore_mismatched_sizes=True 27 | ) 28 | 29 | def setup_tokenizer(self) -> RobertaTokenizerFast: 30 | return RobertaTokenizerFast.from_pretrained(self.hyperparameters.pre_trained_tokenizer) 31 | -------------------------------------------------------------------------------- /transformers_framework/adapters/transformer_adapter.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from argparse import ArgumentParser, Namespace 3 | 4 | from transformers import PreTrainedTokenizer 5 | from transformers_lightning.adapters import SuperAdapter 6 | 7 | from transformers_framework.transformations.transformation import Transformation, TransformationsConcatenation 8 | 9 | 10 | class TransformersAdapter(SuperAdapter): 11 | 12 | def __init__( 13 | self, 14 | hyperparameters: Namespace, 15 | tokenizer: PreTrainedTokenizer, 16 | stage_name: str, 17 | seed: int = 0, 18 | ) -> None: 19 | super().__init__(hyperparameters) 20 | 21 | self.tokenizer = tokenizer 22 | self.stage_name = stage_name 23 | self.seed = seed 24 | self.transformations = self.__get_transformations__() 25 | 26 | def __get_transformations__(self) -> Transformation: 27 | return TransformationsConcatenation(self.hyperparameters) # empty transformations 28 | 29 | @abstractmethod 30 | def is_active(self) -> bool: 31 | r""" Return True or False based on whether this adapter could return data or not. """ 32 | 33 | @staticmethod 34 | def add_adapter_instance_specific_args(parser: ArgumentParser, stage_name: str): 35 | r""" In the case many adapters are used, it could be useful 36 | to organize the arguments of every adapter using a different prefix. 37 | Put here all the arguments that are not shared by every adapter instance, for 38 | example the path of the data on the disk. 
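        For instance, a subclass that reads per-stage files could register something like
        (illustrative only, not part of this base class):

            parser.add_argument(f'--{stage_name}_filepath', type=str, required=False, default=None)
            parser.add_argument(f'--{stage_name}_split', type=str, required=False, default=None)

        which matches the `--train_filepath`/`--train_split` style of arguments used by the experiment scripts.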
""" 39 | -------------------------------------------------------------------------------- /transformers_framework/architectures/modeling_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from transformers.modeling_outputs import BaseModelOutput 6 | 7 | 8 | @dataclass 9 | class MaskedLMOutput(BaseModelOutput): 10 | r""" 11 | Base class for masked language modeling outputs. 12 | 13 | Args: 14 | masked_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, 15 | `optional`, returned when :obj:`masked_lm_labels` is provided): Masked language modeling (MLM) loss. 16 | masked_lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 17 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 18 | """ 19 | 20 | masked_lm_loss: Optional[torch.FloatTensor] = None 21 | masked_lm_logits: torch.FloatTensor = None 22 | 23 | 24 | @dataclass 25 | class SequenceClassificationOutput(BaseModelOutput): 26 | r""" 27 | Base class for sequence classification outputs. 28 | 29 | Args: 30 | seq_class_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, 31 | `optional`, returned when :obj:`seq_class_labels` is provided): 32 | Sequence classification loss. 33 | seq_class_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): 34 | Sequence classification logits (scores for each input example taken before SoftMax). 35 | """ 36 | 37 | seq_class_loss: Optional[torch.FloatTensor] = None 38 | seq_class_logits: Optional[torch.FloatTensor] = None 39 | 40 | 41 | @dataclass 42 | class MaskedLMAndSequenceClassificationOutput( 43 | MaskedLMOutput, SequenceClassificationOutput 44 | ): 45 | r""" Masked language modeling + sequence classification. """ 46 | -------------------------------------------------------------------------------- /transformers_framework/transformations/transformation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import Namespace 3 | from typing import Generator, Sequence, Union 4 | 5 | from transformers_framework.utilities.structures import Sample 6 | 7 | 8 | class Transformation(ABC): 9 | r""" Transformation class encodes a transformation from samples to other samples. """ 10 | 11 | def __init__(self, hyperparameters: Namespace): 12 | self.hyperparameters = hyperparameters 13 | 14 | @abstractmethod 15 | def __call__( 16 | self, 17 | samples: Union[Generator[Sample, None, None], Sample] 18 | ) -> Union[Generator[Sample, None, None], Sample]: 19 | r""" Apply here the trasformation on a single element or on all elements. """ 20 | 21 | 22 | class TransformationsConcatenation(Transformation): 23 | r""" Concatenate a sequence of transformations and return elements after parsing with every transformation 24 | in the same order provided in the __init__. 
""" 25 | 26 | def __init__(self, hyperparameters: Namespace, *transformations: Sequence[Transformation]): 27 | super().__init__(hyperparameters) 28 | self.transformations = list(transformations) 29 | 30 | def __call__( 31 | self, 32 | samples: Union[Generator[Sample, None, None], Sample] 33 | ) -> Union[Generator[Sample, None, None], Sample]: 34 | for transformation in self.transformations: 35 | samples = transformation(samples) 36 | return samples 37 | 38 | def __str__(self) -> str: 39 | return ( 40 | f"" 42 | ) 43 | 44 | def append_transformation(self, transformation: Transformation): 45 | r""" Append transformation to the list. """ 46 | self.transformations.append(transformation) 47 | -------------------------------------------------------------------------------- /transformers_utilities/datasets/create_trecqa_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import gzip 5 | from argparse import ArgumentParser 6 | from typing import Dict, List 7 | 8 | from datasets import Dataset, DatasetDict 9 | 10 | 11 | logging.getLogger().setLevel(logging.INFO) 12 | 13 | 14 | def to_bool(string): 15 | return string.lower() in ('yes', 'pos', 'positive', '1', 'correct') 16 | 17 | 18 | def to_int(string): 19 | return int(to_bool(string)) 20 | 21 | 22 | def load_dataset(filepath: str) -> List[Dict]: 23 | r""" Load a JSONL from disk. """ 24 | with (gzip.open(filepath) if filepath.endswith('gz') else open(filepath)) as fi: 25 | return [json.loads(line) for line in fi] 26 | 27 | 28 | def main(args): 29 | r""" Create TREC-QA dataset. """ 30 | assert not os.path.exists(args.output_folder) 31 | 32 | logging.info("Loading data") 33 | trecqa = { 34 | 'train': load_dataset(os.path.join(args.input_folder, 'train-all.jsonl.gz')), 35 | 'validation': load_dataset(os.path.join(args.input_folder, 'dev-filtered.jsonl')), 36 | 'test': load_dataset(os.path.join(args.input_folder, 'test-filtered.jsonl')), 37 | } 38 | 39 | res = {} 40 | for split, data in trecqa.items(): 41 | dataset = dict(question=[], answer=[], label=[], key=[]) 42 | for i, sample in enumerate(sorted(data, key=lambda a: a['question'])): 43 | labels = [s['label'] for s in sample["candidates"]] 44 | answers = [s['sentence'] for s in sample["candidates"]] 45 | dataset['question'].append(sample['question']) 46 | dataset['answer'].append(answers) 47 | dataset['key'].append(i) 48 | dataset['label'].append(labels) 49 | 50 | res[split] = Dataset.from_dict(dataset) 51 | 52 | logging.info("Saving results") 53 | res = DatasetDict(res) 54 | res.save_to_disk(args.output_folder) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = ArgumentParser() 59 | parser.add_argument('--input_folder', type=str, required=True) 60 | parser.add_argument('--output_folder', type=str, required=True) 61 | args = parser.parse_args() 62 | main(args) 63 | -------------------------------------------------------------------------------- /transformers_utilities/datasets/create_wikiqa_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | from datasets import load_dataset, Dataset, DatasetDict 6 | from tqdm import tqdm 7 | 8 | 9 | logging.getLogger().setLevel(logging.INFO) 10 | 11 | 12 | def main(args): 13 | r""" Create WikiQA dataset. 
""" 14 | assert not os.path.exists(args.output_folder) 15 | 16 | logging.info("Loading data") 17 | wikiqa = load_dataset('wiki_qa') 18 | 19 | res = {} 20 | for split in wikiqa.keys(): 21 | dataset = wikiqa[split] 22 | 23 | questions_to_answers = {} 24 | for example in tqdm(dataset, total=len(dataset), desc=f"Processing split {split}..."): 25 | if example['question'] not in questions_to_answers: 26 | questions_to_answers[example['question']] = { 27 | 'answer': [], 28 | 'label': [], 29 | 'key': example['question_id'] 30 | } 31 | 32 | questions_to_answers[example['question']]['answer'].append(example['answer']) 33 | questions_to_answers[example['question']]['label'].append(example['label']) 34 | 35 | if split in ('validation', 'test'): # cleaning all+ and all- 36 | questions_to_answers = { 37 | k: v for k, v in questions_to_answers.items() if (0 < sum(v['label']) < len(v['label'])) 38 | } 39 | 40 | dataset = dict(question=[], answer=[], label=[], key=[]) 41 | for question in sorted(list(questions_to_answers.keys())): 42 | values = questions_to_answers[question] 43 | dataset['question'].append(question) 44 | dataset['answer'].append(values['answer']) 45 | dataset['key'].append(values['key']) 46 | dataset['label'].append(values['label']) 47 | 48 | res[split] = Dataset.from_dict(dataset) 49 | 50 | logging.info("Saving results") 51 | res = DatasetDict(res) 52 | res.save_to_disk(args.output_folder) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = ArgumentParser() 57 | parser.add_argument('--output_folder', type=str, required=True) 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /transformers_utilities/datasets/create_fever_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from argparse import ArgumentParser 5 | from typing import Dict 6 | from datasets import Dataset, DatasetDict 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | logging.getLogger().setLevel(logging.INFO) 12 | 13 | 14 | label_to_id = { 15 | "SUPPORTS": 0, 16 | "NOT ENOUGH INFO": 1, 17 | "REFUTES": 2, 18 | } 19 | 20 | fake_evicence = ("no doc", -1, "", 0) 21 | 22 | 23 | def get_dataset_from_file(filepath: str) -> Dict: 24 | r""" Build a single split of the dataset. """ 25 | with open(filepath) as fi: 26 | data = [json.loads(line) for line in fi] 27 | 28 | res = dict(claim=[], evidence=[], label=[], key=[], doc=[]) 29 | 30 | for example in tqdm(data, desc="Processing..."): 31 | label = example.get('label', None) 32 | if label is not None: 33 | label = label_to_id[label] 34 | 35 | if not example['evidence']: 36 | logging.warn(f"No evidence for claim id {example['id']}") 37 | example['evidence'] = [fake_evicence] 38 | 39 | evicences = [evidence[0] + ". " + evidence[2] for evidence in example['evidence']] 40 | doc_names = [evidence[0] for evidence in example['evidence']] 41 | 42 | res['claim'].append(example['claim']) 43 | res['evidence'].append(evicences) 44 | res['label'].append(label) 45 | res['doc'].append(doc_names) 46 | res['key'].append(example['id']) 47 | 48 | return Dataset.from_dict(res) 49 | 50 | 51 | def main(args): 52 | r""" Create FEVER dataset. 
""" 53 | 54 | assert os.path.isfile(args.train_file) 55 | assert os.path.isfile(args.dev_file) 56 | assert os.path.isfile(args.test_file) 57 | 58 | res = DatasetDict( 59 | train=get_dataset_from_file(args.train_file), 60 | validation=get_dataset_from_file(args.dev_file), 61 | test=get_dataset_from_file(args.test_file), 62 | ) 63 | res.save_to_disk(args.output_folder) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = ArgumentParser() 68 | parser.add_argument('--train_file', type=str, required=True) 69 | parser.add_argument('--dev_file', type=str, required=True) 70 | parser.add_argument('--test_file', type=str, required=True) 71 | parser.add_argument('--output_folder', type=str, required=True) 72 | args = parser.parse_args() 73 | main(args) 74 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Code check and install requirements 2 | check: 3 | # isort and flake 4 | isort . 5 | flake8 . 6 | 7 | env: 8 | pip install -r requirements.txt 9 | 10 | wikiqa: 11 | rm -rf datasets/wikiqa /tmp/wikiqa_tmp 12 | python transformers_utilities/datasets/create_wikiqa_dataset.py --output_folder /tmp/wikiqa_tmp 13 | python transformers_utilities/datasets/merge_datasets.py \ 14 | --input /tmp/wikiqa_tmp datasets/scores_as2/scores_roberta_base_wikiqa \ 15 | --output datasets/wikiqa 16 | rm -r /tmp/wikiqa_tmp 17 | 18 | trecqa: 19 | rm -rf datasets/trecqa /tmp/lexdecomp-master /tmp/trecqa.zip /tmp/trecqa_tmp 20 | wget https://github.com/mcrisc/lexdecomp/archive/refs/heads/master.zip -O /tmp/trecqa.zip 21 | unzip /tmp/trecqa.zip -d /tmp 22 | python transformers_utilities/datasets/create_trecqa_dataset.py \ 23 | --input_folder /tmp/lexdecomp-master/trec-qa \ 24 | --output_folder /tmp/trecqa_tmp 25 | python transformers_utilities/datasets/merge_datasets.py \ 26 | --input /tmp/trecqa_tmp datasets/scores_as2/scores_roberta_base_trecqa \ 27 | --output datasets/trecqa 28 | rm -r /tmp/lexdecomp-master /tmp/trecqa.zip /tmp/trecqa_tmp 29 | 30 | asnq: 31 | rm -rf datasets/asnq /tmp/asnq.tar /tmp/data /tmp/wqa-cascade-transformers-master /tmp/asnq_tmp 32 | wget https://d3t7erp6ge410c.cloudfront.net/tanda-aaai-2020/data/asnq.tar -O /tmp/asnq.tar 33 | tar xvf /tmp/asnq.tar -C /tmp 34 | wget https://github.com/alexa/wqa-cascade-transformers/archive/refs/heads/master.zip -O /tmp/cascade.zip 35 | unzip /tmp/cascade.zip -d /tmp 36 | python transformers_utilities/datasets/create_asnq_dataset.py \ 37 | --input_folder /tmp/data/asnq \ 38 | --output /tmp/asnq_tmp \ 39 | --dev_filter /tmp/wqa-cascade-transformers-master/acl2020cascade/data/unique.dev \ 40 | --test_filter /tmp/wqa-cascade-transformers-master/acl2020cascade/data/unique.test 41 | python transformers_utilities/datasets/merge_datasets.py \ 42 | --input /tmp/asnq_tmp datasets/scores_as2/scores_roberta_base_asnq \ 43 | --output datasets/asnq 44 | rm -r /tmp/cascade.zip /tmp/asnq.tar /tmp/data /tmp/wqa-cascade-transformers-master /tmp/asnq_tmp 45 | 46 | fever: 47 | rm -rf datasets/fever /tmp/kgat.zip /tmp/KernelGAT 48 | wget https://thunlp.oss-cn-qingdao.aliyuncs.com/KernelGAT/FEVER/KernelGAT.zip -O /tmp/kgat.zip 49 | unzip /tmp/kgat.zip -d /tmp 50 | python transformers_utilities/datasets/create_fever_dataset.py \ 51 | --train_file /tmp/KernelGAT/data/bert_train.json \ 52 | --dev_file /tmp/KernelGAT/data/bert_dev.json \ 53 | --test_file /tmp/KernelGAT/data/bert_test.json \ 54 | --output datasets/fever 55 | rm -r /tmp/kgat.zip /tmp/KernelGAT 56 
| -------------------------------------------------------------------------------- /transformers_framework/models/base/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | from pytorch_lightning.utilities import rank_zero_warn 5 | from transformers import PreTrainedTokenizer 6 | from transformers.configuration_utils import PretrainedConfig 7 | from transformers.modeling_utils import PreTrainedModel 8 | from transformers_lightning.models import TransformersModel 9 | 10 | 11 | class BaseModel(TransformersModel, ABC): 12 | 13 | config_class: PretrainedConfig 14 | model_class: PreTrainedModel 15 | tokenizer_class: PreTrainedTokenizer 16 | 17 | def __init__(self, hyperparameters): 18 | super().__init__(hyperparameters) 19 | self.save_hyperparameters(hyperparameters) 20 | 21 | if self.hyperparameters.pre_trained_config is None: 22 | self.hyperparameters.pre_trained_config = self.hyperparameters.pre_trained_model 23 | rank_zero_warn('Found None `pre_trained_config`, setting equal to `pre_trained_model`') 24 | 25 | if self.hyperparameters.pre_trained_tokenizer is None: 26 | self.hyperparameters.pre_trained_tokenizer = self.hyperparameters.pre_trained_model 27 | rank_zero_warn('Found None `pre_trained_tokenizer`, setting equal to `pre_trained_model`') 28 | 29 | assert self.hyperparameters.pre_trained_config is not None, ( 30 | "Cannot instantiate model without a pre_trained_config." 31 | ) 32 | 33 | self.config = self.setup_config() 34 | self.model = self.setup_model(self.config) 35 | self.tokenizer = self.setup_tokenizer() 36 | 37 | @abstractmethod 38 | def setup_config(self) -> PretrainedConfig: 39 | r""" Return the config instance. """ 40 | 41 | @abstractmethod 42 | def setup_model(self, config: PretrainedConfig) -> PreTrainedModel: 43 | r""" Return the model instance. """ 44 | 45 | @abstractmethod 46 | def setup_tokenizer(self) -> PreTrainedTokenizer: 47 | r""" Return the tokenizer instance. """ 48 | 49 | @staticmethod 50 | def add_model_specific_args(parser: ArgumentParser): 51 | super(BaseModel, BaseModel).add_model_specific_args(parser) 52 | # add pre_trained model, tokenizer and config arguments. default config and tokenizer to model if missing 53 | parser.add_argument('--pre_trained_model', type=str, required=False, default=None) 54 | parser.add_argument('--pre_trained_tokenizer', type=str, required=False, default=None) 55 | parser.add_argument('--pre_trained_config', type=str, required=False, default=None) 56 | parser.add_argument('--num_labels', type=int, required=False, default=2) 57 | -------------------------------------------------------------------------------- /transformers_framework/utilities/processors.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from transformers import PreTrainedTokenizer 4 | 5 | from transformers_framework.utilities.structures import DataSample, JointwiseSample, PairwiseSample 6 | from transformers_framework.utilities.tokenization import encode_many_sequences, encode_pair, encode_sequences 7 | 8 | 9 | def data_processor( 10 | sample: DataSample, 11 | tokenizer: PreTrainedTokenizer, 12 | max_sequence_length: int, 13 | chain: bool, 14 | ): 15 | r""" Generinc encoder. Encodes data by concatenating on the msl axis. 
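    A minimal, illustrative example (the sample values and the choice of tokenizer are not part of the
    original code):

        >>> from transformers import RobertaTokenizerFast
        >>> tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        >>> sample = DataSample(data=['who wrote hamlet?', 'Shakespeare wrote Hamlet.'], key=0, label=1, score=None)
        >>> encoded = data_processor(sample, tokenizer, max_sequence_length=64, chain=True)
        >>> len(encoded['input_ids'])  # 2 sequences, each padded to 64 tokens, then chained
        128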
""" 16 | encoded = encode_sequences(sample.data, tokenizer, max_sequence_length, chain=chain) 17 | 18 | if sample.label is not None: 19 | encoded['labels'] = sample.label 20 | if sample.key is not None: 21 | encoded['keys'] = sample.key 22 | 23 | return encoded 24 | 25 | 26 | def pair_processor( 27 | sample: PairwiseSample, 28 | tokenizer: PreTrainedTokenizer, 29 | separated: bool, 30 | max_sequence_length: int, 31 | chain: bool, 32 | truncation: Union[int, str] = True, 33 | padding: str = "max_length", 34 | return_overflowing_tokens: bool = False, 35 | return_offsets_mapping: bool = False, 36 | stride: int = 0, 37 | allow_null_second: bool = False, 38 | ): 39 | r""" Encode a pair for pairwise training. """ 40 | # TODO: remove comment 41 | # assert sample.first is not None and (sample.second is not None or allow_null_second is True) 42 | 43 | if separated: 44 | encoded = encode_sequences([sample.first, sample.second], tokenizer, max_sequence_length, chain=chain) 45 | else: 46 | encoded = encode_pair( 47 | sample.first, 48 | sample.second, 49 | tokenizer, 50 | max_sequence_length, 51 | truncation=truncation, 52 | padding=padding, 53 | return_overflowing_tokens=return_overflowing_tokens, 54 | return_offsets_mapping=return_offsets_mapping, 55 | stride=stride, 56 | allow_null_second=allow_null_second, 57 | ) 58 | 59 | if sample.label is not None: 60 | encoded['labels'] = sample.label 61 | if sample.key is not None: 62 | encoded['keys'] = sample.key 63 | 64 | return encoded 65 | 66 | 67 | def joint_processor( 68 | sample: JointwiseSample, 69 | tokenizer: PreTrainedTokenizer, 70 | separated: bool, 71 | max_sequence_length: int, 72 | reduce_labels: bool, 73 | ): 74 | r""" Encoding for Jointwise models. 75 | Encodes a sequence on sentences with internal padding by concatenating along the msl axis. 
76 | """ 77 | if separated: 78 | encoded = encode_sequences([sample.first] + sample.seconds, tokenizer, max_sequence_length, chain=True) 79 | else: 80 | encoded = encode_many_sequences([sample.first] + sample.seconds, tokenizer, max_sequence_length) 81 | 82 | if sample.label is not None: 83 | encoded['labels'] = sample.label[0] if reduce_labels else sample.label 84 | if sample.key is not None: 85 | encoded['keys'] = sample.key 86 | if sample.valid is not None: 87 | encoded['valid'] = sample.valid 88 | 89 | return encoded 90 | -------------------------------------------------------------------------------- /process_datasets/__main__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from argparse import ArgumentParser 4 | 5 | from datasets import DatasetDict 6 | from psutil import cpu_count 7 | from transformers_lightning.utils import get_classes_from_module 8 | 9 | from process_datasets import loaders, strategies 10 | from process_datasets.loaders.loader import Loader 11 | from process_datasets.strategies.strategy import Strategy 12 | 13 | 14 | ALL_STRATEGIES = get_classes_from_module(strategies, parent=Strategy) 15 | ALL_LOADERS = get_classes_from_module(loaders, parent=Loader) 16 | 17 | 18 | def main(hparams): 19 | 20 | logging.info("Setting seed...") 21 | random.seed(hparams.seed) 22 | 23 | # Create instances 24 | logging.info("Creating loader and strategy instances...") 25 | loader = ALL_LOADERS[hparams.loader](hparams) 26 | strategy = ALL_STRATEGIES[hparams.strategy](hparams) 27 | 28 | logging.info("Starting pipeline...") 29 | 30 | # Data loading 31 | dataset = loader() 32 | 33 | # Data processing 34 | kwargs = dict(num_proc=hparams.num_proc) if hparams.num_proc > 0 else dict() 35 | if isinstance(dataset, DatasetDict): 36 | dataset = DatasetDict(**{ 37 | k: dataset[k].map( 38 | strategy, 39 | batched=True, 40 | batch_size=hparams.batch_size, 41 | remove_columns=dataset[k].column_names, 42 | keep_in_memory=False, 43 | **kwargs, 44 | ) for k in dataset 45 | }) 46 | else: 47 | dataset = dataset.map( 48 | strategy, 49 | batched=True, 50 | batch_size=hparams.batch_size, 51 | remove_columns=dataset.column_names, 52 | keep_in_memory=False, 53 | **kwargs, 54 | ) 55 | 56 | # Data writing 57 | logging.info("Writing to disk") 58 | dataset.save_to_disk(hparams.output_folder) 59 | 60 | logging.info("Done!") 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = ArgumentParser(f"Create pretraining datasets") 65 | parser.add_argument( 66 | '--loader', type=str, required=True, choices=ALL_LOADERS, help="Loader class to load data" 67 | ) 68 | parser.add_argument( 69 | '--strategy', 70 | type=str, 71 | required=False, 72 | default='Sentence2SentenceStrategy', 73 | choices=ALL_STRATEGIES.keys(), 74 | help="Strategy to use to create the dataset", 75 | ) 76 | parser.add_argument('--output_folder', type=str, required=True, help="Output folder") 77 | parser.add_argument('--seed', default=1337, required=False, type=int, help="Seed for reproducibility.") 78 | parser.add_argument( 79 | '--batch_size', default=10000, type=int, required=False, help="How many input examples to process together." 
80 | ) 81 | parser.add_argument('--num_proc', type=int, default=cpu_count(), required=False, help="How many process to use.") 82 | # add strategy parameters 83 | tmp_hparams, _ = parser.parse_known_args() 84 | 85 | loader_class = ALL_LOADERS[tmp_hparams.loader] 86 | strategy_class = ALL_STRATEGIES[tmp_hparams.strategy] 87 | 88 | loader_class.add_loader_specific_args(parser) 89 | strategy_class.add_arguments_to_argparse(parser) 90 | 91 | hparams = parser.parse_args() 92 | main(hparams) 93 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /transformers_utilities/datasets/create_asnq_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser 4 | from typing import Dict, List, Tuple 5 | import csv 6 | 7 | from datasets import load_dataset, Dataset, DatasetDict 8 | from tqdm import tqdm 9 | 10 | 11 | logging.getLogger().setLevel(logging.INFO) 12 | 13 | 14 | def load_filter(filepath: str) -> List[str]: 15 | r""" Load a linewise file into an array of strings. """ 16 | with open(filepath) as fi: 17 | return set(line.strip() for line in fi) 18 | 19 | 20 | def load_dataset(filepath: str) -> List[Tuple]: 21 | r""" Load ASNQ dataset in TSV from disk. """ 22 | with open(filepath) as fi: 23 | reader = csv.reader(fi, delimiter='\t', quoting=csv.QUOTE_NONE) 24 | yield from reader 25 | 26 | 27 | def process_split(dataset: List[Tuple], question_filter: List[str] = None) -> Dict: 28 | r""" Process a single split and filter on questions that are in filter. """ 29 | questions_to_answers = {} 30 | for question, candidate, label in tqdm(dataset, desc=f"Processing..."): 31 | if question not in questions_to_answers: 32 | questions_to_answers[question] = { 33 | 'answer': [], 34 | 'label': [], 35 | 'key': len(questions_to_answers) 36 | } 37 | 38 | questions_to_answers[question]['answer'].append(candidate) 39 | questions_to_answers[question]['label'].append(int(label.strip() == '4')) 40 | 41 | if question_filter is not None: 42 | questions_to_answers = { 43 | k: v for k, v in questions_to_answers.items() if k in question_filter 44 | } 45 | 46 | dataset = dict(question=[], answer=[], label=[], key=[]) 47 | for question in sorted(list(questions_to_answers.keys())): 48 | values = questions_to_answers[question] 49 | dataset['question'].append(question) 50 | dataset['answer'].append(values['answer']) 51 | dataset['key'].append(values['key']) 52 | dataset['label'].append(values['label']) 53 | 54 | return Dataset.from_dict(dataset) 55 | 56 | 57 | def main(args): 58 | r""" Create ASNQ dataset. 
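    Typical invocation, with the paths used by the `asnq` target of the Makefile (which downloads and
    unpacks the data first; the Makefile passes the abbreviated `--output`):

        python transformers_utilities/datasets/create_asnq_dataset.py \
            --input_folder /tmp/data/asnq \
            --output_folder /tmp/asnq_tmp \
            --dev_filter /tmp/wqa-cascade-transformers-master/acl2020cascade/data/unique.dev \
            --test_filter /tmp/wqa-cascade-transformers-master/acl2020cascade/data/unique.test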
""" 59 | assert not os.path.exists(args.output_folder) 60 | 61 | logging.info("Loading data") 62 | asnq = { 63 | 'train': load_dataset(os.path.join(args.input_folder, 'train.tsv')), 64 | 'validation': load_dataset(os.path.join(args.input_folder, 'dev.tsv')), 65 | 'test': load_dataset(os.path.join(args.input_folder, 'dev.tsv')) 66 | } 67 | 68 | logging.info("Loading filters") 69 | filters = { 70 | 'dev': load_filter(args.dev_filter), 71 | 'test': load_filter(args.test_filter) 72 | } 73 | 74 | res = DatasetDict( 75 | train=process_split(asnq['train']), 76 | validation=process_split(asnq['validation'], question_filter=filters['dev']), 77 | test=process_split(asnq['test'], question_filter=filters['test']), 78 | ) 79 | 80 | logging.info("Saving results") 81 | res.save_to_disk(args.output_folder) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = ArgumentParser() 86 | parser.add_argument('--input_folder', type=str, required=True) 87 | parser.add_argument('--output_folder', type=str, required=True) 88 | parser.add_argument('--dev_filter', type=str, required=True) 89 | parser.add_argument('--test_filter', type=str, required=True) 90 | args = parser.parse_args() 91 | main(args) 92 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm/base.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from transformers_lightning.language_modeling import IGNORE_IDX 4 | 5 | from transformers_framework.architectures.roberta.modeling_config import MONOLITHIC_HEAD_TYPES 6 | from transformers_framework.models.base.mlm import BaseModelMLM 7 | from transformers_framework.utilities.functional import index_multi_tensors 8 | 9 | 10 | class BaseJointMLM(BaseModelMLM): 11 | 12 | def training_step(self, batch, *args): 13 | r""" 14 | Start by masking tokens some tokens. 15 | """ 16 | input_ids, attention_mask = batch["input_ids"], batch["attention_mask"] 17 | input_ids, labels = self.mlm(input_ids) 18 | 19 | # tokens_type_ids are automatically created by the model based on the config 20 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 21 | predictions, labels = index_multi_tensors( 22 | results.seq_class_logits.argmax(dim=-1), labels, positions=labels != IGNORE_IDX 23 | ) 24 | 25 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 26 | self.log('training/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 27 | self.log('training/accuracy', self.train_mlm_acc(predictions, labels), on_epoch=True) 28 | 29 | return results.seq_class_loss 30 | 31 | def validation_step(self, batch, *args): 32 | r""" 33 | Start by masking tokens some tokens. 
34 | """ 35 | input_ids, attention_mask = batch["input_ids"], batch["attention_mask"] 36 | input_ids, labels = self.mlm(input_ids) 37 | 38 | # tokens_type_ids are automatically created by the model based on the config 39 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 40 | predictions, labels = index_multi_tensors( 41 | results.seq_class_logits.argmax(dim=-1), labels, positions=labels != IGNORE_IDX 42 | ) 43 | 44 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 45 | self.log('validation/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 46 | self.log('validation/accuracy', self.valid_mlm_acc(predictions, labels), on_epoch=True) 47 | 48 | def test_step(self, batch, *args): 49 | r""" 50 | Start by masking tokens some tokens. 51 | """ 52 | input_ids, attention_mask = batch["input_ids"], batch["attention_mask"] 53 | input_ids, labels = self.mlm(input_ids) 54 | 55 | # tokens_type_ids are automatically created by the model based on the config 56 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 57 | predictions, labels = index_multi_tensors( 58 | results.seq_class_logits.argmax(dim=-1), labels, positions=labels != IGNORE_IDX 59 | ) 60 | 61 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 62 | self.log('test/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 63 | self.log('test/accuracy', self.valid_mlm_acc(predictions, labels), on_epoch=True) 64 | 65 | @staticmethod 66 | def add_model_specific_args(parser: ArgumentParser): 67 | super(BaseJointMLM, BaseJointMLM).add_model_specific_args(parser) 68 | parser.set_defaults(max_sequence_length=64) 69 | parser.add_argument('--head_type', type=str, required=True, choices=MONOLITHIC_HEAD_TYPES) 70 | -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/arrow/pairwise_adapter.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import Dict 3 | 4 | from pytorch_lightning.trainer.states import TrainerFn 5 | from transformers import PreTrainedTokenizer 6 | 7 | from transformers_framework.adapters.map_adapters.arrow.arrow_adapter import ArrowAdapter 8 | from transformers_framework.transformations.conversion_transformation import DictSample2PairwiseSampleTransformation 9 | from transformers_framework.transformations.transformation import Transformation, TransformationsConcatenation 10 | from transformers_framework.utilities.processors import pair_processor 11 | from transformers_framework.utilities.structures import PairwiseSample 12 | 13 | 14 | class PairwiseArrowAdapter(ArrowAdapter): 15 | r""" Pairwise version of Arrow File readers, which implements filtering on scores and limits. 
""" 16 | 17 | def __init__( 18 | self, 19 | hyperparameters: Namespace, 20 | tokenizer: PreTrainedTokenizer, 21 | stage: TrainerFn, 22 | seed: int = 0, 23 | ) -> None: 24 | super().__init__(hyperparameters, tokenizer, stage, seed=seed) 25 | 26 | # arguments checks 27 | assert not hyperparameters.chain or hyperparameters.separated, ( 28 | "`chain` requires `separated`" 29 | ) 30 | assert type(self) != PairwiseArrowAdapter or len(hyperparameters.field_names) == 2, ( 31 | "`field_names` must have length 2" 32 | ) 33 | assert type(self) != PairwiseArrowAdapter or ( 34 | hyperparameters.separated is False or hyperparameters.allow_null_second is False 35 | ), "`allow_null_second` not allowed with `separated`" 36 | 37 | def __get_transformations__(self) -> Transformation: 38 | return TransformationsConcatenation( 39 | self.hyperparameters, 40 | DictSample2PairwiseSampleTransformation( 41 | self.hyperparameters, 42 | first_field=self.hyperparameters.field_names[0], 43 | second_field=self.hyperparameters.field_names[1], 44 | key_field=self.hyperparameters.key_name, 45 | label_field=self.hyperparameters.label_name, 46 | score_field=self.hyperparameters.score_name, 47 | ), # Dict -> PairwiseSample 48 | ) 49 | 50 | def preprocess_line(self, sample: PairwiseSample) -> Dict: 51 | r""" 52 | Process a line. The structure of each line is exactly 53 | the same returned by the __iter__ method. Here you should do data preparation 54 | for the actual model being trained. This is a good place to do batch tokenization, 55 | padding and so on. 56 | """ 57 | return pair_processor( 58 | sample, 59 | tokenizer=self.tokenizer, 60 | separated=self.hyperparameters.separated, 61 | max_sequence_length=self.hyperparameters.max_sequence_length, 62 | chain=self.hyperparameters.chain, 63 | allow_null_second=self.hyperparameters.allow_null_second, 64 | ) 65 | 66 | @staticmethod 67 | def add_adapter_specific_args(parser: ArgumentParser): 68 | super(PairwiseArrowAdapter, PairwiseArrowAdapter).add_adapter_specific_args(parser) 69 | parser.add_argument( 70 | '--separated', action="store_true", help="Candidates are separated between question and answer" 71 | ) 72 | parser.add_argument( 73 | '--score_name', 74 | required=False, 75 | default=None, 76 | help="Name of the score field" 77 | ) 78 | parser.add_argument('--allow_null_second', action="store_true", help='Allow second field to be None.') 79 | -------------------------------------------------------------------------------- /transformers_framework/utilities/tokenization.py: -------------------------------------------------------------------------------- 1 | from itertools import chain as iterchain 2 | from typing import Dict, Sequence, Union 3 | 4 | from transformers import PreTrainedTokenizer 5 | 6 | from transformers_framework.architectures.modeling_tokenizer import ExtendedTokenizerFast 7 | 8 | 9 | def encode_sequences( 10 | sequences: Sequence[str], 11 | tokenizer: PreTrainedTokenizer, 12 | max_sequence_length: int, 13 | chain: bool = True, 14 | extended_token_type_ids: int = None, 15 | ) -> Dict: 16 | r""" Encode a sequence of sentences as 17 | [CLS] question [SEP] [CLS] candidate_1 [SEP] [CLS] candidate_2 [SEP] ... 18 | or 19 | question candidate_1 candidate_2 ... 20 | such that the total length is equal to `max_sequence_length` * len(sequences) 21 | 22 | If chain, chain every sequence before returning. 
23 | """ 24 | encoded = tokenizer( 25 | sequences, 26 | add_special_tokens=True, 27 | return_attention_mask=True, 28 | return_token_type_ids=True, 29 | max_length=max_sequence_length, 30 | truncation=True, 31 | padding='max_length', 32 | ) 33 | 34 | if extended_token_type_ids is not None: 35 | encoded['token_type_ids'] = [ 36 | [min(i, extended_token_type_ids)] * len(ids) 37 | for i, ids in enumerate(encoded['input_ids']) 38 | ] 39 | 40 | if chain: 41 | encoded = {k: list(iterchain(*v)) for k, v in encoded.items()} 42 | 43 | return encoded 44 | 45 | 46 | def encode_many_sequences( 47 | sequences: Sequence[str], 48 | tokenizer: ExtendedTokenizerFast, 49 | max_sequence_length: int, 50 | extended_token_type_ids: int = None, 51 | ) -> Dict: 52 | r""" Encode a list of sequences as 53 | [CLS] first [SEP] second1 [SEP] second2 [SEP] ... [SEP] 54 | or 55 | first candidate1 candidate2 ... 56 | such that the total length is equal to `max_sequence_length`. 57 | """ 58 | assert isinstance(tokenizer, ExtendedTokenizerFast), ( 59 | "Cannot use `encode_many_sequences` without ExtendedTokenizer" 60 | ) 61 | 62 | encoded = tokenizer.encode_many( 63 | sequences, 64 | add_special_tokens=True, 65 | return_attention_mask=True, 66 | return_token_type_ids=True, 67 | max_length=max_sequence_length, 68 | truncation='longest_first', 69 | padding="max_length", 70 | extended_token_type_ids=extended_token_type_ids, 71 | ) 72 | return encoded 73 | 74 | 75 | def encode_pair( 76 | first: str, 77 | second: str, 78 | tokenizer: PreTrainedTokenizer, 79 | max_sequence_length: int, 80 | truncation: Union[int, str] = True, 81 | padding: str = "max_length", 82 | return_overflowing_tokens: bool = False, 83 | return_offsets_mapping: bool = False, 84 | stride: int = 0, 85 | allow_null_second: bool = False, 86 | ) -> Dict: 87 | r""" Encode a first-second pair as 88 | [CLS] first [SEP] second [SEP] 89 | or 90 | first second 91 | such that the total length is equal to `max_sequence_length`. 92 | """ 93 | tok_args = dict( 94 | add_special_tokens=True, 95 | return_attention_mask=True, 96 | return_token_type_ids=True, 97 | max_length=max_sequence_length, 98 | truncation=truncation, 99 | padding=padding, 100 | return_overflowing_tokens=return_overflowing_tokens, 101 | return_offsets_mapping=return_offsets_mapping, 102 | stride=stride, 103 | ) 104 | 105 | if second is None and allow_null_second is True: 106 | encoded = tokenizer(first, **tok_args) 107 | else: 108 | encoded = tokenizer(first, second, **tok_args) 109 | return encoded 110 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/fact_checking/base.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | 5 | from transformers_framework.architectures.roberta.modeling_config import MONOLITHIC_HEAD_TYPES 6 | from transformers_framework.models.base.classification import BaseModelClassification 7 | 8 | 9 | class BaseJointFactChecking(BaseModelClassification): 10 | 11 | def training_step(self, batch, *args): 12 | r""" Just compute the loss and log it. 
""" 13 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 14 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 15 | preds = results.seq_class_logits.argmax(dim=-1) 16 | 17 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 18 | self.log('training/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 19 | self.log('training/accuracy', self.train_acc(preds, labels), on_epoch=True, prog_bar=True) 20 | for class_id, value in enumerate(self.train_f1(preds, labels)): 21 | self.log(f'training/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 22 | return results.seq_class_loss 23 | 24 | def validation_step(self, batch, *args): 25 | r""" 26 | Compute predictions and log retrieval results. 27 | """ 28 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 29 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 30 | preds = results.seq_class_logits.argmax(dim=-1) 31 | 32 | self.log('validation/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 33 | self.log('validation/accuracy', self.valid_acc(preds, labels), on_epoch=True, prog_bar=True) 34 | for class_id, value in enumerate(self.valid_f1(preds, labels)): 35 | self.log(f'validation/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 36 | 37 | def test_step(self, batch, *args): 38 | r""" 39 | Compute predictions and log retrieval results. 40 | """ 41 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 42 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 43 | preds = results.seq_class_logits.argmax(dim=-1) 44 | 45 | self.log('test/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 46 | self.log('test/accuracy', self.test_acc(preds, labels), on_epoch=True, prog_bar=True) 47 | for class_id, value in enumerate(self.test_f1(preds, labels)): 48 | self.log(f'test/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 49 | 50 | def predict_step(self, batch, *args): 51 | r""" 52 | Compute predictions. 53 | """ 54 | input_ids, attention_mask, keys = batch["input_ids"], batch["attention_mask"], batch['keys'] 55 | results = self(input_ids=input_ids, attention_mask=attention_mask) 56 | preds = results.seq_class_logits.argmax(dim=-1) 57 | return {'preds': preds, 'keys': keys} 58 | 59 | def predict_epoch_end(self, predictions): 60 | r""" Receive a list of predictions and return a dict to write to files. 
""" 61 | preds = torch.cat([o['preds'] for o in predictions], dim=0) 62 | keys = torch.cat([o['keys'] for o in predictions], dim=0) 63 | assert preds.shape == keys.shape 64 | return {'preds': preds.flatten(), 'keys': keys.flatten()} 65 | 66 | @staticmethod 67 | def add_model_specific_args(parser: ArgumentParser): 68 | super(BaseJointFactChecking, BaseJointFactChecking).add_model_specific_args(parser) 69 | parser.set_defaults(max_sequence_length=64) 70 | parser.add_argument('--head_type', type=str, required=True, choices=MONOLITHIC_HEAD_TYPES) 71 | -------------------------------------------------------------------------------- /process_datasets/utils/general.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Dict, Generator, Iterable, List 3 | 4 | from blingfire import text_to_sentences 5 | 6 | 7 | cleaner = re.compile(r"\s+") 8 | 9 | 10 | def clean_sentences(sentences: List[str], min_sentence_length: int = 1) -> Generator[str, None, None]: 11 | r""" Check that sentences are long enough and non empty. """ 12 | for sentence in sentences: 13 | sentence = sentence.strip() 14 | if len(sentence) >= min_sentence_length: 15 | yield sentence 16 | 17 | 18 | def clean_paragraphs( 19 | paragraphs: List[str], 20 | min_paragraph_length: int = 1, 21 | min_sentences_per_paragraph: int = 1, 22 | min_sentence_length: int = 1, 23 | ) -> Generator[str, None, None]: 24 | r""" () is remainder after link in it was filtered out. """ 25 | for paragraph in paragraphs: 26 | paragraphs = cleaner.sub(" ", paragraph.strip()).replace("()", "") 27 | if len(paragraph) >= min_paragraph_length: 28 | paragraph = list( 29 | clean_sentences( 30 | text_to_sentences(paragraph).split("\n"), min_sentence_length=min_sentence_length 31 | ) 32 | ) 33 | if len(paragraph) >= min_sentences_per_paragraph: 34 | yield paragraph 35 | 36 | 37 | def clean_documents( 38 | documents: Iterable[Dict], 39 | paragraph_separator: str = "\n\n", 40 | min_sentence_length: int = 1, 41 | min_sentences_per_paragraph: int = 1, 42 | min_paragraph_length: int = 1, 43 | min_paragraphs_per_document: int = 1, 44 | min_document_length: int = 1, 45 | ) -> List[List[List[str]]]: 46 | r""" Clean every document by splitting it in paragraphs and then by splitting each paragraph in sentences. """ 47 | 48 | for document in documents: 49 | document = document.strip() 50 | if len(document) >= min_document_length: 51 | # generic filter on min length and special chars at the paragraph level 52 | document = list( 53 | clean_paragraphs( 54 | re.split(paragraph_separator, document), 55 | min_paragraph_length=min_paragraph_length, 56 | min_sentence_length=min_sentence_length, 57 | min_sentences_per_paragraph=min_sentences_per_paragraph, 58 | ) 59 | ) 60 | if len(document) >= min_paragraphs_per_document: 61 | yield document 62 | 63 | 64 | def check_dict(dictionary: Dict): 65 | return all(v is not None for v in dictionary.values()) 66 | 67 | 68 | def cumsum_limit(values: Iterable[int], limit: int) -> int: 69 | r""" Return the position of the element which first violates the limit in the cumulative sum. """ 70 | count = 0 71 | for i, v in enumerate(values): 72 | count += v 73 | if count > limit: 74 | return i 75 | return i 76 | 77 | 78 | def dict2list(data: Dict[Any, List]) -> List[Dict]: 79 | r""" Convert a dict or lists to a list of dicts. 
""" 80 | values = list(data.values()) 81 | assert all(isinstance(v, list) for v in values) 82 | assert all(len(v) == len(values[0]) for v in values) 83 | 84 | if not data or all(len(v) == 0 for v in values): 85 | return [] 86 | 87 | keys = data.keys() 88 | res = [ 89 | {a: b for a, b in zip(keys, values)} 90 | for values in zip(*[data[key] for key in keys]) 91 | ] 92 | return res 93 | 94 | 95 | def list2dict(data: List[Dict]) -> Dict[Any, List]: 96 | r""" Convert a list of dicts to a dict of lists. """ 97 | if not data: 98 | return {} 99 | 100 | assert all(isinstance(d, dict) for d in data) 101 | keys = data[0].keys() 102 | assert all(d.keys() == keys for d in data) 103 | 104 | res = {k: [d[k] for d in data] for k in keys} 105 | return res 106 | -------------------------------------------------------------------------------- /transformers_framework/architectures/modeling_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | from transformers_lightning.language_modeling import IGNORE_IDX 7 | 8 | from transformers_framework.architectures.modeling_output import SequenceClassificationOutput 9 | 10 | 11 | class ClassificationHead(nn.Module): 12 | r""" Head for sentence-level classification tasks. """ 13 | 14 | def __init__(self, config, hidden_size: int = None, num_labels: int = None): 15 | super().__init__() 16 | classifier_dropout = ( 17 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 18 | ) 19 | hidden_size = hidden_size if hidden_size is not None else config.hidden_size 20 | num_labels = num_labels if num_labels is not None else config.num_labels 21 | 22 | self.dropout = nn.Dropout(classifier_dropout) 23 | self.dense = nn.Linear(hidden_size, hidden_size) 24 | self.out_proj = nn.Linear(hidden_size, num_labels) 25 | 26 | def forward(self, features): 27 | features = self.dropout(features) 28 | features = self.dense(features) 29 | features = torch.tanh(features) 30 | features = self.dropout(features) 31 | features = self.out_proj(features) 32 | return features 33 | 34 | 35 | class JointClassificationHead(nn.Module): 36 | 37 | def __init__(self, config): 38 | super().__init__() 39 | hidden_size = config.hidden_size * 2 if config.head_type == "AE_k" else config.hidden_size 40 | 41 | self.config = config 42 | self.cls_positions = [self.config.sentence_msl * i for i in range(self.config.k + 1)] 43 | self.classifier = ClassificationHead(config, hidden_size=hidden_size) 44 | self.loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX) 45 | 46 | def forward(self, hidden_state: torch.Tensor, labels: torch.Tensor) -> SequenceClassificationOutput: 47 | r""" 48 | Args: 49 | hidden_state: the last hidden_state of some model with shape (batch_size, (k + 1) * seq_len, hidden_size) 50 | labels: the labels for every candidate with shape (batch_size) or (batch_size, k) 51 | 52 | Return: 53 | the loss and the logits of shape (batch_size, num_labels) or (batch_size, k, num_labels). 
54 | """ 55 | 56 | if self.config.head_type == "IE_1": 57 | hidden_state = hidden_state[:, self.cls_positions[0], :] 58 | elif self.config.head_type == "AE_1": 59 | hidden_state = hidden_state[:, self.cls_positions, :].sum(dim=1) 60 | elif self.config.head_type == "IE_k": 61 | hidden_state = hidden_state[:, self.cls_positions[1:], :] 62 | else: # "AE_k" 63 | question_hidden_states = hidden_state[:, [self.cls_positions[0]], :] 64 | candidates_hidden_states = hidden_state[:, self.cls_positions[1:], :] 65 | question_hidden_states = question_hidden_states.expand_as(candidates_hidden_states) 66 | hidden_state = torch.cat([question_hidden_states, candidates_hidden_states], dim=2) 67 | 68 | logits = self.classifier(hidden_state) 69 | 70 | loss = None 71 | if labels is not None: 72 | assert 1 <= labels.dim() <= 2, "labels must be of shape (batch_size) or (batch_size, k)" 73 | 74 | if self.config.head_type == "IE_1": 75 | assert labels.dim() == 1, "IE_1 classification head needs labels of shape (batch_size)" 76 | elif self.config.head_type == "AE_1": 77 | assert labels.dim() == 1, "AE_1 classification head needs labels of shape (batch_size)" 78 | elif self.config.head_type == "IE_k": 79 | assert labels.dim() == 2, "IE_k classification head needs labels of shape (batch_size, k)" 80 | else: # "AE_k" 81 | assert labels.dim() == 2, "AE_k classification head needs labels of shape (batch_size, k)" 82 | 83 | loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.flatten()) 84 | 85 | return SequenceClassificationOutput(seq_class_loss=loss, seq_class_logits=logits) 86 | -------------------------------------------------------------------------------- /transformers_framework/transformations/conversion_transformation.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from types import GeneratorType 3 | from typing import Dict, Generator, Union 4 | 5 | from transformers_framework.transformations.transformation import Transformation 6 | from transformers_framework.utilities.functional import apply_to_generator 7 | from transformers_framework.utilities.structures import JointwiseSample, PairwiseSample 8 | 9 | 10 | # Dict -> PairwiseSample 11 | class DictSample2PairwiseSampleTransformation(Transformation): 12 | r""" Transforms DictSample to DataSample. """ 13 | 14 | def __init__( 15 | self, 16 | hyperparameters: Namespace, 17 | first_field: str, 18 | second_field: str, 19 | key_field: str = 'key', 20 | label_field: str = 'label', 21 | score_field: str = 'score', 22 | ): 23 | super().__init__(hyperparameters) 24 | self.first_field = first_field 25 | self.second_field = second_field 26 | self.key_field = key_field 27 | self.label_field = label_field 28 | self.score_field = score_field 29 | 30 | def dict_sample_to_pairwise_sample(self, sample: Dict) -> PairwiseSample: 31 | r""" Transforming Dict instances to PairwiseSample. 
""" 32 | assert isinstance(sample, Dict), f"input must be of type Dict, found {sample.__class__.__name__}" 33 | return PairwiseSample( 34 | first=" ".join(sample[f] for f in self.first_field.split(":")), 35 | second=" ".join(sample[f] for f in self.second_field.split(":")), 36 | key=sample[self.key_field] if self.key_field in sample else None, 37 | label=sample[self.label_field] if self.label_field in sample else None, 38 | score=sample[self.score_field] if self.score_field in sample else None, 39 | ) 40 | 41 | def __call__( 42 | self, 43 | samples: Union[Generator[Dict, None, None], Dict] 44 | ) -> Union[Generator[PairwiseSample, None, None], PairwiseSample]: 45 | r""" Transforming Dict instances to PairwiseSample. """ 46 | if isinstance(samples, GeneratorType): 47 | return apply_to_generator(samples, self.dict_sample_to_pairwise_sample) 48 | else: 49 | return self.dict_sample_to_pairwise_sample(samples) 50 | 51 | 52 | # Dict -> JointwiseSample 53 | class DictSample2JointwiseSampleTransformation(Transformation): 54 | r""" Transforms DictSample to JointwiseSample. """ 55 | 56 | def __init__( 57 | self, 58 | hyperparameters: Namespace, 59 | first_field: str, 60 | seconds_field: str, 61 | key_field: str = 'key', 62 | label_field: str = 'label', 63 | score_field: str = 'score', 64 | ): 65 | super().__init__(hyperparameters) 66 | 67 | assert ":" not in first_field 68 | assert ":" not in seconds_field 69 | 70 | self.first_field = first_field 71 | self.seconds_field = seconds_field 72 | self.key_field = key_field 73 | self.label_field = label_field 74 | self.score_field = score_field 75 | 76 | def dict_sample_to_jointwise_sample(self, sample: Dict) -> JointwiseSample: 77 | r""" Transform a Dict instance to JointwiseSample. """ 78 | assert isinstance(sample, Dict), f"input must be of type Dict, found {sample.__class__.__name__}" 79 | return JointwiseSample( 80 | first=sample[self.first_field], 81 | seconds=sample[self.seconds_field], 82 | key=sample[self.key_field] if self.key_field in sample else None, 83 | label=sample[self.label_field] if self.label_field in sample else None, 84 | score=sample[self.score_field] if self.score_field in sample else None, 85 | valid=[True] * len(sample[self.seconds_field]), 86 | ) 87 | 88 | def __call__( 89 | self, 90 | samples: Union[Generator[Dict, None, None], Dict] 91 | ) -> Union[Generator[JointwiseSample, None, None], JointwiseSample]: 92 | r""" Transform a Dict instance to JointwiseSample. """ 93 | if isinstance(samples, GeneratorType): 94 | return apply_to_generator(samples, self.dict_sample_to_jointwise_sample) 95 | else: 96 | return self.dict_sample_to_jointwise_sample(samples) 97 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/as2/base.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from transformers_lightning.language_modeling import IGNORE_IDX 5 | 6 | from transformers_framework.architectures.roberta.modeling_config import MONOLITHIC_HEAD_TYPES 7 | from transformers_framework.models.base.as2 import BaseModelAS2 8 | from transformers_framework.utilities.functional import index_multi_tensors 9 | 10 | 11 | class BaseJointAS2(BaseModelAS2): 12 | 13 | def training_step(self, batch, *args): 14 | r""" Just compute the loss and log it. 
""" 15 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 16 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 17 | 18 | preds = results.seq_class_logits.argmax(dim=-1) 19 | preds, labels = index_multi_tensors(preds, labels, positions=labels != IGNORE_IDX) 20 | 21 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 22 | self.log('training/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 23 | self.log('training/accuracy', self.train_acc(preds, labels), on_epoch=True, prog_bar=True) 24 | 25 | return results.seq_class_loss 26 | 27 | def validation_step(self, batch, *args): 28 | r""" 29 | Compute predictions and log retrieval results. 30 | """ 31 | input_ids, attention_mask, labels, keys, valid = ( 32 | batch["input_ids"], batch["attention_mask"], batch['labels'], batch['keys'], batch['valid'] 33 | ) 34 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 35 | 36 | keys = keys.unsqueeze(-1).expand_as(labels) 37 | logits, labels, keys = index_multi_tensors(results.seq_class_logits, labels, keys, positions=valid) 38 | 39 | preds = logits.argmax(dim=-1) 40 | scores = torch.softmax(logits, dim=-1)[:, -1] 41 | 42 | self.valid_map.update(preds=scores, target=labels, indexes=keys) 43 | self.valid_mrr.update(preds=scores, target=labels, indexes=keys) 44 | self.valid_p1.update(preds=scores, target=labels, indexes=keys) 45 | self.valid_hr5.update(preds=scores, target=labels, indexes=keys) 46 | self.valid_ndgc.update(preds=scores, target=labels, indexes=keys) 47 | 48 | self.log('validation/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 49 | self.log('validation/accuracy', self.valid_acc(preds, labels), on_epoch=True, prog_bar=True) 50 | 51 | def test_step(self, batch, *args): 52 | r""" 53 | Compute predictions and log retrieval results. 54 | """ 55 | input_ids, attention_mask, labels, keys, valid = ( 56 | batch["input_ids"], batch["attention_mask"], batch['labels'], batch['keys'], batch['valid'] 57 | ) 58 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 59 | 60 | keys = keys.unsqueeze(-1).expand_as(labels) 61 | logits, labels, keys = index_multi_tensors(results.seq_class_logits, labels, keys, positions=valid) 62 | 63 | preds = logits.argmax(dim=-1) 64 | scores = torch.softmax(logits, dim=-1)[:, -1] 65 | 66 | self.test_map.update(preds=scores, target=labels, indexes=keys) 67 | self.test_mrr.update(preds=scores, target=labels, indexes=keys) 68 | self.test_p1.update(preds=scores, target=labels, indexes=keys) 69 | self.test_hr5.update(preds=scores, target=labels, indexes=keys) 70 | self.test_ndgc.update(preds=scores, target=labels, indexes=keys) 71 | 72 | self.log('test/loss', results.seq_class_loss, on_epoch=True, prog_bar=True) 73 | self.log('test/accuracy', self.test_acc(preds, labels), on_epoch=True, prog_bar=True) 74 | 75 | def predict_step(self, batch, *args): 76 | r""" Like test step but without metrics. 
""" 77 | input_ids, attention_mask, keys, valid = ( 78 | batch["input_ids"], batch["attention_mask"], batch['keys'], batch['valid'] 79 | ) 80 | results = self(input_ids=input_ids, attention_mask=attention_mask) 81 | 82 | keys = keys.unsqueeze(-1).expand_as(valid) 83 | logits, keys = index_multi_tensors(results.seq_class_logits, keys, positions=valid) 84 | scores = torch.softmax(logits, dim=-1)[..., -1] 85 | 86 | return {'scores': scores, 'keys': keys} 87 | 88 | @staticmethod 89 | def add_model_specific_args(parser: ArgumentParser): 90 | super(BaseJointAS2, BaseJointAS2).add_model_specific_args(parser) 91 | parser.set_defaults(max_sequence_length=64) 92 | parser.add_argument('--head_type', type=str, required=True, choices=MONOLITHIC_HEAD_TYPES) 93 | -------------------------------------------------------------------------------- /transformers_framework/models/base/classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.classification import Accuracy, F1Score 3 | 4 | from transformers_framework.models.base.base import BaseModel 5 | from transformers_framework.utilities import shrink_batch 6 | 7 | 8 | class BaseModelClassification(BaseModel): 9 | 10 | def __init__(self, hyperparameters): 11 | super().__init__(hyperparameters) 12 | 13 | self.train_acc = Accuracy(num_classes=self.hyperparameters.num_labels) 14 | self.train_f1 = F1Score(num_classes=self.hyperparameters.num_labels, average=None) 15 | 16 | self.valid_acc = Accuracy(num_classes=self.hyperparameters.num_labels) 17 | self.valid_f1 = F1Score(num_classes=self.hyperparameters.num_labels, average=None) 18 | 19 | self.test_acc = Accuracy(num_classes=self.hyperparameters.num_labels) 20 | self.test_f1 = F1Score(num_classes=self.hyperparameters.num_labels, average=None) 21 | 22 | def training_step(self, batch, *args): 23 | r""" 24 | Just compute the loss and log it. 25 | """ 26 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 27 | token_type_ids = batch.get('token_type_ids', None) 28 | 29 | input_ids, attention_mask, token_type_ids = shrink_batch( 30 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 31 | ) 32 | 33 | results = self( 34 | input_ids=input_ids, 35 | attention_mask=attention_mask, 36 | token_type_ids=token_type_ids, 37 | labels=labels, 38 | ) 39 | preds = results.logits.argmax(dim=-1) 40 | 41 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 42 | self.log('training/loss', results.loss, on_epoch=True, prog_bar=True) 43 | self.log('training/accuracy', self.train_acc(preds, labels), on_epoch=True, prog_bar=True) 44 | for class_id, value in enumerate(self.train_f1(preds, labels)): 45 | self.log(f'training/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 46 | 47 | return results.loss 48 | 49 | def validation_step(self, batch, *args): 50 | r""" 51 | Compute predictions and log retrieval results. 
52 | """ 53 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 54 | token_type_ids = batch.get('token_type_ids', None) 55 | 56 | input_ids, attention_mask, token_type_ids = shrink_batch( 57 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 58 | ) 59 | 60 | results = self( 61 | input_ids=input_ids, 62 | attention_mask=attention_mask, 63 | token_type_ids=token_type_ids, 64 | labels=labels, 65 | ) 66 | preds = results.logits.argmax(dim=-1) 67 | 68 | self.log('validation/accuracy', self.valid_acc(preds, labels), on_epoch=True, prog_bar=True) 69 | for class_id, value in enumerate(self.valid_f1(preds, labels)): 70 | self.log(f'validation/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 71 | 72 | def test_step(self, batch, *args): 73 | r""" 74 | Compute predictions and log retrieval results. 75 | """ 76 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 77 | token_type_ids = batch.get('token_type_ids', None) 78 | 79 | input_ids, attention_mask, token_type_ids = shrink_batch( 80 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 81 | ) 82 | 83 | results = self( 84 | input_ids=input_ids, 85 | attention_mask=attention_mask, 86 | token_type_ids=token_type_ids, 87 | labels=labels, 88 | ) 89 | preds = results.logits.argmax(dim=-1) 90 | 91 | self.log('test/accuracy', self.test_acc(preds, labels), on_epoch=True, prog_bar=True) 92 | for class_id, value in enumerate(self.test_f1(preds, labels)): 93 | self.log(f'test/f1_class_{class_id}', value, on_epoch=True, prog_bar=False) 94 | 95 | def predict_step(self, batch, *args): 96 | r""" 97 | Compute predictions. 98 | """ 99 | input_ids, attention_mask, keys = batch["input_ids"], batch["attention_mask"], batch['keys'] 100 | token_type_ids = batch.get('token_type_ids', None) 101 | 102 | input_ids, attention_mask, token_type_ids = shrink_batch( 103 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 104 | ) 105 | 106 | results = self( 107 | input_ids=input_ids, 108 | attention_mask=attention_mask, 109 | token_type_ids=token_type_ids, 110 | ) 111 | preds = results.logits.argmax(dim=-1) 112 | return {'preds': preds, 'keys': keys} 113 | 114 | def predict_epoch_end(self, predictions): 115 | r""" Receive a list of predictions and return a dict to write to files. 
""" 116 | preds = torch.cat([o['preds'] for o in predictions], dim=0) 117 | keys = torch.cat([o['keys'] for o in predictions], dim=0) 118 | 119 | assert preds.shape == keys.shape 120 | return {'preds': preds.flatten(), 'keys': keys.flatten()} 121 | -------------------------------------------------------------------------------- /transformers_framework/models/joint/mlm_as2/base.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Any, List 3 | 4 | from torchmetrics.classification.accuracy import Accuracy 5 | from torchmetrics.retrieval import ( 6 | RetrievalHitRate, 7 | RetrievalMAP, 8 | RetrievalMRR, 9 | RetrievalNormalizedDCG, 10 | RetrievalPrecision, 11 | ) 12 | from transformers_lightning.language_modeling.masked_language_modeling import IGNORE_IDX 13 | 14 | from transformers_framework.architectures.roberta.modeling_config import MONOLITHIC_HEAD_TYPES 15 | from transformers_framework.models.base.mlm import BaseModelMLM 16 | from transformers_framework.utilities import index_multi_tensors 17 | 18 | 19 | class BaseJointMLMAndAS2(BaseModelMLM): 20 | 21 | def __init__(self, hyperparameters): 22 | super().__init__(hyperparameters) 23 | 24 | self.train_acc = Accuracy() 25 | 26 | self.valid_acc = Accuracy() 27 | self.valid_map = RetrievalMAP() 28 | self.valid_mrr = RetrievalMRR() 29 | self.valid_p1 = RetrievalPrecision(k=1) 30 | self.valid_hr5 = RetrievalHitRate(k=5) 31 | self.valid_ndgc = RetrievalNormalizedDCG() 32 | 33 | self.valid_acc = Accuracy() 34 | self.valid_map = RetrievalMAP() 35 | self.valid_mrr = RetrievalMRR() 36 | self.valid_p1 = RetrievalPrecision(k=1) 37 | self.valid_hr5 = RetrievalHitRate(k=5) 38 | self.valid_ndgc = RetrievalNormalizedDCG() 39 | 40 | def training_step(self, batch, *args): 41 | r""" 42 | Start by masking tokens some tokens. 43 | """ 44 | input_ids, attention_mask, labels, valid = ( 45 | batch["input_ids"], batch["attention_mask"], batch["labels"], batch["valid"] 46 | ) 47 | input_ids, mlm_labels = self.mlm(input_ids) 48 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=mlm_labels, class_labels=labels) 49 | 50 | # MLM part 51 | mlm_predictions = results.masked_lm_logits.argmax(dim=-1) 52 | mlm_predictions, mlm_labels = index_multi_tensors( 53 | mlm_predictions, mlm_labels, positions=mlm_labels != IGNORE_IDX 54 | ) 55 | 56 | loss = results.masked_lm_loss + results.seq_class_loss 57 | 58 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 59 | self.log('training/loss', loss, on_epoch=True, prog_bar=True) 60 | 61 | # MLM part 62 | self.log('training/mlm_loss', results.masked_lm_loss, on_epoch=True, prog_bar=True) 63 | self.log('training/mlm_accuracy', self.train_mlm_acc(mlm_predictions, mlm_labels), on_epoch=True) 64 | 65 | # Class part 66 | class_preds = results.seq_class_logits.argmax(dim=-1) 67 | class_preds, labels = index_multi_tensors(class_preds, labels, positions=valid) 68 | 69 | self.log('training/classification_loss', results.seq_class_loss, on_epoch=True) 70 | self.log('training/classification_accuracy', self.train_acc(class_preds, labels), on_epoch=True, prog_bar=True) 71 | 72 | return loss 73 | 74 | def validation_step(self, batch, *args, **kwargs): 75 | r""" 76 | Start by masking tokens some tokens. 
77 | """ 78 | input_ids, attention_mask, labels, keys, valid = ( 79 | batch["input_ids"], batch["attention_mask"], batch['labels'], batch['keys'], batch['valid'] 80 | ) 81 | input_ids, mlm_labels = self.mlm(input_ids) 82 | results = self(input_ids=input_ids, attention_mask=attention_mask, labels=mlm_labels, class_labels=labels) 83 | 84 | # MLM part 85 | mlm_predictions = results.masked_lm_logits.argmax(dim=-1) 86 | mlm_predictions, mlm_labels = index_multi_tensors( 87 | mlm_predictions, mlm_labels, positions=mlm_labels != IGNORE_IDX 88 | ) 89 | 90 | loss = results.masked_lm_loss + results.seq_class_loss 91 | 92 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 93 | self.log('validation/loss', loss, on_epoch=True, prog_bar=True) 94 | 95 | # MLM part 96 | self.log('validation/mlm_loss', results.masked_lm_loss, on_epoch=True, prog_bar=True) 97 | self.log('validation/mlm_accuracy', self.train_mlm_acc(mlm_predictions, mlm_labels), on_epoch=True) 98 | 99 | # Class part 100 | keys = keys.unsqueeze(-1).expand_as(labels) 101 | logits, labels, keys = index_multi_tensors(results.seq_class_logits, labels, keys, positions=valid) 102 | 103 | preds = logits.argmax(dim=-1) 104 | scores = logits.softmax(dim=-1)[:, -1] 105 | 106 | self.valid_map.update(preds=scores, target=labels, indexes=keys) 107 | self.valid_mrr.update(preds=scores, target=labels, indexes=keys) 108 | self.valid_p1.update(preds=scores, target=labels, indexes=keys) 109 | self.valid_hr5.update(preds=scores, target=labels, indexes=keys) 110 | self.valid_ndgc.update(preds=scores, target=labels, indexes=keys) 111 | 112 | self.log('validation/classification_loss', results.seq_class_loss, on_epoch=True) 113 | self.log('validation/classification_accuracy', self.valid_acc(preds, labels), on_epoch=True, prog_bar=True) 114 | 115 | def validation_epoch_end(self, outputs: List[Any]) -> None: 116 | r""" Just log metrics. """ 117 | self.log('validation/map', self.valid_map, on_epoch=True) 118 | self.log('validation/mrr', self.valid_mrr, on_epoch=True) 119 | self.log('validation/p1', self.valid_p1, on_epoch=True) 120 | self.log('validation/hr5', self.valid_hr5, on_epoch=True) 121 | self.log('validation/ndcg', self.valid_ndgc, on_epoch=True) 122 | 123 | @staticmethod 124 | def add_model_specific_args(parser: ArgumentParser): 125 | super(BaseJointMLMAndAS2, BaseJointMLMAndAS2).add_model_specific_args(parser) 126 | parser.set_defaults(max_sequence_length=64) 127 | parser.add_argument('--head_type', type=str, required=True, choices=MONOLITHIC_HEAD_TYPES) 128 | -------------------------------------------------------------------------------- /transformers_framework/samplers/keys_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Iterator, Optional 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torch.utils.data.sampler import Sampler, T_co 8 | 9 | 10 | def get_keys_to_indexes(dataset: Dataset) -> Dict: 11 | r""" Returns a dict mapping keys to tuples (key, index). 
""" 12 | keys_to_indexes = {} 13 | 14 | for i in range(len(dataset)): 15 | sample = dataset._get_sample(i) 16 | if sample.key not in keys_to_indexes: 17 | keys_to_indexes[sample.key] = [] 18 | keys_to_indexes[sample.key].append((sample.key, i)) 19 | 20 | return keys_to_indexes 21 | 22 | 23 | def get_indexes( 24 | dataset: Dataset, 25 | generator: torch.Generator, 26 | shuffle: bool = True, 27 | ): 28 | r""" Simply iterate over the keys and provide all the data in a shuffled-at-the-key-level order. """ 29 | 30 | keys_to_indexes = get_keys_to_indexes(dataset) 31 | if shuffle is True: 32 | permutation = torch.randperm(len(keys_to_indexes), generator=generator).tolist() 33 | else: 34 | permutation = torch.arange(len(keys_to_indexes)).tolist() 35 | keys_indexes = [x for i in permutation for x in keys_to_indexes[i]] 36 | 37 | # simply yield every index 38 | for triple in keys_indexes: 39 | yield triple[1] 40 | 41 | 42 | class KeysSampler(Sampler[T_co]): 43 | r"""Samples elements randomly by providing elements with the same key sequentially. 44 | 45 | Args: 46 | dataset (Dataset): dataset to sample from 47 | generator (Generator): Generator used in sampling. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | dataset: Dataset, 53 | generator: torch.Generator = None, 54 | shuffle: bool = True, 55 | ) -> None: 56 | self.dataset = dataset 57 | self.generator = generator 58 | self.shuffle = shuffle 59 | 60 | def __iter__(self) -> Iterator[T_co]: 61 | if self.generator is None: 62 | generator = torch.Generator() 63 | generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) 64 | else: 65 | generator = self.generator 66 | 67 | yield from get_indexes( 68 | dataset=self.dataset, 69 | generator=generator, 70 | shuffle=self.shuffle, 71 | ) 72 | 73 | def __len__(self): 74 | return len(self.dataset) 75 | 76 | 77 | class DistributedKeysSampler(DistributedSampler[T_co]): 78 | r"""Sampler that restricts data loading to a subset of the indexes returned by KeysSampler. 79 | 80 | It is especially useful in conjunction with 81 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 82 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 83 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 84 | original dataset that is exclusive to it. 85 | 86 | .. note:: 87 | Dataset is assumed to be of constant size. 88 | 89 | Args: 90 | dataset: Dataset used for sampling. 91 | num_replicas (int, optional): Number of processes participating in 92 | distributed training. By default, :attr:`world_size` is retrieved from the 93 | current distributed group. 94 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 95 | By default, :attr:`rank` is retrieved from the current distributed 96 | group. 97 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 98 | indices. 99 | generator (Generator): Generator used in sampling. 100 | drop_last (bool, optional): if ``True``, then the sampler will drop the 101 | tail of the data to make it evenly divisible across the number of 102 | replicas. If ``False``, the sampler will add extra indices to make 103 | the data evenly divisible across the replicas. Default: ``False``. 104 | 105 | .. warning:: 106 | In distributed mode, calling the :meth:`set_epoch` method at 107 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 108 | is necessary to make shuffling work properly across multiple epochs. 
Otherwise, 109 | the same ordering will be always used. 110 | 111 | Example:: 112 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 113 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 114 | ... sampler=sampler) 115 | >>> for epoch in range(start_epoch, n_epochs): 116 | ... if is_distributed: 117 | ... sampler.set_epoch(epoch) 118 | ... train(loader) 119 | """ 120 | 121 | def __init__( 122 | self, 123 | dataset: Dataset, 124 | num_replicas: Optional[int] = None, 125 | rank: Optional[int] = None, 126 | shuffle: bool = True, 127 | seed: int = 0, 128 | drop_last: bool = False, 129 | ) -> None: 130 | super().__init__( 131 | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last 132 | ) 133 | 134 | def __iter__(self) -> Iterator[T_co]: 135 | 136 | # deterministically shuffle based on epoch and seed 137 | generator = torch.Generator() 138 | generator.manual_seed(self.seed + self.epoch) 139 | 140 | # shuffle at the key level, not at the index one 141 | indices = get_indexes( 142 | dataset=self.dataset, 143 | generator=generator, 144 | shuffle=self.shuffle, 145 | ) 146 | indices = list(indices) 147 | 148 | if not self.drop_last: 149 | # add extra samples to make it evenly divisible 150 | padding_size = self.total_size - len(indices) 151 | if padding_size <= len(indices): 152 | indices += indices[:padding_size] 153 | else: 154 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 155 | else: 156 | # remove tail of data to make it evenly divisible. 157 | indices = indices[:self.total_size] 158 | assert len(indices) == self.total_size 159 | 160 | # subsample 161 | indices = indices[self.rank * self.num_samples:(self.rank + 1) * self.num_samples] 162 | assert len(indices) == self.num_samples 163 | 164 | return iter(indices) 165 | 166 | def __len__(self) -> int: 167 | return self.num_samples 168 | 169 | def set_epoch(self, epoch: int) -> None: 170 | r""" 171 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 172 | use a different random ordering for each epoch. Otherwise, the next iteration of this 173 | sampler will yield the same ordering. 174 | Args: 175 | epoch (int): Epoch number. 176 | """ 177 | self.epoch = epoch 178 | -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/arrow/arrow_adapter.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from json import JSONDecodeError 3 | from typing import Dict 4 | 5 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk 6 | from pytorch_lightning import _logger as logger 7 | from pytorch_lightning.utilities import rank_zero_warn 8 | from transformers import PreTrainedTokenizer 9 | 10 | from transformers_framework.adapters.map_adapters.map_adapter import MapAdapter 11 | from transformers_framework.utilities.processors import data_processor 12 | from transformers_framework.utilities.structures import DataSample 13 | 14 | 15 | def load_dataset_from_disk(path: str, keep_in_memory: bool = False, split: str = None) -> DatasetDict: 16 | r""" Load both Dataset's dumps and json folders transparently from disk. 
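    For example (the paths and split are illustrative):

        >>> data = load_dataset_from_disk("path/to/dataset_dump", split="train")
        >>> data = load_dataset_from_disk("path/to/jsonl_folder", split="-")   # "-" skips split selection

    When the path is not a Dataset dump, the function falls back to
    ``load_dataset('json', data_dir=path)`` and always uses its single "train" split, which is why
    ``--splits -`` is recommended for jsonl folders.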
""" 17 | try: 18 | res = load_from_disk(path, keep_in_memory=keep_in_memory) 19 | if split is not None and split != "-": 20 | res = res[split] 21 | except FileNotFoundError: 22 | try: 23 | res = load_dataset('json', data_dir=path, keep_in_memory=keep_in_memory)['train'] 24 | if split is not None and split != "-": 25 | rank_zero_warn( 26 | "Jsonl dataset does not require a split, just use `--splits -`." 27 | " For this run I will set `--splits -` for you." 28 | ) 29 | except JSONDecodeError: 30 | logger.error( 31 | f"Could not load dataset from {path}. " 32 | f"Make sure this path is a valid folder containing jsonl files or a dataset dump." 33 | ) 34 | exit(1) 35 | return res 36 | 37 | 38 | class ArrowAdapter(MapAdapter): 39 | r""" Superclass of Arrow File readers, which implements filtering on scores and limits. """ 40 | 41 | def __init__( 42 | self, 43 | hyperparameters: Namespace, 44 | tokenizer: PreTrainedTokenizer, 45 | stage_name: str, 46 | seed: int = 0, 47 | ) -> None: 48 | super().__init__(hyperparameters, tokenizer, stage_name, seed=seed) 49 | 50 | if self.is_active(): 51 | self.data = self.load_data() 52 | 53 | def load_data(self): 54 | r""" Load data from disk first parsing input parameters. This method should be protected by `is_active`. """ 55 | filepaths = self.hyperparameters[f'{self.stage_name}_filepaths'] 56 | splits = self.hyperparameters[f'{self.stage_name}_splits'] 57 | 58 | if len(splits) == 1 and len(filepaths) > 1: 59 | splits = splits * len(filepaths) 60 | 61 | assert len(splits) == len(filepaths), ( 62 | "You must provide a single split for every dataset or a split for every dataset" 63 | ) 64 | 65 | rank_zero_warn(f"Loading datasets from disk{' and concatenating' if len(filepaths) > 1 else ''}...") 66 | data = concatenate_datasets([ 67 | load_dataset_from_disk(filepath, keep_in_memory=self.hyperparameters.keep_in_memory, split=split) 68 | for split, filepath in zip(splits, filepaths) 69 | ]) 70 | 71 | for field in self.hyperparameters.field_names: 72 | for f in field.split(":"): 73 | assert f in data.column_names, ( 74 | f"column {f} was not found among available dataset's columns {data.column_names}" 75 | ) 76 | return data 77 | 78 | def is_active(self) -> bool: 79 | return self.hyperparameters[f'{self.stage_name}_filepaths'] is not None 80 | 81 | def check(self, sample: Dict, idx: int): 82 | if any(v is None for v in sample.values()): 83 | if not hasattr(self, "already_logged_warning"): 84 | self.already_logged_warning = True 85 | rank_zero_warn( 86 | f"Sample {sample} with id {idx} seems incomplete. Will not log warning like this anymore." 87 | ) 88 | 89 | def __getitem__(self, idx) -> Dict: 90 | r""" Get dict of data at a given position. """ 91 | sample = self.data[idx] 92 | if self.hyperparameters.generate_key is True and ('key' not in sample or sample['key'] is None): 93 | sample['key'] = idx 94 | self.check(sample, idx) # check sample is valid and raise warning the first time 95 | return self.transformations(sample) 96 | 97 | def __len__(self) -> int: 98 | return len(self.data) 99 | 100 | def __iter__(self): 101 | for idx, sample in enumerate(self.data): 102 | self.check(sample, idx) # check sample is valid and raise warning 103 | yield self.transformations(sample) 104 | 105 | def preprocess_line(self, sample: DataSample) -> Dict: 106 | r""" 107 | Process a line. The structure of each line is exactly 108 | the same returned by the __iter__ method. Here you should do data preparation 109 | for the actual model being trained. 
This is a good place to do tokenization, 110 | padding and so on. 111 | """ 112 | return data_processor( 113 | sample, 114 | tokenizer=self.tokenizer, 115 | max_sequence_length=self.hyperparameters.max_sequence_length, 116 | chain=self.hyperparameters.chain, 117 | ) 118 | 119 | @staticmethod 120 | def add_adapter_specific_args(parser: ArgumentParser): 121 | super(ArrowAdapter, ArrowAdapter).add_adapter_specific_args(parser) 122 | parser.add_argument('--keep_in_memory', action="store_true", help="Read whole Dataset into memory.") 123 | parser.add_argument('--max_sequence_length', required=True, type=int, help="Model max sequence length") 124 | parser.add_argument( 125 | '--chain', action="store_true", help="Whether to chain sentences when encoded separately" 126 | ) 127 | parser.add_argument( 128 | '--field_names', 129 | required=True, 130 | nargs='+', 131 | help="Names of the fields of the input data to use for training. Use : to concatenate field together." 132 | ) 133 | parser.add_argument( 134 | '--key_name', 135 | required=False, 136 | default=None, 137 | help="Name of the key field" 138 | ) 139 | parser.add_argument( 140 | '--label_name', 141 | required=False, 142 | default=None, 143 | help="Name of the label field" 144 | ) 145 | parser.add_argument('--generate_key', action="store_true", help="Use dataset indexes as keys.") 146 | 147 | @staticmethod 148 | def add_adapter_instance_specific_args(parser: ArgumentParser, stage_name: str): 149 | super(ArrowAdapter, ArrowAdapter).add_adapter_instance_specific_args(parser, stage_name=stage_name) 150 | parser.add_argument( 151 | f'--{stage_name}_filepaths', 152 | type=str, 153 | required=False, 154 | default=None, 155 | nargs='+', 156 | help=f"Path to {stage_name} dataset dump", 157 | ) 158 | parser.add_argument( 159 | f'--{stage_name}_splits', 160 | type=str, 161 | required=False, 162 | default=[stage_name], 163 | nargs='+', 164 | help="The dataset split to load.", 165 | ) 166 | -------------------------------------------------------------------------------- /transformers_framework/models/base/as2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import torch 4 | from torchmetrics.classification import Accuracy 5 | from torchmetrics.retrieval import ( 6 | RetrievalHitRate, 7 | RetrievalMAP, 8 | RetrievalMRR, 9 | RetrievalNormalizedDCG, 10 | RetrievalPrecision, 11 | ) 12 | from transformers_lightning.language_modeling import IGNORE_IDX 13 | 14 | from transformers_framework.models.base.base import BaseModel 15 | from transformers_framework.utilities import shrink_batch 16 | from transformers_framework.utilities.functional import index_multi_tensors 17 | 18 | 19 | class BaseModelAS2(BaseModel): 20 | 21 | def __init__(self, hyperparameters): 22 | super().__init__(hyperparameters) 23 | 24 | self.train_acc = Accuracy() 25 | 26 | self.valid_acc = Accuracy() 27 | self.valid_map = RetrievalMAP() 28 | self.valid_mrr = RetrievalMRR() 29 | self.valid_p1 = RetrievalPrecision(k=1) 30 | self.valid_hr5 = RetrievalHitRate(k=5) 31 | self.valid_ndgc = RetrievalNormalizedDCG() 32 | 33 | self.test_acc = Accuracy() 34 | self.test_map = RetrievalMAP() 35 | self.test_mrr = RetrievalMRR() 36 | self.test_p1 = RetrievalPrecision(k=1) 37 | self.test_hr5 = RetrievalHitRate(k=5) 38 | self.test_ndgc = RetrievalNormalizedDCG() 39 | 40 | def training_step(self, batch, *args): 41 | r""" Just compute the loss and log it. 
""" 42 | input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch['labels'] 43 | token_type_ids = batch.get('token_type_ids', None) 44 | 45 | input_ids, attention_mask, token_type_ids = shrink_batch( 46 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 47 | ) 48 | 49 | results = self( 50 | input_ids=input_ids, 51 | attention_mask=attention_mask, 52 | token_type_ids=token_type_ids, 53 | labels=labels, 54 | ) 55 | preds = results.logits.argmax(dim=-1) 56 | preds, labels = index_multi_tensors(preds, labels, positions=labels != IGNORE_IDX) 57 | 58 | # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger 59 | self.log('training/loss', results.loss, on_epoch=True, prog_bar=True) 60 | self.log('training/accuracy', self.train_acc(preds, labels), on_epoch=True, prog_bar=True) 61 | return results.loss 62 | 63 | def validation_step(self, batch, *args): 64 | r""" Compute predictions and log retrieval results. """ 65 | input_ids, attention_mask, labels, keys = ( 66 | batch["input_ids"], batch["attention_mask"], batch['labels'], batch['keys'] 67 | ) 68 | token_type_ids = batch.get('token_type_ids', None) 69 | 70 | input_ids, attention_mask, token_type_ids = shrink_batch( 71 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 72 | ) 73 | 74 | results = self( 75 | input_ids=input_ids, 76 | attention_mask=attention_mask, 77 | token_type_ids=token_type_ids, 78 | labels=labels, 79 | ) 80 | preds = results.logits.argmax(dim=-1) 81 | scores = torch.softmax(results.logits, dim=-1)[..., -1] 82 | preds, scores, labels = index_multi_tensors(preds, scores, labels, positions=labels != IGNORE_IDX) 83 | 84 | self.valid_map.update(preds=scores, target=labels, indexes=keys) 85 | self.valid_mrr.update(preds=scores, target=labels, indexes=keys) 86 | self.valid_p1.update(preds=scores, target=labels, indexes=keys) 87 | self.valid_hr5.update(preds=scores, target=labels, indexes=keys) 88 | self.valid_ndgc.update(preds=scores, target=labels, indexes=keys) 89 | 90 | self.log('validation/accuracy', self.valid_acc(preds, labels), on_epoch=True, prog_bar=True) 91 | 92 | def validation_epoch_end(self, outputs: List[Any]) -> None: 93 | r""" Just log metrics. """ 94 | self.log('validation/map', self.valid_map, on_epoch=True) 95 | self.log('validation/mrr', self.valid_mrr, on_epoch=True) 96 | self.log('validation/p1', self.valid_p1, on_epoch=True) 97 | self.log('validation/hr5', self.valid_hr5, on_epoch=True) 98 | self.log('validation/ndcg', self.valid_ndgc, on_epoch=True) 99 | 100 | def test_step(self, batch, *args): 101 | r""" Compute predictions and log retrieval results. 
""" 102 | input_ids, attention_mask, labels, keys = ( 103 | batch["input_ids"], batch["attention_mask"], batch['labels'], batch['keys'] 104 | ) 105 | token_type_ids = batch.get('token_type_ids', None) 106 | 107 | input_ids, attention_mask, token_type_ids = shrink_batch( 108 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 109 | ) 110 | 111 | results = self( 112 | input_ids=input_ids, 113 | attention_mask=attention_mask, 114 | token_type_ids=token_type_ids, 115 | labels=labels, 116 | ) 117 | preds = results.logits.argmax(dim=-1) 118 | scores = torch.softmax(results.logits, dim=-1)[..., -1] 119 | preds, scores, labels = index_multi_tensors(preds, scores, labels, positions=labels != IGNORE_IDX) 120 | 121 | self.test_map.update(preds=scores, target=labels, indexes=keys) 122 | self.test_mrr.update(preds=scores, target=labels, indexes=keys) 123 | self.test_p1.update(preds=scores, target=labels, indexes=keys) 124 | self.test_hr5.update(preds=scores, target=labels, indexes=keys) 125 | self.test_ndgc.update(preds=scores, target=labels, indexes=keys) 126 | 127 | self.log('test/accuracy', self.test_acc(preds, labels), on_epoch=True, prog_bar=True) 128 | 129 | def test_epoch_end(self, outputs: List[Any]) -> None: 130 | r""" Just log metrics. """ 131 | self.log('test/map', self.test_map, on_step=False, on_epoch=True) 132 | self.log('test/mrr', self.test_mrr, on_step=False, on_epoch=True) 133 | self.log('test/p1', self.test_p1, on_step=False, on_epoch=True) 134 | self.log('test/hr5', self.test_hr5, on_step=False, on_epoch=True) 135 | self.log('test/ndcg', self.test_ndgc, on_step=False, on_epoch=True) 136 | 137 | def predict_step(self, batch, *args): 138 | r""" Like test step but without metrics and labels. """ 139 | input_ids, attention_mask, keys = ( 140 | batch["input_ids"], batch["attention_mask"], batch['keys'] 141 | ) 142 | token_type_ids = batch.get('token_type_ids', None) 143 | 144 | input_ids, attention_mask, token_type_ids = shrink_batch( 145 | input_ids, attention_mask, token_type_ids, pad_token_id=self.tokenizer.pad_token_id 146 | ) 147 | 148 | results = self( 149 | input_ids=input_ids, 150 | attention_mask=attention_mask, 151 | token_type_ids=token_type_ids, 152 | ) 153 | scores = torch.softmax(results.logits, dim=-1)[:, 1] # take predictions on pos class 154 | return {'scores': scores, 'keys': keys} 155 | 156 | def predict_epoch_end(self, predictions): 157 | r""" Receive a list of predictions and return a List of scores to write to a file. 
""" 158 | scores = torch.cat([o['scores'] for o in predictions], dim=0) 159 | keys = torch.cat([o['keys'] for o in predictions], dim=0) 160 | 161 | assert scores.shape == keys.shape 162 | return {'scores': scores.flatten(), 'keys': keys.flatten()} 163 | -------------------------------------------------------------------------------- /transformers_framework/adapters/map_adapters/arrow/jointwise_adapter.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import Dict 3 | 4 | from pytorch_lightning.trainer.states import TrainerFn 5 | from transformers import PreTrainedTokenizer 6 | 7 | from transformers_framework.adapters.map_adapters.arrow.pairwise_adapter import PairwiseArrowAdapter 8 | from transformers_framework.transformations.conversion_transformation import DictSample2JointwiseSampleTransformation 9 | from transformers_framework.transformations.filtering_transformation import ( 10 | FilterJointSampleOnKTransformation, 11 | FilterJointwiseSampleOnScoreAndLabelTransformation, 12 | selection_possibilities, 13 | ) 14 | from transformers_framework.transformations.transformation import Transformation, TransformationsConcatenation 15 | from transformers_framework.utilities import JointwiseSample 16 | from transformers_framework.utilities.processors import joint_processor 17 | 18 | 19 | class JointwiseArrowAdapter(PairwiseArrowAdapter): 20 | r""" Jointwise version of Arrow File readers, which implements filtering on scores and limits. """ 21 | 22 | def __init__( 23 | self, 24 | hyperparameters: Namespace, 25 | tokenizer: PreTrainedTokenizer, 26 | stage: TrainerFn, 27 | seed: int = 0, 28 | ) -> None: 29 | super().__init__(hyperparameters, tokenizer, stage, seed=seed) 30 | 31 | # arguments check 32 | assert type(self) != JointwiseArrowAdapter or isinstance(hyperparameters.k, int), ( 33 | f"provided `k` {hyperparameters.k} is not an integer" 34 | ) 35 | assert type(self) != JointwiseArrowAdapter or ( 36 | hyperparameters.selection is not None and hyperparameters.selection in selection_possibilities 37 | ), ( 38 | f"provided `selection` {hyperparameters.selection} is not among " 39 | f"the accepted values {selection_possibilities} or is `None`" 40 | ) 41 | assert type(self) != JointwiseArrowAdapter or ( 42 | hyperparameters.selection != 'all' or hyperparameters.force_load_dataset_in_memory is True 43 | ), "selection=`all` requires `force_load_dataset_in_memory` to be true" 44 | assert hyperparameters.min_threshold is None or isinstance(hyperparameters.min_threshold, float), ( 45 | "`min_threshold` must be a float or None" 46 | ) 47 | assert hyperparameters.max_threshold is None or isinstance(hyperparameters.max_threshold, float), ( 48 | "`max_threshold` must be a float or None" 49 | ) 50 | assert hyperparameters.max_positives is None or isinstance(hyperparameters.max_positives, int), ( 51 | "`max_positives` must be a int or None" 52 | ) 53 | assert hyperparameters.max_negatives is None or isinstance(hyperparameters.max_negatives, int), ( 54 | "`max_negatives` must be a int or None" 55 | ) 56 | assert type(self) != JointwiseArrowAdapter or hyperparameters.shuffle_candidates is False or ( 57 | isinstance(hyperparameters.reload_dataloaders_every_n_epochs, int) 58 | and hyperparameters.reload_dataloaders_every_n_epochs > 0 59 | ), "`shuffle_candidates` requires not None `reload_dataloaders_every_n_epochs` greater than 0" 60 | assert type(self) != JointwiseArrowAdapter or hyperparameters.separated is 
True, ( 61 | "`separated` must be True" 62 | ) 63 | assert type(self) != JointwiseArrowAdapter or len(self.hyperparameters.field_names) == 2, ( 64 | "`field_names` must have length 2" 65 | ) 66 | assert type(self) != JointwiseArrowAdapter or all( 67 | ":" not in field for field in self.hyperparameters.field_names 68 | ), "every `field_names` must not contain the `:`" 69 | 70 | def __iter__(self): 71 | for idx, sample in enumerate(self.data): 72 | self.check(sample, idx) # check sample is valid and raise warning 73 | if self.hyperparameters.selection == 'all': 74 | yield from self.transformations(sample) 75 | else: 76 | yield self.transformations(sample) 77 | 78 | def __get_transformations__(self) -> Transformation: 79 | return TransformationsConcatenation( 80 | self.hyperparameters, 81 | DictSample2JointwiseSampleTransformation( 82 | self.hyperparameters, 83 | first_field=self.hyperparameters.field_names[0], 84 | seconds_field=self.hyperparameters.field_names[1], 85 | key_field=self.hyperparameters.key_name, 86 | label_field=self.hyperparameters.label_name, 87 | score_field=self.hyperparameters.score_name, 88 | ), # Dict -> JointwiseSample 89 | FilterJointwiseSampleOnScoreAndLabelTransformation(self.hyperparameters), # Filter on score and labels 90 | FilterJointSampleOnKTransformation( # Filter on K 91 | self.hyperparameters, 92 | padding=True, 93 | seed=self.seed, 94 | ) # -> JointwiseSample 95 | ) 96 | 97 | def preprocess_line(self, sample: JointwiseSample) -> Dict: 98 | r""" 99 | Process a line. The structure of each line is exactly 100 | the same returned by the __iter__ method. Here you should do data preparation 101 | for the actual model being trained. This is a good place to do batch tokenization, 102 | padding and so on. 103 | """ 104 | return joint_processor( 105 | sample, 106 | tokenizer=self.tokenizer, 107 | separated=self.hyperparameters.separated, 108 | max_sequence_length=self.hyperparameters.max_sequence_length, 109 | reduce_labels=self.hyperparameters.reduce_labels, 110 | ) 111 | 112 | @staticmethod 113 | def add_adapter_specific_args(parser: ArgumentParser): 114 | super(JointwiseArrowAdapter, JointwiseArrowAdapter).add_adapter_specific_args(parser) 115 | parser.add_argument( 116 | '--shuffle_candidates', action="store_true", help="Shuffle candidates when using `k` to select them" 117 | ) 118 | parser.add_argument( 119 | '--min_threshold', 120 | type=float, 121 | required=False, 122 | default=None, 123 | help="Lower threshold for candidates filtering", 124 | ) 125 | parser.add_argument( 126 | '--max_threshold', 127 | type=float, 128 | required=False, 129 | default=None, 130 | help="Upper threshold for candidates filtering", 131 | ) 132 | parser.add_argument( 133 | '--max_negatives', type=int, required=False, default=None, help="Lower threshold for candidates filtering" 134 | ) 135 | parser.add_argument( 136 | '--max_positives', type=int, required=False, default=None, help="Upper threshold for candidates filtering" 137 | ) 138 | parser.add_argument( 139 | '-k', type=int, required=False, default=None, help="Number of candidates to be cumulated on a row" 140 | ) 141 | parser.add_argument( 142 | '--selection', 143 | type=str, 144 | required=False, 145 | default=None, 146 | choices=selection_possibilities, 147 | help="How to get candidates if `-k` is defined", 148 | ) 149 | parser.add_argument( 150 | '--reduce_labels', 151 | action="store_true", 152 | help="Whether to reduce label to single dim when using k. 
Labels must be all equal in the same group", 153 | ) 154 | parser.add_argument( 155 | '--extended_token_type_ids', 156 | type=int, 157 | default=None, 158 | help="How many extended TT ids should be generated.", 159 | ) 160 | -------------------------------------------------------------------------------- /transformers_framework/datamodules/transformers_datamodule.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from typing import Callable 3 | 4 | from pytorch_lightning import _logger as logger 5 | from pytorch_lightning.trainer.states import TrainerFn 6 | from pytorch_lightning.trainer.trainer import Trainer 7 | from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn 8 | from transformers.tokenization_utils import PreTrainedTokenizer 9 | from transformers_lightning.adapters.super_adapter import SuperAdapter 10 | from transformers_lightning.datamodules.adapter_datamodule import AdaptersDataModule 11 | from transformers_lightning.datasets.iterable_dataset import TransformersIterableDataset 12 | from transformers_lightning.datasets.map_dataset import TransformersMapDataset 13 | from transformers_lightning.utils import collate_single_fn 14 | from transformers_lightning.utils.inspectors import get_classes_from_module 15 | 16 | from transformers_framework import adapters 17 | from transformers_framework.adapters.transformer_adapter import TransformersAdapter 18 | from transformers_framework.samplers import DistributedKeysSampler, KeysSampler 19 | from transformers_framework.utilities.datamodules import STAGES_TO_NAMES 20 | 21 | 22 | ADAPTER_CLASSES = get_classes_from_module(adapters, parent=TransformersAdapter) 23 | 24 | 25 | class TransformersDataModule(AdaptersDataModule): 26 | r""" 27 | MultiFileDataModule implements some simple methods to check whether training, val or testing is required. 28 | This shoudl not directly instantiated. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | hyperparameters: Namespace, 34 | trainer: Trainer, 35 | collate_fn: Callable = collate_single_fn, 36 | tokenizer: PreTrainedTokenizer = None, 37 | ): 38 | super().__init__(hyperparameters, trainer, collate_fn=collate_fn) 39 | self.tokenizer = tokenizer 40 | self.adapter_class = ADAPTER_CLASSES[hyperparameters.adapter] 41 | 42 | self.train_adapter = self.get_adapter(TrainerFn.FITTING) 43 | self.valid_adapter = self.get_adapter(TrainerFn.VALIDATING) 44 | self.test_adapter = self.get_adapter(TrainerFn.TESTING) 45 | self.predict_adapter = self.get_adapter(TrainerFn.PREDICTING) 46 | 47 | assert not (hyperparameters.iterable is True and hyperparameters.keep_same_keys_close is True), ( 48 | "cannot use `keep_same_keys_close` with `iterable` and vice-versa." 49 | ) 50 | assert hyperparameters.keep_same_keys_close is False or hyperparameters.replace_sampler_ddp is False, ( 51 | "when using `keep_same_keys_close` you must set `replace_sampler_ddp=False`" 52 | ) 53 | 54 | # Optional, called for every GPU/machine (assigning state is OK) 55 | def setup(self, stage: str = None): 56 | r""" Load datasets only if respective file is defined. 
""" 57 | 58 | if stage is None: 59 | return 60 | 61 | if stage == TrainerFn.FITTING.value or stage == TrainerFn.VALIDATING.value: 62 | if self.do_train(): 63 | self.train_dataset = self.load_dataset(TrainerFn.FITTING) 64 | if self.do_validation(): 65 | self.valid_dataset = self.load_dataset(TrainerFn.VALIDATING) 66 | 67 | elif stage == TrainerFn.TESTING.value: 68 | if self.do_test(): 69 | self.test_dataset = [self.load_dataset(TrainerFn.TESTING)] 70 | 71 | elif stage == TrainerFn.PREDICTING.value: 72 | if self.do_predict(): 73 | self.predict_dataset = self.load_dataset(TrainerFn.PREDICTING) 74 | 75 | def get_adapter(self, stage: TrainerFn) -> SuperAdapter: 76 | r""" Return the adapter to use. """ 77 | return self.adapter_class( 78 | self.hyperparameters, self.tokenizer, STAGES_TO_NAMES[stage], seed=self.trainer.current_epoch, 79 | ) 80 | 81 | def load_dataset(self, stage: TrainerFn = None): 82 | r""" Load a dataset given the stage name. """ 83 | logger.info(f"Loading {stage.value} dataset...") 84 | adapter = getattr(self, f"{STAGES_TO_NAMES[stage]}_adapter") 85 | dataset_class = TransformersIterableDataset if self.hyperparameters.iterable else TransformersMapDataset 86 | 87 | # map dataset must be told not to load everything in memory 88 | kwargs = {} 89 | if not self.hyperparameters.iterable: 90 | kwargs = dict(keep_in_memory=self.hyperparameters.force_load_dataset_in_memory) 91 | 92 | dataset = dataset_class(self.hyperparameters, adapter, self.trainer, **kwargs) 93 | rank_zero_info( 94 | f"{stage.value.capitalize()} dataset has length " 95 | f"{len(dataset) if not self.hyperparameters.iterable else 'inf'}" 96 | ) 97 | return dataset 98 | 99 | def do_train(self): 100 | return self.train_adapter.is_active() 101 | 102 | def do_validation(self): 103 | return self.valid_adapter.is_active() 104 | 105 | def do_test(self): 106 | return self.test_adapter.is_active() 107 | 108 | def do_predict(self): 109 | return self.predict_adapter.is_active() 110 | 111 | def train_dataloader(self): 112 | r""" Return the training dataloader. 113 | If user requested keep_same_keys_close, we will provide a custom sampler to the dataloader. 
""" 114 | if not self.do_train(): 115 | return None 116 | 117 | if self.hyperparameters.keep_same_keys_close is True: 118 | # keep keys together only in training 119 | rank_zero_warn("Using custom keys sampler") 120 | sampler_cls = ( 121 | DistributedKeysSampler if self.trainer.accelerator_connector.is_distributed else KeysSampler 122 | ) 123 | sampler = sampler_cls( 124 | self.train_dataset, 125 | keep_same_keys_close=self.hyperparameters.keep_same_keys_close, 126 | shuffle=True, 127 | ) 128 | return self.default_dataloader(self.train_dataset, self.hyperparameters.batch_size, sampler=sampler) 129 | 130 | if ( 131 | self.hyperparameters.reload_dataloaders_every_n_epochs > 0 132 | and self.trainer.current_epoch > 0 133 | and self.trainer.current_epoch % self.hyperparameters.reload_dataloaders_every_n_epochs == 0 134 | ): 135 | rank_zero_warn("Reloading train dataset every epoch.") 136 | self.train_adapter = self.get_adapter(TrainerFn.FITTING) 137 | self.train_dataset = self.load_dataset(TrainerFn.FITTING) 138 | 139 | return self.default_dataloader( 140 | self.train_dataset, 141 | self.hyperparameters.batch_size, 142 | shuffle=not self.hyperparameters.iterable, 143 | prefetch_factor=self.hyperparameters.prefetch_factor, 144 | ) 145 | 146 | @classmethod 147 | def add_datamodule_specific_args(cls, parser: ArgumentParser): 148 | super(TransformersDataModule, TransformersDataModule).add_datamodule_specific_args(parser) 149 | parser.add_argument('--adapter', type=str, required=True, choices=ADAPTER_CLASSES.keys()) 150 | parser.add_argument( 151 | '--prefetch_factor', default=2, type=int, required=False, help='Number of examples to prepare in advance.' 152 | ) 153 | parser.add_argument( 154 | '--keep_same_keys_close', 155 | action="store_true", 156 | help="Keep entries with same first together when shuffling, valid only for training.", 157 | ) 158 | parser.add_argument( 159 | '--force_load_dataset_in_memory', 160 | action="store_true", 161 | help=( 162 | "Load whole dataset in memory even backed by pyarrow." 163 | " This may be usefull with transformations that change number of examples." 
164 | ) 165 | ) 166 | tmp_hyperparameters, _ = parser.parse_known_args() 167 | adapter_class = ADAPTER_CLASSES[tmp_hyperparameters.adapter] 168 | for stage_name in STAGES_TO_NAMES.values(): 169 | adapter_class.add_adapter_instance_specific_args(parser, stage_name=stage_name) 170 | adapter_class.add_adapter_specific_args(parser) 171 | 172 | 173 | -------------------------------------------------------------------------------- /transformers_framework/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import datasets 5 | import pytorch_lightning as pl 6 | import transformers 7 | from pytorch_lightning import seed_everything 8 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichModelSummary 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.plugins import DDPPlugin 11 | from pytorch_lightning.utilities import rank_zero_warn 12 | from transformers_lightning.callbacks import RichProgressBar, TransformersModelCheckpointCallback 13 | from transformers_lightning.defaults import DefaultConfig 14 | from transformers_lightning.utils import get_classes_from_module 15 | 16 | from transformers_framework import models 17 | from transformers_framework.datamodules import TransformersDataModule 18 | from transformers_framework.utilities import ExtendedNamespace, write_dict_to_disk 19 | 20 | 21 | def main(hyperparameters): 22 | 23 | # too much complains of the tokenizers 24 | transformers.logging.set_verbosity_error() 25 | 26 | os.environ['TOKENIZERS_PARALLELISM'] = "true" 27 | datasets.config.IN_MEMORY_MAX_SIZE = 1024 * 1024 * 1024 * hyperparameters.datasets_max_in_memory # in GB 28 | 29 | # set the random seed 30 | seed_everything(seed=hyperparameters.seed, workers=True) 31 | 32 | # instantiate PL model 33 | pl_model_class = all_models[hyperparameters.model] 34 | model = pl_model_class(hyperparameters) 35 | 36 | # default tensorboard logger 37 | tb_logger = TensorBoardLogger( 38 | save_dir=os.path.join(hyperparameters.output_dir, hyperparameters.tensorboard_dir), 39 | name=hyperparameters.name, 40 | ) 41 | loggers = [tb_logger] 42 | 43 | # save pre-trained models to 44 | save_transformers_callback = TransformersModelCheckpointCallback(hyperparameters) 45 | 46 | # and log learning rate 47 | lr_monitor_callback = LearningRateMonitor(logging_interval='step') 48 | 49 | # and normal checkpoints with 50 | checkpoints_dir = os.path.join(hyperparameters.output_dir, hyperparameters.checkpoints_dir, hyperparameters.name) 51 | checkpoint_callback_args = dict(verbose=True, dirpath=checkpoints_dir) 52 | 53 | if hyperparameters.monitor is not None: 54 | checkpoint_callback_args = dict( 55 | **checkpoint_callback_args, 56 | monitor=hyperparameters.monitor, 57 | save_last=True, 58 | mode=hyperparameters.monitor_direction, 59 | save_top_k=1, 60 | ) 61 | checkpoint_callback = ModelCheckpoint(**checkpoint_callback_args) 62 | 63 | # rich progress bar 64 | rich_progress_bar = RichProgressBar(leave=True) 65 | 66 | # modelsummary callback 67 | model_summary = RichModelSummary(max_depth=2) 68 | 69 | # all callbacks 70 | callbacks = [ 71 | save_transformers_callback, 72 | lr_monitor_callback, 73 | checkpoint_callback, 74 | rich_progress_bar, 75 | model_summary, 76 | ] 77 | 78 | # early stopping if defined 79 | if hyperparameters.early_stopping: 80 | if hyperparameters.monitor is None: 81 | raise ValueError("cannot use early_stopping without a 
monitored variable") 82 | 83 | early_stopping_callback = EarlyStopping( 84 | monitor=hyperparameters.monitor, 85 | patience=hyperparameters.patience, 86 | verbose=True, 87 | mode=hyperparameters.monitor_direction, 88 | ) 89 | callbacks.append(early_stopping_callback) 90 | 91 | # disable find unused parameters to improve performance 92 | kwargs = dict() 93 | if hyperparameters.strategy in ("dp", "ddp2"): 94 | rank_zero_warn("This repo is not designed to work with DataParallel. Use strategy `ddp` instead.") 95 | 96 | if hyperparameters.strategy == "ddp": 97 | kwargs['strategy'] = DDPPlugin(find_unused_parameters=False) 98 | 99 | # instantiate PL trainer 100 | trainer = pl.Trainer.from_argparse_args( 101 | hyperparameters, 102 | default_root_dir=hyperparameters.output_dir, 103 | logger=loggers, 104 | callbacks=callbacks, 105 | profiler='simple', 106 | **kwargs, 107 | ) 108 | 109 | # DataModules 110 | datamodule = TransformersDataModule(hyperparameters, trainer, tokenizer=model.tokenizer) 111 | 112 | # Train! 113 | if datamodule.do_train(): 114 | trainer.fit(model, datamodule=datamodule) 115 | 116 | # Test! 117 | if datamodule.do_test(): 118 | if datamodule.do_train() and hyperparameters.monitor is not None: 119 | rank_zero_warn( 120 | f"Going to test on best ckpt chosen over " 121 | f"{hyperparameters.monitor}: {checkpoint_callback.best_model_path}" 122 | ) 123 | trainer.test(datamodule=datamodule, ckpt_path='best') 124 | else: 125 | rank_zero_warn("Going to test on last or pretrained ckpt") 126 | trainer.test(model, datamodule=datamodule) 127 | 128 | if datamodule.do_predict(): 129 | assert hasattr(model, "predict_step") and hasattr(model, "predict_epoch_end"), ( 130 | "To do predictions, the model must implement both `predict_step` and `predict_epoch_end`" 131 | ) 132 | 133 | if trainer._accelerator_connector.is_distributed: 134 | rank_zero_warn("Predicting on more than 1 GPU may give results in different order, use keys to sort them.") 135 | 136 | predictions = trainer.predict(model, datamodule=datamodule, return_predictions=True) 137 | predictions = model.predict_epoch_end(predictions) 138 | 139 | basepath = os.path.join(hyperparameters.output_dir, hyperparameters.predictions_dir, hyperparameters.name) 140 | write_dict_to_disk(predictions, basepath, trainer=trainer) 141 | 142 | 143 | if __name__ == '__main__': 144 | 145 | # Read config for defaults and eventually override with hyperparameters from command line 146 | parser = ArgumentParser(add_help=False) 147 | 148 | # model classname 149 | all_models = get_classes_from_module(models, parent=pl.LightningModule) 150 | parser.add_argument('--model', type=str, required=True, choices=all_models.keys()) 151 | 152 | # experiment name, used both for checkpointing, pre_trained_names, logging and tensorboard 153 | parser.add_argument('--name', type=str, required=True, help='Name of the model') 154 | 155 | # various options 156 | parser.add_argument('--seed', type=int, default=1337, help='Set the random seed') 157 | parser.add_argument('--monitor', type=str, help='Value to monitor for best checkpoint', default=None) 158 | parser.add_argument( 159 | '--monitor_direction', type=str, help='Monitor value direction for best', default='max', choices=['min', 'max'] 160 | ) 161 | parser.add_argument('--early_stopping', action="store_true", help="Use early stopping") 162 | parser.add_argument( 163 | '--patience', 164 | type=int, 165 | default=5, 166 | required=False, 167 | help="Number of non-improving validations to wait before early stopping" 168 | ) 169 
| parser.add_argument( 170 | '--find_unused_parameters', 171 | action="store_true", 172 | help="Whether to check for unused params at each iteration" 173 | ) 174 | parser.add_argument( 175 | '--datasets_max_in_memory', type=int, default=0, help="Datasets max in-memory cache (in GB)" 176 | ) 177 | 178 | # I/O folders 179 | parser.add_argument( 180 | '--predictions_dir', type=str, default="predictions", required=False, help="Predictions folder" 181 | ) 182 | 183 | DefaultConfig.add_defaults_args(parser) 184 | 185 | # retrieving model with temporarily parsed arguments 186 | tmp_params, extra = parser.parse_known_args() 187 | 188 | # get pl_model_class in advance to know which params it needs 189 | all_models[tmp_params.model].add_model_specific_args(parser) 190 | TransformersDataModule.add_datamodule_specific_args(parser) 191 | 192 | # add callback / logger specific parameters 193 | TransformersModelCheckpointCallback.add_callback_specific_args(parser) 194 | 195 | # add all the available trainer options to argparse 196 | # i.e.: now --devices --num_nodes ... --fast_dev_run all work in the cli 197 | parser = pl.Trainer.add_argparse_args(parser) 198 | 199 | # get Namespace of parameters 200 | hyperparameters = parser.parse_args() 201 | hyperparameters = ExtendedNamespace.from_namespace(hyperparameters) 202 | main(hyperparameters) 203 | -------------------------------------------------------------------------------- /process_datasets/strategies/paragraph/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import abstractmethod 3 | from argparse import ArgumentParser, Namespace 4 | from typing import Dict, Generator, List, Tuple 5 | 6 | from process_datasets.strategies.strategy import Strategy 7 | from process_datasets.utils.general import clean_documents 8 | 9 | 10 | class _ParagraphStrategy(Strategy): 11 | 12 | def __init__(self, hparams: Namespace): 13 | super().__init__(hparams) 14 | 15 | assert self.hparams.min_sentence_length >= 1, ( 16 | "`--min_sentence_length` must be a positive integer" 17 | ) 18 | assert self.hparams.min_paragraph_length >= 1, ( 19 | "`--min_paragraph_length` must be a positive integer" 20 | ) 21 | assert self.hparams.min_document_length >= 1, ( 22 | "`--min_document_length` must be a positive integer" 23 | ) 24 | 25 | assert self.hparams.min_sentences_per_paragraph >= 1, ( 26 | "`--min_sentences_per_paragraph` must be a positive integer" 27 | ) 28 | assert self.hparams.min_paragraphs_per_document >= 1, ( 29 | "`--min_paragraphs_per_document` must be a positive integer" 30 | ) 31 | 32 | assert self.hparams.paragraph_ratio is None or (0.0 <= self.hparams.paragraph_ratio <= 1.0), ( 33 | "`--paragraph_ratio` must be a float in [0.0, 1.0]" 34 | ) 35 | assert 0.0 <= self.hparams.document_ratio <= 1.0, ( 36 | "`--document_ratio` must be a float in [0.0, 1.0]" 37 | ) 38 | 39 | self.actual = None 40 | 41 | @abstractmethod 42 | def make(self, documents: Generator[List[List[str]], None, None]) -> Generator[Dict, None, None]: 43 | r""" Create and yield examples that will be returned. """ 44 | 45 | def process_batch(self, batch: List[Dict]) -> List[Dict]: 46 | r""" Process a batch of texts.
""" 47 | documents = (b[self.hparams.field] for b in batch) 48 | documents = clean_documents( 49 | documents=documents, 50 | paragraph_separator=self.hparams.paragraph_separator, 51 | min_sentence_length=self.hparams.min_sentence_length, 52 | min_paragraph_length=self.hparams.min_paragraph_length, 53 | min_document_length=self.hparams.min_document_length, 54 | min_sentences_per_paragraph=self.hparams.min_sentences_per_paragraph, 55 | min_paragraphs_per_document=self.hparams.min_paragraphs_per_document, 56 | ) 57 | examples = list(self.make(documents)) 58 | return examples 59 | 60 | @staticmethod 61 | def add_arguments_to_argparse(parser: ArgumentParser): 62 | super(_ParagraphStrategy, _ParagraphStrategy).add_arguments_to_argparse(parser) 63 | parser.add_argument('--paragraph_separator', default='\n\n', required=False, 64 | help="Split documents into paragraphs on this characted (string)") 65 | parser.add_argument('--min_sentence_length', type=int, default=20, required=False, 66 | help="Minimum length to consider a sentence (in characters)") 67 | parser.add_argument('--min_paragraph_length', type=int, default=60, required=False, 68 | help="Minimum length to consider a paragraph (in characters)") 69 | parser.add_argument('--min_document_length', type=int, default=200, required=False, 70 | help="Minimum length to consider a document (in characters)") 71 | parser.add_argument('--min_sentences_per_paragraph', type=int, default=1, required=False, 72 | help="Minimum number of cleaned sentences per paragraph (in characters)") 73 | parser.add_argument('--min_paragraphs_per_document', type=int, default=1, required=False, 74 | help="Minimum number of cleaned paragraphs per document (in characters)") 75 | parser.add_argument( 76 | '--paragraph_ratio', 77 | type=float, 78 | default=None, 79 | help=( 80 | "How many paragraphs per documents should be used as pivot. " 81 | "None means 1 per document. A float between 0.0 and 1.0 " 82 | "means the corrisponding percentage of documents. " 83 | "With 0.0, no output pairs will be created." 84 | ) 85 | ) 86 | parser.add_argument( 87 | '--document_ratio', 88 | type=float, 89 | default=1.0, 90 | help=( 91 | "How many documents should be considered. None means all. " 92 | "A float will be used as probability to select a document. " 93 | "Negatives may anyway considered discarded documents." 94 | ) 95 | ) 96 | 97 | 98 | class _PairwiseStrategy(_ParagraphStrategy): 99 | 100 | def __init__(self, hparams: Namespace): 101 | super().__init__(hparams) 102 | 103 | if len(self.hparams.max_negatives) == 1: 104 | self.hparams.max_negatives = (self.hparams.max_negatives[0], self.hparams.max_negatives[0]) 105 | 106 | if self.hparams.max_hard_negatives is not None and len(self.hparams.max_hard_negatives) == 1: 107 | self.hparams.max_hard_negatives = (self.hparams.max_hard_negatives[0], self.hparams.max_hard_negatives[0]) 108 | 109 | assert len(self.hparams.max_negatives) == 2, ( 110 | "`--max_negatives` must be an integer or a range" 111 | ) 112 | assert self.hparams.max_hard_negatives is None or len(self.hparams.max_hard_negatives) == 2, ( 113 | "`--max_hard_negatives` must be an integer or a range" 114 | ) 115 | 116 | def get_random_max_negatives(self) -> int: 117 | r""" Get random number of negatives. """ 118 | return random.randint(*self.hparams.max_negatives) 119 | 120 | def get_random_max_hard_negatives(self) -> int: 121 | r""" Get random number of hard negatives. 
""" 122 | if self.hparams.max_hard_negatives is None: 123 | raise ValueError("Cannot call `get_random_max_hard_negatives` without setting `max_hard_negatives`") 124 | return random.randint(*self.hparams.max_hard_negatives) 125 | 126 | def extract_random_span(self, paragraph: List[str], length: int, return_remain: bool = False): 127 | r""" Extract a random span of sentences from a paragraph. """ 128 | position = random.randint(0, len(paragraph) - length) 129 | if return_remain: 130 | res = paragraph[:position] + paragraph[position + length:] 131 | return paragraph[position:position + length], res 132 | else: 133 | return paragraph[position:position + length] 134 | 135 | def extract_two_random_spans( 136 | self, paragraph: List[str], length_1: int, length_2: int, return_remain: bool = False 137 | ) -> Tuple: 138 | r""" Extract a random span of sentences from a paragraph. """ 139 | position_1 = random.randint(0, len(paragraph) - (length_1 + length_2)) 140 | position_2 = random.randint(position_1 + length_1, len(paragraph) - length_2) 141 | if return_remain: 142 | res = ( 143 | paragraph[:position_1] \ 144 | + paragraph[position_1 + length_1:position_2] \ 145 | + paragraph[position_2 + length_2:] 146 | ) 147 | return paragraph[position_1:position_1 + length_1], paragraph[position_2:position_2 + length_2], res 148 | else: 149 | return paragraph[position_1:position_1 + length_1], paragraph[position_2:position_2 + length_2] 150 | 151 | @staticmethod 152 | def add_arguments_to_argparse(parser: ArgumentParser): 153 | super(_PairwiseStrategy, _PairwiseStrategy).add_arguments_to_argparse(parser) 154 | parser.add_argument('--max_negatives', default=None, required=True, type=int, nargs='+', 155 | help="Max number or range of all negatives.") 156 | parser.add_argument('--max_hard_negatives', default=None, required=False, type=int, nargs='+', 157 | help="Max number or range of hard negatives (from same document).") 158 | 159 | 160 | class _SortingStrategy(_ParagraphStrategy): 161 | 162 | def __init__(self, hparams: Namespace): 163 | super().__init__(hparams) 164 | 165 | assert 0 < self.hparams.number_range[0] <= self.hparams.number_range[1], ( 166 | "`--number_range` must be a non-empty range with limits greater than 0" 167 | ) 168 | 169 | self.hparams.number_range = tuple(range(self.hparams.number_range[0], self.hparams.number_range[1] + 1)) 170 | 171 | if self.hparams.number_probs is None: 172 | self.hparams.number_probs = (1.0, ) * len(self.hparams.number_range) 173 | 174 | @staticmethod 175 | def add_arguments_to_argparse(parser: ArgumentParser): 176 | super(_SortingStrategy, _SortingStrategy).add_arguments_to_argparse(parser) 177 | parser.add_argument('--number_range', required=True, type=int, nargs=2, 178 | help="Min and max number of paragraphs to order.") 179 | parser.add_argument('--number_probs', default=None, required=False, type=float, nargs='+', 180 | help="Probabilities of paragraphs numbers.") 181 | -------------------------------------------------------------------------------- /transformers_framework/utilities/functional.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | from collections.abc import Iterable 5 | from functools import partial 6 | from itertools import combinations 7 | from string import ascii_uppercase 8 | from typing import Any, Callable, Dict, Generator, List, Sequence, Union 9 | 10 | import torch 11 | from pytorch_lightning.trainer.trainer import Trainer 12 | from torch import Tensor 
13 | from transformers import PreTrainedTokenizer 14 | from transformers.pipelines.base import pad_collate_fn 15 | from transformers_lightning.utils import filter_generator 16 | 17 | 18 | def bool_to_int(string): 19 | return int(string.lower().strip() in ('yes', 'pos', 'positive', '1', 'correct', 'true')) 20 | 21 | 22 | def to_int(string): 23 | return int(string.lower().strip()) 24 | 25 | 26 | def to_float(string): 27 | return float(string.strip()) 28 | 29 | 30 | def _check_types(argument: str, types=[]): 31 | r""" Parse argument in one of the given types (in order) and return converted value. """ 32 | for _type in types: 33 | try: 34 | if _type is bool: 35 | if argument.lower() not in ('true', 'false'): 36 | raise ValueError() 37 | x = (argument.lower() == 'true') 38 | else: 39 | x = _type(argument) 40 | return x 41 | except ValueError: 42 | pass 43 | raise TypeError(f"Argument {argument} is not of allowed types: {types}") 44 | 45 | 46 | def check_types(*types): 47 | r""" Parse argument in one of the given types (in order) and return converted value. """ 48 | return partial(_check_types, types=types) 49 | 50 | 51 | def split(_list: Sequence, part_length: int, drop_last: bool = False): 52 | r""" 53 | Split a list `_list` in parts of length `part_length`. 54 | Eventually drop last piece if it would have been shorter. """ 55 | assert isinstance(part_length, int) and part_length > 0 56 | assert isinstance(_list, (list, tuple)) 57 | 58 | res = [] 59 | for i in range(0, len(_list), part_length): 60 | res.append(_list[i: i + part_length]) 61 | 62 | if drop_last and len(res[-1]) < part_length: 63 | res = res[:-1] 64 | 65 | return res 66 | 67 | 68 | def l2_norm(x, y, dim: int = -1, keepdim: bool = False, normalize: bool = True): # noqa: E741 69 | r""" Computes L-Norm between two tensors on the given dimension. """ 70 | if normalize: 71 | x = x / torch.linalg.norm(x, ord=2, dim=dim, keepdim=True) 72 | y = y / torch.linalg.norm(y, ord=2, dim=dim, keepdim=True) 73 | 74 | return (x - y).pow(2).sum(dim=dim, keepdim=keepdim).sqrt() 75 | 76 | 77 | def _get_scattered_tensor(size: int, device: torch.device): 78 | indexes = list(zip(*[[x, x + 1] if x % 2 == 0 else [x, x - 1] for x in range(size)])) 79 | res = torch.zeros(size, size, dtype=torch.bool, device=device, requires_grad=False) 80 | res[indexes] = True 81 | return res 82 | 83 | 84 | cache = {} 85 | 86 | def get_scattered_tensor(size: int, device: torch.device, use_cache: bool = True): 87 | r""" Return a tensor (matrix) with the following True values: 88 | Example with size = 4 89 | 0 1 0 0 90 | 1 0 0 0 91 | 0 0 0 1 92 | 0 0 1 0 93 | """ 94 | if use_cache is False: 95 | return _get_scattered_tensor(size, device) 96 | 97 | if size not in cache: 98 | cache[size] = _get_scattered_tensor(size, device) 99 | 100 | return cache[size] 101 | 102 | 103 | def expand_logits(logits: torch.Tensor) -> torch.Tensor: 104 | probs = torch.sigmoid(logits) 105 | logits = torch.stack([1 - probs, probs], dim=-1).log() 106 | return logits 107 | 108 | 109 | def split_list(_list: List, length: int): 110 | r""" Split a list `l` in parts on length `length`. 
""" 111 | assert length >= 1 112 | 113 | index = 0 114 | while True: 115 | yield _list[index:index + length] 116 | index += length 117 | if index >= len(_list): 118 | break 119 | 120 | 121 | def get_rng_index(list_or_tuple) -> int: 122 | return torch.randint(0, len(list_or_tuple), size=()).item() 123 | 124 | 125 | def write_dict_to_disk(_dict: Dict, folder_path, trainer: Trainer): 126 | r""" Write some dict to disk as key-value pairs. """ 127 | os.makedirs(folder_path, exist_ok=True) 128 | for key, values in _dict.items(): 129 | if isinstance(values, Tensor): 130 | values = extract_data_from_every_tensor(values) 131 | filename = os.path.join(folder_path, f"{key}-{trainer.global_rank}.tsv") 132 | with open(filename, "w") as fo: 133 | for line in values: 134 | if isinstance(line, (list, tuple)): 135 | line = "\t".join(line) 136 | fo.write(f"{line}\n") 137 | 138 | 139 | def shrink_batch( 140 | input_ids: torch.Tensor, 141 | attention_mask: torch.Tensor, 142 | token_type_ids: torch.Tensor = None, 143 | pad_token_id: int = 0, 144 | ): 145 | r""" Remove data on the sequence length dimension in the positions where every example is padded. """ 146 | indexes = (input_ids != pad_token_id).any(dim=0) 147 | return ( 148 | input_ids[..., indexes], 149 | attention_mask[..., indexes], 150 | token_type_ids[..., indexes] if token_type_ids is not None else None, 151 | ) 152 | 153 | 154 | def shrink_batch_dict( 155 | batch: Dict, 156 | pad_token_id: int = 0, 157 | ): 158 | r""" Remove data on the sequence length dimension in the positions where every example is padded. """ 159 | indexes = (batch['input_ids'] != pad_token_id).any(dim=0) 160 | return {k: v[..., indexes] if v is not None else None for k, v in batch.items()} 161 | 162 | 163 | def pad_sequence(sequence: List, padding_value: Any, length: int): 164 | r""" Pad a sequence with values up to length. """ 165 | sequence += [padding_value] * (length - len(sequence)) 166 | return sequence 167 | 168 | 169 | def string_to_signature(string, length: int = 16): 170 | return hashlib.sha1(string.encode("utf-8")).hexdigest()[:length] 171 | 172 | 173 | def special_zip(*iterators) -> Iterable: 174 | r""" Zip allowing None iterators (which will be threated as infinite None generators. """ 175 | def inf_gen(): 176 | while True: 177 | yield None 178 | iterators = (iter(iterator) if iterator is not None else inf_gen() for iterator in iterators) 179 | yield from zip(*iterators) 180 | 181 | 182 | def none_if_all_none(iterable: Iterable) -> Union[Iterable, None]: 183 | r""" If all elements in iterable are None, return None, else return iterable. """ 184 | if all(x is None for x in iterable): 185 | return None 186 | return iterable 187 | 188 | 189 | def extract_data_from_every_tensor(tensor: Tensor): 190 | r""" Extract list of data from every kind of tensor on every device. """ 191 | return tensor.cpu().detach().tolist() 192 | 193 | 194 | def sample_from_distribution(logits: Tensor, sample_function: str = 'gumbel'): 195 | r""" 196 | Sample from generator logits either using gumbel distrib or multinomial distribution. 197 | Reimplement gumbel softmax because there is a bug in torch.nn.functional.gumbel_softmax 198 | when fp16 is used (https://github.com/pytorch/pytorch/issues/41663). 199 | Code taken from 200 | https://github.com/richarddwang/electra_pytorch/blob/9b2533e62cd1b6126feca323fb7b48480b8c2df0/pretrain.py#L318. 201 | Gumbel softmax is equal to what official ELECTRA code do, 202 | standard gumbel dist. 
= -ln(-ln(standard uniform dist.)) 203 | """ 204 | if sample_function == 'gumbel': 205 | loc = torch.tensor(0., device=logits.device, dtype=logits.dtype) 206 | scale = torch.tensor(1., device=logits.device, dtype=logits.dtype) 207 | gumbel_dist = torch.distributions.gumbel.Gumbel(loc, scale) 208 | return (logits + gumbel_dist.sample(logits.shape)).argmax(dim=-1) 209 | elif sample_function == 'multinomial': 210 | return torch.multinomial(torch.softmax(logits, dim=-1), 1).squeeze() 211 | else: 212 | raise ValueError("`sample_function` not valid, choose between 'gumbel' and 'multinomial'") 213 | 214 | 215 | def apply_to_generator(generator: Generator, function: Callable) -> Generator: 216 | r""" Apply a function to every element of a generator. """ 217 | yield from (function(element) for element in generator) 218 | 219 | 220 | def index_multi_tensors(*tensors: Sequence[Tensor], positions: Tensor = None): 221 | r""" Index many tensors where positions is True. """ 222 | return (ten[positions] for ten in tensors) 223 | 224 | 225 | def get_group_indexes_dict(indexes: Tensor) -> Dict[int, Tensor]: 226 | r""" 227 | Given an integer `torch.Tensor` `indexes`, return a dict mapping each distinct value in `indexes` to the 228 | `torch.Tensor` of positions where it appears. 229 | 230 | Args: 231 | indexes: a `torch.Tensor` 232 | 233 | Return: 234 | A dict of integer `torch.Tensor`s keyed by the distinct values 235 | 236 | Example: 237 | >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) 238 | >>> get_group_indexes_dict(indexes) 239 | {0: tensor([0, 1, 2]), 1: tensor([3, 4, 5, 6])} 240 | """ 241 | 242 | res: dict = {} 243 | for i, _id in enumerate(indexes): 244 | _id = _id.item() 245 | if _id in res: 246 | res[_id] += [i] 247 | else: 248 | res[_id] = [i] 249 | 250 | return {k: torch.tensor(x, dtype=torch.long) for k, x in res.items()} 251 | 252 | 253 | def safe_value_to_list(integer, length): 254 | if isinstance(integer, Iterable): 255 | integer = list(integer) 256 | assert len(integer) == length 257 | return integer 258 | else: 259 | return [integer] * length 260 | 261 | 262 | def _names_infinite_generator(prefix: str = '', postfix: str = ''): 263 | number_of_letters = 1 264 | while True: 265 | for el in combinations(ascii_uppercase, r=number_of_letters): 266 | el = prefix + ''.join(el) + postfix 267 | yield el 268 | number_of_letters += 1 269 | 270 | 271 | def names_infinite_generator(prefix: str = '', postfix: str = '', process_id: int = None, world_size: int = None): 272 | generator = _names_infinite_generator(prefix=prefix, postfix=postfix) 273 | if process_id is not None and world_size is not None: 274 | return filter_generator(generator, step=world_size, offset=process_id) 275 | return generator 276 | 277 | 278 | def collate_single_fn_with_exceptions(tokenizer: PreTrainedTokenizer, skip: List[str] = []) -> Callable: 279 | r""" 280 | Merge n dicts with identical keys, creating a list of value tensors. 281 | Do not convert original documents to tensors.
282 | """ 283 | # convert values to tensors 284 | pad_collate_fn_instance = pad_collate_fn(tokenizer, None) 285 | 286 | def collate_fn(data: List[Dict]): 287 | process = pad_collate_fn_instance(data) 288 | return process 289 | 290 | return collate_fn 291 | 292 | 293 | class Writer: 294 | 295 | def __init__(self, output_path: str, trainer: Trainer, chunk_size: int = 1000000): 296 | self.chunk_size = chunk_size 297 | self.output_path = output_path 298 | self.names_generator = names_infinite_generator( 299 | postfix='.jsonl', process_id=trainer.global_rank, world_size=trainer.world_size 300 | ) 301 | self.start() 302 | 303 | def start(self): 304 | self.fo = open(os.path.join(self.output_path, next(self.names_generator)), "w") 305 | self.written = 0 306 | 307 | def reset(self): 308 | self.close() 309 | self.start() 310 | 311 | def close(self): 312 | self.fo.close() 313 | 314 | def write(self, data: Dict): 315 | self.fo.write(json.dumps(data) + "\n") 316 | self.written += 1 317 | 318 | def write_lines(self, lines_iterable: Iterable): 319 | for line in lines_iterable: 320 | if self.written == self.chunk_size: 321 | self.reset() 322 | self.write(line) 323 | --------------------------------------------------------------------------------