├── assets
│   └── adapter.png
├── adapter-bert
│   ├── cfgs
│   │   ├── adapter.yaml
│   │   └── full-finetuning.yaml
│   ├── model
│   │   ├── adapter.py
│   │   ├── model.py
│   │   └── bert.py
│   ├── config.py
│   ├── train.py
│   └── dataset.py
├── .gitignore
├── README.md
├── environment.yml
└── LICENSE

/assets/adapter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cs-mshah/Adapter-Bert/HEAD/assets/adapter.png
--------------------------------------------------------------------------------
/adapter-bert/cfgs/adapter.yaml:
--------------------------------------------------------------------------------
1 | ACCELERATOR: auto
2 | ADAPTER_BOTTLENECK: 64
3 | CFG_BASE: adapter.yaml
4 | CFG_DIR: cfgs
5 | EPOCHS: 3
6 | LEARNING_RATE: 3.0e-05
7 | MAX_SEQ_LENGTH: 128
8 | MODEL_NAME: bert-large-uncased
9 | NUM_GPUS: 1
10 | NUM_WORKERS: 12
11 | RNG_SEED: 42
12 | TASK_NAME: cola
13 | TRAINING_STRATEGY: adapter
14 | TRAIN_BATCH: 32
15 | VAL_BATCH: 32
16 | WARMUP_STEPS: 0
17 | WEIGHT_DECAY: 0.0
18 | 
--------------------------------------------------------------------------------
/adapter-bert/cfgs/full-finetuning.yaml:
--------------------------------------------------------------------------------
1 | ACCELERATOR: auto
2 | ADAPTER_BOTTLENECK: 64
3 | CFG_BASE: full-finetuning.yaml
4 | CFG_DIR: cfgs
5 | EPOCHS: 3
6 | LEARNING_RATE: 3.0e-05
7 | MAX_SEQ_LENGTH: 128
8 | MODEL_NAME: bert-large-uncased
9 | NUM_GPUS: 1
10 | NUM_WORKERS: 12
11 | RNG_SEED: 42
12 | TASK_NAME: cola
13 | TRAINING_STRATEGY: full-finetuning
14 | TRAIN_BATCH: 32
15 | VAL_BATCH: 32
16 | WARMUP_STEPS: 0
17 | WEIGHT_DECAY: 0.0
18 | 
--------------------------------------------------------------------------------
/adapter-bert/model/adapter.py:
--------------------------------------------------------------------------------
1 | from config import cfg
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class AdapterModule(nn.Module):
6 |     def __init__(self,
7 |                  in_feature
8 |                  ):
9 |         super().__init__()
10 | 
11 |         self.proj_down = nn.Linear(in_features=in_feature, out_features=cfg.ADAPTER_BOTTLENECK)
12 |         self.proj_up = nn.Linear(in_features=cfg.ADAPTER_BOTTLENECK, out_features=in_feature)
13 | 
14 |     def forward(self, x):
15 |         residual = x
16 | 
17 |         x = self.proj_down(x)
18 |         x = F.relu(x)
19 |         return self.proj_up(x) + residual  # skip connection
--------------------------------------------------------------------------------
/adapter-bert/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from yacs.config import CfgNode as CN
3 | 
4 | # global config object
5 | _C = CN()
6 | 
7 | # ----------------configuration options------------------------ #
8 | 
9 | # number of GPUs to use in the experiment
10 | _C.NUM_GPUS = 1
11 | # number of dataloader workers
12 | _C.NUM_WORKERS = 12
13 | # random seed
14 | _C.RNG_SEED = 42
15 | # configuration directory
16 | _C.CFG_DIR = 'cfgs'
17 | # base configuration yaml file
18 | _C.CFG_BASE = 'adapter.yaml'
19 | # train batch size
20 | _C.TRAIN_BATCH = 32
21 | # val batch size
22 | _C.VAL_BATCH = 32
23 | # task name
24 | _C.TASK_NAME = 'cola'
25 | # whether to use full-finetuning or adapter
26 | _C.TRAINING_STRATEGY = 'adapter'
27 | # adapter bottleneck size (8, 64, 256)
28 | _C.ADAPTER_BOTTLENECK = 64
29 | # max sequence length
30 | _C.MAX_SEQ_LENGTH = 128
31 | # model name
32 | _C.MODEL_NAME = 'bert-large-uncased'
33 | # lr (3e-5, 3e-4, 3e-3)
34 | _C.LEARNING_RATE = 3e-5
35 | # weight decay
36 | 
_C.WEIGHT_DECAY = 0.0 37 | # warmup steps 38 | _C.WARMUP_STEPS = 0 39 | # number of epochs (3, 20) 40 | _C.EPOCHS = 20 41 | # trainer accelerator 42 | _C.ACCELERATOR = 'auto' 43 | 44 | # ----------------default config-------------------------------- # 45 | 46 | # import the defaults as a global singleton: 47 | cfg = _C # `from config import cfg` 48 | 49 | _CFG_DEFAULT = _C.clone() 50 | _CFG_DEFAULT.freeze() 51 | 52 | def dump_cfg(config_name='cfg.yaml'): 53 | """Dumps the config to the output directory.""" 54 | cfg_file = os.path.join(_C.CFG_DIR, config_name) 55 | with open(cfg_file, "w") as f: 56 | _C.dump(stream=f) 57 | 58 | if __name__ == '__main__': 59 | dump_cfg(_C.CFG_BASE) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # repo/wandb/lightning generated 10 | lightning_logs/ 11 | adapter-bert/adapter-bert/ 12 | wandb/ 13 | data/ 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /adapter-bert/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from pytorch_lightning import Trainer, seed_everything 5 | from pytorch_lightning.loggers import WandbLogger 6 | from pytorch_lightning.callbacks import EarlyStopping 7 | from dataset import GLUEDataModule 8 | from model.model import GLUETransformer 9 | from config import cfg 10 | import wandb 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='Adapter-Bert') 14 | parser.add_argument('--config', required=False, type=str, help='path to yaml config') 15 | args = parser.parse_args() 16 | if args.config: 17 | cfg.merge_from_file(args.config) 18 | 19 | os.environ['CUDA_VISIBLE_DEVICES']='4' 20 | 21 | seed_everything(cfg.RNG_SEED) 22 | os.environ["TOKENIZERS_PARALLELISM"] = 'False' 23 | wandb_logger = WandbLogger(name=f'{cfg.MODEL_NAME}-{cfg.TRAINING_STRATEGY}', 24 | project='adapter-bert') 25 | 26 | dm = GLUEDataModule(model_name_or_path=cfg.MODEL_NAME, 27 | task_name=cfg.TASK_NAME, 28 | max_seq_length=cfg.MAX_SEQ_LENGTH, 29 | train_batch_size=cfg.TRAIN_BATCH, 30 | eval_batch_size=cfg.VAL_BATCH, 31 | num_workers=cfg.NUM_WORKERS) 32 | 33 | dm.prepare_data() 34 | dm.setup("fit") 35 | 36 | warmup_steps = int(0.1 * len(dm.dataset['train'])) * cfg.EPOCHS 37 | 38 | model = GLUETransformer( 39 | model_name_or_path=cfg.MODEL_NAME, 40 | num_labels=dm.num_labels, 41 | task_name=dm.task_name, 42 | strategy=cfg.TRAINING_STRATEGY, 43 | learning_rate=cfg.LEARNING_RATE, 44 | warmup_steps=warmup_steps, 45 | weight_decay=cfg.WEIGHT_DECAY, 46 | train_batch_size=cfg.TRAIN_BATCH, 47 | eval_batch_size=cfg.VAL_BATCH, 48 | eval_splits=dm.eval_splits 49 | ) 50 | 51 | callbacks = None 52 | # callbacks = [EarlyStopping(monitor='val_loss', patience=5)] 53 | 54 | trainer = Trainer( 55 | max_epochs=cfg.EPOCHS, 56 | accelerator=cfg.ACCELERATOR, 57 | devices=cfg.NUM_GPUS if torch.cuda.is_available() else None, 58 | logger=wandb_logger, 59 | enable_checkpointing=False, 60 | deterministic=True, 61 | callbacks=callbacks 62 | ) 63 | 64 | trainer.fit(model, datamodule=dm) 65 | 66 | wandb.config.update( 67 | { 68 | "warmup_steps": warmup_steps, 69 | "random_seed": cfg.RNG_SEED, 70 | "max_sequence_length": cfg.MAX_SEQ_LENGTH, 71 | "num_gpus": cfg.NUM_GPUS, 72 | "num_workers": cfg.NUM_WORKERS, 73 | "Trainable params": model.trainable_params_count 74 | }) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parameter-Efficient Transfer Learning for NLP 2 | 3 | [![arXiv](http://img.shields.io/badge/paper-arxiv.1902.00751-B31B1B.svg)](https://arxiv.org/abs/1902.00751) 4 | 
[![Conference](http://img.shields.io/badge/PMLR-2019-4b44ce.svg)](http://proceedings.mlr.press/v97/houlsby19a.html)
5 | 
6 | This repository contains an implementation of the paper ["Parameter-Efficient Transfer Learning for NLP"](https://arxiv.org/abs/1902.00751), built with pytorch-lightning and huggingface transformers. The runs so far cover the CoLA task from the [GLUE benchmark](https://gluebenchmark.com/tasks). The authors propose an adapter module for transfer learning that trains far fewer parameters than full fine-tuning; BERT is used here for demonstration. Here is an overview of the adapter module:
7 | 
8 | ![adapter](assets/adapter.png)
9 | 
10 | All experiment runs can be found here: [wandb/adapter-bert](https://wandb.ai/manan-shah/adapter-bert)
11 | 
12 | ## Setup
13 | 
14 | ### Clone repository
15 | 
16 | ```
17 | git clone https://github.com/cs-mshah/Adapter-Bert.git
18 | ```
19 | 
20 | ### Environment
21 | 
22 | ```
23 | conda env create -f environment.yml
24 | conda activate adapter
25 | cd adapter-bert
26 | ```
27 | 
28 | ### Structure
29 | 
30 | ```
31 | .
32 | ├── adapter-bert
33 | │   ├── cfgs              # configurations to run
34 | │   ├── config.py         # default configuration
35 | │   ├── dataset.py        # LightningDataModule for GLUE tasks
36 | │   ├── train.py          # main file for training
37 | │   └── model             # model architectures
38 | │       ├── bert.py       # modifications to huggingface/transformers/bert
39 | │       ├── adapter.py    # adapter module
40 | │       └── model.py      # LightningModule to train
41 | ├── assets                # figures/outputs
42 | ├── environment.yml       # environment configuration
43 | ├── .gitignore            # files that should not be committed to Git
44 | ├── README.md             # project description
45 | └── LICENSE               # Apache 2.0 License
46 | ```
47 | 
48 | ## Training
49 | 
50 | [yacs](https://github.com/rbgirshick/yacs) is used as the configuration system and [wandb](https://wandb.ai/) for logging. You can change the configuration in `config.py`, which is imported by `train.py` for training.
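As a quick illustration, a saved config can be merged into the defaults and individual fields overridden programmatically (a minimal sketch of the yacs API as used by `train.py`; the override values below are purely illustrative):

```
from config import cfg

# merge a saved yaml config (train.py does this for --config)
cfg.merge_from_file('cfgs/adapter.yaml')

# override individual keys defined in config.py
cfg.merge_from_list(['TRAINING_STRATEGY', 'full-finetuning', 'EPOCHS', 3])

print(cfg.TRAINING_STRATEGY, cfg.EPOCHS)
```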
For training run: 51 | 52 | ``` 53 | python train.py 54 | 55 | # use saved config 56 | python train.py --config cfgs/adapter.yaml 57 | ``` 58 | 59 | ``` 60 | python train.py --help 61 | usage: train.py [-h] [--config CONFIG] 62 | 63 | Adapter-Bert 64 | 65 | optional arguments: 66 | -h, --help show this help message and exit 67 | --config CONFIG path to yaml config 68 | ``` 69 | 70 | ## References 71 | [lightining examples: text transformers](https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/text-transformers.html) 72 | [krypticmouse/Adapter-BERT](https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/text-transformers.html) 73 | [huggingface/transformers](https://github.com/huggingface/transformers/tree/main/src/transformers/models/bert) 74 | 75 | ## Citation 76 | 77 | Official paper citation: 78 | 79 | ``` 80 | @inproceedings{houlsby2019parameter, 81 | title = {Parameter-Efficient Transfer Learning for {NLP}}, 82 | author = {Houlsby, Neil and Giurgiu, Andrei and Jastrzebski, Stanislaw and Morrone, Bruna and De Laroussilhe, Quentin and Gesmundo, Andrea and Attariyan, Mona and Gelly, Sylvain}, 83 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 84 | year = {2019}, 85 | } 86 | ``` -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: adapter 2 | channels: 3 | - conda-forge 4 | - https://www.idiap.ch/software/bob/conda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - ca-certificates=2022.12.7=ha878542_0 10 | - certifi=2022.12.7=pyhd8ed1ab_0 11 | - gh=2.25.1=ha8f183a_0 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.4.2=h6a678d5_6 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgomp=11.2.0=h1234567_1 16 | - libstdcxx-ng=11.2.0=h1234567_1 17 | - ncurses=6.4=h6a678d5_0 18 | - openssl=1.1.1t=h7f8727e_0 19 | - pip=23.0.1=py39h06a4308_0 20 | - python=3.9.16=h7a1cb2a_2 21 | - readline=8.2=h5eee18b_0 22 | - setuptools=65.6.3=py39h06a4308_0 23 | - sqlite=3.41.1=h5eee18b_0 24 | - tk=8.6.12=h1ccaba5_0 25 | - tzdata=2022g=h04d1e81_0 26 | - wheel=0.38.4=py39h06a4308_0 27 | - xz=5.2.10=h5eee18b_1 28 | - zlib=1.2.13=h5eee18b_0 29 | - pip: 30 | - aiohttp==3.8.4 31 | - aiosignal==1.3.1 32 | - anyio==3.6.2 33 | - appdirs==1.4.4 34 | - arrow==1.2.3 35 | - async-timeout==4.0.2 36 | - attrs==22.2.0 37 | - beautifulsoup4==4.12.0 38 | - blessed==1.20.0 39 | - charset-normalizer==3.1.0 40 | - click==8.1.3 41 | - cmake==3.26.0 42 | - croniter==1.3.8 43 | - datasets==2.10.1 44 | - dateutils==0.6.12 45 | - deepdiff==6.3.0 46 | - dill==0.3.6 47 | - dnspython==2.3.0 48 | - docker-pycreds==0.4.0 49 | - email-validator==1.3.1 50 | - evaluate==0.4.0 51 | - fastapi==0.88.0 52 | - filelock==3.10.1 53 | - frozenlist==1.3.3 54 | - fsspec==2023.3.0 55 | - gitdb==4.0.10 56 | - gitpython==3.1.31 57 | - h11==0.14.0 58 | - httpcore==0.16.3 59 | - httptools==0.5.0 60 | - httpx==0.23.3 61 | - huggingface-hub==0.13.3 62 | - idna==3.4 63 | - inquirer==3.1.3 64 | - itsdangerous==2.1.2 65 | - jinja2==3.1.2 66 | - joblib==1.2.0 67 | - lightning==2.0.0 68 | - lightning-cloud==0.5.32 69 | - lightning-utilities==0.8.0 70 | - lit==16.0.0 71 | - markdown-it-py==2.2.0 72 | - markupsafe==2.1.2 73 | - mdurl==0.1.2 74 | - mpmath==1.3.0 75 | - multidict==6.0.4 76 | - multiprocess==0.70.14 77 | - networkx==3.0 78 | - numpy==1.24.2 79 | - nvidia-cublas-cu11==11.10.3.66 80 | - 
nvidia-cuda-cupti-cu11==11.7.101 81 | - nvidia-cuda-nvrtc-cu11==11.7.99 82 | - nvidia-cuda-runtime-cu11==11.7.99 83 | - nvidia-cudnn-cu11==8.5.0.96 84 | - nvidia-cufft-cu11==10.9.0.58 85 | - nvidia-curand-cu11==10.2.10.91 86 | - nvidia-cusolver-cu11==11.4.0.1 87 | - nvidia-cusparse-cu11==11.7.4.91 88 | - nvidia-nccl-cu11==2.14.3 89 | - nvidia-nvtx-cu11==11.7.91 90 | - ordered-set==4.1.0 91 | - orjson==3.8.8 92 | - packaging==23.0 93 | - pandas==1.5.3 94 | - pathtools==0.1.2 95 | - pillow==9.4.0 96 | - protobuf==4.22.1 97 | - psutil==5.9.4 98 | - pyarrow==11.0.0 99 | - pydantic==1.10.6 100 | - pygments==2.14.0 101 | - pyjwt==2.6.0 102 | - python-dateutil==2.8.2 103 | - python-dotenv==1.0.0 104 | - python-editor==1.0.4 105 | - python-multipart==0.0.6 106 | - pytorch-lightning==2.0.0 107 | - pytz==2022.7.1 108 | - pyyaml==6.0 109 | - readchar==4.0.5 110 | - regex==2022.10.31 111 | - requests==2.28.2 112 | - responses==0.18.0 113 | - rfc3986==1.5.0 114 | - rich==13.3.2 115 | - scikit-learn==1.2.2 116 | - scipy==1.10.1 117 | - sentry-sdk==1.17.0 118 | - setproctitle==1.3.2 119 | - six==1.16.0 120 | - smmap==5.0.0 121 | - sniffio==1.3.0 122 | - soupsieve==2.4 123 | - starlette==0.22.0 124 | - starsessions==1.3.0 125 | - sympy==1.11.1 126 | - threadpoolctl==3.1.0 127 | - tokenizers==0.13.2 128 | - torch==2.0.0 129 | - torchaudio==2.0.1 130 | - torchdata==0.6.0 131 | - torchmetrics==0.11.4 132 | - torchtext==0.15.1 133 | - torchvision==0.15.1 134 | - tqdm==4.65.0 135 | - traitlets==5.9.0 136 | - transformers==4.27.2 137 | - triton==2.0.0 138 | - typing-extensions==4.5.0 139 | - ujson==5.7.0 140 | - urllib3==1.26.15 141 | - uvicorn==0.21.1 142 | - uvloop==0.17.0 143 | - wandb==0.14.0 144 | - watchfiles==0.18.1 145 | - wcwidth==0.2.6 146 | - websocket-client==1.5.1 147 | - websockets==10.4 148 | - xxhash==3.2.0 149 | - yacs==0.1.8 150 | - yarl==1.8.2 151 | -------------------------------------------------------------------------------- /adapter-bert/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch.utils.data import DataLoader 3 | import datasets 4 | from pytorch_lightning import LightningDataModule 5 | from transformers import AutoTokenizer 6 | 7 | 8 | class GLUEDataModule(LightningDataModule): 9 | task_text_field_map = { 10 | "cola": ["sentence"], 11 | "sst2": ["sentence"], 12 | "mrpc": ["sentence1", "sentence2"], 13 | "qqp": ["question1", "question2"], 14 | "stsb": ["sentence1", "sentence2"], 15 | "mnli": ["premise", "hypothesis"], 16 | "qnli": ["question", "sentence"], 17 | "rte": ["sentence1", "sentence2"], 18 | "wnli": ["sentence1", "sentence2"], 19 | "ax": ["premise", "hypothesis"], 20 | } 21 | 22 | glue_task_num_labels = { 23 | "cola": 2, 24 | "sst2": 2, 25 | "mrpc": 2, 26 | "qqp": 2, 27 | "stsb": 1, 28 | "mnli": 3, 29 | "qnli": 2, 30 | "rte": 2, 31 | "wnli": 2, 32 | "ax": 3, 33 | } 34 | 35 | loader_columns = [ 36 | "datasets_idx", 37 | "input_ids", 38 | "token_type_ids", 39 | "attention_mask", 40 | "start_positions", 41 | "end_positions", 42 | "labels", 43 | ] 44 | 45 | def __init__( 46 | self, 47 | model_name_or_path: str, 48 | task_name: str = "cola", 49 | max_seq_length: int = 128, 50 | train_batch_size: int = 32, 51 | eval_batch_size: int = 32, 52 | num_workers: int = 4, 53 | **kwargs, 54 | ): 55 | super().__init__() 56 | self.model_name_or_path = model_name_or_path 57 | self.task_name = task_name 58 | self.max_seq_length = max_seq_length 59 | self.train_batch_size = train_batch_size 60 | 
self.eval_batch_size = eval_batch_size 61 | self.num_workers = num_workers 62 | 63 | self.text_fields = self.task_text_field_map[task_name] 64 | self.num_labels = self.glue_task_num_labels[task_name] 65 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) 66 | 67 | def setup(self, stage: str): 68 | self.dataset = datasets.load_dataset("glue", self.task_name) 69 | 70 | for split in self.dataset.keys(): 71 | self.dataset[split] = self.dataset[split].map( 72 | self.convert_to_features, 73 | batched=True, 74 | remove_columns=["label"], 75 | ) 76 | self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] 77 | self.dataset[split].set_format(type="torch", columns=self.columns) 78 | 79 | self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] 80 | 81 | def prepare_data(self): 82 | datasets.load_dataset("glue", self.task_name) 83 | AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) 84 | 85 | def train_dataloader(self): 86 | return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers) 87 | 88 | def val_dataloader(self): 89 | if len(self.eval_splits) == 1: 90 | return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size, num_workers=self.num_workers) 91 | elif len(self.eval_splits) > 1: 92 | return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=self.num_workers) for x in self.eval_splits] 93 | 94 | def test_dataloader(self): 95 | if len(self.eval_splits) == 1: 96 | return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size, num_workers=self.num_workers) 97 | elif len(self.eval_splits) > 1: 98 | return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size, num_workers=self.num_workers) for x in self.eval_splits] 99 | 100 | def convert_to_features(self, example_batch, indices=None): 101 | # Either encode single sentence or sentence pairs 102 | if len(self.text_fields) > 1: 103 | texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) 104 | else: 105 | texts_or_text_pairs = example_batch[self.text_fields[0]] 106 | 107 | # Tokenize the text/text pairs 108 | features = self.tokenizer.batch_encode_plus( 109 | texts_or_text_pairs, max_length=self.max_seq_length, padding='max_length', truncation=True 110 | ) 111 | 112 | # Rename label to labels to make it easier to pass to model forward 113 | features["labels"] = example_batch["label"] 114 | 115 | return features 116 | -------------------------------------------------------------------------------- /adapter-bert/model/model.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from transformers import ( 7 | AutoConfig, 8 | AutoModelForSequenceClassification, 9 | get_linear_schedule_with_warmup, 10 | ) 11 | import evaluate 12 | from model.bert import BertForSequenceClassification 13 | 14 | class GLUETransformer(LightningModule): 15 | 16 | def __init__( 17 | self, 18 | model_name_or_path: str, 19 | num_labels: int, 20 | task_name: str, 21 | strategy: str, 22 | learning_rate: float = 3e-5, 23 | adam_epsilon: float = 1e-8, 24 | warmup_steps: int = 0, 25 | weight_decay: float = 0.0, 26 | train_batch_size: int = 32, 27 | eval_batch_size: int = 32, 28 | eval_splits: Optional[list] = None, 29 | **kwargs, 30 | ): 31 | 
        super().__init__()
32 | 
33 |         self.save_hyperparameters()
34 | 
35 |         self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
36 |         if self.hparams.strategy == 'full-finetuning':
37 |             self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
38 |         elif self.hparams.strategy == 'adapter':
39 |             self.model = BertForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
40 |         else:
41 |             raise ValueError("Unknown training strategy. Select from 'full-finetuning' or 'adapter'.")
42 |         self.metric = evaluate.load(
43 |             "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
44 |         )
45 |         self.validation_step_outputs = []
46 |         self.trainable_params_count = 0
47 | 
48 |     def forward(self, **inputs):
49 |         return self.model(**inputs)
50 | 
51 |     def training_step(self, batch, batch_idx):
52 |         outputs = self(**batch)
53 |         loss = outputs[0]
54 |         return loss
55 | 
56 |     def validation_step(self, batch, batch_idx, dataloader_idx=0):
57 |         outputs = self(**batch)
58 |         val_loss, logits = outputs[:2]
59 | 
60 |         if self.hparams.num_labels > 1:
61 |             preds = torch.argmax(logits, axis=1)
62 |         elif self.hparams.num_labels == 1:
63 |             preds = logits.squeeze()
64 | 
65 |         labels = batch["labels"]
66 |         self.validation_step_outputs.append({"loss": val_loss, "preds": preds, "labels": labels, "dataloader_idx": dataloader_idx})
67 | 
68 |     def on_validation_epoch_end(self) -> None:
69 |         if self.hparams.task_name == "mnli":
70 |             for i, eval_split in enumerate(self.hparams.eval_splits):
71 |                 # matched or mismatched validation split
72 |                 split = eval_split.split("_")[-1]
73 |                 outputs = [x for x in self.validation_step_outputs if x["dataloader_idx"] == i]
74 |                 preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
75 |                 labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
76 |                 loss = torch.stack([x["loss"] for x in outputs]).mean()
77 |                 self.log(f"val_loss_{split}", loss, prog_bar=True, logger=True, sync_dist=True)
78 |                 split_metrics = {
79 |                     f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
80 |                 }
81 |                 self.log_dict(split_metrics, prog_bar=True, logger=True, sync_dist=True)
82 |             self.validation_step_outputs.clear()  # free memory
83 |             return
84 | 
85 |         preds = torch.cat([x["preds"] for x in self.validation_step_outputs]).detach().cpu().numpy()
86 |         labels = torch.cat([x["labels"] for x in self.validation_step_outputs]).detach().cpu().numpy()
87 |         loss = torch.stack([x["loss"] for x in self.validation_step_outputs]).mean()
88 |         self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
89 |         self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True, logger=True, sync_dist=True)
90 |         self.validation_step_outputs.clear()  # free memory
91 | 
92 |     def configure_optimizers(self):
93 |         """Prepare optimizer and schedule (linear warmup and decay)"""
94 |         model = self.model
95 |         optimizer_grouped_parameters = []
96 |         if self.hparams.strategy == 'full-finetuning':
97 |             no_decay = ["bias", "LayerNorm.weight"]
98 |             optimizer_grouped_parameters = [
99 |                 {
100 |                     "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
101 |                     "weight_decay": self.hparams.weight_decay,
102 |                 },
103 |                 {
104 |                     "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
105 |                     "weight_decay": 0.0,
106 |                 },
107 |             ]
108 |         elif self.hparams.strategy == 'adapter':
109 |             no_decay = ["adapter.proj_up.bias", "adapter.proj_down.bias", "LayerNorm"]
110 |             cls_bias = 
['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias'] 111 | cls_weight = ['cls.seq_relationship.weight', 'cls.predicions.transform.dense.weight', 'cls.predictions.decoder.weight'] 112 | layers = ["adapter.proj_up.weight", "adapter.proj_down.weight"] 113 | layers.extend(cls_weight) 114 | no_decay.extend(cls_bias) 115 | 116 | optimizer_grouped_parameters = [ 117 | { 118 | "params": [p for n, p in self.model.named_parameters() if any([nd in n for nd in layers])], 119 | "weight_decay": self.hparams.weight_decay, 120 | }, 121 | { 122 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 123 | "weight_decay": 0.0, 124 | }, 125 | ] 126 | 127 | # count the total no. of trainable params 128 | for group in optimizer_grouped_parameters: 129 | for param in group["params"]: 130 | self.trainable_params_count += param.numel() 131 | print(f'Total Trainable params: {self.trainable_params_count}') 132 | 133 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 134 | 135 | scheduler = get_linear_schedule_with_warmup( 136 | optimizer, 137 | num_warmup_steps=self.hparams.warmup_steps, 138 | num_training_steps=self.trainer.estimated_stepping_batches, 139 | ) 140 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 141 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /adapter-bert/model/bert.py: -------------------------------------------------------------------------------- 1 | """PyTorch BERT model adapted from huggingface/transformers/models/bert/modeling_bert.py""" 2 | 3 | import math 4 | import os 5 | import warnings 6 | from dataclasses import dataclass 7 | from typing import List, Optional, Tuple, Union 8 | 9 | import torch 10 | import torch.utils.checkpoint 11 | from torch import nn 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from .adapter import AdapterModule 14 | 15 | 16 | from transformers.activations import ACT2FN 17 | from transformers.modeling_outputs import ( 18 | BaseModelOutputWithPastAndCrossAttentions, 19 | BaseModelOutputWithPoolingAndCrossAttentions, 20 | CausalLMOutputWithCrossAttentions, 21 | MaskedLMOutput, 22 | MultipleChoiceModelOutput, 23 | NextSentencePredictorOutput, 24 | QuestionAnsweringModelOutput, 25 | SequenceClassifierOutput, 26 | TokenClassifierOutput, 27 | ) 28 | from transformers.modeling_utils import PreTrainedModel 29 | from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer 30 | from transformers.utils import ( 31 | ModelOutput, 32 | add_code_sample_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | logging, 36 | replace_return_docstrings, 37 | ) 38 | from transformers.models.bert.configuration_bert import BertConfig 39 | 40 | 41 | logger = logging.get_logger(__name__) 42 | 43 | _CHECKPOINT_FOR_DOC = "bert-base-uncased" 44 | _CONFIG_FOR_DOC = "BertConfig" 45 | 46 | # TokenClassification docstring 47 | _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" 48 | _TOKEN_CLASS_EXPECTED_OUTPUT = ( 49 | "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " 50 | ) 51 | _TOKEN_CLASS_EXPECTED_LOSS = 0.01 52 | 53 | # QuestionAnswering docstring 54 | _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2" 55 | _QA_EXPECTED_OUTPUT = "'a nice puppet'" 56 | _QA_EXPECTED_LOSS = 7.41 57 | _QA_TARGET_START_INDEX = 14 58 | _QA_TARGET_END_INDEX = 15 59 | 60 | # SequenceClassification docstring 61 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" 62 | _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" 63 | _SEQ_CLASS_EXPECTED_LOSS = 0.01 64 | 65 | 66 | BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 67 | "bert-base-uncased", 68 | "bert-large-uncased", 69 | "bert-base-cased", 70 | "bert-large-cased", 71 | "bert-base-multilingual-uncased", 72 | "bert-base-multilingual-cased", 73 | "bert-base-chinese", 74 | "bert-base-german-cased", 75 | "bert-large-uncased-whole-word-masking", 76 | "bert-large-cased-whole-word-masking", 77 | "bert-large-uncased-whole-word-masking-finetuned-squad", 78 | "bert-large-cased-whole-word-masking-finetuned-squad", 79 | "bert-base-cased-finetuned-mrpc", 80 | "bert-base-german-dbmdz-cased", 81 | "bert-base-german-dbmdz-uncased", 82 | "cl-tohoku/bert-base-japanese", 83 | "cl-tohoku/bert-base-japanese-whole-word-masking", 84 | "cl-tohoku/bert-base-japanese-char", 85 | "cl-tohoku/bert-base-japanese-char-whole-word-masking", 86 | "TurkuNLP/bert-base-finnish-cased-v1", 87 | "TurkuNLP/bert-base-finnish-uncased-v1", 88 | "wietsedv/bert-base-dutch-cased", 89 | # See all BERT models at https://huggingface.co/models?filter=bert 90 | ] 91 | 92 | 93 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 94 
| """Load tf checkpoints in a pytorch model.""" 95 | try: 96 | import re 97 | 98 | import numpy as np 99 | import tensorflow as tf 100 | except ImportError: 101 | logger.error( 102 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 103 | "https://www.tensorflow.org/install/ for installation instructions." 104 | ) 105 | raise 106 | tf_path = os.path.abspath(tf_checkpoint_path) 107 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 108 | # Load weights from TF model 109 | init_vars = tf.train.list_variables(tf_path) 110 | names = [] 111 | arrays = [] 112 | for name, shape in init_vars: 113 | logger.info(f"Loading TF weight {name} with shape {shape}") 114 | array = tf.train.load_variable(tf_path, name) 115 | names.append(name) 116 | arrays.append(array) 117 | 118 | for name, array in zip(names, arrays): 119 | name = name.split("/") 120 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 121 | # which are not required for using pretrained model 122 | if any( 123 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 124 | for n in name 125 | ): 126 | logger.info(f"Skipping {'/'.join(name)}") 127 | continue 128 | pointer = model 129 | for m_name in name: 130 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 131 | scope_names = re.split(r"_(\d+)", m_name) 132 | else: 133 | scope_names = [m_name] 134 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 135 | pointer = getattr(pointer, "weight") 136 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 137 | pointer = getattr(pointer, "bias") 138 | elif scope_names[0] == "output_weights": 139 | pointer = getattr(pointer, "weight") 140 | elif scope_names[0] == "squad": 141 | pointer = getattr(pointer, "classifier") 142 | else: 143 | try: 144 | pointer = getattr(pointer, scope_names[0]) 145 | except AttributeError: 146 | logger.info(f"Skipping {'/'.join(name)}") 147 | continue 148 | if len(scope_names) >= 2: 149 | num = int(scope_names[1]) 150 | pointer = pointer[num] 151 | if m_name[-11:] == "_embeddings": 152 | pointer = getattr(pointer, "weight") 153 | elif m_name == "kernel": 154 | array = np.transpose(array) 155 | try: 156 | if pointer.shape != array.shape: 157 | raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") 158 | except AssertionError as e: 159 | e.args += (pointer.shape, array.shape) 160 | raise 161 | logger.info(f"Initialize PyTorch weight {name}") 162 | pointer.data = torch.from_numpy(array) 163 | return model 164 | 165 | 166 | class BertEmbeddings(nn.Module): 167 | """Construct the embeddings from word, position and token_type embeddings.""" 168 | 169 | def __init__(self, config): 170 | super().__init__() 171 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 172 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 173 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 174 | 175 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 176 | # any TensorFlow checkpoint file 177 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 178 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 179 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 180 | self.position_embedding_type = getattr(config, 
"position_embedding_type", "absolute") 181 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 182 | self.register_buffer( 183 | "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False 184 | ) 185 | 186 | def forward( 187 | self, 188 | input_ids: Optional[torch.LongTensor] = None, 189 | token_type_ids: Optional[torch.LongTensor] = None, 190 | position_ids: Optional[torch.LongTensor] = None, 191 | inputs_embeds: Optional[torch.FloatTensor] = None, 192 | past_key_values_length: int = 0, 193 | ) -> torch.Tensor: 194 | if input_ids is not None: 195 | input_shape = input_ids.size() 196 | else: 197 | input_shape = inputs_embeds.size()[:-1] 198 | 199 | seq_length = input_shape[1] 200 | 201 | if position_ids is None: 202 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 203 | 204 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 205 | # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves 206 | # issue #5664 207 | if token_type_ids is None: 208 | if hasattr(self, "token_type_ids"): 209 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 210 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 211 | token_type_ids = buffered_token_type_ids_expanded 212 | else: 213 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 214 | 215 | if inputs_embeds is None: 216 | inputs_embeds = self.word_embeddings(input_ids) 217 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 218 | 219 | embeddings = inputs_embeds + token_type_embeddings 220 | if self.position_embedding_type == "absolute": 221 | position_embeddings = self.position_embeddings(position_ids) 222 | embeddings += position_embeddings 223 | embeddings = self.LayerNorm(embeddings) 224 | embeddings = self.dropout(embeddings) 225 | return embeddings 226 | 227 | 228 | class BertSelfAttention(nn.Module): 229 | def __init__(self, config, position_embedding_type=None): 230 | super().__init__() 231 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 232 | raise ValueError( 233 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 234 | f"heads ({config.num_attention_heads})" 235 | ) 236 | 237 | self.num_attention_heads = config.num_attention_heads 238 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 239 | self.all_head_size = self.num_attention_heads * self.attention_head_size 240 | 241 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 242 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 243 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 244 | 245 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 246 | self.position_embedding_type = position_embedding_type or getattr( 247 | config, "position_embedding_type", "absolute" 248 | ) 249 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 250 | self.max_position_embeddings = config.max_position_embeddings 251 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 252 | 253 | self.is_decoder = config.is_decoder 254 | 255 | def transpose_for_scores(self, x: 
torch.Tensor) -> torch.Tensor: 256 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 257 | x = x.view(new_x_shape) 258 | return x.permute(0, 2, 1, 3) 259 | 260 | def forward( 261 | self, 262 | hidden_states: torch.Tensor, 263 | attention_mask: Optional[torch.FloatTensor] = None, 264 | head_mask: Optional[torch.FloatTensor] = None, 265 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 266 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 267 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 268 | output_attentions: Optional[bool] = False, 269 | ) -> Tuple[torch.Tensor]: 270 | mixed_query_layer = self.query(hidden_states) 271 | 272 | # If this is instantiated as a cross-attention module, the keys 273 | # and values come from an encoder; the attention mask needs to be 274 | # such that the encoder's padding tokens are not attended to. 275 | is_cross_attention = encoder_hidden_states is not None 276 | 277 | if is_cross_attention and past_key_value is not None: 278 | # reuse k,v, cross_attentions 279 | key_layer = past_key_value[0] 280 | value_layer = past_key_value[1] 281 | attention_mask = encoder_attention_mask 282 | elif is_cross_attention: 283 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 284 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 285 | attention_mask = encoder_attention_mask 286 | elif past_key_value is not None: 287 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 288 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 289 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 290 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 291 | else: 292 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 293 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 294 | 295 | query_layer = self.transpose_for_scores(mixed_query_layer) 296 | 297 | use_cache = past_key_value is not None 298 | if self.is_decoder: 299 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 300 | # Further calls to cross_attention layer can then reuse all cross-attention 301 | # key/value_states (first "if" case) 302 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 303 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 304 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 305 | # if encoder bi-directional self-attention `past_key_value` is always `None` 306 | past_key_value = (key_layer, value_layer) 307 | 308 | # Take the dot product between "query" and "key" to get the raw attention scores. 
309 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 310 | 311 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 312 | query_length, key_length = query_layer.shape[2], key_layer.shape[2] 313 | if use_cache: 314 | position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( 315 | -1, 1 316 | ) 317 | else: 318 | position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 319 | position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 320 | distance = position_ids_l - position_ids_r 321 | 322 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 323 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 324 | 325 | if self.position_embedding_type == "relative_key": 326 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 327 | attention_scores = attention_scores + relative_position_scores 328 | elif self.position_embedding_type == "relative_key_query": 329 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 330 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 331 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 332 | 333 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 334 | if attention_mask is not None: 335 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 336 | attention_scores = attention_scores + attention_mask 337 | 338 | # Normalize the attention scores to probabilities. 339 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 340 | 341 | # This is actually dropping out entire tokens to attend to, which might 342 | # seem a bit unusual, but is taken from the original Transformer paper. 
343 | attention_probs = self.dropout(attention_probs) 344 | 345 | # Mask heads if we want to 346 | if head_mask is not None: 347 | attention_probs = attention_probs * head_mask 348 | 349 | context_layer = torch.matmul(attention_probs, value_layer) 350 | 351 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 352 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 353 | context_layer = context_layer.view(new_context_layer_shape) 354 | 355 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 356 | 357 | if self.is_decoder: 358 | outputs = outputs + (past_key_value,) 359 | return outputs 360 | 361 | 362 | class BertSelfOutput(nn.Module): 363 | def __init__(self, config): 364 | super().__init__() 365 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 366 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 367 | self.adapter = AdapterModule(config.hidden_size) 368 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 369 | 370 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 371 | hidden_states = self.dense(hidden_states) 372 | hidden_states = self.dropout(hidden_states) 373 | hidden_states = self.adapter(hidden_states) 374 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 375 | return hidden_states 376 | 377 | 378 | class BertAttention(nn.Module): 379 | def __init__(self, config, position_embedding_type=None): 380 | super().__init__() 381 | self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) 382 | self.output = BertSelfOutput(config) 383 | self.pruned_heads = set() 384 | 385 | def prune_heads(self, heads): 386 | if len(heads) == 0: 387 | return 388 | heads, index = find_pruneable_heads_and_indices( 389 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 390 | ) 391 | 392 | # Prune linear layers 393 | self.self.query = prune_linear_layer(self.self.query, index) 394 | self.self.key = prune_linear_layer(self.self.key, index) 395 | self.self.value = prune_linear_layer(self.self.value, index) 396 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 397 | 398 | # Update hyper params and store pruned heads 399 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 400 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 401 | self.pruned_heads = self.pruned_heads.union(heads) 402 | 403 | def forward( 404 | self, 405 | hidden_states: torch.Tensor, 406 | attention_mask: Optional[torch.FloatTensor] = None, 407 | head_mask: Optional[torch.FloatTensor] = None, 408 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 409 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 410 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 411 | output_attentions: Optional[bool] = False, 412 | ) -> Tuple[torch.Tensor]: 413 | self_outputs = self.self( 414 | hidden_states, 415 | attention_mask, 416 | head_mask, 417 | encoder_hidden_states, 418 | encoder_attention_mask, 419 | past_key_value, 420 | output_attentions, 421 | ) 422 | attention_output = self.output(self_outputs[0], hidden_states) 423 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 424 | return outputs 425 | 426 | 427 | class BertIntermediate(nn.Module): 428 | def __init__(self, config): 429 | super().__init__() 430 | self.dense = nn.Linear(config.hidden_size, 
config.intermediate_size) 431 | if isinstance(config.hidden_act, str): 432 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 433 | else: 434 | self.intermediate_act_fn = config.hidden_act 435 | 436 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 437 | hidden_states = self.dense(hidden_states) 438 | hidden_states = self.intermediate_act_fn(hidden_states) 439 | return hidden_states 440 | 441 | 442 | class BertOutput(nn.Module): 443 | def __init__(self, config): 444 | super().__init__() 445 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 446 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 447 | self.adapter = AdapterModule(config.hidden_size) 448 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 449 | 450 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 451 | hidden_states = self.dense(hidden_states) 452 | hidden_states = self.dropout(hidden_states) 453 | hidden_states = self.adapter(hidden_states) 454 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 455 | return hidden_states 456 | 457 | 458 | class BertLayer(nn.Module): 459 | def __init__(self, config): 460 | super().__init__() 461 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 462 | self.seq_len_dim = 1 463 | self.attention = BertAttention(config) 464 | self.is_decoder = config.is_decoder 465 | self.add_cross_attention = config.add_cross_attention 466 | if self.add_cross_attention: 467 | if not self.is_decoder: 468 | raise ValueError(f"{self} should be used as a decoder model if cross attention is added") 469 | self.crossattention = BertAttention(config, position_embedding_type="absolute") 470 | self.intermediate = BertIntermediate(config) 471 | self.output = BertOutput(config) 472 | 473 | def forward( 474 | self, 475 | hidden_states: torch.Tensor, 476 | attention_mask: Optional[torch.FloatTensor] = None, 477 | head_mask: Optional[torch.FloatTensor] = None, 478 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 479 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 480 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 481 | output_attentions: Optional[bool] = False, 482 | ) -> Tuple[torch.Tensor]: 483 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 484 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 485 | self_attention_outputs = self.attention( 486 | hidden_states, 487 | attention_mask, 488 | head_mask, 489 | output_attentions=output_attentions, 490 | past_key_value=self_attn_past_key_value, 491 | ) 492 | attention_output = self_attention_outputs[0] 493 | 494 | # if decoder, the last output is tuple of self-attn cache 495 | if self.is_decoder: 496 | outputs = self_attention_outputs[1:-1] 497 | present_key_value = self_attention_outputs[-1] 498 | else: 499 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 500 | 501 | cross_attn_present_key_value = None 502 | if self.is_decoder and encoder_hidden_states is not None: 503 | if not hasattr(self, "crossattention"): 504 | raise ValueError( 505 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" 506 | " by setting `config.add_cross_attention=True`" 507 | ) 508 | 509 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 510 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else 
None 511 | cross_attention_outputs = self.crossattention( 512 | attention_output, 513 | attention_mask, 514 | head_mask, 515 | encoder_hidden_states, 516 | encoder_attention_mask, 517 | cross_attn_past_key_value, 518 | output_attentions, 519 | ) 520 | attention_output = cross_attention_outputs[0] 521 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 522 | 523 | # add cross-attn cache to positions 3,4 of present_key_value tuple 524 | cross_attn_present_key_value = cross_attention_outputs[-1] 525 | present_key_value = present_key_value + cross_attn_present_key_value 526 | 527 | layer_output = apply_chunking_to_forward( 528 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 529 | ) 530 | outputs = (layer_output,) + outputs 531 | 532 | # if decoder, return the attn key/values as the last output 533 | if self.is_decoder: 534 | outputs = outputs + (present_key_value,) 535 | 536 | return outputs 537 | 538 | def feed_forward_chunk(self, attention_output): 539 | intermediate_output = self.intermediate(attention_output) 540 | layer_output = self.output(intermediate_output, attention_output) 541 | return layer_output 542 | 543 | 544 | class BertEncoder(nn.Module): 545 | def __init__(self, config): 546 | super().__init__() 547 | self.config = config 548 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 549 | self.gradient_checkpointing = False 550 | 551 | def forward( 552 | self, 553 | hidden_states: torch.Tensor, 554 | attention_mask: Optional[torch.FloatTensor] = None, 555 | head_mask: Optional[torch.FloatTensor] = None, 556 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 557 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 558 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 559 | use_cache: Optional[bool] = None, 560 | output_attentions: Optional[bool] = False, 561 | output_hidden_states: Optional[bool] = False, 562 | return_dict: Optional[bool] = True, 563 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: 564 | all_hidden_states = () if output_hidden_states else None 565 | all_self_attentions = () if output_attentions else None 566 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 567 | 568 | if self.gradient_checkpointing and self.training: 569 | if use_cache: 570 | logger.warning_once( 571 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
572 | ) 573 | use_cache = False 574 | 575 | next_decoder_cache = () if use_cache else None 576 | for i, layer_module in enumerate(self.layer): 577 | if output_hidden_states: 578 | all_hidden_states = all_hidden_states + (hidden_states,) 579 | 580 | layer_head_mask = head_mask[i] if head_mask is not None else None 581 | past_key_value = past_key_values[i] if past_key_values is not None else None 582 | 583 | if self.gradient_checkpointing and self.training: 584 | 585 | def create_custom_forward(module): 586 | def custom_forward(*inputs): 587 | return module(*inputs, past_key_value, output_attentions) 588 | 589 | return custom_forward 590 | 591 | layer_outputs = torch.utils.checkpoint.checkpoint( 592 | create_custom_forward(layer_module), 593 | hidden_states, 594 | attention_mask, 595 | layer_head_mask, 596 | encoder_hidden_states, 597 | encoder_attention_mask, 598 | ) 599 | else: 600 | layer_outputs = layer_module( 601 | hidden_states, 602 | attention_mask, 603 | layer_head_mask, 604 | encoder_hidden_states, 605 | encoder_attention_mask, 606 | past_key_value, 607 | output_attentions, 608 | ) 609 | 610 | hidden_states = layer_outputs[0] 611 | if use_cache: 612 | next_decoder_cache += (layer_outputs[-1],) 613 | if output_attentions: 614 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 615 | if self.config.add_cross_attention: 616 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 617 | 618 | if output_hidden_states: 619 | all_hidden_states = all_hidden_states + (hidden_states,) 620 | 621 | if not return_dict: 622 | return tuple( 623 | v 624 | for v in [ 625 | hidden_states, 626 | next_decoder_cache, 627 | all_hidden_states, 628 | all_self_attentions, 629 | all_cross_attentions, 630 | ] 631 | if v is not None 632 | ) 633 | return BaseModelOutputWithPastAndCrossAttentions( 634 | last_hidden_state=hidden_states, 635 | past_key_values=next_decoder_cache, 636 | hidden_states=all_hidden_states, 637 | attentions=all_self_attentions, 638 | cross_attentions=all_cross_attentions, 639 | ) 640 | 641 | 642 | class BertPooler(nn.Module): 643 | def __init__(self, config): 644 | super().__init__() 645 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 646 | self.activation = nn.Tanh() 647 | 648 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 649 | # We "pool" the model by simply taking the hidden state corresponding 650 | # to the first token. 
651 | first_token_tensor = hidden_states[:, 0] 652 | pooled_output = self.dense(first_token_tensor) 653 | pooled_output = self.activation(pooled_output) 654 | return pooled_output 655 | 656 | 657 | class BertPredictionHeadTransform(nn.Module): 658 | def __init__(self, config): 659 | super().__init__() 660 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 661 | if isinstance(config.hidden_act, str): 662 | self.transform_act_fn = ACT2FN[config.hidden_act] 663 | else: 664 | self.transform_act_fn = config.hidden_act 665 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 666 | 667 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 668 | hidden_states = self.dense(hidden_states) 669 | hidden_states = self.transform_act_fn(hidden_states) 670 | hidden_states = self.LayerNorm(hidden_states) 671 | return hidden_states 672 | 673 | 674 | class BertLMPredictionHead(nn.Module): 675 | def __init__(self, config): 676 | super().__init__() 677 | self.transform = BertPredictionHeadTransform(config) 678 | 679 | # The output weights are the same as the input embeddings, but there is 680 | # an output-only bias for each token. 681 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 682 | 683 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 684 | 685 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 686 | self.decoder.bias = self.bias 687 | 688 | def forward(self, hidden_states): 689 | hidden_states = self.transform(hidden_states) 690 | hidden_states = self.decoder(hidden_states) 691 | return hidden_states 692 | 693 | 694 | class BertOnlyMLMHead(nn.Module): 695 | def __init__(self, config): 696 | super().__init__() 697 | self.predictions = BertLMPredictionHead(config) 698 | 699 | def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: 700 | prediction_scores = self.predictions(sequence_output) 701 | return prediction_scores 702 | 703 | 704 | class BertOnlyNSPHead(nn.Module): 705 | def __init__(self, config): 706 | super().__init__() 707 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 708 | 709 | def forward(self, pooled_output): 710 | seq_relationship_score = self.seq_relationship(pooled_output) 711 | return seq_relationship_score 712 | 713 | 714 | class BertPreTrainingHeads(nn.Module): 715 | def __init__(self, config): 716 | super().__init__() 717 | self.predictions = BertLMPredictionHead(config) 718 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 719 | 720 | def forward(self, sequence_output, pooled_output): 721 | prediction_scores = self.predictions(sequence_output) 722 | seq_relationship_score = self.seq_relationship(pooled_output) 723 | return prediction_scores, seq_relationship_score 724 | 725 | 726 | class BertPreTrainedModel(PreTrainedModel): 727 | """ 728 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 729 | models. 
730 | """ 731 | 732 | config_class = BertConfig 733 | load_tf_weights = load_tf_weights_in_bert 734 | base_model_prefix = "bert" 735 | supports_gradient_checkpointing = True 736 | _keys_to_ignore_on_load_missing = [r"position_ids"] 737 | 738 | def _init_weights(self, module): 739 | """Initialize the weights""" 740 | if isinstance(module, nn.Linear): 741 | # Slightly different from the TF version which uses truncated_normal for initialization 742 | # cf https://github.com/pytorch/pytorch/pull/5617 743 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 744 | if module.bias is not None: 745 | module.bias.data.zero_() 746 | elif isinstance(module, nn.Embedding): 747 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 748 | if module.padding_idx is not None: 749 | module.weight.data[module.padding_idx].zero_() 750 | elif isinstance(module, nn.LayerNorm): 751 | module.bias.data.zero_() 752 | module.weight.data.fill_(1.0) 753 | elif isinstance(module, AdapterModule): 754 | # correct place to do near zero initialization 755 | module.proj_down.weight.data.normal_(mean=0.0, std=1e-7) 756 | module.proj_down.bias.data.zero_() 757 | module.proj_up.weight.data.normal_(mean=0.0, std=1e-7) 758 | module.proj_up.bias.data.zero_() 759 | 760 | def _set_gradient_checkpointing(self, module, value=False): 761 | if isinstance(module, BertEncoder): 762 | module.gradient_checkpointing = value 763 | 764 | 765 | @dataclass 766 | class BertForPreTrainingOutput(ModelOutput): 767 | """ 768 | Output type of [`BertForPreTraining`]. 769 | 770 | Args: 771 | loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): 772 | Total loss as the sum of the masked language modeling loss and the next sequence prediction 773 | (classification) loss. 774 | prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 775 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 776 | seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): 777 | Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation 778 | before SoftMax). 779 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 780 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of 781 | shape `(batch_size, sequence_length, hidden_size)`. 782 | 783 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 784 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 785 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 786 | sequence_length)`. 787 | 788 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 789 | heads. 790 | """ 791 | 792 | loss: Optional[torch.FloatTensor] = None 793 | prediction_logits: torch.FloatTensor = None 794 | seq_relationship_logits: torch.FloatTensor = None 795 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 796 | attentions: Optional[Tuple[torch.FloatTensor]] = None 797 | 798 | 799 | BERT_START_DOCSTRING = r""" 800 | 801 | This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the 802 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 803 | etc.) 804 | 805 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 806 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 807 | and behavior. 808 | 809 | Parameters: 810 | config ([`BertConfig`]): Model configuration class with all the parameters of the model. 811 | Initializing with a config file does not load the weights associated with the model, only the 812 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 813 | """ 814 | 815 | BERT_INPUTS_DOCSTRING = r""" 816 | Args: 817 | input_ids (`torch.LongTensor` of shape `({0})`): 818 | Indices of input sequence tokens in the vocabulary. 819 | 820 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 821 | [`PreTrainedTokenizer.__call__`] for details. 822 | 823 | [What are input IDs?](../glossary#input-ids) 824 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 825 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 826 | 827 | - 1 for tokens that are **not masked**, 828 | - 0 for tokens that are **masked**. 829 | 830 | [What are attention masks?](../glossary#attention-mask) 831 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 832 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 833 | 1]`: 834 | 835 | - 0 corresponds to a *sentence A* token, 836 | - 1 corresponds to a *sentence B* token. 837 | 838 | [What are token type IDs?](../glossary#token-type-ids) 839 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 840 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 841 | config.max_position_embeddings - 1]`. 842 | 843 | [What are position IDs?](../glossary#position-ids) 844 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 845 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 846 | 847 | - 1 indicates the head is **not masked**, 848 | - 0 indicates the head is **masked**. 849 | 850 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 851 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 852 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 853 | model's internal embedding lookup matrix. 854 | output_attentions (`bool`, *optional*): 855 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 856 | tensors for more detail. 857 | output_hidden_states (`bool`, *optional*): 858 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 859 | more detail. 860 | return_dict (`bool`, *optional*): 861 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
862 | """ 863 | 864 | 865 | @add_start_docstrings( 866 | "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", 867 | BERT_START_DOCSTRING, 868 | ) 869 | class BertModel(BertPreTrainedModel): 870 | """ 871 | 872 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 873 | cross-attention is added between the self-attention layers, following the architecture described in [Attention is 874 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 875 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 876 | 877 | To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set 878 | to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and 879 | `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 880 | """ 881 | 882 | def __init__(self, config, add_pooling_layer=True): 883 | super().__init__(config) 884 | self.config = config 885 | 886 | self.embeddings = BertEmbeddings(config) 887 | self.encoder = BertEncoder(config) 888 | 889 | self.pooler = BertPooler(config) if add_pooling_layer else None 890 | 891 | # Initialize weights and apply final processing 892 | self.post_init() 893 | 894 | def get_input_embeddings(self): 895 | return self.embeddings.word_embeddings 896 | 897 | def set_input_embeddings(self, value): 898 | self.embeddings.word_embeddings = value 899 | 900 | def _prune_heads(self, heads_to_prune): 901 | """ 902 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 903 | class PreTrainedModel 904 | """ 905 | for layer, heads in heads_to_prune.items(): 906 | self.encoder.layer[layer].attention.prune_heads(heads) 907 | 908 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 909 | @add_code_sample_docstrings( 910 | checkpoint=_CHECKPOINT_FOR_DOC, 911 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 912 | config_class=_CONFIG_FOR_DOC, 913 | ) 914 | def forward( 915 | self, 916 | input_ids: Optional[torch.Tensor] = None, 917 | attention_mask: Optional[torch.Tensor] = None, 918 | token_type_ids: Optional[torch.Tensor] = None, 919 | position_ids: Optional[torch.Tensor] = None, 920 | head_mask: Optional[torch.Tensor] = None, 921 | inputs_embeds: Optional[torch.Tensor] = None, 922 | encoder_hidden_states: Optional[torch.Tensor] = None, 923 | encoder_attention_mask: Optional[torch.Tensor] = None, 924 | past_key_values: Optional[List[torch.FloatTensor]] = None, 925 | use_cache: Optional[bool] = None, 926 | output_attentions: Optional[bool] = None, 927 | output_hidden_states: Optional[bool] = None, 928 | return_dict: Optional[bool] = None, 929 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: 930 | r""" 931 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 932 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 933 | the model is configured as a decoder. 934 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 935 | Mask to avoid performing attention on the padding token indices of the encoder input. 
This mask is used in 936 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 937 | 938 | - 1 for tokens that are **not masked**, 939 | - 0 for tokens that are **masked**. 940 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 941 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 942 | 943 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 944 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 945 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 946 | use_cache (`bool`, *optional*): 947 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 948 | `past_key_values`). 949 | """ 950 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 951 | output_hidden_states = ( 952 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 953 | ) 954 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 955 | 956 | if self.config.is_decoder: 957 | use_cache = use_cache if use_cache is not None else self.config.use_cache 958 | else: 959 | use_cache = False 960 | 961 | if input_ids is not None and inputs_embeds is not None: 962 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 963 | elif input_ids is not None: 964 | input_shape = input_ids.size() 965 | elif inputs_embeds is not None: 966 | input_shape = inputs_embeds.size()[:-1] 967 | else: 968 | raise ValueError("You have to specify either input_ids or inputs_embeds") 969 | 970 | batch_size, seq_length = input_shape 971 | device = input_ids.device if input_ids is not None else inputs_embeds.device 972 | 973 | # past_key_values_length 974 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 975 | 976 | if attention_mask is None: 977 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 978 | 979 | if token_type_ids is None: 980 | if hasattr(self.embeddings, "token_type_ids"): 981 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 982 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 983 | token_type_ids = buffered_token_type_ids_expanded 984 | else: 985 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 986 | 987 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 988 | # ourselves in which case we just need to make it broadcastable to all heads. 
989 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) 990 | 991 | # If a 2D or 3D attention mask is provided for the cross-attention 992 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 993 | if self.config.is_decoder and encoder_hidden_states is not None: 994 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 995 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 996 | if encoder_attention_mask is None: 997 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 998 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 999 | else: 1000 | encoder_extended_attention_mask = None 1001 | 1002 | # Prepare head mask if needed 1003 | # 1.0 in head_mask indicate we keep the head 1004 | # attention_probs has shape bsz x n_heads x N x N 1005 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 1006 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 1007 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 1008 | 1009 | embedding_output = self.embeddings( 1010 | input_ids=input_ids, 1011 | position_ids=position_ids, 1012 | token_type_ids=token_type_ids, 1013 | inputs_embeds=inputs_embeds, 1014 | past_key_values_length=past_key_values_length, 1015 | ) 1016 | encoder_outputs = self.encoder( 1017 | embedding_output, 1018 | attention_mask=extended_attention_mask, 1019 | head_mask=head_mask, 1020 | encoder_hidden_states=encoder_hidden_states, 1021 | encoder_attention_mask=encoder_extended_attention_mask, 1022 | past_key_values=past_key_values, 1023 | use_cache=use_cache, 1024 | output_attentions=output_attentions, 1025 | output_hidden_states=output_hidden_states, 1026 | return_dict=return_dict, 1027 | ) 1028 | sequence_output = encoder_outputs[0] 1029 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 1030 | 1031 | if not return_dict: 1032 | return (sequence_output, pooled_output) + encoder_outputs[1:] 1033 | 1034 | return BaseModelOutputWithPoolingAndCrossAttentions( 1035 | last_hidden_state=sequence_output, 1036 | pooler_output=pooled_output, 1037 | past_key_values=encoder_outputs.past_key_values, 1038 | hidden_states=encoder_outputs.hidden_states, 1039 | attentions=encoder_outputs.attentions, 1040 | cross_attentions=encoder_outputs.cross_attentions, 1041 | ) 1042 | 1043 | 1044 | @add_start_docstrings( 1045 | """ 1046 | Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next 1047 | sentence prediction (classification)` head. 
1048 | """, 1049 | BERT_START_DOCSTRING, 1050 | ) 1051 | class BertForPreTraining(BertPreTrainedModel): 1052 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"] 1053 | 1054 | def __init__(self, config): 1055 | super().__init__(config) 1056 | 1057 | self.bert = BertModel(config) 1058 | self.cls = BertPreTrainingHeads(config) 1059 | 1060 | # Initialize weights and apply final processing 1061 | self.post_init() 1062 | 1063 | def get_output_embeddings(self): 1064 | return self.cls.predictions.decoder 1065 | 1066 | def set_output_embeddings(self, new_embeddings): 1067 | self.cls.predictions.decoder = new_embeddings 1068 | 1069 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1070 | @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) 1071 | def forward( 1072 | self, 1073 | input_ids: Optional[torch.Tensor] = None, 1074 | attention_mask: Optional[torch.Tensor] = None, 1075 | token_type_ids: Optional[torch.Tensor] = None, 1076 | position_ids: Optional[torch.Tensor] = None, 1077 | head_mask: Optional[torch.Tensor] = None, 1078 | inputs_embeds: Optional[torch.Tensor] = None, 1079 | labels: Optional[torch.Tensor] = None, 1080 | next_sentence_label: Optional[torch.Tensor] = None, 1081 | output_attentions: Optional[bool] = None, 1082 | output_hidden_states: Optional[bool] = None, 1083 | return_dict: Optional[bool] = None, 1084 | ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: 1085 | r""" 1086 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1087 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., 1088 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), 1089 | the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 1090 | next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1091 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence 1092 | pair (see `input_ids` docstring) Indices should be in `[0, 1]`: 1093 | 1094 | - 0 indicates sequence B is a continuation of sequence A, 1095 | - 1 indicates sequence B is a random sequence. 1096 | kwargs (`Dict[str, any]`, optional, defaults to *{}*): 1097 | Used to hide legacy arguments that have been deprecated. 
1098 | 1099 | Returns: 1100 | 1101 | Example: 1102 | 1103 | ```python 1104 | >>> from transformers import AutoTokenizer, BertForPreTraining 1105 | >>> import torch 1106 | 1107 | >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 1108 | >>> model = BertForPreTraining.from_pretrained("bert-base-uncased") 1109 | 1110 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1111 | >>> outputs = model(**inputs) 1112 | 1113 | >>> prediction_logits = outputs.prediction_logits 1114 | >>> seq_relationship_logits = outputs.seq_relationship_logits 1115 | ``` 1116 | """ 1117 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1118 | 1119 | outputs = self.bert( 1120 | input_ids, 1121 | attention_mask=attention_mask, 1122 | token_type_ids=token_type_ids, 1123 | position_ids=position_ids, 1124 | head_mask=head_mask, 1125 | inputs_embeds=inputs_embeds, 1126 | output_attentions=output_attentions, 1127 | output_hidden_states=output_hidden_states, 1128 | return_dict=return_dict, 1129 | ) 1130 | 1131 | sequence_output, pooled_output = outputs[:2] 1132 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 1133 | 1134 | total_loss = None 1135 | if labels is not None and next_sentence_label is not None: 1136 | loss_fct = CrossEntropyLoss() 1137 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1138 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1139 | total_loss = masked_lm_loss + next_sentence_loss 1140 | 1141 | if not return_dict: 1142 | output = (prediction_scores, seq_relationship_score) + outputs[2:] 1143 | return ((total_loss,) + output) if total_loss is not None else output 1144 | 1145 | return BertForPreTrainingOutput( 1146 | loss=total_loss, 1147 | prediction_logits=prediction_scores, 1148 | seq_relationship_logits=seq_relationship_score, 1149 | hidden_states=outputs.hidden_states, 1150 | attentions=outputs.attentions, 1151 | ) 1152 | 1153 | 1154 | @add_start_docstrings( 1155 | """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING 1156 | ) 1157 | class BertLMHeadModel(BertPreTrainedModel): 1158 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1159 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"] 1160 | 1161 | def __init__(self, config): 1162 | super().__init__(config) 1163 | 1164 | if not config.is_decoder: 1165 | logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") 1166 | 1167 | self.bert = BertModel(config, add_pooling_layer=False) 1168 | self.cls = BertOnlyMLMHead(config) 1169 | 1170 | # Initialize weights and apply final processing 1171 | self.post_init() 1172 | 1173 | def get_output_embeddings(self): 1174 | return self.cls.predictions.decoder 1175 | 1176 | def set_output_embeddings(self, new_embeddings): 1177 | self.cls.predictions.decoder = new_embeddings 1178 | 1179 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1180 | @add_code_sample_docstrings( 1181 | checkpoint=_CHECKPOINT_FOR_DOC, 1182 | output_type=CausalLMOutputWithCrossAttentions, 1183 | config_class=_CONFIG_FOR_DOC, 1184 | ) 1185 | def forward( 1186 | self, 1187 | input_ids: Optional[torch.Tensor] = None, 1188 | attention_mask: Optional[torch.Tensor] = None, 1189 | token_type_ids: Optional[torch.Tensor] = None, 1190 | position_ids: 
Optional[torch.Tensor] = None, 1191 | head_mask: Optional[torch.Tensor] = None, 1192 | inputs_embeds: Optional[torch.Tensor] = None, 1193 | encoder_hidden_states: Optional[torch.Tensor] = None, 1194 | encoder_attention_mask: Optional[torch.Tensor] = None, 1195 | labels: Optional[torch.Tensor] = None, 1196 | past_key_values: Optional[List[torch.Tensor]] = None, 1197 | use_cache: Optional[bool] = None, 1198 | output_attentions: Optional[bool] = None, 1199 | output_hidden_states: Optional[bool] = None, 1200 | return_dict: Optional[bool] = None, 1201 | ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: 1202 | r""" 1203 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1204 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 1205 | the model is configured as a decoder. 1206 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 1207 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 1208 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 1209 | 1210 | - 1 for tokens that are **not masked**, 1211 | - 0 for tokens that are **masked**. 1212 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1213 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 1214 | `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are 1215 | ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` 1216 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 1217 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 1218 | 1219 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 1220 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 1221 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1222 | use_cache (`bool`, *optional*): 1223 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 1224 | `past_key_values`). 
1225 | """ 1226 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1227 | if labels is not None: 1228 | use_cache = False 1229 | 1230 | outputs = self.bert( 1231 | input_ids, 1232 | attention_mask=attention_mask, 1233 | token_type_ids=token_type_ids, 1234 | position_ids=position_ids, 1235 | head_mask=head_mask, 1236 | inputs_embeds=inputs_embeds, 1237 | encoder_hidden_states=encoder_hidden_states, 1238 | encoder_attention_mask=encoder_attention_mask, 1239 | past_key_values=past_key_values, 1240 | use_cache=use_cache, 1241 | output_attentions=output_attentions, 1242 | output_hidden_states=output_hidden_states, 1243 | return_dict=return_dict, 1244 | ) 1245 | 1246 | sequence_output = outputs[0] 1247 | prediction_scores = self.cls(sequence_output) 1248 | 1249 | lm_loss = None 1250 | if labels is not None: 1251 | # we are doing next-token prediction; shift prediction scores and input ids by one 1252 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 1253 | labels = labels[:, 1:].contiguous() 1254 | loss_fct = CrossEntropyLoss() 1255 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1256 | 1257 | if not return_dict: 1258 | output = (prediction_scores,) + outputs[2:] 1259 | return ((lm_loss,) + output) if lm_loss is not None else output 1260 | 1261 | return CausalLMOutputWithCrossAttentions( 1262 | loss=lm_loss, 1263 | logits=prediction_scores, 1264 | past_key_values=outputs.past_key_values, 1265 | hidden_states=outputs.hidden_states, 1266 | attentions=outputs.attentions, 1267 | cross_attentions=outputs.cross_attentions, 1268 | ) 1269 | 1270 | def prepare_inputs_for_generation( 1271 | self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs 1272 | ): 1273 | input_shape = input_ids.shape 1274 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1275 | if attention_mask is None: 1276 | attention_mask = input_ids.new_ones(input_shape) 1277 | 1278 | # cut decoder_input_ids if past_key_values is used 1279 | if past_key_values is not None: 1280 | input_ids = input_ids[:, -1:] 1281 | 1282 | return { 1283 | "input_ids": input_ids, 1284 | "attention_mask": attention_mask, 1285 | "past_key_values": past_key_values, 1286 | "use_cache": use_cache, 1287 | } 1288 | 1289 | def _reorder_cache(self, past_key_values, beam_idx): 1290 | reordered_past = () 1291 | for layer_past in past_key_values: 1292 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1293 | return reordered_past 1294 | 1295 | 1296 | @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) 1297 | class BertForMaskedLM(BertPreTrainedModel): 1298 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1299 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"] 1300 | 1301 | def __init__(self, config): 1302 | super().__init__(config) 1303 | 1304 | if config.is_decoder: 1305 | logger.warning( 1306 | "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " 1307 | "bi-directional self-attention." 
1308 | ) 1309 | 1310 | self.bert = BertModel(config, add_pooling_layer=False) 1311 | self.cls = BertOnlyMLMHead(config) 1312 | 1313 | # Initialize weights and apply final processing 1314 | self.post_init() 1315 | 1316 | def get_output_embeddings(self): 1317 | return self.cls.predictions.decoder 1318 | 1319 | def set_output_embeddings(self, new_embeddings): 1320 | self.cls.predictions.decoder = new_embeddings 1321 | 1322 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1323 | @add_code_sample_docstrings( 1324 | checkpoint=_CHECKPOINT_FOR_DOC, 1325 | output_type=MaskedLMOutput, 1326 | config_class=_CONFIG_FOR_DOC, 1327 | expected_output="'paris'", 1328 | expected_loss=0.88, 1329 | ) 1330 | def forward( 1331 | self, 1332 | input_ids: Optional[torch.Tensor] = None, 1333 | attention_mask: Optional[torch.Tensor] = None, 1334 | token_type_ids: Optional[torch.Tensor] = None, 1335 | position_ids: Optional[torch.Tensor] = None, 1336 | head_mask: Optional[torch.Tensor] = None, 1337 | inputs_embeds: Optional[torch.Tensor] = None, 1338 | encoder_hidden_states: Optional[torch.Tensor] = None, 1339 | encoder_attention_mask: Optional[torch.Tensor] = None, 1340 | labels: Optional[torch.Tensor] = None, 1341 | output_attentions: Optional[bool] = None, 1342 | output_hidden_states: Optional[bool] = None, 1343 | return_dict: Optional[bool] = None, 1344 | ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: 1345 | r""" 1346 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1347 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., 1348 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 1349 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 1350 | """ 1351 | 1352 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1353 | 1354 | outputs = self.bert( 1355 | input_ids, 1356 | attention_mask=attention_mask, 1357 | token_type_ids=token_type_ids, 1358 | position_ids=position_ids, 1359 | head_mask=head_mask, 1360 | inputs_embeds=inputs_embeds, 1361 | encoder_hidden_states=encoder_hidden_states, 1362 | encoder_attention_mask=encoder_attention_mask, 1363 | output_attentions=output_attentions, 1364 | output_hidden_states=output_hidden_states, 1365 | return_dict=return_dict, 1366 | ) 1367 | 1368 | sequence_output = outputs[0] 1369 | prediction_scores = self.cls(sequence_output) 1370 | 1371 | masked_lm_loss = None 1372 | if labels is not None: 1373 | loss_fct = CrossEntropyLoss() # -100 index = padding token 1374 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1375 | 1376 | if not return_dict: 1377 | output = (prediction_scores,) + outputs[2:] 1378 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1379 | 1380 | return MaskedLMOutput( 1381 | loss=masked_lm_loss, 1382 | logits=prediction_scores, 1383 | hidden_states=outputs.hidden_states, 1384 | attentions=outputs.attentions, 1385 | ) 1386 | 1387 | def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): 1388 | input_shape = input_ids.shape 1389 | effective_batch_size = input_shape[0] 1390 | 1391 | # add a dummy token 1392 | if self.config.pad_token_id is None: 1393 | raise ValueError("The PAD token should be defined for generation") 1394 | 1395 | attention_mask = torch.cat([attention_mask, 
attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) 1396 | dummy_token = torch.full( 1397 | (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device 1398 | ) 1399 | input_ids = torch.cat([input_ids, dummy_token], dim=1) 1400 | 1401 | return {"input_ids": input_ids, "attention_mask": attention_mask} 1402 | 1403 | 1404 | @add_start_docstrings( 1405 | """Bert Model with a `next sentence prediction (classification)` head on top.""", 1406 | BERT_START_DOCSTRING, 1407 | ) 1408 | class BertForNextSentencePrediction(BertPreTrainedModel): 1409 | def __init__(self, config): 1410 | super().__init__(config) 1411 | 1412 | self.bert = BertModel(config) 1413 | self.cls = BertOnlyNSPHead(config) 1414 | 1415 | # Initialize weights and apply final processing 1416 | self.post_init() 1417 | 1418 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1419 | @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) 1420 | def forward( 1421 | self, 1422 | input_ids: Optional[torch.Tensor] = None, 1423 | attention_mask: Optional[torch.Tensor] = None, 1424 | token_type_ids: Optional[torch.Tensor] = None, 1425 | position_ids: Optional[torch.Tensor] = None, 1426 | head_mask: Optional[torch.Tensor] = None, 1427 | inputs_embeds: Optional[torch.Tensor] = None, 1428 | labels: Optional[torch.Tensor] = None, 1429 | output_attentions: Optional[bool] = None, 1430 | output_hidden_states: Optional[bool] = None, 1431 | return_dict: Optional[bool] = None, 1432 | **kwargs, 1433 | ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: 1434 | r""" 1435 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1436 | Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair 1437 | (see `input_ids` docstring). Indices should be in `[0, 1]`: 1438 | 1439 | - 0 indicates sequence B is a continuation of sequence A, 1440 | - 1 indicates sequence B is a random sequence. 1441 | 1442 | Returns: 1443 | 1444 | Example: 1445 | 1446 | ```python 1447 | >>> from transformers import AutoTokenizer, BertForNextSentencePrediction 1448 | >>> import torch 1449 | 1450 | >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 1451 | >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") 1452 | 1453 | >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." 1454 | >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
1455 | >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") 1456 | 1457 | >>> outputs = model(**encoding, labels=torch.LongTensor([1])) 1458 | >>> logits = outputs.logits 1459 | >>> assert logits[0, 0] < logits[0, 1] # next sentence was random 1460 | ``` 1461 | """ 1462 | 1463 | if "next_sentence_label" in kwargs: 1464 | warnings.warn( 1465 | "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" 1466 | " `labels` instead.", 1467 | FutureWarning, 1468 | ) 1469 | labels = kwargs.pop("next_sentence_label") 1470 | 1471 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1472 | 1473 | outputs = self.bert( 1474 | input_ids, 1475 | attention_mask=attention_mask, 1476 | token_type_ids=token_type_ids, 1477 | position_ids=position_ids, 1478 | head_mask=head_mask, 1479 | inputs_embeds=inputs_embeds, 1480 | output_attentions=output_attentions, 1481 | output_hidden_states=output_hidden_states, 1482 | return_dict=return_dict, 1483 | ) 1484 | 1485 | pooled_output = outputs[1] 1486 | 1487 | seq_relationship_scores = self.cls(pooled_output) 1488 | 1489 | next_sentence_loss = None 1490 | if labels is not None: 1491 | loss_fct = CrossEntropyLoss() 1492 | next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) 1493 | 1494 | if not return_dict: 1495 | output = (seq_relationship_scores,) + outputs[2:] 1496 | return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output 1497 | 1498 | return NextSentencePredictorOutput( 1499 | loss=next_sentence_loss, 1500 | logits=seq_relationship_scores, 1501 | hidden_states=outputs.hidden_states, 1502 | attentions=outputs.attentions, 1503 | ) 1504 | 1505 | 1506 | @add_start_docstrings( 1507 | """ 1508 | Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled 1509 | output) e.g. for GLUE tasks. 
1510 | """, 1511 | BERT_START_DOCSTRING, 1512 | ) 1513 | class BertForSequenceClassification(BertPreTrainedModel): 1514 | def __init__(self, config): 1515 | super().__init__(config) 1516 | self.num_labels = config.num_labels 1517 | self.config = config 1518 | 1519 | self.bert = BertModel(config) 1520 | classifier_dropout = ( 1521 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1522 | ) 1523 | self.dropout = nn.Dropout(classifier_dropout) 1524 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1525 | 1526 | # Initialize weights and apply final processing 1527 | self.post_init() 1528 | 1529 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1530 | @add_code_sample_docstrings( 1531 | checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, 1532 | output_type=SequenceClassifierOutput, 1533 | config_class=_CONFIG_FOR_DOC, 1534 | expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, 1535 | expected_loss=_SEQ_CLASS_EXPECTED_LOSS, 1536 | ) 1537 | def forward( 1538 | self, 1539 | input_ids: Optional[torch.Tensor] = None, 1540 | attention_mask: Optional[torch.Tensor] = None, 1541 | token_type_ids: Optional[torch.Tensor] = None, 1542 | position_ids: Optional[torch.Tensor] = None, 1543 | head_mask: Optional[torch.Tensor] = None, 1544 | inputs_embeds: Optional[torch.Tensor] = None, 1545 | labels: Optional[torch.Tensor] = None, 1546 | output_attentions: Optional[bool] = None, 1547 | output_hidden_states: Optional[bool] = None, 1548 | return_dict: Optional[bool] = None, 1549 | ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: 1550 | r""" 1551 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1552 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1553 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1554 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1555 | """ 1556 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1557 | 1558 | outputs = self.bert( 1559 | input_ids, 1560 | attention_mask=attention_mask, 1561 | token_type_ids=token_type_ids, 1562 | position_ids=position_ids, 1563 | head_mask=head_mask, 1564 | inputs_embeds=inputs_embeds, 1565 | output_attentions=output_attentions, 1566 | output_hidden_states=output_hidden_states, 1567 | return_dict=return_dict, 1568 | ) 1569 | 1570 | pooled_output = outputs[1] 1571 | 1572 | pooled_output = self.dropout(pooled_output) 1573 | logits = self.classifier(pooled_output) 1574 | 1575 | loss = None 1576 | if labels is not None: 1577 | if self.config.problem_type is None: 1578 | if self.num_labels == 1: 1579 | self.config.problem_type = "regression" 1580 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1581 | self.config.problem_type = "single_label_classification" 1582 | else: 1583 | self.config.problem_type = "multi_label_classification" 1584 | 1585 | if self.config.problem_type == "regression": 1586 | loss_fct = MSELoss() 1587 | if self.num_labels == 1: 1588 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1589 | else: 1590 | loss = loss_fct(logits, labels) 1591 | elif self.config.problem_type == "single_label_classification": 1592 | loss_fct = CrossEntropyLoss() 1593 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1594 | elif self.config.problem_type == "multi_label_classification": 1595 | loss_fct = BCEWithLogitsLoss() 1596 | loss = loss_fct(logits, labels) 1597 | if not return_dict: 1598 | output = (logits,) + outputs[2:] 1599 | return ((loss,) + output) if loss is not None else output 1600 | 1601 | return SequenceClassifierOutput( 1602 | loss=loss, 1603 | logits=logits, 1604 | hidden_states=outputs.hidden_states, 1605 | attentions=outputs.attentions, 1606 | ) 1607 | 1608 | 1609 | @add_start_docstrings( 1610 | """ 1611 | Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a 1612 | softmax) e.g. for RocStories/SWAG tasks. 
1613 | """, 1614 | BERT_START_DOCSTRING, 1615 | ) 1616 | class BertForMultipleChoice(BertPreTrainedModel): 1617 | def __init__(self, config): 1618 | super().__init__(config) 1619 | 1620 | self.bert = BertModel(config) 1621 | classifier_dropout = ( 1622 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1623 | ) 1624 | self.dropout = nn.Dropout(classifier_dropout) 1625 | self.classifier = nn.Linear(config.hidden_size, 1) 1626 | 1627 | # Initialize weights and apply final processing 1628 | self.post_init() 1629 | 1630 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) 1631 | @add_code_sample_docstrings( 1632 | checkpoint=_CHECKPOINT_FOR_DOC, 1633 | output_type=MultipleChoiceModelOutput, 1634 | config_class=_CONFIG_FOR_DOC, 1635 | ) 1636 | def forward( 1637 | self, 1638 | input_ids: Optional[torch.Tensor] = None, 1639 | attention_mask: Optional[torch.Tensor] = None, 1640 | token_type_ids: Optional[torch.Tensor] = None, 1641 | position_ids: Optional[torch.Tensor] = None, 1642 | head_mask: Optional[torch.Tensor] = None, 1643 | inputs_embeds: Optional[torch.Tensor] = None, 1644 | labels: Optional[torch.Tensor] = None, 1645 | output_attentions: Optional[bool] = None, 1646 | output_hidden_states: Optional[bool] = None, 1647 | return_dict: Optional[bool] = None, 1648 | ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: 1649 | r""" 1650 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1651 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., 1652 | num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See 1653 | `input_ids` above) 1654 | """ 1655 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1656 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 1657 | 1658 | input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 1659 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1660 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1661 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1662 | inputs_embeds = ( 1663 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 1664 | if inputs_embeds is not None 1665 | else None 1666 | ) 1667 | 1668 | outputs = self.bert( 1669 | input_ids, 1670 | attention_mask=attention_mask, 1671 | token_type_ids=token_type_ids, 1672 | position_ids=position_ids, 1673 | head_mask=head_mask, 1674 | inputs_embeds=inputs_embeds, 1675 | output_attentions=output_attentions, 1676 | output_hidden_states=output_hidden_states, 1677 | return_dict=return_dict, 1678 | ) 1679 | 1680 | pooled_output = outputs[1] 1681 | 1682 | pooled_output = self.dropout(pooled_output) 1683 | logits = self.classifier(pooled_output) 1684 | reshaped_logits = logits.view(-1, num_choices) 1685 | 1686 | loss = None 1687 | if labels is not None: 1688 | loss_fct = CrossEntropyLoss() 1689 | loss = loss_fct(reshaped_logits, labels) 1690 | 1691 | if not return_dict: 1692 | output = (reshaped_logits,) + outputs[2:] 1693 | return ((loss,) + output) if loss is not None else output 1694 | 1695 | return MultipleChoiceModelOutput( 1696 | loss=loss, 1697 | logits=reshaped_logits, 1698 | 
hidden_states=outputs.hidden_states, 1699 | attentions=outputs.attentions, 1700 | ) 1701 | 1702 | 1703 | @add_start_docstrings( 1704 | """ 1705 | Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1706 | Named-Entity-Recognition (NER) tasks. 1707 | """, 1708 | BERT_START_DOCSTRING, 1709 | ) 1710 | class BertForTokenClassification(BertPreTrainedModel): 1711 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1712 | 1713 | def __init__(self, config): 1714 | super().__init__(config) 1715 | self.num_labels = config.num_labels 1716 | 1717 | self.bert = BertModel(config, add_pooling_layer=False) 1718 | classifier_dropout = ( 1719 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1720 | ) 1721 | self.dropout = nn.Dropout(classifier_dropout) 1722 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1723 | 1724 | # Initialize weights and apply final processing 1725 | self.post_init() 1726 | 1727 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1728 | @add_code_sample_docstrings( 1729 | checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, 1730 | output_type=TokenClassifierOutput, 1731 | config_class=_CONFIG_FOR_DOC, 1732 | expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, 1733 | expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, 1734 | ) 1735 | def forward( 1736 | self, 1737 | input_ids: Optional[torch.Tensor] = None, 1738 | attention_mask: Optional[torch.Tensor] = None, 1739 | token_type_ids: Optional[torch.Tensor] = None, 1740 | position_ids: Optional[torch.Tensor] = None, 1741 | head_mask: Optional[torch.Tensor] = None, 1742 | inputs_embeds: Optional[torch.Tensor] = None, 1743 | labels: Optional[torch.Tensor] = None, 1744 | output_attentions: Optional[bool] = None, 1745 | output_hidden_states: Optional[bool] = None, 1746 | return_dict: Optional[bool] = None, 1747 | ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: 1748 | r""" 1749 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1750 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
1751 | """ 1752 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1753 | 1754 | outputs = self.bert( 1755 | input_ids, 1756 | attention_mask=attention_mask, 1757 | token_type_ids=token_type_ids, 1758 | position_ids=position_ids, 1759 | head_mask=head_mask, 1760 | inputs_embeds=inputs_embeds, 1761 | output_attentions=output_attentions, 1762 | output_hidden_states=output_hidden_states, 1763 | return_dict=return_dict, 1764 | ) 1765 | 1766 | sequence_output = outputs[0] 1767 | 1768 | sequence_output = self.dropout(sequence_output) 1769 | logits = self.classifier(sequence_output) 1770 | 1771 | loss = None 1772 | if labels is not None: 1773 | loss_fct = CrossEntropyLoss() 1774 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1775 | 1776 | if not return_dict: 1777 | output = (logits,) + outputs[2:] 1778 | return ((loss,) + output) if loss is not None else output 1779 | 1780 | return TokenClassifierOutput( 1781 | loss=loss, 1782 | logits=logits, 1783 | hidden_states=outputs.hidden_states, 1784 | attentions=outputs.attentions, 1785 | ) 1786 | 1787 | 1788 | @add_start_docstrings( 1789 | """ 1790 | Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1791 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 1792 | """, 1793 | BERT_START_DOCSTRING, 1794 | ) 1795 | class BertForQuestionAnswering(BertPreTrainedModel): 1796 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1797 | 1798 | def __init__(self, config): 1799 | super().__init__(config) 1800 | self.num_labels = config.num_labels 1801 | 1802 | self.bert = BertModel(config, add_pooling_layer=False) 1803 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1804 | 1805 | # Initialize weights and apply final processing 1806 | self.post_init() 1807 | 1808 | @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1809 | @add_code_sample_docstrings( 1810 | checkpoint=_CHECKPOINT_FOR_QA, 1811 | output_type=QuestionAnsweringModelOutput, 1812 | config_class=_CONFIG_FOR_DOC, 1813 | qa_target_start_index=_QA_TARGET_START_INDEX, 1814 | qa_target_end_index=_QA_TARGET_END_INDEX, 1815 | expected_output=_QA_EXPECTED_OUTPUT, 1816 | expected_loss=_QA_EXPECTED_LOSS, 1817 | ) 1818 | def forward( 1819 | self, 1820 | input_ids: Optional[torch.Tensor] = None, 1821 | attention_mask: Optional[torch.Tensor] = None, 1822 | token_type_ids: Optional[torch.Tensor] = None, 1823 | position_ids: Optional[torch.Tensor] = None, 1824 | head_mask: Optional[torch.Tensor] = None, 1825 | inputs_embeds: Optional[torch.Tensor] = None, 1826 | start_positions: Optional[torch.Tensor] = None, 1827 | end_positions: Optional[torch.Tensor] = None, 1828 | output_attentions: Optional[bool] = None, 1829 | output_hidden_states: Optional[bool] = None, 1830 | return_dict: Optional[bool] = None, 1831 | ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: 1832 | r""" 1833 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1834 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1835 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1836 | are not taken into account for computing the loss. 
1837 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1838 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1839 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1840 | are not taken into account for computing the loss. 1841 | """ 1842 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1843 | 1844 | outputs = self.bert( 1845 | input_ids, 1846 | attention_mask=attention_mask, 1847 | token_type_ids=token_type_ids, 1848 | position_ids=position_ids, 1849 | head_mask=head_mask, 1850 | inputs_embeds=inputs_embeds, 1851 | output_attentions=output_attentions, 1852 | output_hidden_states=output_hidden_states, 1853 | return_dict=return_dict, 1854 | ) 1855 | 1856 | sequence_output = outputs[0] 1857 | 1858 | logits = self.qa_outputs(sequence_output) 1859 | start_logits, end_logits = logits.split(1, dim=-1) 1860 | start_logits = start_logits.squeeze(-1).contiguous() 1861 | end_logits = end_logits.squeeze(-1).contiguous() 1862 | 1863 | total_loss = None 1864 | if start_positions is not None and end_positions is not None: 1865 | # On multi-GPU, the position tensors may carry an extra trailing dimension; squeeze it away 1866 | if len(start_positions.size()) > 1: 1867 | start_positions = start_positions.squeeze(-1) 1868 | if len(end_positions.size()) > 1: 1869 | end_positions = end_positions.squeeze(-1) 1870 | # Sometimes the start/end positions fall outside the model inputs; clamp them to `ignored_index` so the loss ignores them 1871 | ignored_index = start_logits.size(1) 1872 | start_positions = start_positions.clamp(0, ignored_index) 1873 | end_positions = end_positions.clamp(0, ignored_index) 1874 | 1875 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1876 | start_loss = loss_fct(start_logits, start_positions) 1877 | end_loss = loss_fct(end_logits, end_positions) 1878 | total_loss = (start_loss + end_loss) / 2 1879 | 1880 | if not return_dict: 1881 | output = (start_logits, end_logits) + outputs[2:] 1882 | return ((total_loss,) + output) if total_loss is not None else output 1883 | 1884 | return QuestionAnsweringModelOutput( 1885 | loss=total_loss, 1886 | start_logits=start_logits, 1887 | end_logits=end_logits, 1888 | hidden_states=outputs.hidden_states, 1889 | attentions=outputs.attentions, 1890 | ) --------------------------------------------------------------------------------
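
The task heads above differ only in what they attach to the encoder output: BertForMultipleChoice flattens (batch_size, num_choices, seq_len) inputs, scores each choice with a single-unit linear layer, and reshapes the logits back to (batch_size, num_choices); BertForTokenClassification applies dropout and a per-token classifier, keeping the (batch_size, seq_len, num_labels) shape; BertForQuestionAnswering projects every token to two logits and splits them into start/end scores of shape (batch_size, seq_len). The sketch below is a minimal shape smoke test, assuming the classes are taken from the upstream transformers package (which the code in this file follows); the tiny config values and dummy tensors are illustrative assumptions, not part of this repository.

import torch
from transformers import BertConfig, BertForMultipleChoice, BertForQuestionAnswering

# Deliberately tiny config so the sketch runs quickly on CPU; weights are random,
# so only tensor shapes (not predictions) are meaningful here.
config = BertConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
)

# Multiple choice: inputs arrive as (batch_size, num_choices, seq_len), are flattened
# to (batch_size * num_choices, seq_len) for the encoder, and the per-choice scores
# are reshaped back to (batch_size, num_choices) for the cross-entropy loss.
mc_model = BertForMultipleChoice(config)
mc_ids = torch.randint(0, config.vocab_size, (2, 4, 16))  # 2 examples, 4 choices, 16 tokens
mc_out = mc_model(input_ids=mc_ids, labels=torch.tensor([1, 3]))
print(mc_out.logits.shape)  # torch.Size([2, 4])

# Question answering: one linear layer yields 2 logits per token, split into start/end
# scores; the loss averages the two cross-entropies after clamping out-of-range targets.
qa_model = BertForQuestionAnswering(config)
qa_ids = torch.randint(0, config.vocab_size, (2, 16))
qa_out = qa_model(
    input_ids=qa_ids,
    start_positions=torch.tensor([1, 3]),
    end_positions=torch.tensor([4, 7]),
)
print(qa_out.start_logits.shape, qa_out.end_logits.shape)  # torch.Size([2, 16]) each
print(qa_out.loss)  # scalar tensor

Note that omitting attention_mask in this sketch simply defaults it to all ones, which is fine for fixed-length dummy inputs; real batches built by the tokenizer would pass the mask explicitly.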