├── overview.png ├── requirements.txt ├── scripts ├── run_cnn.sh └── run_xsum.sh ├── CONTRIBUTING.md ├── KD_loss.py ├── CODE_OF_CONDUCT.md ├── README.md ├── quant ├── configuration_bart_quant.py ├── utils_quant.py └── modeling_bart_quant.py ├── LICENSE └── run_summarization_no_trainer.py /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Ternary_Binary_Transformer/HEAD/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.17.0 2 | datasets==1.18.4 3 | sacrebleu==2.0 4 | wandb 5 | nltk 6 | accelerate==0.5.1 7 | tensorboard 8 | setuptools<50 9 | rouge_score 10 | -------------------------------------------------------------------------------- /scripts/run_cnn.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | accelerate launch --num_processes 8 --num_machines 1 --multi_gpu run_summarization_no_trainer.py \ 9 | --model_name_or_path local_path/bart-base-cnn \ 10 | --dataset_name cnn_dailymail \ 11 | --dataset_config_name 3.0.0 \ 12 | --pred_distill \ 13 | --hid_distill \ 14 | --num_train_epochs 20 \ 15 | --weight_bits $1 \ 16 | --input_bits $2 \ 17 | --do_train \ 18 | --do_test \ 19 | --distill_encoder 6 \ 20 | --distill_decoder 6 \ 21 | --learning_rate $3 \ 22 | -------------------------------------------------------------------------------- /scripts/run_xsum.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | accelerate launch --num_processes 8 --num_machines 1 --multi_gpu run_summarization_no_trainer.py \ 9 | --model_name_or_path local_path/bart-base-xsum_from_Aktsvigun \ 10 | --dataset_name xsum \ 11 | --dataset_config_name 3.0.0 \ 12 | --pred_distill \ 13 | --hid_distill \ 14 | --num_train_epochs 20 \ 15 | --weight_bits $1 \ 16 | --input_bits $2 \ 17 | --do_train \ 18 | --do_test \ 19 | --distill_encoder 6 \ 20 | --distill_decoder 6 \ 21 | --learning_rate $3 \ 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | #Contributing to DPR 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | TBD 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. 
You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | ## Coding Style 29 | * 2 spaces for indentation rather than tabs 30 | * 120 character line length 31 | * ... 32 | 33 | ## License 34 | By contributing to Facebook AI Research Ternary Binary Transformer toolkit, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /KD_loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch.nn import functional as F 10 | from torch.nn.modules import loss 11 | 12 | 13 | class DistributionLoss(loss._Loss): 14 | """The KL-Divergence loss for the binary student model and real teacher output. 15 | output must be a pair of (model_output, real_output), both NxC tensors. 16 | The rows of real_output must all add up to one (probability scores); 17 | however, model_output must be the pre-softmax output of the network.""" 18 | 19 | def forward(self, model_output, real_output): 20 | 21 | self.size_average = True 22 | 23 | # Target is ignored at training time. Loss is defined as KL divergence 24 | # between the model output and the refined labels. 25 | if real_output.requires_grad: 26 | raise ValueError("real network output should not require gradients.") 27 | 28 | model_output_log_prob = F.log_softmax(model_output, dim=1) 29 | real_output_soft = F.softmax(real_output, dim=1) 30 | del model_output, real_output 31 | 32 | # Loss is -dot(model_output_log_prob, real_output). Prepare tensors 33 | # for batch matrix multiplicatio 34 | real_output_soft = real_output_soft.unsqueeze(1) 35 | model_output_log_prob = model_output_log_prob.unsqueeze(2) 36 | 37 | # Compute the loss, and average/sum for the batch. 38 | cross_entropy_loss = -torch.bmm(real_output_soft, model_output_log_prob) 39 | if self.size_average: 40 | cross_entropy_loss = cross_entropy_loss.mean() 41 | else: 42 | cross_entropy_loss = cross_entropy_loss.sum() 43 | # Return a pair of (loss_output, model_output). Model output will be 44 | # used for top-1 and top-5 evaluation. 45 | # model_output_log_prob = model_output_log_prob.squeeze(2) 46 | return cross_entropy_loss 47 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ternary_Binary_Transformer 2 | 3 | 4 | This repository contains the training code of TBT introduced in our work: "[Binary and Ternary Natural Language Generation](https://arxiv.org/abs/2306.01841)", published in ACL 2023. 
5 | 6 | We approach the problem with a mix of statistics-based quantization for the weights and elastic quantization of the activations, and demonstrate the first ternary and binary transformer models on the downstream tasks of summarization and machine translation. 7 | 8 |
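For concreteness, the statistics-based weight quantization used throughout the repo (see `QuantizeLinear.forward` in `quant/utils_quant.py`) scales each weight row by a statistic of its absolute values and snaps it to {-1, +1} (binary) or {-1, 0, +1} (ternary), passing gradients to the latent full-precision weights with a straight-through estimator. The snippet below is a minimal, self-contained sketch of that idea; the helper name `quantize_weights_rowwise` is illustrative and not part of the codebase.

```python
import torch

def quantize_weights_rowwise(w: torch.Tensor, weight_bits: int) -> torch.Tensor:
    """Statistics-based row-wise weight quantization with a straight-through estimator.

    w           : full-precision weight matrix, shape (out_features, in_features)
    weight_bits : 1 (binary, {-1, +1} * scale) or 2 (ternary, {-1, 0, +1} * scale)
    """
    if weight_bits == 1:
        # Binary: the per-row scale is the mean absolute value of the row.
        scale = w.abs().mean(dim=1, keepdim=True).detach()
        w_q = scale * torch.sign(w / scale)
    elif weight_bits == 2:
        # Ternary: the per-row scale is 4/3 * mean|w|; values are clamped to [-1, 1] and rounded.
        scale = (4.0 / 3.0) * w.abs().mean(dim=1, keepdim=True).detach()
        w_q = scale * torch.round(torch.clamp(w / scale, -1, 1))
    else:
        raise NotImplementedError("only 1- and 2-bit weights are sketched here")
    # Straight-through estimator: the forward pass uses w_q, gradients flow into w.
    return w_q.detach() - w.detach() + w


# Example: ternarize a random 4x8 weight matrix.
w = torch.nn.Parameter(torch.randn(4, 8))
w_ternary = quantize_weights_rowwise(w, weight_bits=2)
w_ternary.sum().backward()  # gradients reach the latent full-precision weights
```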
9 | ![overview](./overview.png) 10 |
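Activations, in turn, are quantized with an elastic, learnable step size in the spirit of LSQ (https://arxiv.org/abs/1902.08153): the tensor is divided by a trainable scale alpha, rounded and clamped to the signed range, then rescaled, with gradients flowing both to the activations (straight-through inside the clipping range) and to alpha. The following is a condensed sketch of the signed path implemented by `ElasticQuantBinarizerSigned` in `quant/utils_quant.py`; it assumes alpha is already initialized to a positive value and omits the repo's lazy initialization and layerwise flag.

```python
import math
import torch

class ElasticSignedQuant(torch.autograd.Function):
    """LSQ-style signed quantizer with a learnable step size alpha (condensed sketch)."""

    @staticmethod
    def forward(ctx, x, alpha, num_bits):
        # 1- and 2-bit activations use the symmetric range {-1, ..., +1}.
        Qn, Qp = (-1, 1) if num_bits <= 2 else (-2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1)
        grad_scale = 1.0 / math.sqrt(x.numel() * Qp)
        ctx.save_for_backward(x, alpha)
        ctx.other = grad_scale, Qn, Qp, num_bits
        q = x.sign() if num_bits == 1 else (x / alpha).round().clamp(Qn, Qp)
        return q * alpha

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        grad_scale, Qn, Qp, num_bits = ctx.other
        q = x / alpha
        small, big = (q < Qn).float(), (q > Qp).float()
        middle = 1.0 - small - big
        if num_bits == 1:
            grad_alpha = (x.sign() * grad_output * grad_scale).sum().unsqueeze(0)
        else:
            grad_alpha = ((small * Qn + big * Qp + middle * (q.round() - q))
                          * grad_output * grad_scale).sum().unsqueeze(0)
        # Straight-through estimator for activations inside the clipping range.
        return middle * grad_output, grad_alpha, None


# Example: binarize a batch of activations with a learnable step size.
alpha = torch.nn.Parameter(torch.tensor([0.5]))
x = torch.randn(4, 16)
out = ElasticSignedQuant.apply(x, alpha, 1)
out.sum().backward()  # populates alpha.grad
```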
11 | 12 | 13 | ## Citation 14 | 15 | If you find our code useful for your research, please consider citing: 16 | 17 | @article{liu2023binary, 18 | title={Binary and Ternary Natural Language Generation}, 19 | author={Liu, Zechun and Oguz, Barlas and Pappu, Aasish and Shi, Yangyang and Krishnamoorthi, Raghuraman}, 20 | booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics} 21 | year={2023} 22 | } 23 | 24 | Our previous papers related to binarizing BERT model: 25 | * BiT: Robustly Binarized Multi-distilled Transformer (NeurIPS 2022) \[[code](https://github.com/facebookresearch/bit)\] \[[paper](https://arxiv.org/pdf/2205.13016.pdf)\] 26 | 27 | ## Run 28 | 29 | ### 1. Requirements: 30 | * python 3.9.12, pytorch 1.12.1 31 | 32 | ### 2. Pretrained models: 33 | * Download pretrained models from hugging face model zoo. 34 | | Dataset | Finetuned full-precision model | 35 | | --- | --- | 36 | | XSUM | [bart-base-xsum](https://huggingface.co/Aktsvigun/bart-base_xsum_42) | 37 | | CNN/DailyMail | [bart-base-cnn](https://huggingface.co/ainize/bart-base-cnn) | 38 | 39 | ### 3. Steps to run: 40 | * For XSUM benchmark, `bash scrips/run_xsum.sh $w_bit $a_bit $lr` . 41 | * For CNN/DailyMail benchmark, `bash scrips/run_cnn.sh $w_bit $a_bit $lr` . 42 | * Learning rate for each model: 43 | 44 | | | XSUM | CNN/DailyMail | 45 | | --- | --- | --- | 46 | | W2A8 | 3e-4 | 1e-4 | 47 | | W2A2 | 3.5e-4 | 7e-4 | 48 | | W1A8 | 2.5e-4 | 1.5e-4 | 49 | | W1A1 | 5e-4 | 5e-4 | 50 | 51 | ## Models 52 | 53 | | | | | | | XSUM | | | CNN| | 54 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | 55 | | | **#Bits** | **Size (M)** | **FLOPs (G)** | **R1** | **R2** | **RL** | **R1** | **R2** | **RL** | 56 | |BART | 32-32-32 | 532.0 | 1x | 43.84 | 20.79 | 35.71 | 44.90 | 22.25 | 42.09 | 57 | |QuantBart| 8 - 8 - 8 | 138.1 | -- | 40.25 | 17.78 | 32.70 | -- | -- | -- | 58 | |DQ-BART| 8 - 8 - 8 | 138.1 | -- | 42.51 | 19.61 | 34.61 | 44.66 | 21.92 | 41.86 | 59 | |*Ternary* | | | | | | | | | | 60 | |Baseline (TWN) | 2 - 2 - 8 | 39.6 | 0.25x | 39.99 | 17.13 | 31.99 | 42.99 | 20.05 | 40.18| 61 | |QuantBart| 2 - 2 - 8 | 39.6 | 0.25x | 39.15 | 16.72 | 31.72 | -- | -- | -- | 62 | |DQ-BART| 2 - 2 - 8 | 39.6 | 0.25x | 40.06 | 17.34 | 32.46 | 42.94 | 20.07 | 40.13 | 63 | |**TBT** | 2 - 2 - 8 | 39.6 | 0.25x | [**42.40**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ej9CwkrXVVBDmecbj94IoloBA768OTzQSnQyc7U_2iabzA?e=OLtJt9) | [**19.54**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ej9CwkrXVVBDmecbj94IoloBA768OTzQSnQyc7U_2iabzA?e=OLtJt9) | [**34.51**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ej9CwkrXVVBDmecbj94IoloBA768OTzQSnQyc7U_2iabzA?e=OLtJt9) | [**43.46**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Eq0HeMLo0RNIntzXsfHY9gIBZHmVCNL1L-AWvT54IxLn7A?e=Ccr8U5) | [**20.52**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Eq0HeMLo0RNIntzXsfHY9gIBZHmVCNL1L-AWvT54IxLn7A?e=Ccr8U5) | [**40.58**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Eq0HeMLo0RNIntzXsfHY9gIBZHmVCNL1L-AWvT54IxLn7A?e=Ccr8U5) | 64 | |Baseline (TWN) | 2 - 2 - 2 | 39.6 | 0.0625x | 12.80 | 1.21 | 11.4 | 12.92 | 0.32 | 12.42| 65 | |**TBT** | 2 - 2 - 2 | 39.6 | 0.0625x | [**36.21**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ery8OufgDpRIjFL2P9NBxukBHCJ34Tkth7DfhLu5BiHkXA?e=5KrmKE) | 
[**14.38**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ery8OufgDpRIjFL2P9NBxukBHCJ34Tkth7DfhLu5BiHkXA?e=5KrmKE) | [**29.07**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Ery8OufgDpRIjFL2P9NBxukBHCJ34Tkth7DfhLu5BiHkXA?e=5KrmKE) | [**41.03**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EryiIhiloWFAjdDkiqRBeYwBm7l-DQxlXViu8Sm_FAzCSg?e=UEHUvB) | [**18.18**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EryiIhiloWFAjdDkiqRBeYwBm7l-DQxlXViu8Sm_FAzCSg?e=UEHUvB) | [**38.30**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EryiIhiloWFAjdDkiqRBeYwBm7l-DQxlXViu8Sm_FAzCSg?e=UEHUvB) | 66 | |*Binary* | | | | | | | | | | 67 | |Baseline (BWN) | 1 - 1 - 8 | 23.2 | 0.125x | 1.90 | 0.01 | 1.78 | 2.78 | 0.08 | 2.48| 68 | |**TBT** | 1 - 1 - 8 | 23.2 | 0.125x | [**40.96**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ErpKKcUo_RlDpIHqqQXMTU8BamD85JA0Ebtg4J5oFhTYJA?e=wtNL2b) | [**18.37**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ErpKKcUo_RlDpIHqqQXMTU8BamD85JA0Ebtg4J5oFhTYJA?e=wtNL2b) | [**33.30**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ErpKKcUo_RlDpIHqqQXMTU8BamD85JA0Ebtg4J5oFhTYJA?e=wtNL2b) | [**42.66**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EriYUYGE2cZAqgSM0YKY7vcBs0PmVvIyqtdnZKWXlANztQ?e=Q3aOWx) | [**19.72**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EriYUYGE2cZAqgSM0YKY7vcBs0PmVvIyqtdnZKWXlANztQ?e=Q3aOWx) | [**39.80**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/EriYUYGE2cZAqgSM0YKY7vcBs0PmVvIyqtdnZKWXlANztQ?e=Q3aOWx) | 69 | |Baseline (BWN)| 1 - 1 - 1 | 23.2 | 0.0156x | 1.90 | 0.01 | 1.78 | 2.78 | 0.08 | 2.48| 70 | |**TBT** | 1 - 1 - 1 | 23.2 | 0.0156x | [**31.68**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ElDbaSZyLx1ItZSp2O6rzocBHsjHf_IRMT9kvWk3QIZOkQ?e=WOUFza) | [**11.19**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ElDbaSZyLx1ItZSp2O6rzocBHsjHf_IRMT9kvWk3QIZOkQ?e=WOUFza) | [**25.29**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/ElDbaSZyLx1ItZSp2O6rzocBHsjHf_IRMT9kvWk3QIZOkQ?e=WOUFza) | [**35.56**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Epvrka2zQfJNvqXfevDA3KkBALQr_0571d-iFD8d6ezyug?e=PsgzBM) | [**11.71**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Epvrka2zQfJNvqXfevDA3KkBALQr_0571d-iFD8d6ezyug?e=PsgzBM) | [**33.23**](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zliubq_connect_ust_hk/Epvrka2zQfJNvqXfevDA3KkBALQr_0571d-iFD8d6ezyug?e=PsgzBM) | 71 | 72 | 73 | ## Acknowledgement 74 | 75 | The original code is borrowed from [DQ-BART](https://github.com/amazon-science/dq-bart). 76 | 77 | ## Contact 78 | 79 | Zechun Liu, Reality Labs, Meta Inc (liuzechun0216 at gmail.com) 80 | 81 | ## License 82 | BiT is CC-BY-NC 4.0 licensed as of now. 83 | 84 | -------------------------------------------------------------------------------- /quant/configuration_bart_quant.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | # 2022 - add quantization modules Amazon.com, Inc. or its affiliates 9 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 21 | """ BART model configuration """ 22 | import warnings 23 | from collections import OrderedDict 24 | from typing import Mapping 25 | 26 | import transformers 27 | 28 | from transformers.configuration_utils import PretrainedConfig 29 | from transformers.onnx import OnnxConfigWithPast 30 | from transformers.utils import logging 31 | 32 | import json 33 | import sys 34 | import os 35 | import copy 36 | from io import open 37 | 38 | CONFIG_NAME = 'config.json' 39 | logger = logging.get_logger(__name__) 40 | 41 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { 42 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json", 43 | # See all BART models at https://huggingface.co/models?filter=bart 44 | } 45 | 46 | 47 | class BartConfig(PretrainedConfig): 48 | r""" 49 | This is the configuration class to store the configuration of a :class:`~transformers.BartModel`. It is used to 50 | instantiate a BART model according to the specified arguments, defining the model architecture. Instantiating a 51 | configuration with the defaults will yield a similar configuration to that of the BART `facebook/bart-large 52 | `__ architecture. 53 | 54 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 55 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 56 | 57 | 58 | Args: 59 | vocab_size (:obj:`int`, `optional`, defaults to 50265): 60 | Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the 61 | :obj:`inputs_ids` passed when calling :class:`~transformers.BartModel` or 62 | :class:`~transformers.TFBartModel`. 63 | d_model (:obj:`int`, `optional`, defaults to 1024): 64 | Dimensionality of the layers and the pooler layer. 65 | encoder_layers (:obj:`int`, `optional`, defaults to 12): 66 | Number of encoder layers. 67 | decoder_layers (:obj:`int`, `optional`, defaults to 12): 68 | Number of decoder layers. 69 | encoder_attention_heads (:obj:`int`, `optional`, defaults to 16): 70 | Number of attention heads for each attention layer in the Transformer encoder. 71 | decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): 72 | Number of attention heads for each attention layer in the Transformer decoder. 73 | decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): 74 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 75 | encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): 76 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 
77 | activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): 78 | The non-linear activation function (function or string) in the encoder and pooler. If string, 79 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. 80 | dropout (:obj:`float`, `optional`, defaults to 0.1): 81 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_dropout (:obj:`float`, `optional`, defaults to 0.0): 83 | The dropout ratio for the attention probabilities. 84 | activation_dropout (:obj:`float`, `optional`, defaults to 0.0): 85 | The dropout ratio for activations inside the fully connected layer. 86 | classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): 87 | The dropout ratio for classifier. 88 | max_position_embeddings (:obj:`int`, `optional`, defaults to 1024): 89 | The maximum sequence length that this model might ever be used with. Typically set this to something large 90 | just in case (e.g., 512 or 1024 or 2048). 91 | init_std (:obj:`float`, `optional`, defaults to 0.02): 92 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 93 | encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): 94 | The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. 96 | decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): 97 | The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. 99 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): 100 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 101 | scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): 102 | Scale embeddings by diving by sqrt(d_model). 103 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 104 | Whether or not the model should return the last key/values attentions (not used by all models). 105 | num_labels: (:obj:`int`, `optional`, defaults to 3): 106 | The number of labels to use in :class:`~transformers.BartForSequenceClassification`. 107 | forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): 108 | The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to 109 | :obj:`eos_token_id`. 
110 | 111 | quantize_act: (:obj:'bool', 'optional', defaulsts to False): 112 | Whether to do quantization or not 113 | input_bits: (:obj:`int`, `optional`, defaults to 8): 114 | the number of bits we use for quantization 115 | 116 | 117 | Example:: 118 | 119 | >>> from transformers import BartModel, BartConfig 120 | 121 | >>> # Initializing a BART facebook/bart-large style configuration 122 | >>> configuration = BartConfig() 123 | 124 | >>> # Initializing a model from the facebook/bart-large style configuration 125 | >>> model = BartModel(configuration) 126 | 127 | >>> # Accessing the model configuration 128 | >>> configuration = model.config 129 | """ 130 | model_type = "bart" 131 | keys_to_ignore_at_inference = ["past_key_values"] 132 | 133 | def __init__( 134 | self, 135 | vocab_size=50265, 136 | max_position_embeddings=1024, 137 | encoder_layers=12, 138 | encoder_ffn_dim=4096, 139 | encoder_attention_heads=16, 140 | decoder_layers=12, 141 | decoder_ffn_dim=4096, 142 | decoder_attention_heads=16, 143 | encoder_layerdrop=0.0, 144 | decoder_layerdrop=0.0, 145 | activation_function="gelu", 146 | d_model=1024, 147 | dropout=0.1, 148 | attention_dropout=0.0, 149 | activation_dropout=0.0, 150 | init_std=0.02, 151 | classifier_dropout=0.0, 152 | scale_embedding=False, 153 | gradient_checkpointing=False, 154 | use_cache=True, 155 | num_labels=3, 156 | pad_token_id=1, 157 | bos_token_id=0, 158 | eos_token_id=2, 159 | is_encoder_decoder=True, 160 | decoder_start_token_id=2, 161 | forced_eos_token_id=2, 162 | quantize_act=False, 163 | input_bits=8, 164 | weight_bits=2, 165 | clip_val=2.5, 166 | **kwargs 167 | ): 168 | super().__init__( 169 | num_labels=num_labels, 170 | pad_token_id=pad_token_id, 171 | bos_token_id=bos_token_id, 172 | eos_token_id=eos_token_id, 173 | is_encoder_decoder=is_encoder_decoder, 174 | decoder_start_token_id=decoder_start_token_id, 175 | forced_eos_token_id=forced_eos_token_id, 176 | **kwargs, 177 | ) 178 | 179 | self.vocab_size = vocab_size 180 | self.max_position_embeddings = max_position_embeddings 181 | self.d_model = d_model 182 | self.encoder_ffn_dim = encoder_ffn_dim 183 | self.encoder_layers = encoder_layers 184 | self.encoder_attention_heads = encoder_attention_heads 185 | self.decoder_ffn_dim = decoder_ffn_dim 186 | self.decoder_layers = decoder_layers 187 | self.decoder_attention_heads = decoder_attention_heads 188 | self.dropout = dropout 189 | self.attention_dropout = attention_dropout 190 | self.activation_dropout = activation_dropout 191 | self.activation_function = activation_function 192 | self.init_std = init_std 193 | self.encoder_layerdrop = encoder_layerdrop 194 | self.decoder_layerdrop = decoder_layerdrop 195 | self.classifier_dropout = classifier_dropout 196 | self.use_cache = use_cache 197 | self.num_hidden_layers = encoder_layers 198 | self.gradient_checkpointing = gradient_checkpointing 199 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 200 | 201 | self.quantize_act = quantize_act 202 | self.input_bits = input_bits 203 | self.weight_bits = weight_bits 204 | self.clip_val = clip_val 205 | 206 | # ensure backward compatibility for BART CNN models 207 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): 208 | self.forced_bos_token_id = self.bos_token_id 209 | warnings.warn( 210 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." 211 | "The config can simply be saved and uploaded again to be fixed." 
212 | ) 213 | 214 | @property 215 | def num_attention_heads(self) -> int: 216 | return self.encoder_attention_heads 217 | 218 | @property 219 | def hidden_size(self) -> int: 220 | return self.d_model 221 | 222 | @classmethod 223 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 224 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 225 | logger.info("loading configuration file {}".format(config_file)) 226 | # Load config 227 | config = cls.from_json_file(config_file) 228 | 229 | # Update config with kwargs if needed 230 | to_remove = [] 231 | for key, value in kwargs.items(): 232 | setattr(config, key, value) 233 | to_remove.append(key) 234 | for key in to_remove: 235 | kwargs.pop(key, None) 236 | 237 | logger.info("Model config %s", str(config)) 238 | return config 239 | 240 | @classmethod 241 | def from_dict(cls, json_object): 242 | """Constructs a `Config` from a Python dictionary of parameters.""" 243 | config = cls(vocab_size_or_config_json_file=-1) 244 | for key, value in json_object.items(): 245 | setattr(config, key, value) 246 | return config 247 | 248 | @classmethod 249 | def from_json_file(cls, json_file): 250 | """Constructs a `BartConfig` from a json file of parameters.""" 251 | with open(json_file, "r", encoding='utf-8') as reader: 252 | text = reader.read() 253 | return cls.from_dict(json.loads(text)) 254 | 255 | def __eq__(self, other): 256 | return self.__dict__ == other.__dict__ 257 | 258 | def __repr__(self): 259 | return str(self.to_json_string()) 260 | 261 | def to_dict(self): 262 | """Serializes this instance to a Python dictionary.""" 263 | output = copy.deepcopy(self.__dict__) 264 | return output 265 | 266 | def to_json_string(self): 267 | """Serializes this instance to a JSON string.""" 268 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 269 | 270 | def to_json_file(self, json_file_path): 271 | """ Save this instance to a json file.""" 272 | with open(json_file_path, "w", encoding='utf-8') as writer: 273 | writer.write(self.to_json_string()) 274 | 275 | 276 | class BartOnnxConfig(OnnxConfigWithPast): 277 | @property 278 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 279 | return OrderedDict( 280 | [ 281 | ("input_ids", {0: "batch", 1: "sequence"}), 282 | ("attention_mask", {0: "batch", 1: "sequence"}), 283 | ] 284 | ) 285 | 286 | @property 287 | def outputs(self) -> Mapping[str, Mapping[int, str]]: 288 | if self.use_past: 289 | return OrderedDict( 290 | [ 291 | ("last_hidden_state", {0: "batch", 1: "sequence"}), 292 | ("past_keys", {0: "batch", 2: "sequence"}), 293 | ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), 294 | ] 295 | ) 296 | else: 297 | return OrderedDict( 298 | [ 299 | ("last_hidden_state", {0: "batch", 1: "sequence"}), 300 | ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), 301 | ] 302 | ) 303 | 304 | -------------------------------------------------------------------------------- /quant/utils_quant.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | 12 | class ElasticQuantBinarizerSigned(torch.autograd.Function): 13 | """ 14 | Modified from Learned Step-size Quantization. 
15 | https://arxiv.org/abs/1902.08153 16 | """ 17 | @staticmethod 18 | def forward(ctx, input, alpha, num_bits, layerwise): 19 | """ 20 | :param input: input to be quantized 21 | :param alpha: the step size 22 | :param num_bits: quantization bits 23 | :param layerwise: rowwise quant 24 | :return: quantized output 25 | """ 26 | if not layerwise: 27 | # TODO 28 | raise NotImplementedError 29 | ctx.num_bits = num_bits 30 | if num_bits == 32: 31 | return input 32 | if num_bits == 1 or num_bits == 2: 33 | Qn = -1 34 | Qp = 1 35 | else: 36 | Qn = -2 ** (num_bits - 1) 37 | Qp = 2 ** (num_bits - 1) - 1 38 | 39 | eps = torch.tensor(0.00001).float().to(alpha.device) 40 | if alpha.item() == 1.0 and (not alpha.initialized): 41 | alpha.initialize_wrapper(input, num_bits, symmetric=True, init_method='default') 42 | alpha = torch.where(alpha > eps, alpha, eps) 43 | assert alpha > 0, 'alpha = {:.6f} becomes non-positive'.format(alpha) 44 | 45 | grad_scale = 1.0 / math.sqrt(input.numel()) if not Qp else 1.0 / math.sqrt(input.numel() * Qp) 46 | ctx.save_for_backward(input, alpha) 47 | ctx.other = grad_scale, Qn, Qp 48 | if num_bits == 1: 49 | q_w = input.sign() 50 | else: 51 | q_w = (input / alpha).round().clamp(Qn, Qp) 52 | w_q = q_w * alpha 53 | return w_q 54 | 55 | @staticmethod 56 | def backward(ctx, grad_output): 57 | if ctx.num_bits == 32: 58 | return grad_output, None, None, None 59 | 60 | input_, alpha = ctx.saved_tensors 61 | grad_scale, Qn, Qp = ctx.other 62 | q_w = input_ / alpha 63 | indicate_small = (q_w < Qn).float() 64 | indicate_big = (q_w > Qp).float() 65 | indicate_middle = 1.0 - indicate_small - indicate_big # this is more cpu-friendly than torch.ones(input_.shape) 66 | if ctx.num_bits == 1: 67 | grad_alpha = ((input_.sign()) * grad_output * grad_scale).sum().unsqueeze(dim=0) 68 | else: 69 | grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * ( 70 | -q_w + q_w.round())) * grad_output * grad_scale).sum().unsqueeze(dim=0) 71 | grad_input = indicate_middle * grad_output 72 | return grad_input, grad_alpha, None, None 73 | 74 | 75 | class ElasticQuantBinarizerUnsigned(torch.autograd.Function): 76 | """ 77 | Modified from Learned Step-size Quantization. 
78 | https://arxiv.org/abs/1902.08153 79 | """ 80 | @staticmethod 81 | def forward(ctx, input, alpha, num_bits, layerwise): 82 | """ 83 | :param input: input to be quantized 84 | :param alpha: the step size 85 | :param num_bits: quantization bits 86 | :param layerwise: rowwise quant 87 | :return: quantized output 88 | """ 89 | if not layerwise: 90 | # TODO 91 | raise NotImplementedError 92 | ctx.num_bits = num_bits 93 | if num_bits == 32: 94 | return input 95 | Qn = 0 96 | if num_bits == 2: 97 | Qp = 2 98 | else: 99 | Qp = 2 ** (num_bits) - 1 100 | 101 | if num_bits == 1: 102 | input_ = input 103 | else: 104 | min_val = input.min().item() 105 | input_ = input - min_val 106 | 107 | eps = torch.tensor(0.00001).float().to(alpha.device) 108 | if alpha.item() == 1.0 and (not alpha.initialized): 109 | alpha.initialize_wrapper(input, num_bits, symmetric=False, init_method='default') 110 | alpha = torch.where(alpha > eps, alpha, eps) 111 | assert alpha > 0, 'alpha = {:.6f} becomes non-positive'.format(alpha) 112 | 113 | grad_scale = 1.0 / math.sqrt(input.numel() * Qp) 114 | ctx.save_for_backward(input_, alpha) 115 | ctx.other = grad_scale, Qn, Qp 116 | q_w = (input_ / alpha).round().clamp(Qn, Qp) 117 | w_q = q_w * alpha 118 | if num_bits != 1: 119 | w_q = w_q + min_val 120 | 121 | return w_q 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | if ctx.num_bits == 32: 126 | return grad_output, None, None, None 127 | 128 | input_, alpha = ctx.saved_tensors 129 | grad_scale, Qn, Qp = ctx.other 130 | q_w = input_ / alpha 131 | indicate_small = (q_w < Qn).float() 132 | indicate_big = (q_w > Qp).float() 133 | indicate_middle = 1.0 - indicate_small - indicate_big # this is more cpu-friendly than torch.ones(input_.shape) 134 | grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * ( 135 | -q_w + q_w.round())) * grad_output * grad_scale).sum().unsqueeze(dim=0) 136 | grad_input = indicate_middle * grad_output 137 | return grad_input, grad_alpha, None, None 138 | 139 | class AlphaInit(nn.Parameter): 140 | def __init__(self, tensor): 141 | super(AlphaInit, self).__new__(nn.Parameter, data=tensor) 142 | self.initialized = False 143 | 144 | def _initialize(self, init_tensor): 145 | assert not self.initialized, 'already initialized.' 146 | self.data.copy_(init_tensor) 147 | self.initialized = True 148 | 149 | def initialize_wrapper(self, tensor, num_bits, symmetric, init_method='default'): 150 | Qp = 2 ** (num_bits - 1) - 1 if symmetric else 2 ** (num_bits) - 1 151 | if Qp == 0: 152 | Qp = 1.0 153 | if init_method == 'default': 154 | init_val = 2 * tensor.abs().mean() / math.sqrt(Qp) if symmetric \ 155 | else 4 * tensor.abs().mean() / math.sqrt(Qp) 156 | elif init_method == 'uniform': 157 | init_val = 1./(2*Qp+1) if symmetric else 1./Qp 158 | 159 | self._initialize(init_val) 160 | 161 | class SymQuantizer(torch.autograd.Function): 162 | """ 163 | uniform quantization 164 | """ 165 | @staticmethod 166 | def forward(ctx, input, clip_val, num_bits, layerwise): 167 | """ 168 | :param ctx: 169 | :param input: tensor to be quantized 170 | :param clip_val: clip the tensor before quantization 171 | :param quant_bits: number of bits 172 | :return: quantized tensor 173 | """ 174 | ctx.save_for_backward(input, clip_val) 175 | input = torch.clamp(input, clip_val[0], clip_val[1]) 176 | #input = torch.where(input < clip_val[1], input, clip_val[1]) 177 | #input = torch.where(input > clip_val[0], input, clip_val[0]) 178 | # NOTE: dynamic scaling (max_input). 
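# Scale granularity for the branch below: with layerwise=True a single max-|x| scale
# covers the whole tensor; otherwise the scale is computed per row (last dimension) for
# <=3-D weights / hidden states, and per attention head for 4-D attention-score tensors.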
179 | if layerwise: 180 | max_input = torch.max(torch.abs(input)).expand_as(input) 181 | else: 182 | if input.ndimension() <= 3: 183 | # weight & hidden layer 184 | max_input = torch.max(torch.abs(input), dim=-1, keepdim=True)[0].expand_as(input).detach() 185 | elif input.ndimension() == 4: 186 | # TODO: attention score matrix, calculate alpha / beta per head 187 | tmp = input.view(input.shape[0], input.shape[1], -1) 188 | max_input = torch.max(torch.abs(tmp), dim=-1, keepdim=True)[0].unsqueeze(-1).expand_as(input).detach() 189 | else: 190 | raise ValueError 191 | s = (2 ** (num_bits - 1) - 1) / max_input 192 | output = torch.round(input * s).div(s) 193 | 194 | return output 195 | 196 | @staticmethod 197 | def backward(ctx, grad_output): 198 | """ 199 | :param ctx: saved non-clipped full-precision tensor and clip_val 200 | :param grad_output: gradient ert the quantized tensor 201 | :return: estimated gradient wrt the full-precision tensor 202 | """ 203 | input, clip_val = ctx.saved_tensors # unclipped input 204 | grad_input = grad_output.clone() 205 | grad_input[input.ge(clip_val[1])] = 0 206 | grad_input[input.le(clip_val[0])] = 0 207 | return grad_input, None, None, None 208 | 209 | 210 | class AsymQuantizer(torch.autograd.Function): 211 | """ 212 | min-max quantization 213 | """ 214 | @staticmethod 215 | def forward(ctx, input, clip_val, num_bits, layerwise): 216 | """ 217 | :param ctx: 218 | :param input: tensor to be quantized 219 | :param clip_val: clip the tensor before quantization 220 | :param quant_bits: number of bits 221 | :return: quantized tensor 222 | """ 223 | ctx.save_for_backward(input, clip_val) 224 | 225 | input = torch.where(input < clip_val[1], input, clip_val[1]) 226 | input = torch.where(input > clip_val[0], input, clip_val[0]) 227 | # input = torch.clamp(input, clip_val[0], clip_val[1]) 228 | # NOTE: dynamic scaling gives better performance than static 229 | if layerwise: 230 | alpha = (input.max() - input.min()).detach() 231 | beta = input.min().detach() 232 | else: 233 | if input.ndimension() <= 3: 234 | # weight & hidden layer 235 | alpha = (input.max(dim=-1, keepdim=True)[0] - input.min(dim=-1, keepdim=True)[0]).expand_as(input).detach() 236 | beta = input.min(dim=-1, keepdim=True)[0].expand_as(input).detach() 237 | elif input.ndimension() == 4: 238 | # TODO: attention score matrix, calculate alpha / beta per head 239 | tmp = input.view(input.shape[0], input.shape[1], -1) 240 | alpha = (tmp.max(dim=-1, keepdim=True)[0].unsqueeze(-1) - \ 241 | tmp.min(dim=-1, keepdim=True)[0].unsqueeze(-1)).expand_as(input).detach() 242 | beta = tmp.min(dim=-1, keepdim=True)[0].unsqueeze(-1).expand_as(input).detach() 243 | else: 244 | raise ValueError 245 | input_normalized = (input - beta) / (alpha + 1e-8) 246 | s = (2**num_bits - 1) 247 | quant_input = torch.round(input_normalized * s).div(s) 248 | output = quant_input * (alpha + 1e-8) + beta 249 | 250 | 251 | return output 252 | 253 | @staticmethod 254 | def backward(ctx, grad_output): 255 | """ 256 | :param ctx: saved non-clipped full-precision tensor and clip_val 257 | :param grad_output: gradient ert the quantized tensor 258 | :return: estimated gradient wrt the full-precision tensor 259 | """ 260 | input, clip_val = ctx.saved_tensors # unclipped input 261 | grad_input = grad_output.clone() 262 | grad_input[input.ge(clip_val[1])] = 0 263 | grad_input[input.le(clip_val[0])] = 0 264 | return grad_input, None, None, None 265 | 266 | 267 | class TwnQuantizer(torch.autograd.Function): 268 | """Ternary Weight Networks (TWN) 269 
| Ref: https://arxiv.org/abs/1605.04711 270 | """ 271 | 272 | @staticmethod 273 | def forward(ctx, input, clip_val, num_bits, layerwise): 274 | """ 275 | :param input: tensor to be ternarized 276 | :return: quantized tensor 277 | """ 278 | ctx.save_for_backward(input, clip_val) 279 | input = torch.where(input < clip_val[1], input, clip_val[1]) 280 | input = torch.where(input > clip_val[0], input, clip_val[0]) 281 | if layerwise: 282 | m = input.norm(p=1).div(input.nelement()) 283 | thres = 0.7 * m 284 | pos = (input > thres).float() 285 | neg = (input < -thres).float() 286 | mask = (input.abs() > thres).float() 287 | alpha = (mask * input).abs().sum() / mask.sum() 288 | result = alpha * pos - alpha * neg 289 | else: # row-wise only for embed / weight 290 | n = input[0].nelement() 291 | m = input.data.norm(p=1, dim=1).div(n) 292 | thres = (0.7 * m).view(-1, 1).expand_as(input) 293 | pos = (input > thres).float() 294 | neg = (input < -thres).float() 295 | mask = (input.abs() > thres).float() 296 | alpha = ((mask * input).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1) 297 | result = alpha * pos - alpha * neg 298 | 299 | return result 300 | 301 | @staticmethod 302 | def backward(ctx, grad_output): 303 | """ 304 | :param ctx: saved non-clipped full-precision tensor and clip_val 305 | :param grad_output: gradient ert the quantized tensor 306 | :return: estimated gradient wrt the full-precision tensor 307 | """ 308 | input, clip_val = ctx.saved_tensors # unclipped input 309 | grad_input = grad_output.clone() 310 | grad_input[input.ge(clip_val[1])] = 0 311 | grad_input[input.le(clip_val[0])] = 0 312 | return grad_input, None, None, None 313 | 314 | class BwnQuantizer(torch.autograd.Function): 315 | """Binary Weight Network (BWN) 316 | Ref: https://arxiv.org/abs/1603.05279 317 | """ 318 | 319 | @staticmethod 320 | def forward(ctx, input, clip_val, num_bits, layerwise): 321 | """ 322 | :param input: tensor to be binarized 323 | :return: quantized tensor 324 | """ 325 | ctx.save_for_backward(input) 326 | if layerwise: 327 | s = input.size() 328 | m = input.norm(p=1).div(input.nelement()) 329 | e = input.mean() 330 | result = (input-e).sign().mul(m.expand(s)) 331 | else: 332 | n = input[0].nelement() # W of size axb, return a vector of ax1 333 | s = input.size() 334 | m = input.norm(1, 1, keepdim=True).div(n) 335 | e = input.mean() 336 | result = (input-e).sign().mul(m.expand(s)) 337 | 338 | return result 339 | 340 | @staticmethod 341 | def backward(ctx, grad_output): 342 | """ 343 | :param ctx: saved non-clipped full-precision tensor and clip_val 344 | :param grad_output: gradient ert the quantized tensor 345 | :return: estimated gradient wrt the full-precision tensor 346 | """ 347 | grad_input = grad_output.clone() 348 | return grad_input, None, None, None 349 | 350 | class QuantizeLinear(nn.Linear): 351 | 352 | def __init__(self, *kargs, symmetric=True, bias=True, config=None): 353 | super(QuantizeLinear, self).__init__(*kargs,bias=True) 354 | self.weight_bits = config.weight_bits 355 | self.quantize_act = config.quantize_act 356 | #params for weight quant 357 | self.register_buffer('weight_clip_val', torch.tensor([config.clip_val])) 358 | if self.quantize_act: 359 | self.input_bits = config.input_bits 360 | if self.input_bits <= 2 and symmetric: 361 | self.act_clip_val = AlphaInit(torch.tensor(1.0)) 362 | self.act_quantizer = ElasticQuantBinarizerSigned 363 | elif self.input_bits <= 2 and not symmetric: 364 | self.act_clip_val = AlphaInit(torch.tensor(1.0)) 365 | self.act_quantizer = 
ElasticQuantBinarizerUnsigned 366 | elif self.input_bits == 8 and symmetric: 367 | self.register_buffer('act_clip_val', torch.tensor([-config.clip_val, config.clip_val])) 368 | self.act_quantizer = SymQuantizer 369 | elif self.input_bits == 8 and not symmetric: 370 | self.register_buffer('act_clip_val', torch.tensor([-config.clip_val, config.clip_val])) 371 | self.act_quantizer = AsymQuantizer 372 | else: 373 | raise NotImplementedError 374 | 375 | def forward(self, input): 376 | # quantize weight 377 | assert len(self.weight.size()) == 2 378 | real_weights = self.weight 379 | if self.weight_bits == 1: 380 | scaling_factor = torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 381 | quan_weights_no_grad = scaling_factor * (torch.sign(real_weights/scaling_factor)) 382 | elif self.weight_bits == 2: 383 | scaling_factor = 4/3 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 384 | quan_weights_no_grad = scaling_factor * (torch.round(torch.clamp(real_weights/scaling_factor, -1, 1))) 385 | else: 386 | raise NotImplementedError 387 | 388 | weight = quan_weights_no_grad.detach() - real_weights.detach() + real_weights 389 | # quantize input 390 | input = self.act_quantizer.apply(input, self.act_clip_val, self.input_bits, True) 391 | 392 | out = nn.functional.linear(input, weight) 393 | if not self.bias is None: 394 | out += self.bias.view(1, -1).expand_as(out) 395 | 396 | return out 397 | 398 | 399 | class QuantizeEmbedding(nn.Embedding): 400 | 401 | def __init__(self, *kargs,padding_idx=None, config = None): 402 | print('init quantize emb') 403 | super(QuantizeEmbedding, self).__init__(*kargs, padding_idx = padding_idx) 404 | self.weight_bits = config.weight_bits 405 | self.layerwise = False 406 | self.register_buffer('weight_clip_val', torch.tensor([-config.clip_val, config.clip_val])) 407 | 408 | def forward(self, input): 409 | assert len(self.weight.size()) == 2 410 | real_weights = self.weight 411 | if self.weight_bits == 1: 412 | scaling_factor = torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 413 | quan_weights_no_grad = scaling_factor * (torch.sign(real_weights/scaling_factor)) 414 | elif self.weight_bits == 2: 415 | scaling_factor = 4/3 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 416 | quan_weights_no_grad = scaling_factor * (torch.round(torch.clamp(real_weights/scaling_factor, -1, 1))) 417 | else: 418 | raise NotImplementedError 419 | 420 | weight = quan_weights_no_grad.detach() - real_weights.detach() + real_weights 421 | 422 | out = nn.functional.embedding( 423 | input, weight, self.padding_idx, self.max_norm, 424 | self.norm_type, self.scale_grad_by_freq, self.sparse) 425 | return out 426 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. 
Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. 
Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. 
To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /run_summarization_no_trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 2023.02.01 - Add support for two-set quantization and using KD loss for logit distillation 9 | # Meta Platforms, Inc. 10 | # 11 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 12 | # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | """ 26 | Fine-tuning a 🤗 Transformers model on summarization. 
27 | """ 28 | # You can also adapt this script on your own summarization task. Pointers for this are left as comments. 29 | 30 | import argparse 31 | import logging 32 | import math 33 | import os 34 | import random 35 | import sys 36 | 37 | import datasets 38 | import nltk 39 | import numpy as np 40 | import torch 41 | import transformers 42 | from accelerate import Accelerator 43 | from datasets import load_dataset, load_metric 44 | from filelock import FileLock 45 | from torch.nn import MSELoss 46 | from torch.utils.data.dataloader import DataLoader 47 | from tqdm import tqdm 48 | from tqdm.auto import tqdm 49 | from transformers import ( 50 | CONFIG_MAPPING, 51 | MODEL_MAPPING, 52 | AdamW, 53 | AutoConfig, 54 | AutoModelForSeq2SeqLM, 55 | AutoTokenizer, 56 | DataCollatorForSeq2Seq, 57 | SchedulerType, 58 | get_scheduler, 59 | set_seed, 60 | ) 61 | from transformers.file_utils import is_offline_mode 62 | from transformers.utils.versions import require_version 63 | 64 | from quant.configuration_bart_quant import BartConfig as QBartConfig 65 | from quant.modeling_bart_quant import BartForConditionalGeneration as QBart 66 | import KD_loss 67 | 68 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 69 | 70 | # You should update this to your particular problem to have better documentation of `model_type` 71 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 72 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 73 | 74 | try: 75 | nltk.data.find("tokenizers/punkt") 76 | except (LookupError, OSError): 77 | if is_offline_mode(): 78 | raise LookupError( 79 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 80 | ) 81 | with FileLock(".lock") as lock: 82 | nltk.download("punkt", quiet=True) 83 | 84 | summarization_name_mapping = { 85 | "amazon_reviews_multi": ("review_body", "review_title"), 86 | "big_patent": ("description", "abstract"), 87 | "cnn_dailymail": ("article", "highlights"), 88 | "orange_sum": ("text", "summary"), 89 | "pn_summary": ("article", "summary"), 90 | "psc": ("extract_text", "summary_text"), 91 | "samsum": ("dialogue", "summary"), 92 | "thaisum": ("body", "summary"), 93 | "xglue": ("news_body", "news_title"), 94 | "xsum": ("document", "summary"), 95 | "wiki_summary": ("article", "highlights"), 96 | } 97 | 98 | distill_mappings = {1: {0: 5}, 99 | 2: {0: 0, 1: 5}, 100 | 3: {0: 0, 1: 2, 2: 5}, 101 | 4: {0: 0, 1: 2, 2: 3, 3: 5}, 102 | 5: {0: 0, 1: 1, 2: 3, 3: 4, 4: 5}, 103 | 6: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} 104 | } 105 | distill_mappings_new = {1: {0: 0}} 106 | NUMS = [str(i) for i in range(6)] 107 | 108 | 109 | def parse_args(): 110 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 111 | parser.add_argument("--dataset_name", 112 | type=str, 113 | default=None, 114 | help="The name of the dataset to use (via the datasets library).", ) 115 | parser.add_argument("--dataset_config_name", 116 | type=str, 117 | default=None, 118 | help="The configuration name of the dataset to use (via the datasets library).", ) 119 | parser.add_argument("--train_file", 120 | type=str, 121 | default=None, 122 | help="A csv or a json file containing the training data.") 123 | parser.add_argument("--validation_file", 124 | type=str, 125 | default=None, 126 | help="A csv or a json file containing the validation data.") 127 | parser.add_argument("--ignore_pad_token_for_loss", 128 | type=bool, 129 | default=True, 
130 | help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.", ) 131 | parser.add_argument("--max_source_length", 132 | type=int, 133 | default=1024, 134 | help="The maximum total input sequence length after " 135 | "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.", ) 136 | parser.add_argument("--source_prefix", 137 | type=str, 138 | default=None, 139 | help="A prefix to add before every source text " "(useful for T5 models).", ) 140 | parser.add_argument("--preprocessing_num_workers", 141 | type=int, 142 | default=None, 143 | help="The number of processes to use for the preprocessing.", ) 144 | parser.add_argument("--overwrite_cache", 145 | type=bool, 146 | default=None, 147 | help="Overwrite the cached training and evaluation sets") 148 | parser.add_argument("--max_target_length", 149 | type=int, 150 | default=128, 151 | help="The maximum total sequence length for target text after " 152 | "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." 153 | "during ``evaluate`` and ``predict``.", ) 154 | parser.add_argument("--val_max_target_length", 155 | type=int, 156 | default=None, 157 | help="The maximum total sequence length for validation " 158 | "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be " 159 | "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` " 160 | "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.", ) 161 | parser.add_argument("--pad_to_max_length", 162 | action="store_true", 163 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", ) 164 | parser.add_argument("--model_name_or_path", 165 | type=str, 166 | help="Path to pretrained model or model identifier from huggingface.co/models.", 167 | required=True, ) 168 | parser.add_argument("--config_name", 169 | type=str, 170 | default=None, 171 | help="Pretrained config name or path if not the same as model_name", ) 172 | parser.add_argument("--tokenizer_name", 173 | type=str, 174 | default=None, 175 | help="Pretrained tokenizer name or path if not the same as model_name", ) 176 | parser.add_argument("--text_column", 177 | type=str, 178 | default=None, 179 | help="The name of the column in the datasets containing the full texts (for summarization).", ) 180 | parser.add_argument("--summary_column", 181 | type=str, 182 | default=None, 183 | help="The name of the column in the datasets containing the summaries (for summarization).", ) 184 | parser.add_argument("--use_slow_tokenizer", 185 | action="store_true", 186 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", ) 187 | parser.add_argument("--per_device_train_batch_size", 188 | type=int, 189 | default=8, 190 | help="Batch size (per device) for the training dataloader.", ) 191 | parser.add_argument("--per_device_eval_batch_size", 192 | type=int, 193 | default=4, 194 | help="Batch size (per device) for the evaluation dataloader.", ) 195 | parser.add_argument("--learning_rate", 196 | type=float, 197 | default=3e-5, 198 | help="Initial learning rate (after the potential warmup period) to use.", ) 199 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 200 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 201 | 
parser.add_argument("--max_train_steps", 202 | type=int, 203 | default=None, 204 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) 205 | parser.add_argument("--gradient_accumulation_steps", 206 | type=int, 207 | default=2, 208 | help="Number of update steps to accumulate before performing a backward/update pass.", ) 209 | parser.add_argument("--lr_scheduler_type", 210 | type=SchedulerType, 211 | default="linear", 212 | help="The scheduler type to use.", 213 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", 214 | "constant_with_warmup"], ) 215 | parser.add_argument("--warmup_ratio", 216 | type=float, 217 | default=0.05, 218 | help="Warmup ratio for the lr scheduler.") 219 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 220 | parser.add_argument("--seed", type=int, default=28, help="A seed for reproducible training.") 221 | parser.add_argument("--model_type", 222 | type=str, 223 | default=None, 224 | help="Model type to use if training from scratch.", 225 | choices=MODEL_TYPES, ) 226 | 227 | parser.add_argument("--teacher_model", 228 | default=None, 229 | type=str, 230 | help="The teacher model directory.") 231 | parser.add_argument("--student_model", 232 | default=None, 233 | type=str, 234 | help="The student model directory.") 235 | parser.add_argument('--pred_distill', 236 | action='store_true', 237 | help="Whether to distill with the prediction (task) layer") 238 | parser.add_argument('--hid_distill', 239 | action='store_true', 240 | help="Whether to distill with intermediate hidden states") 241 | parser.add_argument('--attn_distill', 242 | action='store_true', 243 | help="Whether to distill with attention maps") 244 | 245 | 246 | parser.add_argument("--weight_bits", 247 | default=8, 248 | type=int, 249 | choices=[1, 2], 250 | help="Quantization bits for weight.") 251 | parser.add_argument("--input_bits", 252 | default=8, 253 | type=int, 254 | choices=[1, 2, 8], 255 | help="Quantization bits for activation.") 256 | parser.add_argument("--clip_val", 257 | default=2.5, 258 | type=float, 259 | help="Initial clip value.") 260 | parser.add_argument("--length_penalty", 261 | default=1.0, 262 | type=float, 263 | help="model config param length_penalty.") 264 | parser.add_argument("--max_length", 265 | default=128, 266 | type=int, 267 | help=( 268 | "The maximum length of generated summaries. Overridden per dataset later in parse_args and passed to ``model.generate`` during ``evaluate`` and ``predict``."), 269 | ) 270 | parser.add_argument("--min_length", 271 | default=12, 272 | type=int, 273 | help="model config param min_length.") 274 | parser.add_argument("--num_beams", 275 | default=4, 276 | type=int, 277 | help="Number of beams to use for evaluation.
This argument will be " 278 | "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.", ) 279 | 280 | parser.add_argument('--do_train', 281 | action='store_true', 282 | help="Whether to do train and evaluation") 283 | parser.add_argument('--do_test', 284 | action='store_true', 285 | help="Whether to do test") 286 | parser.add_argument('--test_teacher', 287 | action='store_true', 288 | help="Whether to test teacher") 289 | parser.add_argument('--distill_encoder', 290 | default=6, 291 | type=int, 292 | help="Number of encoder layers after distillation") 293 | parser.add_argument('--distill_decoder', 294 | default=6, 295 | type=int, 296 | help="Number of decoder layers after distillation") 297 | parser.add_argument('--sym_quant_ffn_attn', action='store_true', 298 | help='whether use sym quant for attn score and ffn after act') # default asym 299 | parser.add_argument('--sym_quant_qkvo', action='store_true', default=True, 300 | help='whether use asym quant for Q/K/V and others.') # default sym 301 | 302 | parser.add_argument('--log_steps', default=20) 303 | parser.add_argument('--local_rank', default=0) 304 | parser.add_argument('--weighted', action='store_true') 305 | parser.add_argument('--new_distill_map', action='store_true') 306 | 307 | args = parser.parse_args() 308 | 309 | # Sanity checks 310 | if args.new_distill_map: 311 | assert args.distill_decoder == 1 312 | if args.dataset_name is None and args.train_file is None and args.validation_file is None: 313 | raise ValueError("Need either a dataset name or a training/validation file.") 314 | else: 315 | if args.train_file is not None: 316 | extension = args.train_file.split(".")[-1] 317 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 318 | if args.validation_file is not None: 319 | extension = args.validation_file.split(".")[-1] 320 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
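The three `--*_distill` flags above gate which distillation terms enter the training objective assembled later in this script; `task_weight`, `logits_weight`, and `hid_weight` come from `--weighted` a few lines below. A schematic restatement of how the pieces combine (the helper and argument names here are illustrative, not the literal training-loop code):

def combine_losses(args, task_loss, logits_loss, att_losses, hid_losses):
    """Sketch of the objective built in the training loop below.

    task_loss   : the student's own seq2seq cross-entropy loss
    logits_loss : KD_loss.DistributionLoss between student and teacher logits
    att_losses  : MSE losses over encoder/decoder/cross attention maps
    hid_losses  : MSE losses over encoder/decoder hidden states
    """
    total = args.task_weight * task_loss
    if args.pred_distill:                       # logit (prediction-layer) distillation
        total = total + args.logits_weight * logits_loss
    if args.hid_distill or args.attn_distill:   # intermediate-layer distillation
        total = total + args.hid_weight * (sum(att_losses) + sum(hid_losses))
    return total

In the actual loop the unused terms are simply left at 0.0 and always added, which is equivalent.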
321 | 322 | args.output_dir = f'./output_{args.dataset_name}/{args.weight_bits}_{args.input_bits}_{args.distill_encoder}_{args.distill_decoder}_{args.num_train_epochs}_{args.learning_rate}_quant' 323 | if args.new_distill_map: 324 | args.output_dir += '_new' 325 | if (not args.pred_distill) and (not args.hid_distill) and (not args.attn_distill): 326 | args.output_dir += '_nodis' 327 | 328 | if args.student_model is None: 329 | args.student_model = args.model_name_or_path 330 | if args.teacher_model is None: 331 | args.teacher_model = args.model_name_or_path 332 | 333 | if args.dataset_name == "xsum": 334 | args.length_penalty = 1.0 335 | args.max_length = 62 336 | args.min_length = 11 337 | args.num_beams = 6 338 | elif args.dataset_name == "cnn_dailymail": 339 | args.length_penalty = 2.0 340 | args.max_length = 142 341 | args.min_length = 56 342 | args.num_beams = 4 343 | else: 344 | assert False, f'args error: dataset name {args.dataset_name}' 345 | if args.weighted: 346 | args.task_weight = 1 347 | args.logits_weight = 0.8 348 | args.hid_weight = 3 349 | args.output_dir += '_weighted' 350 | else: 351 | args.task_weight = 1 352 | args.logits_weight = 1 353 | args.hid_weight = 1 354 | if args.output_dir is not None: 355 | os.makedirs(args.output_dir, exist_ok=True) 356 | return args 357 | 358 | 359 | def main(): 360 | args = parse_args() 361 | logging.basicConfig( 362 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 363 | datefmt="%m/%d/%Y %H:%M:%S", 364 | handlers=[ 365 | logging.StreamHandler(sys.stdout), 366 | logging.FileHandler(os.path.join(args.output_dir, "training.log")), 367 | ], 368 | ) 369 | 370 | logger = logging.getLogger(__name__) 371 | logger.setLevel(logging.INFO) 372 | 373 | if args.source_prefix is None and args.model_name_or_path in [ 374 | "t5-small", 375 | "t5-base", 376 | "t5-large", 377 | "t5-3b", 378 | "t5-11b", 379 | ]: 380 | logger.warning( 381 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 382 | "`--source_prefix 'summarize: ' `" 383 | ) 384 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 385 | accelerator = Accelerator() 386 | # Make one log on every process with the configuration for debugging. 387 | is_master = accelerator.is_local_main_process 388 | if is_master: 389 | logger.info(accelerator.state) 390 | logger.warning(args) 391 | task, run = args.output_dir.split('/')[1:] 392 | 393 | # Setup logging, we only want one process per machine to log things on the screen. 394 | # accelerator.is_local_main_process is only True for one process per machine. 395 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) 396 | if accelerator.is_local_main_process: 397 | datasets.utils.logging.set_verbosity_warning() 398 | transformers.utils.logging.set_verbosity_info() 399 | else: 400 | datasets.utils.logging.set_verbosity_error() 401 | transformers.utils.logging.set_verbosity_error() 402 | 403 | # If passed along, set the training seed now. 404 | if args.seed is not None: 405 | set_seed(args.seed) 406 | 407 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 408 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 409 | # (the dataset will be downloaded automatically from the datasets Hub). 
410 | # 411 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 412 | # 'text' is found. You can easily tweak this behavior (see below). 413 | # 414 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 415 | # download the dataset. 416 | if args.dataset_name is not None: 417 | # Downloading and loading a dataset from the hub. 418 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 419 | else: 420 | data_files = {} 421 | if args.train_file is not None: 422 | data_files["train"] = args.train_file 423 | if args.validation_file is not None: 424 | data_files["validation"] = args.validation_file 425 | extension = args.train_file.split(".")[-1] 426 | raw_datasets = load_dataset(extension, data_files=data_files) 427 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 428 | # https://huggingface.co/docs/datasets/loading_datasets.html. 429 | 430 | # Load pretrained model and tokenizer 431 | # 432 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 433 | # download model & vocab. 434 | if args.config_name: 435 | config = AutoConfig.from_pretrained(args.config_name) 436 | elif args.model_name_or_path: 437 | config = AutoConfig.from_pretrained(args.model_name_or_path) 438 | else: 439 | config = CONFIG_MAPPING[args.model_type]() 440 | logger.warning("You are instantiating a new config instance from scratch.") 441 | 442 | if args.tokenizer_name: 443 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) 444 | elif args.model_name_or_path: 445 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) 446 | else: 447 | raise ValueError( 448 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 449 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
450 | ) 451 | 452 | distill_enc_mapping = distill_mappings[args.distill_encoder] 453 | distill_dec_mapping = distill_mappings[args.distill_decoder] if not args.new_distill_map else distill_mappings_new[ 454 | args.distill_decoder] 455 | maps = {'enc': distill_enc_mapping, 'dec': distill_dec_mapping} 456 | 457 | if args.model_name_or_path: 458 | teacher_model = AutoModelForSeq2SeqLM.from_pretrained( 459 | args.model_name_or_path, 460 | from_tf=bool(".ckpt" in args.model_name_or_path), 461 | config=config, 462 | ) 463 | 464 | student_config = QBartConfig.from_pretrained(args.teacher_model, 465 | quantize_act=True, 466 | weight_bits=args.weight_bits, 467 | input_bits=args.input_bits, 468 | clip_val=args.clip_val, 469 | decoder_layers=args.distill_decoder, 470 | encoder_layers=args.distill_encoder, 471 | sym_quant_ffn_attn=args.sym_quant_ffn_attn, 472 | sym_quant_qkvo=args.sym_quant_qkvo) 473 | student_model = QBart(student_config) 474 | 475 | dst_dict = student_model.state_dict() # Initialized student model state dict, which still needs weights loaded 476 | src_dict = teacher_model.state_dict() # Pretrained teacher model state dict, whose weights will be loaded 477 | 478 | for key in dst_dict.keys(): 479 | if ("encoder" in key or "decoder" in key) and key[ 480 | 21] in NUMS: # Determine if the key belongs to an encoder/decoder layer, 481 | # which starts with sth like model.decoder.layers.1 482 | 483 | m = maps[key[6:9]] # Determine if it is an encoder or decoder key, and get the layer mapping 484 | old_idx = int(key[21]) # The layer index of the student model that needs loading 485 | new_idx = str(m[old_idx]) # The layer index of the teacher model that should be loaded 486 | mapped_key = key[:21] + new_idx + key[22:] # Get the full teacher layer key 487 | if mapped_key in src_dict.keys(): # Skip keys 488 | # that do not exist in the teacher model 489 | dst_dict[key] = src_dict[mapped_key] # Load the weights of the layer 490 | else: 491 | if key in src_dict.keys(): # Load the weights of non-encoder/decoder layers 492 | dst_dict[key] = src_dict[key] 493 | 494 | student_model.load_state_dict(dst_dict, strict=False) # Pass the dict to the student model 495 | 496 | else: 497 | raise ValueError( 498 | "You did not provide a pre-trained teacher_model." 499 | ) 500 | 501 | teacher_model.resize_token_embeddings(len(tokenizer)) 502 | student_model.resize_token_embeddings(len(tokenizer)) 503 | 504 | if teacher_model.config.decoder_start_token_id is None: 505 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 506 | 507 | prefix = args.source_prefix if args.source_prefix is not None else "" 508 | 509 | # Preprocessing the datasets. 510 | # First we tokenize all the texts. 511 | column_names = raw_datasets["train"].column_names 512 | 513 | # Get the column names for input/target.
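To make the layer re-mapping above concrete: with bart-base (6 encoder and 6 decoder layers) and, say, `--distill_decoder 3`, the mapping `distill_mappings[3] == {0: 0, 1: 2, 2: 5}` initializes the 3-layer student decoder from teacher layers 0, 2, and 5. The hard-coded string indices work because every layer key starts with the 21-character prefix `model.encoder.layers.` or `model.decoder.layers.` and, for bart-base, the layer index is a single digit; quantization-specific parameters in the student (e.g. learned clip values) have no teacher counterpart, which is why the code checks membership in the teacher dict and loads with `strict=False`. A standalone sketch of the same key surgery (variable names here are illustrative):

# Assuming bart-base and --distill_decoder 3, i.e. distill_mappings[3] == {0: 0, 1: 2, 2: 5}
dec_map = {0: 0, 1: 2, 2: 5}

student_key = "model.decoder.layers.2.fc1.weight"
# student_key[6:9] == "dec"  -> use the decoder mapping
# student_key[21]  == "2"    -> student layer index
old_idx = int(student_key[21])
new_idx = str(dec_map[old_idx])
teacher_key = student_key[:21] + new_idx + student_key[22:]
print(teacher_key)  # model.decoder.layers.5.fc1.weight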
514 | dataset_columns = summarization_name_mapping.get(args.dataset_name, None) 515 | if args.text_column is None: 516 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 517 | else: 518 | text_column = args.text_column 519 | if text_column not in column_names: 520 | raise ValueError( 521 | f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}" 522 | ) 523 | if args.summary_column is None: 524 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 525 | else: 526 | summary_column = args.summary_column 527 | if summary_column not in column_names: 528 | raise ValueError( 529 | f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}" 530 | ) 531 | 532 | # Temporarily set max_target_length for training. 533 | max_target_length = args.max_target_length 534 | padding = "max_length" if args.pad_to_max_length else False 535 | 536 | def preprocess_function(examples): 537 | inputs = examples[text_column] 538 | targets = examples[summary_column] 539 | inputs = [prefix + inp for inp in inputs] 540 | model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) 541 | 542 | # Setup the tokenizer for targets 543 | with tokenizer.as_target_tokenizer(): 544 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 545 | 546 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 547 | # padding in the loss. 548 | if padding == "max_length" and args.ignore_pad_token_for_loss: 549 | labels["input_ids"] = [ 550 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 551 | ] 552 | 553 | model_inputs["labels"] = labels["input_ids"] 554 | return model_inputs 555 | 556 | processed_datasets = raw_datasets.map( 557 | preprocess_function, 558 | batched=True, 559 | remove_columns=column_names, 560 | load_from_cache_file=not args.overwrite_cache, 561 | desc="Running tokenizer on dataset", 562 | num_proc=10 563 | ) 564 | 565 | train_dataset = processed_datasets["train"] 566 | eval_dataset = processed_datasets["validation"] 567 | test_dataset = processed_datasets["test"] 568 | 569 | # Log a few random samples from the training set: 570 | for index in random.sample(range(len(train_dataset)), 1): 571 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 572 | 573 | label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id 574 | data_collator = DataCollatorForSeq2Seq( 575 | tokenizer, 576 | model=teacher_model, 577 | label_pad_token_id=label_pad_token_id, 578 | pad_to_multiple_of=8 if accelerator.use_fp16 else None, 579 | ) 580 | 581 | def postprocess_text(preds, labels): 582 | preds = [pred.strip() for pred in preds] 583 | labels = [label.strip() for label in labels] 584 | 585 | # rougeLSum expects newline after each sentence 586 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 587 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 588 | 589 | return preds, labels 590 | 591 | train_dataloader = DataLoader( 592 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 593 | ) 594 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 595 | test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, 
batch_size=args.per_device_eval_batch_size) 596 | 597 | # Optimizer 598 | # Split weights in two groups, one with weight decay and the other not. 599 | no_decay = ["bias", "LayerNorm.weight"] 600 | optimizer_grouped_parameters = [ 601 | { 602 | "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)], 603 | "weight_decay": args.weight_decay, 604 | }, 605 | { 606 | "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)], 607 | "weight_decay": 0.0, 608 | }, 609 | ] 610 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 611 | 612 | # Prepare everything with our `accelerator`. 613 | student_model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare( 614 | student_model, optimizer, train_dataloader, eval_dataloader, test_dataloader 615 | ) 616 | teacher_model.to('cuda' if torch.cuda.is_available() else 'cpu') 617 | 618 | # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be 619 | # shorter in multiprocess) 620 | 621 | # Scheduler and math around the number of training steps. 622 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 623 | if args.max_train_steps is None: 624 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 625 | else: 626 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 627 | 628 | lr_scheduler = get_scheduler( 629 | name=args.lr_scheduler_type, 630 | optimizer=optimizer, 631 | num_warmup_steps=int(args.warmup_ratio * args.max_train_steps), 632 | num_training_steps=args.max_train_steps, 633 | ) 634 | 635 | # Metric 636 | metric = load_metric("rouge") 637 | 638 | # Train! 639 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 640 | 641 | logger.info("***** Running training *****") 642 | logger.info(f" Num examples = {len(train_dataset)}") 643 | logger.info(f" Num Epochs = {args.num_train_epochs}") 644 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 645 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 646 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 647 | logger.info(f" Total optimization steps = {args.max_train_steps}") 648 | 649 | # print(distill_enc_mapping, distill_dec_mapping) 650 | 651 | logger.info(f" student encoder layers = {student_config.encoder_layers}") 652 | logger.info(f" student decoder layers = {student_config.decoder_layers}") 653 | logger.info( 654 | f" student encoder layers {list(distill_enc_mapping.keys())} is mapped with teacher encoder layers {list(distill_enc_mapping.values())}") 655 | logger.info( 656 | f" student decoder layers {list(distill_dec_mapping.keys())} is mapped with teacher decoder layers {list(distill_dec_mapping.values())}") 657 | # Only show the progress bar once on each machine. 
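A note on `postprocess_text` above: the `rougeLsum` score reported by `load_metric("rouge")` is computed sentence by sentence, splitting on newlines, so predictions and references are re-segmented with `nltk.sent_tokenize` and joined with "\n" before scoring; without the newlines it would effectively collapse to plain `rougeL`. A minimal usage sketch with dummy strings (it assumes the `punkt` data downloaded at the top of this script and the pinned `datasets==1.18.4` API with `AggregateScore.mid`):

import nltk
from datasets import load_metric

pred = "The cat sat on the mat. It then fell asleep."
pred = "\n".join(nltk.sent_tokenize(pred))   # one sentence per line, as rougeLsum expects

metric = load_metric("rouge")
metric.add_batch(predictions=[pred],
                 references=["A cat sat on a mat.\nThen it slept."])
result = metric.compute(use_stemmer=True)
print(round(result["rougeLsum"].mid.fmeasure * 100, 4))  # aggregate F1, as logged by this script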
658 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, mininterval=2) 659 | completed_steps = 0 660 | loss_mse = MSELoss() 661 | assert teacher_model.training == False 662 | prev = 0.0 663 | 664 | gen_kwargs = { 665 | "length_penalty": args.length_penalty, 666 | "max_length": args.max_length, 667 | "min_length": args.min_length, 668 | "num_beams": args.num_beams, 669 | } 670 | 671 | if args.do_train: 672 | for epoch in range(args.num_train_epochs): 673 | student_model.train() 674 | 675 | log_total_loss = 0.0 676 | log_task_loss = 0.0 677 | log_logits_loss = 0.0 678 | log_enc_att_loss = 0.0 679 | log_dec_att_loss = 0.0 680 | log_crs_att_loss = 0.0 681 | log_enc_hid_loss = 0.0 682 | log_enc_hid_last_loss = 0.0 683 | log_dec_hid_loss = 0.0 684 | 685 | for step, batch in enumerate(train_dataloader): 686 | task_loss = 0.0 687 | logits_loss = 0.0 688 | enc_att_loss = 0.0 689 | dec_att_loss = 0.0 690 | crs_att_loss = 0.0 691 | enc_hid_loss = 0.0 692 | enc_hid_last_loss = 0.0 693 | dec_hid_loss = 0.0 694 | 695 | student_outputs = student_model(**batch, output_attentions=True, output_hidden_states=True) 696 | task_loss = student_outputs.loss 697 | 698 | with torch.no_grad(): 699 | teacher_outputs = teacher_model(**batch, output_attentions=True, output_hidden_states=True) 700 | if args.pred_distill: 701 | # logits_loss = loss_mse(student_outputs.logits, teacher_outputs.logits) 702 | criterion_kd = KD_loss.DistributionLoss() 703 | logits_shape = student_outputs.logits.shape 704 | logits_loss = criterion_kd(student_outputs.logits.reshape(-1, logits_shape[-1]), teacher_outputs.logits.reshape(-1, logits_shape[-1])) 705 | 706 | if args.hid_distill or args.attn_distill: 707 | for i, student_att in enumerate(student_outputs.encoder_attentions): 708 | mapped_idx = distill_enc_mapping[i] 709 | teacher_att = teacher_outputs.encoder_attentions[mapped_idx] 710 | enc_att_loss += loss_mse(student_att, teacher_att) 711 | 712 | for student_hs, teacher_hs in zip(student_outputs.encoder_last_hidden_state, 713 | teacher_outputs.encoder_last_hidden_state): 714 | enc_hid_last_loss += loss_mse(student_hs, teacher_hs) 715 | 716 | for i, student_hs in enumerate(student_outputs.encoder_hidden_states): 717 | if i == 0: 718 | mapped_idx = 0 719 | else: 720 | mapped_idx = distill_enc_mapping[i - 1] + 1 721 | 722 | teacher_hs = teacher_outputs.encoder_hidden_states[mapped_idx] 723 | enc_hid_loss += loss_mse(student_hs, teacher_hs) 724 | 725 | for i, student_att in enumerate(student_outputs.cross_attentions): 726 | mapped_idx = distill_dec_mapping[i] 727 | teacher_att = teacher_outputs.cross_attentions[mapped_idx] 728 | crs_att_loss += loss_mse(student_att, teacher_att) 729 | 730 | for i, student_att in enumerate(student_outputs.decoder_attentions): 731 | mapped_idx = distill_dec_mapping[i] 732 | teacher_att = teacher_outputs.decoder_attentions[mapped_idx] 733 | dec_att_loss += loss_mse(student_att, teacher_att) 734 | 735 | for i, student_hs in enumerate(student_outputs.decoder_hidden_states): 736 | if i == 0: 737 | mapped_idx = 0 738 | else: 739 | mapped_idx = distill_dec_mapping[i - 1] + 1 740 | 741 | teacher_hs = teacher_outputs.decoder_hidden_states[mapped_idx] 742 | dec_hid_loss += loss_mse(student_hs, teacher_hs) 743 | 744 | total_loss = args.task_weight * task_loss + \ 745 | args.logits_weight * logits_loss + \ 746 | args.hid_weight * ( 747 | enc_att_loss + dec_att_loss + crs_att_loss + enc_hid_loss + enc_hid_last_loss + dec_hid_loss) 748 | 749 | 
accelerator.backward(total_loss / args.gradient_accumulation_steps) 750 | 751 | log_total_loss += total_loss.item() 752 | log_task_loss += task_loss.item() 753 | if args.pred_distill: 754 | log_logits_loss += logits_loss.item() 755 | if args.attn_distill: 756 | log_enc_att_loss += enc_att_loss.item() 757 | log_dec_att_loss += dec_att_loss.item() 758 | log_crs_att_loss += crs_att_loss.item() 759 | if args.hid_distill: 760 | log_enc_hid_loss += enc_hid_loss.item() 761 | log_enc_hid_last_loss += enc_hid_last_loss.item() 762 | log_dec_hid_loss += dec_hid_loss.item() 763 | 764 | if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 765 | 766 | optimizer.step() 767 | lr_scheduler.step() 768 | 769 | optimizer.zero_grad() 770 | progress_bar.update(1) 771 | completed_steps += 1 772 | 773 | cur_step = (epoch * len(train_dataloader) + step) // args.gradient_accumulation_steps 774 | log_accu_steps = args.log_steps * args.gradient_accumulation_steps 775 | if (step + 1) % log_accu_steps == 0 and is_master: 776 | 777 | log_total_loss = 0.0 778 | log_task_loss = 0.0 779 | log_logits_loss = 0.0 780 | log_enc_att_loss = 0.0 781 | log_dec_att_loss = 0.0 782 | log_crs_att_loss = 0.0 783 | log_enc_hid_loss = 0.0 784 | log_enc_hid_last_loss = 0.0 785 | log_dec_hid_loss = 0.0 786 | 787 | if completed_steps >= args.max_train_steps: 788 | break 789 | 790 | if epoch == 0: 791 | for step, batch in enumerate(tqdm(eval_dataloader)): 792 | with torch.no_grad(): 793 | generated_tokens = accelerator.unwrap_model(student_model).generate( 794 | batch["input_ids"], 795 | attention_mask=batch["attention_mask"], 796 | **gen_kwargs, 797 | ) 798 | 799 | generated_tokens = accelerator.pad_across_processes( 800 | generated_tokens, dim=1, pad_index=tokenizer.pad_token_id 801 | ) 802 | labels = batch["labels"] 803 | if not args.pad_to_max_length: 804 | # If we did not pad to max length, we need to pad the labels too 805 | labels = accelerator.pad_across_processes(batch["labels"], dim=1, 806 | pad_index=tokenizer.pad_token_id) 807 | 808 | generated_tokens = accelerator.gather(generated_tokens).cpu().numpy() 809 | labels = accelerator.gather(labels).cpu().numpy() 810 | 811 | if args.ignore_pad_token_for_loss: 812 | # Replace -100 in the labels as we can't decode them. 
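The -100 values originate from `preprocess_function` and `DataCollatorForSeq2Seq`, which set padded label positions to `label_pad_token_id = -100` so the cross-entropy loss skips them; the tokenizer cannot decode -100, so the next line maps those positions back to the pad id. A minimal sketch of the convention with dummy values (pad id 1 as in BART, used here only for illustration):

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

pad_token_id = 1
labels = torch.tensor([[250, 117, 2, -100, -100]])   # last two positions were padding
logits = torch.randn(1, 5, 50265)                    # dummy (batch, seq, vocab) scores

# CrossEntropyLoss ignores targets equal to -100 (its default ignore_index),
# so the padded positions contribute nothing to the loss.
loss = CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))

# Before decoding, -100 must be mapped back to a real token id:
decodable = np.where(labels.numpy() != -100, labels.numpy(), pad_token_id)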
813 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 814 | if isinstance(generated_tokens, tuple): 815 | generated_tokens = generated_tokens[0] 816 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 817 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 818 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 819 | 820 | metric.add_batch(predictions=decoded_preds, references=decoded_labels) 821 | 822 | result = metric.compute(use_stemmer=True) 823 | # Extract a few results from ROUGE 824 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 825 | result = {'eval/' + k: round(v, 4) for k, v in result.items()} 826 | res_rougeL = result['eval/rougeLsum'] 827 | 828 | logger.info(f"evaluation result: {result} ") 829 | 830 | if args.output_dir is not None : 831 | accelerator.wait_for_everyone() 832 | unwrapped_model = accelerator.unwrap_model(student_model) 833 | unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) 834 | 835 | # load best model and evaluate on testset 836 | if args.do_test: 837 | try: 838 | student_model.to('cpu') 839 | teacher_model.to('cpu') 840 | del student_model 841 | del teacher_model 842 | del student_outputs 843 | del teacher_outputs 844 | for i in batch: 845 | del i 846 | except Exception as e: 847 | logger.warning(f'Error in deletion: {e}') 848 | if not args.test_teacher: 849 | best_model_config = QBartConfig.from_pretrained(args.output_dir, 850 | quantize_act=True, 851 | weight_bits=args.weight_bits, 852 | input_bits=args.input_bits, 853 | clip_val=args.clip_val, 854 | decoder_layers=args.distill_decoder, 855 | encoder_layers=args.distill_encoder, 856 | sym_quant_ffn_attn=args.sym_quant_ffn_attn, 857 | sym_quant_qkvo=args.sym_quant_qkvo) 858 | best_model = QBart(best_model_config) 859 | best_model.load_state_dict( 860 | torch.load(os.path.join(args.output_dir + "/", "pytorch_model.bin"), map_location='cpu')) 861 | 862 | if args.test_teacher: 863 | best_model = teacher_model 864 | logger.info(f"testing teacher model from {args.teacher_model} ") 865 | 866 | best_model = accelerator.prepare(best_model) 867 | best_model.eval() 868 | 869 | for step, batch in enumerate(tqdm(test_dataloader)): 870 | with torch.no_grad(): 871 | generated_tokens = accelerator.unwrap_model(best_model).generate( 872 | batch["input_ids"], 873 | attention_mask=batch["attention_mask"], 874 | **gen_kwargs, 875 | ) 876 | 877 | generated_tokens = accelerator.pad_across_processes( 878 | generated_tokens, dim=1, pad_index=tokenizer.pad_token_id 879 | ) 880 | labels = batch["labels"] 881 | if not args.pad_to_max_length: 882 | # If we did not pad to max length, we need to pad the labels too 883 | labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id) 884 | 885 | generated_tokens = accelerator.gather(generated_tokens).cpu().numpy() 886 | labels = accelerator.gather(labels).cpu().numpy() 887 | 888 | if args.ignore_pad_token_for_loss: 889 | # Replace -100 in the labels as we can't decode them. 
890 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 891 | if isinstance(generated_tokens, tuple): 892 | generated_tokens = generated_tokens[0] 893 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 894 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 895 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 896 | 897 | metric.add_batch(predictions=decoded_preds, references=decoded_labels) 898 | result = metric.compute(use_stemmer=True) 899 | # Extract a few results from ROUGE 900 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 901 | result = {'test/' + k: round(v, 4) for k, v in result.items()} 902 | logger.info(f"test result: {result}") 903 | 904 | 905 | if __name__ == "__main__": 906 | main() 907 | -------------------------------------------------------------------------------- /quant/modeling_bart_quant.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 2022 - add quantization modules Amazon.com, Inc. or its affiliates 9 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 10 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 
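The quantized BART implementation below replaces the attention (and similar) `nn.Linear` projections with `QuantizeLinear` and quantizes activations via the quantizers imported from `quant/utils_quant.py`. As rough background for the `--weight_bits` choices (1 = binary, 2 = ternary), here is a generic sketch of scaled binarization and TWN-style ternarization; the repo's actual quantizers additionally learn clip/scale parameters (`AlphaInit`, `ElasticQuantBinarizer*`) and typically back-propagate through the discretization with a straight-through estimator, so this is illustrative only, not the utils_quant.py implementation:

import torch

def binarize(w: torch.Tensor) -> torch.Tensor:
    """1-bit weights: sign(w) scaled by the mean absolute value (XNOR-Net-style)."""
    alpha = w.abs().mean()
    return alpha * torch.sign(w)

def ternarize(w: torch.Tensor) -> torch.Tensor:
    """2-bit (ternary) weights in {-alpha, 0, +alpha} with a TWN-style threshold."""
    delta = 0.7 * w.abs().mean()                      # classic TWN threshold heuristic
    mask = (w.abs() > delta).float()
    alpha = (w.abs() * mask).sum() / mask.sum().clamp(min=1.0)
    return alpha * torch.sign(w) * mask

w = torch.randn(8, 8)
print(torch.unique(binarize(w)).numel())    # 2 distinct values (+/- alpha)
print(torch.unique(ternarize(w)).numel())   # at most 3 distinct values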
23 | """ BART model configuration """ 24 | import copy 25 | import math 26 | import random 27 | import warnings 28 | from typing import Optional, Tuple 29 | 30 | import torch 31 | import torch.utils.checkpoint 32 | from torch import nn 33 | from torch.nn import CrossEntropyLoss, MSELoss 34 | 35 | from transformers.activations import ACT2FN 36 | from transformers.file_utils import ( 37 | add_code_sample_docstrings, 38 | add_end_docstrings, 39 | add_start_docstrings, 40 | add_start_docstrings_to_model_forward, 41 | replace_return_docstrings, 42 | ) 43 | from transformers.modeling_outputs import ( 44 | BaseModelOutput, 45 | BaseModelOutputWithPastAndCrossAttentions, 46 | CausalLMOutputWithCrossAttentions, 47 | Seq2SeqLMOutput, 48 | Seq2SeqModelOutput, 49 | Seq2SeqQuestionAnsweringModelOutput, 50 | Seq2SeqSequenceClassifierOutput, 51 | ) 52 | from transformers.modeling_utils import PreTrainedModel 53 | from transformers.utils import logging 54 | from .configuration_bart_quant import BartConfig 55 | 56 | from .utils_quant import QuantizeLinear, QuantizeEmbedding, SymQuantizer, AsymQuantizer, ElasticQuantBinarizerSigned, ElasticQuantBinarizerUnsigned, AlphaInit 57 | 58 | 59 | logger = logging.get_logger(__name__) 60 | 61 | _CHECKPOINT_FOR_DOC = "facebook/bart-large" 62 | _CONFIG_FOR_DOC = "BartConfig" 63 | _TOKENIZER_FOR_DOC = "BartTokenizer" 64 | 65 | 66 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 67 | "facebook/bart-large", 68 | # See all BART models at https://huggingface.co/models?filter=bart 69 | ] 70 | 71 | 72 | # Base model docstring 73 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 768] 74 | 75 | # SequenceClassification docstring 76 | _SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2] 77 | 78 | # QuestionAsnwering docstring 79 | _QA_EXPECTED_LOSS = 2.98 80 | _QA_EXPECTED_OUTPUT_SHAPE = [1, 17] 81 | 82 | 83 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 84 | """ 85 | Shift input ids one token to the right. 86 | """ 87 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 88 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 89 | shifted_input_ids[:, 0] = decoder_start_token_id 90 | 91 | if pad_token_id is None: 92 | raise ValueError("self.model.config.pad_token_id has to be defined.") 93 | # replace possible -100 values in labels by `pad_token_id` 94 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 95 | 96 | return shifted_input_ids 97 | 98 | 99 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 100 | """ 101 | Make causal mask used for bi-directional self-attention. 102 | """ 103 | bsz, tgt_len = input_ids_shape 104 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 105 | mask_cond = torch.arange(mask.size(-1)) 106 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 107 | mask = mask.to(dtype) 108 | 109 | if past_key_values_length > 0: 110 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 111 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 112 | 113 | 114 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 115 | """ 116 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
117 | """ 118 | bsz, src_len = mask.size() 119 | tgt_len = tgt_len if tgt_len is not None else src_len 120 | 121 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 122 | 123 | inverted_mask = 1.0 - expanded_mask 124 | 125 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 126 | 127 | 128 | class BartLearnedPositionalEmbedding(nn.Embedding): 129 | """ 130 | This module learns positional embeddings up to a fixed maximum size. 131 | """ 132 | 133 | def __init__(self, num_embeddings: int, embedding_dim: int): 134 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 135 | # and adjust num_embeddings appropriately. Other models don't have this hack 136 | self.offset = 2 137 | super().__init__(num_embeddings + self.offset, embedding_dim) 138 | 139 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 140 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 141 | bsz, seq_len = input_ids_shape[:2] 142 | positions = torch.arange( 143 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 144 | ) 145 | return super().forward(positions + self.offset) 146 | 147 | 148 | class BartAttention(nn.Module): 149 | """Multi-headed attention from 'Attention Is All You Need' paper""" 150 | 151 | def __init__( 152 | self, 153 | embed_dim: int, 154 | num_heads: int, 155 | config, 156 | dropout: float = 0.0, 157 | is_decoder: bool = False, 158 | bias: bool = True, 159 | ): 160 | super().__init__() 161 | self.embed_dim = embed_dim 162 | self.num_heads = num_heads 163 | self.dropout = dropout 164 | self.head_dim = embed_dim // num_heads 165 | 166 | if (self.head_dim * num_heads) != self.embed_dim: 167 | raise ValueError( 168 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 169 | f" and `num_heads`: {num_heads})." 
170 | ) 171 | self.scaling = self.head_dim**-0.5 172 | self.is_decoder = is_decoder 173 | 174 | # self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 175 | # self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 176 | # self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 177 | # self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 178 | self.k_proj = QuantizeLinear(embed_dim, embed_dim, symmetric=config.sym_quant_qkvo, config=config) 179 | self.v_proj = QuantizeLinear(embed_dim, embed_dim, symmetric=config.sym_quant_qkvo, config=config) 180 | self.q_proj = QuantizeLinear(embed_dim, embed_dim, symmetric=config.sym_quant_qkvo, config=config) 181 | self.out_proj = QuantizeLinear(embed_dim, embed_dim, symmetric=config.sym_quant_qkvo, config=config) 182 | 183 | self.quantize_act = config.quantize_act 184 | if self.quantize_act: 185 | self.input_bits = config.input_bits 186 | if self.input_bits <= 2: 187 | if config.sym_quant_qkvo: 188 | self.act_quantizer_q = ElasticQuantBinarizerSigned 189 | self.act_quantizer_k = ElasticQuantBinarizerSigned 190 | self.act_quantizer_v = ElasticQuantBinarizerSigned 191 | else: 192 | self.act_quantizer_q = ElasticQuantBinarizerUnsigned 193 | self.act_quantizer_k = ElasticQuantBinarizerUnsigned 194 | self.act_quantizer_v = ElasticQuantBinarizerUnsigned 195 | 196 | if config.sym_quant_ffn_attn: 197 | self.act_quantizer_attn = ElasticQuantBinarizerSigned 198 | else: 199 | self.act_quantizer_attn = ElasticQuantBinarizerUnsigned 200 | 201 | self.clip_query = AlphaInit(torch.tensor(1.0)) 202 | self.clip_key = AlphaInit(torch.tensor(1.0)) 203 | self.clip_value = AlphaInit(torch.tensor(1.0)) 204 | self.clip_attn = AlphaInit(torch.tensor(1.0)) 205 | 206 | elif self.input_bits == 8: 207 | if config.sym_quant_qkvo: 208 | self.act_quantizer_q = SymQuantizer 209 | self.act_quantizer_k = SymQuantizer 210 | self.act_quantizer_v = SymQuantizer 211 | else: 212 | self.act_quantizer_q = AsymQuantizer 213 | self.act_quantizer_k = AsymQuantizer 214 | self.act_quantizer_v = AsymQuantizer 215 | if config.sym_quant_ffn_attn: 216 | self.act_quantizer_attn = SymQuantizer 217 | else: 218 | self.act_quantizer_attn = AsymQuantizer 219 | 220 | self.register_buffer('clip_query', torch.Tensor([-config.clip_val, config.clip_val])) 221 | self.register_buffer('clip_key', torch.Tensor([-config.clip_val, config.clip_val])) 222 | self.register_buffer('clip_value', torch.Tensor([-config.clip_val, config.clip_val])) 223 | self.register_buffer('clip_attn', torch.Tensor([-config.clip_val, config.clip_val])) 224 | 225 | else: 226 | raise NotImplementedError 227 | 228 | 229 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 230 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 231 | 232 | def forward( 233 | self, 234 | hidden_states: torch.Tensor, 235 | key_value_states: Optional[torch.Tensor] = None, 236 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 237 | attention_mask: Optional[torch.Tensor] = None, 238 | layer_head_mask: Optional[torch.Tensor] = None, 239 | output_attentions: bool = False, 240 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 241 | """Input shape: Batch x Time x Channel""" 242 | 243 | # if key_value_states are provided this layer is used as a cross-attention layer 244 | # for the decoder 245 | is_cross_attention = key_value_states is not None 246 | 247 | bsz, tgt_len, _ = hidden_states.size() 248 | 249 | # get query proj 250 | query_states = 
self.q_proj(hidden_states) * self.scaling 251 | # get key, value proj 252 | if is_cross_attention and past_key_value is not None: 253 | # reuse k,v, cross_attentions 254 | key_states = past_key_value[0] 255 | value_states = past_key_value[1] 256 | elif is_cross_attention: 257 | # cross_attentions 258 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 259 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 260 | elif past_key_value is not None: 261 | # reuse k, v, self_attention 262 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 263 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 264 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 265 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 266 | else: 267 | # self_attention 268 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 269 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 270 | 271 | if self.is_decoder: 272 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 273 | # Further calls to cross_attention layer can then reuse all cross-attention 274 | # key/value_states (first "if" case) 275 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 276 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 277 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 278 | # if encoder bi-directional self-attention `past_key_value` is always `None` 279 | past_key_value = (key_states, value_states) 280 | 281 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 282 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 283 | key_states = key_states.view(*proj_shape) 284 | value_states = value_states.view(*proj_shape) 285 | 286 | src_len = key_states.size(1) 287 | 288 | if self.quantize_act: 289 | query_states = self.act_quantizer_q.apply(query_states, self.clip_query, self.input_bits, True) 290 | key_states = self.act_quantizer_k.apply(key_states, self.clip_key, self.input_bits, True) 291 | value_states = self.act_quantizer_v.apply(value_states, self.clip_value, self.input_bits, True) 292 | 293 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 294 | 295 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 296 | raise ValueError( 297 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 298 | ) 299 | 300 | if attention_mask is not None: 301 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 302 | raise ValueError( 303 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 304 | ) 305 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 306 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 307 | 308 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 309 | 310 | if layer_head_mask is not None: 311 | if layer_head_mask.size() != (self.num_heads,): 312 | raise ValueError( 313 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 314 | ) 315 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 316 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 317 | 318 | if output_attentions: 319 | # 
this operation is a bit awkward, but it's required to 320 | # make sure that attn_weights keeps its gradient. 321 | # In order to do so, attn_weights have to be reshaped 322 | # twice and have to be reused in the following 323 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 324 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 325 | else: 326 | attn_weights_reshaped = None 327 | 328 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 329 | 330 | # quantize both attention probs and value states for dot product 331 | if self.quantize_act: 332 | attn_probs = self.act_quantizer_attn.apply(attn_probs, self.clip_attn, self.input_bits, True) 333 | 334 | attn_output = torch.bmm(attn_probs, value_states) 335 | 336 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 337 | raise ValueError( 338 | f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 339 | ) 340 | 341 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 342 | attn_output = attn_output.transpose(1, 2) 343 | 344 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 345 | # partitioned across GPUs when using tensor-parallelism. 346 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 347 | 348 | attn_output = self.out_proj(attn_output) 349 | 350 | return attn_output, attn_weights_reshaped, past_key_value 351 | 352 | 353 | class BartEncoderLayer(nn.Module): 354 | def __init__(self, config: BartConfig): 355 | super().__init__() 356 | self.embed_dim = config.d_model 357 | self.self_attn = BartAttention( 358 | embed_dim=self.embed_dim, 359 | num_heads=config.encoder_attention_heads, 360 | config = config, 361 | dropout=config.attention_dropout, 362 | ) 363 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 364 | self.dropout = config.dropout 365 | self.activation_fn = ACT2FN[config.activation_function] 366 | self.activation_dropout = config.activation_dropout 367 | # self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 368 | # self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 369 | self.fc1 = QuantizeLinear(self.embed_dim, config.encoder_ffn_dim, symmetric=config.sym_quant_qkvo, config=config) 370 | self.fc2 = QuantizeLinear(config.encoder_ffn_dim, self.embed_dim, symmetric=config.sym_quant_ffn_attn, config=config) 371 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 372 | 373 | def forward( 374 | self, 375 | hidden_states: torch.Tensor, 376 | attention_mask: torch.Tensor, 377 | layer_head_mask: torch.Tensor, 378 | output_attentions: bool = False, 379 | ): 380 | """ 381 | Args: 382 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 383 | attention_mask (`torch.FloatTensor`): attention mask of size 384 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 385 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 386 | `(encoder_attention_heads,)`. 387 | output_attentions (`bool`, *optional*): 388 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 389 | returned tensors for more detail.
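A minimal sketch of the additive mask convention used here, mirroring the `_expand_mask` helper defined earlier in this module: a padding mask of shape `(batch, src_len)` is expanded to `(batch, 1, tgt_len, src_len)` and padded positions are filled with the most negative value of the dtype, so softmax assigns them effectively zero weight.

```python
import torch


def expand_padding_mask(mask, dtype, tgt_len=None):
    # mask: (bsz, src_len) with 1 for real tokens and 0 for padding
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # padded positions get the most negative representable value, so softmax gives them ~0 weight
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)


pad_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence whose last token is padding
print(expand_padding_mask(pad_mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])
```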
390 | """ 391 | residual = hidden_states 392 | hidden_states, attn_weights, _ = self.self_attn( 393 | hidden_states=hidden_states, 394 | attention_mask=attention_mask, 395 | layer_head_mask=layer_head_mask, 396 | output_attentions=output_attentions, 397 | ) 398 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 399 | hidden_states = residual + hidden_states 400 | hidden_states = self.self_attn_layer_norm(hidden_states) 401 | 402 | residual = hidden_states 403 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 404 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 405 | hidden_states = self.fc2(hidden_states) 406 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 407 | hidden_states = residual + hidden_states 408 | hidden_states = self.final_layer_norm(hidden_states) 409 | 410 | if hidden_states.dtype == torch.float16 and ( 411 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() 412 | ): 413 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 414 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 415 | 416 | outputs = (hidden_states,) 417 | 418 | if output_attentions: 419 | outputs += (attn_weights,) 420 | 421 | return outputs 422 | 423 | 424 | class BartDecoderLayer(nn.Module): 425 | def __init__(self, config: BartConfig): 426 | super().__init__() 427 | self.embed_dim = config.d_model 428 | 429 | self.self_attn = BartAttention( 430 | embed_dim=self.embed_dim, 431 | num_heads=config.decoder_attention_heads, 432 | config = config, 433 | dropout=config.attention_dropout, 434 | is_decoder=True, 435 | ) 436 | self.dropout = config.dropout 437 | self.activation_fn = ACT2FN[config.activation_function] 438 | self.activation_dropout = config.activation_dropout 439 | 440 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 441 | self.encoder_attn = BartAttention( 442 | self.embed_dim, 443 | config.decoder_attention_heads, 444 | config = config, 445 | dropout=config.attention_dropout, 446 | is_decoder=True, 447 | ) 448 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 449 | # self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 450 | # self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 451 | self.fc1 = QuantizeLinear(self.embed_dim, config.decoder_ffn_dim, symmetric=config.sym_quant_qkvo, config=config) 452 | self.fc2 = QuantizeLinear(config.decoder_ffn_dim, self.embed_dim, symmetric=config.sym_quant_ffn_attn, config=config) 453 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 454 | 455 | def forward( 456 | self, 457 | hidden_states: torch.Tensor, 458 | attention_mask: Optional[torch.Tensor] = None, 459 | encoder_hidden_states: Optional[torch.Tensor] = None, 460 | encoder_attention_mask: Optional[torch.Tensor] = None, 461 | layer_head_mask: Optional[torch.Tensor] = None, 462 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 463 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 464 | output_attentions: Optional[bool] = False, 465 | use_cache: Optional[bool] = True, 466 | ): 467 | """ 468 | Args: 469 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 470 | attention_mask (`torch.FloatTensor`): attention mask of size 471 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
472 | encoder_hidden_states (`torch.FloatTensor`): 473 | cross attention input to the layer of shape `(batch, seq_len, embed_dim)` 474 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size 475 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 476 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 477 | `(encoder_attention_heads,)`. 478 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of 479 | size `(decoder_attention_heads,)`. 480 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states 481 | output_attentions (`bool`, *optional*): 482 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 483 | returned tensors for more detail. 484 | """ 485 | residual = hidden_states 486 | 487 | # Self Attention 488 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 489 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 490 | # add present self-attn cache to positions 1,2 of present_key_value tuple 491 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 492 | hidden_states=hidden_states, 493 | past_key_value=self_attn_past_key_value, 494 | attention_mask=attention_mask, 495 | layer_head_mask=layer_head_mask, 496 | output_attentions=output_attentions, 497 | ) 498 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 499 | hidden_states = residual + hidden_states 500 | hidden_states = self.self_attn_layer_norm(hidden_states) 501 | 502 | # Cross-Attention Block 503 | cross_attn_present_key_value = None 504 | cross_attn_weights = None 505 | if encoder_hidden_states is not None: 506 | residual = hidden_states 507 | 508 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 509 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 510 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 511 | hidden_states=hidden_states, 512 | key_value_states=encoder_hidden_states, 513 | attention_mask=encoder_attention_mask, 514 | layer_head_mask=cross_attn_layer_head_mask, 515 | past_key_value=cross_attn_past_key_value, 516 | output_attentions=output_attentions, 517 | ) 518 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 519 | hidden_states = residual + hidden_states 520 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 521 | 522 | # add cross-attn to positions 3,4 of present_key_value tuple 523 | present_key_value = present_key_value + cross_attn_present_key_value 524 | 525 | # Fully Connected 526 | residual = hidden_states 527 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 528 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 529 | hidden_states = self.fc2(hidden_states) 530 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 531 | hidden_states = residual + hidden_states 532 | hidden_states = self.final_layer_norm(hidden_states) 533 | 534 | outputs = (hidden_states,) 535 | 536 | if output_attentions: 537 | outputs += (self_attn_weights, cross_attn_weights) 538 | 539 | if use_cache: 540 | outputs += (present_key_value,) 541 | 542 | return outputs 543 | 544 | 545 | class 
BartClassificationHead(nn.Module): 546 | """Head for sentence-level classification tasks.""" 547 | 548 | def __init__( 549 | self, 550 | input_dim: int, 551 | inner_dim: int, 552 | num_classes: int, 553 | pooler_dropout: float, 554 | ): 555 | super().__init__() 556 | self.dense = nn.Linear(input_dim, inner_dim) 557 | self.dropout = nn.Dropout(p=pooler_dropout) 558 | self.out_proj = nn.Linear(inner_dim, num_classes) 559 | 560 | def forward(self, hidden_states: torch.Tensor): 561 | hidden_states = self.dropout(hidden_states) 562 | hidden_states = self.dense(hidden_states) 563 | hidden_states = torch.tanh(hidden_states) 564 | hidden_states = self.dropout(hidden_states) 565 | hidden_states = self.out_proj(hidden_states) 566 | return hidden_states 567 | 568 | 569 | class BartPretrainedModel(PreTrainedModel): 570 | config_class = BartConfig 571 | base_model_prefix = "model" 572 | supports_gradient_checkpointing = True 573 | _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] 574 | 575 | def _init_weights(self, module): 576 | std = self.config.init_std 577 | if isinstance(module, nn.Linear) or isinstance(module, QuantizeLinear): 578 | module.weight.data.normal_(mean=0.0, std=std) 579 | if module.bias is not None: 580 | module.bias.data.zero_() 581 | elif isinstance(module, nn.Embedding) or isinstance(module, QuantizeEmbedding): 582 | module.weight.data.normal_(mean=0.0, std=std) 583 | if module.padding_idx is not None: 584 | module.weight.data[module.padding_idx].zero_() 585 | 586 | def _set_gradient_checkpointing(self, module, value=False): 587 | if isinstance(module, (BartDecoder, BartEncoder)): 588 | module.gradient_checkpointing = value 589 | 590 | @property 591 | def dummy_inputs(self): 592 | pad_token = self.config.pad_token_id 593 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 594 | dummy_inputs = { 595 | "attention_mask": input_ids.ne(pad_token), 596 | "input_ids": input_ids, 597 | } 598 | return dummy_inputs 599 | 600 | 601 | class PretrainedBartModel(BartPretrainedModel): 602 | def __init_subclass__(self): 603 | warnings.warn( 604 | "The class `PretrainedBartModel` has been deprecated, please use `BartPretrainedModel` instead.", 605 | FutureWarning, 606 | ) 607 | 608 | 609 | BART_START_DOCSTRING = r""" 610 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 611 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads 612 | etc.) 613 | 614 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 615 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage 616 | and behavior. 617 | 618 | Parameters: 619 | config ([`BartConfig`]): 620 | Model configuration class with all the parameters of the model. Initializing with a config file does not 621 | load the weights associated with the model, only the configuration. Check out the 622 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
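A minimal usage sketch of the distinction above, using the stock `transformers` BART classes; the quantized classes defined in this file follow the same pattern but additionally expect the quantization fields (such as `weight_bits` and `input_bits` from the training scripts) on their config.

```python
from transformers import BartConfig, BartModel

config = BartConfig()   # configuration only: the model below starts from random weights
model = BartModel(config)

# To load pretrained weights instead, use from_pretrained:
# model = BartModel.from_pretrained("facebook/bart-base")
```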
623 | """ 624 | 625 | BART_GENERATION_EXAMPLE = r""" 626 | Summarization example: 627 | 628 | ```python 629 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 630 | 631 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") 632 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") 633 | 634 | >>> ARTICLE_TO_SUMMARIZE = ( 635 | ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " 636 | ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " 637 | ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." 638 | ... ) 639 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") 640 | 641 | >>> # Generate Summary 642 | >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, max_length=20) 643 | >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 644 | 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' 645 | ``` 646 | 647 | Mask filling example: 648 | 649 | ```python 650 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 651 | 652 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 653 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") 654 | 655 | >>> TXT = "My friends are <mask> but they eat too many carbs." 656 | >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] 657 | >>> logits = model(input_ids).logits 658 | 659 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 660 | >>> probs = logits[0, masked_index].softmax(dim=0) 661 | >>> values, predictions = probs.topk(5) 662 | 663 | >>> tokenizer.decode(predictions).split() 664 | ['not', 'good', 'healthy', 'great', 'very'] 665 | ``` 666 | """ 667 | 668 | BART_INPUTS_DOCSTRING = r""" 669 | Args: 670 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 671 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 672 | it. 673 | 674 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 675 | [`PreTrainedTokenizer.__call__`] for details. 676 | 677 | [What are input IDs?](../glossary#input-ids) 678 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 679 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 680 | 681 | - 1 for tokens that are **not masked**, 682 | - 0 for tokens that are **masked**. 683 | 684 | [What are attention masks?](../glossary#attention-mask) 685 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 686 | Indices of decoder input sequence tokens in the vocabulary. 687 | 688 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 689 | [`PreTrainedTokenizer.__call__`] for details. 690 | 691 | [What are decoder input IDs?](../glossary#decoder-input-ids) 692 | 693 | Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` 694 | is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). 695 | 696 | For translation and summarization training, `decoder_input_ids` should be provided.
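As described in the next sentence, omitted `decoder_input_ids` are created by shifting `input_ids` one position to the right; a minimal sketch of that shift, approximating the `shift_tokens_right` helper used later in this file:

```python
import torch


def shift_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()           # move every token one position to the right
    shifted[:, 0] = decoder_start_token_id               # the decoder always starts from a fixed start token
    shifted.masked_fill_(shifted == -100, pad_token_id)  # label padding (-100) becomes the pad token
    return shifted


labels = torch.tensor([[42, 43, 2]])
print(shift_right(labels, pad_token_id=1, decoder_start_token_id=2))  # tensor([[ 2, 42, 43]])
```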
If no 697 | `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right 698 | for denoising pre-training following the paper. 699 | decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 700 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also 701 | be used by default. 702 | 703 | If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and 704 | modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information 705 | on the default strategy. 706 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 707 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: 708 | 709 | - 1 indicates the head is **not masked**, 710 | - 0 indicates the head is **masked**. 711 | 712 | decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 713 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: 714 | 715 | - 1 indicates the head is **not masked**, 716 | - 0 indicates the head is **masked**. 717 | 718 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 719 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 720 | 1]`: 721 | 722 | - 1 indicates the head is **not masked**, 723 | - 0 indicates the head is **masked**. 724 | 725 | encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): 726 | Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) 727 | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of 728 | hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 729 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 730 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 731 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 732 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 733 | 734 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 735 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 736 | 737 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 738 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 739 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape 740 | `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you 741 | can choose to directly pass an embedded representation. This is useful if you want more control over how to 742 | convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
743 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): 744 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded 745 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be 746 | input (see `past_key_values`). This is useful if you want more control over how to convert 747 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 748 | 749 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value 750 | of `inputs_embeds`. 751 | use_cache (`bool`, *optional*): 752 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 753 | `past_key_values`). 754 | output_attentions (`bool`, *optional*): 755 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 756 | tensors for more detail. 757 | output_hidden_states (`bool`, *optional*): 758 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 759 | more detail. 760 | return_dict (`bool`, *optional*): 761 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 762 | """ 763 | 764 | 765 | class BartEncoder(BartPretrainedModel): 766 | """ 767 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 768 | [`BartEncoderLayer`]. 769 | 770 | Args: 771 | config: BartConfig 772 | embed_tokens (nn.Embedding): output embedding 773 | """ 774 | 775 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 776 | super().__init__(config) 777 | 778 | self.dropout = config.dropout 779 | self.layerdrop = config.encoder_layerdrop 780 | 781 | embed_dim = config.d_model 782 | self.padding_idx = config.pad_token_id 783 | self.max_source_positions = config.max_position_embeddings 784 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 785 | 786 | if embed_tokens is not None: 787 | self.embed_tokens = embed_tokens 788 | else: 789 | # self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) 790 | self.embed_tokens = QuantizeEmbedding(config.vocab_size, embed_dim, padding_idx=self.padding_idx, config=config) 791 | 792 | self.embed_positions = BartLearnedPositionalEmbedding( 793 | config.max_position_embeddings, 794 | embed_dim, 795 | ) 796 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) 797 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 798 | 799 | self.gradient_checkpointing = False 800 | # Initialize weights and apply final processing 801 | self.post_init() 802 | 803 | def get_input_embeddings(self): 804 | return self.embed_tokens 805 | 806 | def set_input_embeddings(self, value): 807 | self.embed_tokens = value 808 | 809 | def forward( 810 | self, 811 | input_ids=None, 812 | attention_mask=None, 813 | head_mask=None, 814 | inputs_embeds=None, 815 | output_attentions=None, 816 | output_hidden_states=None, 817 | return_dict=None, 818 | ): 819 | r""" 820 | Args: 821 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 822 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 823 | provide it. 824 | 825 | Indices can be obtained using [`BartTokenizer`]. 
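For example (a minimal sketch; the checkpoint name is only illustrative):

```python
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
encoded = tokenizer("Hello world", return_tensors="pt")
print(encoded["input_ids"])       # token indices, e.g. tensor([[    0, 31414,   232,     2]])
print(encoded["attention_mask"])  # 1 for real tokens, 0 for padding
```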
See [`PreTrainedTokenizer.encode`] and 826 | [`PreTrainedTokenizer.__call__`] for details. 827 | 828 | [What are input IDs?](../glossary#input-ids) 829 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 830 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 831 | 832 | - 1 for tokens that are **not masked**, 833 | - 0 for tokens that are **masked**. 834 | 835 | [What are attention masks?](../glossary#attention-mask) 836 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 837 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 838 | 839 | - 1 indicates the head is **not masked**, 840 | - 0 indicates the head is **masked**. 841 | 842 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 843 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 844 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 845 | than the model's internal embedding lookup matrix. 846 | output_attentions (`bool`, *optional*): 847 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 848 | returned tensors for more detail. 849 | output_hidden_states (`bool`, *optional*): 850 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 851 | for more detail. 852 | return_dict (`bool`, *optional*): 853 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 854 | """ 855 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 856 | output_hidden_states = ( 857 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 858 | ) 859 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 860 | 861 | # retrieve input_ids and inputs_embeds 862 | if input_ids is not None and inputs_embeds is not None: 863 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 864 | elif input_ids is not None: 865 | input_shape = input_ids.size() 866 | input_ids = input_ids.view(-1, input_shape[-1]) 867 | elif inputs_embeds is not None: 868 | input_shape = inputs_embeds.size()[:-1] 869 | else: 870 | raise ValueError("You have to specify either input_ids or inputs_embeds") 871 | 872 | if inputs_embeds is None: 873 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 874 | 875 | embed_pos = self.embed_positions(input_shape) 876 | 877 | hidden_states = inputs_embeds + embed_pos 878 | hidden_states = self.layernorm_embedding(hidden_states) 879 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 880 | 881 | # expand attention_mask 882 | if attention_mask is not None: 883 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 884 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 885 | 886 | encoder_states = () if output_hidden_states else None 887 | all_attentions = () if output_attentions else None 888 | 889 | # check if head_mask has a correct number of layers specified if desired 890 | if head_mask is not None: 891 | if head_mask.size()[0] != (len(self.layers)): 892 | raise ValueError( 893 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for 
{head_mask.size()[0]}." 894 | ) 895 | 896 | for idx, encoder_layer in enumerate(self.layers): 897 | if output_hidden_states: 898 | encoder_states = encoder_states + (hidden_states,) 899 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 900 | dropout_probability = random.uniform(0, 1) 901 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 902 | layer_outputs = (None, None) 903 | else: 904 | if self.gradient_checkpointing and self.training: 905 | 906 | def create_custom_forward(module): 907 | def custom_forward(*inputs): 908 | return module(*inputs, output_attentions) 909 | 910 | return custom_forward 911 | 912 | layer_outputs = torch.utils.checkpoint.checkpoint( 913 | create_custom_forward(encoder_layer), 914 | hidden_states, 915 | attention_mask, 916 | (head_mask[idx] if head_mask is not None else None), 917 | ) 918 | else: 919 | layer_outputs = encoder_layer( 920 | hidden_states, 921 | attention_mask, 922 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 923 | output_attentions=output_attentions, 924 | ) 925 | 926 | hidden_states = layer_outputs[0] 927 | 928 | if output_attentions: 929 | all_attentions = all_attentions + (layer_outputs[1],) 930 | 931 | if output_hidden_states: 932 | encoder_states = encoder_states + (hidden_states,) 933 | 934 | if not return_dict: 935 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 936 | return BaseModelOutput( 937 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 938 | ) 939 | 940 | 941 | class BartDecoder(BartPretrainedModel): 942 | """ 943 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] 944 | 945 | Args: 946 | config: BartConfig 947 | embed_tokens (nn.Embedding): output embedding 948 | """ 949 | 950 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 951 | super().__init__(config) 952 | self.dropout = config.dropout 953 | self.layerdrop = config.decoder_layerdrop 954 | self.padding_idx = config.pad_token_id 955 | self.max_target_positions = config.max_position_embeddings 956 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 957 | 958 | if embed_tokens is not None: 959 | self.embed_tokens = embed_tokens 960 | else: 961 | # self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 962 | self.embed_tokens = QuantizeEmbedding(config.vocab_size, config.d_model, padding_idx=self.padding_idx, config=config) 963 | 964 | self.embed_positions = BartLearnedPositionalEmbedding( 965 | config.max_position_embeddings, 966 | config.d_model, 967 | ) 968 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 969 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 970 | 971 | self.gradient_checkpointing = False 972 | # Initialize weights and apply final processing 973 | self.post_init() 974 | 975 | def get_input_embeddings(self): 976 | return self.embed_tokens 977 | 978 | def set_input_embeddings(self, value): 979 | self.embed_tokens = value 980 | 981 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 982 | # create causal mask 983 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 984 | combined_attention_mask = None 985 | if input_shape[-1] > 1: 986 | combined_attention_mask = _make_causal_mask( 987 | input_shape, inputs_embeds.dtype,
past_key_values_length=past_key_values_length 988 | ).to(self.device) 989 | 990 | if attention_mask is not None: 991 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 992 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 993 | combined_attention_mask = ( 994 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 995 | ) 996 | 997 | return combined_attention_mask 998 | 999 | def forward( 1000 | self, 1001 | input_ids=None, 1002 | attention_mask=None, 1003 | encoder_hidden_states=None, 1004 | encoder_attention_mask=None, 1005 | head_mask=None, 1006 | cross_attn_head_mask=None, 1007 | past_key_values=None, 1008 | inputs_embeds=None, 1009 | use_cache=None, 1010 | output_attentions=None, 1011 | output_hidden_states=None, 1012 | return_dict=None, 1013 | ): 1014 | r""" 1015 | Args: 1016 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1017 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1018 | provide it. 1019 | 1020 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1021 | [`PreTrainedTokenizer.__call__`] for details. 1022 | 1023 | [What are input IDs?](../glossary#input-ids) 1024 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1025 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1026 | 1027 | - 1 for tokens that are **not masked**, 1028 | - 0 for tokens that are **masked**. 1029 | 1030 | [What are attention masks?](../glossary#attention-mask) 1031 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 1032 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 1033 | of the decoder. 1034 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 1035 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 1036 | selected in `[0, 1]`: 1037 | 1038 | - 1 for tokens that are **not masked**, 1039 | - 0 for tokens that are **masked**. 1040 | 1041 | [What are attention masks?](../glossary#attention-mask) 1042 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1043 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1044 | 1045 | - 1 indicates the head is **not masked**, 1046 | - 0 indicates the head is **masked**. 1047 | 1048 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1049 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 1050 | cross-attention on hidden heads. Mask values selected in `[0, 1]`: 1051 | 1052 | - 1 indicates the head is **not masked**, 1053 | - 0 indicates the head is **masked**. 1054 | 1055 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 1056 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1057 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 1058 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
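A minimal sketch of the cache layout these shapes imply (placeholder tensors only; the code below recovers `past_key_values_length` the same way):

```python
import torch

batch, heads, head_dim, n_layers = 2, 12, 64, 6
steps_decoded, src_len = 4, 16

# one tuple per decoder layer: (self-attn key, self-attn value, cross-attn key, cross-attn value)
past_key_values = tuple(
    (
        torch.randn(batch, heads, steps_decoded, head_dim),  # cached self-attention keys
        torch.randn(batch, heads, steps_decoded, head_dim),  # cached self-attention values
        torch.randn(batch, heads, src_len, head_dim),        # cached cross-attention keys (encoder side)
        torch.randn(batch, heads, src_len, head_dim),        # cached cross-attention values (encoder side)
    )
    for _ in range(n_layers)
)

past_key_values_length = past_key_values[0][0].shape[2]
print(past_key_values_length)  # 4
```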
1059 | 1060 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 1061 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 1062 | 1063 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 1064 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 1065 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of 1066 | shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing 1067 | `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more 1068 | control over how to convert `input_ids` indices into associated vectors than the model's internal 1069 | embedding lookup matrix. 1070 | output_attentions (`bool`, *optional*): 1071 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1072 | returned tensors for more detail. 1073 | output_hidden_states (`bool`, *optional*): 1074 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1075 | for more detail. 1076 | return_dict (`bool`, *optional*): 1077 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 1078 | """ 1079 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1080 | output_hidden_states = ( 1081 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1082 | ) 1083 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1084 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1085 | 1086 | # retrieve input_ids and inputs_embeds 1087 | if input_ids is not None and inputs_embeds is not None: 1088 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 1089 | elif input_ids is not None: 1090 | input_shape = input_ids.size() 1091 | input_ids = input_ids.view(-1, input_shape[-1]) 1092 | elif inputs_embeds is not None: 1093 | input_shape = inputs_embeds.size()[:-1] 1094 | else: 1095 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1096 | 1097 | # past_key_values_length 1098 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 1099 | 1100 | if inputs_embeds is None: 1101 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 1102 | 1103 | attention_mask = self._prepare_decoder_attention_mask( 1104 | attention_mask, input_shape, inputs_embeds, past_key_values_length 1105 | ) 1106 | 1107 | # expand encoder attention mask 1108 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 1109 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1110 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 1111 | 1112 | # embed positions 1113 | positions = self.embed_positions(input_shape, past_key_values_length) 1114 | 1115 | hidden_states = inputs_embeds + positions 1116 | hidden_states = self.layernorm_embedding(hidden_states) 1117 | 1118 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1119 | 1120 | # decoder layers 1121 | all_hidden_states = () if output_hidden_states else None 1122 | 
all_self_attns = () if output_attentions else None 1123 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 1124 | next_decoder_cache = () if use_cache else None 1125 | 1126 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 1127 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): 1128 | if attn_mask is not None: 1129 | if attn_mask.size()[0] != (len(self.layers)): 1130 | raise ValueError( 1131 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}." 1132 | ) 1133 | 1134 | for idx, decoder_layer in enumerate(self.layers): 1135 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1136 | if output_hidden_states: 1137 | all_hidden_states += (hidden_states,) 1138 | dropout_probability = random.uniform(0, 1) 1139 | if self.training and (dropout_probability < self.layerdrop): 1140 | continue 1141 | 1142 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1143 | 1144 | if self.gradient_checkpointing and self.training: 1145 | 1146 | if use_cache: 1147 | logger.warning( 1148 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1149 | ) 1150 | use_cache = False 1151 | 1152 | def create_custom_forward(module): 1153 | def custom_forward(*inputs): 1154 | # None for past_key_value 1155 | return module(*inputs, output_attentions, use_cache) 1156 | 1157 | return custom_forward 1158 | 1159 | layer_outputs = torch.utils.checkpoint.checkpoint( 1160 | create_custom_forward(decoder_layer), 1161 | hidden_states, 1162 | attention_mask, 1163 | encoder_hidden_states, 1164 | encoder_attention_mask, 1165 | head_mask[idx] if head_mask is not None else None, 1166 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 1167 | None, 1168 | ) 1169 | else: 1170 | 1171 | layer_outputs = decoder_layer( 1172 | hidden_states, 1173 | attention_mask=attention_mask, 1174 | encoder_hidden_states=encoder_hidden_states, 1175 | encoder_attention_mask=encoder_attention_mask, 1176 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1177 | cross_attn_layer_head_mask=( 1178 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 1179 | ), 1180 | past_key_value=past_key_value, 1181 | output_attentions=output_attentions, 1182 | use_cache=use_cache, 1183 | ) 1184 | hidden_states = layer_outputs[0] 1185 | 1186 | if use_cache: 1187 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1188 | 1189 | if output_attentions: 1190 | all_self_attns += (layer_outputs[1],) 1191 | 1192 | if encoder_hidden_states is not None: 1193 | all_cross_attentions += (layer_outputs[2],) 1194 | 1195 | # add hidden states from the last decoder layer 1196 | if output_hidden_states: 1197 | all_hidden_states += (hidden_states,) 1198 | 1199 | next_cache = next_decoder_cache if use_cache else None 1200 | if not return_dict: 1201 | return tuple( 1202 | v 1203 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1204 | if v is not None 1205 | ) 1206 | return BaseModelOutputWithPastAndCrossAttentions( 1207 | last_hidden_state=hidden_states, 1208 | past_key_values=next_cache, 1209 | hidden_states=all_hidden_states, 1210 | attentions=all_self_attns, 1211 | cross_attentions=all_cross_attentions, 1212 | ) 1213 | 1214 | 1215 | @add_start_docstrings( 1216 | "The bare BART
Model outputting raw hidden-states without any specific head on top.", 1217 | BART_START_DOCSTRING, 1218 | ) 1219 | class BartModel(BartPretrainedModel): 1220 | def __init__(self, config: BartConfig): 1221 | super().__init__(config) 1222 | 1223 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1224 | # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 1225 | self.shared = QuantizeEmbedding(vocab_size, config.d_model, padding_idx=padding_idx, config=config) 1226 | 1227 | self.encoder = BartEncoder(config, self.shared) 1228 | self.decoder = BartDecoder(config, self.shared) 1229 | 1230 | # Initialize weights and apply final processing 1231 | self.post_init() 1232 | 1233 | def get_input_embeddings(self): 1234 | return self.shared 1235 | 1236 | def set_input_embeddings(self, value): 1237 | self.shared = value 1238 | self.encoder.embed_tokens = self.shared 1239 | self.decoder.embed_tokens = self.shared 1240 | 1241 | def get_encoder(self): 1242 | return self.encoder 1243 | 1244 | def get_decoder(self): 1245 | return self.decoder 1246 | 1247 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1248 | @add_code_sample_docstrings( 1249 | processor_class=_TOKENIZER_FOR_DOC, 1250 | checkpoint=_CHECKPOINT_FOR_DOC, 1251 | output_type=Seq2SeqModelOutput, 1252 | config_class=_CONFIG_FOR_DOC, 1253 | expected_output=_EXPECTED_OUTPUT_SHAPE, 1254 | ) 1255 | def forward( 1256 | self, 1257 | input_ids=None, 1258 | attention_mask=None, 1259 | decoder_input_ids=None, 1260 | decoder_attention_mask=None, 1261 | head_mask=None, 1262 | decoder_head_mask=None, 1263 | cross_attn_head_mask=None, 1264 | encoder_outputs=None, 1265 | past_key_values=None, 1266 | inputs_embeds=None, 1267 | decoder_inputs_embeds=None, 1268 | use_cache=None, 1269 | output_attentions=None, 1270 | output_hidden_states=None, 1271 | return_dict=None, 1272 | ): 1273 | 1274 | # different to other models, Bart automatically creates decoder_input_ids from 1275 | # input_ids if no decoder_input_ids are provided 1276 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1277 | if input_ids is None: 1278 | raise ValueError( 1279 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 1280 | "passed, `input_ids` cannot be `None`. Please pass either " 1281 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
1282 | ) 1283 | 1284 | decoder_input_ids = shift_tokens_right( 1285 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 1286 | ) 1287 | 1288 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1289 | output_hidden_states = ( 1290 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1291 | ) 1292 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1293 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1294 | 1295 | if encoder_outputs is None: 1296 | encoder_outputs = self.encoder( 1297 | input_ids=input_ids, 1298 | attention_mask=attention_mask, 1299 | head_mask=head_mask, 1300 | inputs_embeds=inputs_embeds, 1301 | output_attentions=output_attentions, 1302 | output_hidden_states=output_hidden_states, 1303 | return_dict=return_dict, 1304 | ) 1305 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1306 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1307 | encoder_outputs = BaseModelOutput( 1308 | last_hidden_state=encoder_outputs[0], 1309 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1310 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1311 | ) 1312 | 1313 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1314 | decoder_outputs = self.decoder( 1315 | input_ids=decoder_input_ids, 1316 | attention_mask=decoder_attention_mask, 1317 | encoder_hidden_states=encoder_outputs[0], 1318 | encoder_attention_mask=attention_mask, 1319 | head_mask=decoder_head_mask, 1320 | cross_attn_head_mask=cross_attn_head_mask, 1321 | past_key_values=past_key_values, 1322 | inputs_embeds=decoder_inputs_embeds, 1323 | use_cache=use_cache, 1324 | output_attentions=output_attentions, 1325 | output_hidden_states=output_hidden_states, 1326 | return_dict=return_dict, 1327 | ) 1328 | 1329 | if not return_dict: 1330 | return decoder_outputs + encoder_outputs 1331 | 1332 | return Seq2SeqModelOutput( 1333 | last_hidden_state=decoder_outputs.last_hidden_state, 1334 | past_key_values=decoder_outputs.past_key_values, 1335 | decoder_hidden_states=decoder_outputs.hidden_states, 1336 | decoder_attentions=decoder_outputs.attentions, 1337 | cross_attentions=decoder_outputs.cross_attentions, 1338 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1339 | encoder_hidden_states=encoder_outputs.hidden_states, 1340 | encoder_attentions=encoder_outputs.attentions, 1341 | ) 1342 | 1343 | 1344 | @add_start_docstrings( 1345 | "The BART Model with a language modeling head. 
Can be used for summarization.", BART_START_DOCSTRING 1346 | ) 1347 | class BartForConditionalGeneration(BartPretrainedModel): 1348 | base_model_prefix = "model" 1349 | _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] 1350 | 1351 | def __init__(self, config: BartConfig): 1352 | super().__init__(config) 1353 | self.model = BartModel(config) 1354 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1355 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1356 | 1357 | # Initialize weights and apply final processing 1358 | self.post_init() 1359 | 1360 | def get_encoder(self): 1361 | return self.model.get_encoder() 1362 | 1363 | def get_decoder(self): 1364 | return self.model.get_decoder() 1365 | 1366 | def resize_token_embeddings(self, new_num_tokens: int) -> QuantizeEmbedding: #nn.Embedding: 1367 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1368 | self._resize_final_logits_bias(new_num_tokens) 1369 | return new_embeddings 1370 | 1371 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1372 | old_num_tokens = self.final_logits_bias.shape[-1] 1373 | if new_num_tokens <= old_num_tokens: 1374 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1375 | else: 1376 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1377 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1378 | self.register_buffer("final_logits_bias", new_bias) 1379 | 1380 | def get_output_embeddings(self): 1381 | return self.lm_head 1382 | 1383 | def set_output_embeddings(self, new_embeddings): 1384 | self.lm_head = new_embeddings 1385 | 1386 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1387 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1388 | @add_end_docstrings(BART_GENERATION_EXAMPLE) 1389 | def forward( 1390 | self, 1391 | input_ids=None, 1392 | attention_mask=None, 1393 | decoder_input_ids=None, 1394 | decoder_attention_mask=None, 1395 | head_mask=None, 1396 | decoder_head_mask=None, 1397 | cross_attn_head_mask=None, 1398 | encoder_outputs=None, 1399 | past_key_values=None, 1400 | inputs_embeds=None, 1401 | decoder_inputs_embeds=None, 1402 | labels=None, 1403 | use_cache=None, 1404 | output_attentions=None, 1405 | output_hidden_states=None, 1406 | return_dict=None, 1407 | ): 1408 | r""" 1409 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1410 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1411 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1412 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
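A minimal sketch of how `-100` labels drop out of the loss (the `CrossEntropyLoss` used further down ignores index `-100` by default); shapes here are arbitrary:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
lm_logits = torch.randn(1, 3, vocab_size)   # (batch, seq_len, vocab)
labels = torch.tensor([[4, 7, -100]])       # the last position is padding and is ignored

loss = CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
print(loss)                                 # averaged over the first two positions only
```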
1413 | 1414 | Returns: 1415 | """ 1416 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1417 | 1418 | if labels is not None: 1419 | if use_cache: 1420 | logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") 1421 | use_cache = False 1422 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1423 | decoder_input_ids = shift_tokens_right( 1424 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1425 | ) 1426 | 1427 | outputs = self.model( 1428 | input_ids, 1429 | attention_mask=attention_mask, 1430 | decoder_input_ids=decoder_input_ids, 1431 | encoder_outputs=encoder_outputs, 1432 | decoder_attention_mask=decoder_attention_mask, 1433 | head_mask=head_mask, 1434 | decoder_head_mask=decoder_head_mask, 1435 | cross_attn_head_mask=cross_attn_head_mask, 1436 | past_key_values=past_key_values, 1437 | inputs_embeds=inputs_embeds, 1438 | decoder_inputs_embeds=decoder_inputs_embeds, 1439 | use_cache=use_cache, 1440 | output_attentions=output_attentions, 1441 | output_hidden_states=output_hidden_states, 1442 | return_dict=return_dict, 1443 | ) 1444 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1445 | 1446 | masked_lm_loss = None 1447 | if labels is not None: 1448 | loss_fct = CrossEntropyLoss() 1449 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1450 | 1451 | if not return_dict: 1452 | output = (lm_logits,) + outputs[1:] 1453 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1454 | 1455 | return Seq2SeqLMOutput( 1456 | loss=masked_lm_loss, 1457 | logits=lm_logits, 1458 | past_key_values=outputs.past_key_values, 1459 | decoder_hidden_states=outputs.decoder_hidden_states, 1460 | decoder_attentions=outputs.decoder_attentions, 1461 | cross_attentions=outputs.cross_attentions, 1462 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1463 | encoder_hidden_states=outputs.encoder_hidden_states, 1464 | encoder_attentions=outputs.encoder_attentions, 1465 | ) 1466 | 1467 | def prepare_inputs_for_generation( 1468 | self, 1469 | decoder_input_ids, 1470 | past=None, 1471 | attention_mask=None, 1472 | head_mask=None, 1473 | decoder_head_mask=None, 1474 | cross_attn_head_mask=None, 1475 | use_cache=None, 1476 | encoder_outputs=None, 1477 | **kwargs 1478 | ): 1479 | # cut decoder_input_ids if past is used 1480 | if past is not None: 1481 | decoder_input_ids = decoder_input_ids[:, -1:] 1482 | 1483 | return { 1484 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1485 | "encoder_outputs": encoder_outputs, 1486 | "past_key_values": past, 1487 | "decoder_input_ids": decoder_input_ids, 1488 | "attention_mask": attention_mask, 1489 | "head_mask": head_mask, 1490 | "decoder_head_mask": decoder_head_mask, 1491 | "cross_attn_head_mask": cross_attn_head_mask, 1492 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1493 | } 1494 | 1495 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1496 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1497 | 1498 | @staticmethod 1499 | def _reorder_cache(past, beam_idx): 1500 | reordered_past = () 1501 | for layer_past in past: 1502 | # cached cross_attention states don't have to be reordered -> they are always the same 1503 | reordered_past += ( 1504 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1505 | ) 1506 | return reordered_past 1507 | 1508 | 1509 | @add_start_docstrings( 1510 | """ 1511 | Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE 1512 | tasks. 1513 | """, 1514 | BART_START_DOCSTRING, 1515 | ) 1516 | class BartForSequenceClassification(BartPretrainedModel): 1517 | def __init__(self, config: BartConfig, **kwargs): 1518 | super().__init__(config, **kwargs) 1519 | self.model = BartModel(config) 1520 | self.classification_head = BartClassificationHead( 1521 | config.d_model, 1522 | config.d_model, 1523 | config.num_labels, 1524 | config.classifier_dropout, 1525 | ) 1526 | self.model._init_weights(self.classification_head.dense) 1527 | self.model._init_weights(self.classification_head.out_proj) 1528 | 1529 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1530 | @add_code_sample_docstrings( 1531 | processor_class=_TOKENIZER_FOR_DOC, 1532 | checkpoint=_CHECKPOINT_FOR_DOC, 1533 | output_type=Seq2SeqSequenceClassifierOutput, 1534 | config_class=_CONFIG_FOR_DOC, 1535 | expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE, 1536 | ) 1537 | def forward( 1538 | self, 1539 | input_ids=None, 1540 | attention_mask=None, 1541 | decoder_input_ids=None, 1542 | decoder_attention_mask=None, 1543 | head_mask=None, 1544 | decoder_head_mask=None, 1545 | cross_attn_head_mask=None, 1546 | encoder_outputs=None, 1547 | inputs_embeds=None, 1548 | decoder_inputs_embeds=None, 1549 | labels=None, 1550 | use_cache=None, 1551 | output_attentions=None, 1552 | output_hidden_states=None, 1553 | return_dict=None, 1554 | ): 1555 | r""" 1556 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1557 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1558 | config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
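A minimal sketch of the pooling performed by this head (see the forward body below): the hidden state at the final `eos` token of each sequence is taken as the sentence representation; values here are placeholders.

```python
import torch

eos_token_id = 2
input_ids = torch.tensor([[0, 8, 9, 2], [0, 5, 7, 2]])
hidden_states = torch.randn(2, 4, 16)       # (batch, seq_len, hidden) placeholder decoder output

eos_mask = input_ids.eq(eos_token_id)       # True at eos positions
# keep the hidden state of the last eos token in each sequence
sentence_repr = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
print(sentence_repr.shape)                  # torch.Size([2, 16])
```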
1559 | """ 1560 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1561 | if labels is not None: 1562 | use_cache = False 1563 | 1564 | if input_ids is None and inputs_embeds is not None: 1565 | raise NotImplementedError( 1566 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1567 | ) 1568 | 1569 | outputs = self.model( 1570 | input_ids, 1571 | attention_mask=attention_mask, 1572 | decoder_input_ids=decoder_input_ids, 1573 | decoder_attention_mask=decoder_attention_mask, 1574 | head_mask=head_mask, 1575 | decoder_head_mask=decoder_head_mask, 1576 | cross_attn_head_mask=cross_attn_head_mask, 1577 | encoder_outputs=encoder_outputs, 1578 | inputs_embeds=inputs_embeds, 1579 | decoder_inputs_embeds=decoder_inputs_embeds, 1580 | use_cache=use_cache, 1581 | output_attentions=output_attentions, 1582 | output_hidden_states=output_hidden_states, 1583 | return_dict=return_dict, 1584 | ) 1585 | hidden_states = outputs[0] # last hidden state 1586 | 1587 | eos_mask = input_ids.eq(self.config.eos_token_id) 1588 | 1589 | if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: 1590 | raise ValueError("All examples must have the same number of tokens.") 1591 | sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ 1592 | :, -1, : 1593 | ] 1594 | logits = self.classification_head(sentence_representation) 1595 | 1596 | loss = None 1597 | if labels is not None: 1598 | if self.config.problem_type is None: 1599 | if self.config.num_labels == 1: 1600 | self.config.problem_type = "regression" 1601 | elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1602 | self.config.problem_type = "single_label_classification" 1603 | else: 1604 | self.config.problem_type = "multi_label_classification" 1605 | 1606 | if self.config.problem_type == "regression": 1607 | loss_fct = MSELoss() 1608 | if self.config.num_labels == 1: 1609 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1610 | else: 1611 | loss = loss_fct(logits, labels) 1612 | elif self.config.problem_type == "single_label_classification": 1613 | loss_fct = CrossEntropyLoss() 1614 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 1615 | elif self.config.problem_type == "multi_label_classification": 1616 | loss_fct = BCEWithLogitsLoss() 1617 | loss = loss_fct(logits, labels) 1618 | if not return_dict: 1619 | output = (logits,) + outputs[1:] 1620 | return ((loss,) + output) if loss is not None else output 1621 | 1622 | return Seq2SeqSequenceClassifierOutput( 1623 | loss=loss, 1624 | logits=logits, 1625 | past_key_values=outputs.past_key_values, 1626 | decoder_hidden_states=outputs.decoder_hidden_states, 1627 | decoder_attentions=outputs.decoder_attentions, 1628 | cross_attentions=outputs.cross_attentions, 1629 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1630 | encoder_hidden_states=outputs.encoder_hidden_states, 1631 | encoder_attentions=outputs.encoder_attentions, 1632 | ) 1633 | 1634 | 1635 | @add_start_docstrings( 1636 | """ 1637 | BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1638 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
1639 | """, 1640 | BART_START_DOCSTRING, 1641 | ) 1642 | class BartForQuestionAnswering(BartPretrainedModel): 1643 | def __init__(self, config): 1644 | super().__init__(config) 1645 | 1646 | config.num_labels = 2 1647 | self.num_labels = config.num_labels 1648 | 1649 | self.model = BartModel(config) 1650 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1651 | 1652 | self.model._init_weights(self.qa_outputs) 1653 | 1654 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1655 | @add_code_sample_docstrings( 1656 | processor_class=_TOKENIZER_FOR_DOC, 1657 | checkpoint=_CHECKPOINT_FOR_DOC, 1658 | output_type=Seq2SeqQuestionAnsweringModelOutput, 1659 | config_class=_CONFIG_FOR_DOC, 1660 | expected_loss=_QA_EXPECTED_LOSS, 1661 | expected_output=_QA_EXPECTED_OUTPUT_SHAPE, 1662 | ) 1663 | def forward( 1664 | self, 1665 | input_ids=None, 1666 | attention_mask=None, 1667 | decoder_input_ids=None, 1668 | decoder_attention_mask=None, 1669 | head_mask=None, 1670 | decoder_head_mask=None, 1671 | cross_attn_head_mask=None, 1672 | encoder_outputs=None, 1673 | start_positions=None, 1674 | end_positions=None, 1675 | inputs_embeds=None, 1676 | decoder_inputs_embeds=None, 1677 | use_cache=None, 1678 | output_attentions=None, 1679 | output_hidden_states=None, 1680 | return_dict=None, 1681 | ): 1682 | r""" 1683 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1684 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1685 | Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence 1686 | are not taken into account for computing the loss. 1687 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1688 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1689 | Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence 1690 | are not taken into account for computing the loss. 
1691 | """ 1692 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1693 | if start_positions is not None and end_positions is not None: 1694 | use_cache = False 1695 | 1696 | outputs = self.model( 1697 | input_ids, 1698 | attention_mask=attention_mask, 1699 | decoder_input_ids=decoder_input_ids, 1700 | decoder_attention_mask=decoder_attention_mask, 1701 | head_mask=head_mask, 1702 | decoder_head_mask=decoder_head_mask, 1703 | cross_attn_head_mask=cross_attn_head_mask, 1704 | encoder_outputs=encoder_outputs, 1705 | inputs_embeds=inputs_embeds, 1706 | decoder_inputs_embeds=decoder_inputs_embeds, 1707 | use_cache=use_cache, 1708 | output_attentions=output_attentions, 1709 | output_hidden_states=output_hidden_states, 1710 | return_dict=return_dict, 1711 | ) 1712 | 1713 | sequence_output = outputs[0] 1714 | 1715 | logits = self.qa_outputs(sequence_output) 1716 | start_logits, end_logits = logits.split(1, dim=-1) 1717 | start_logits = start_logits.squeeze(-1).contiguous() 1718 | end_logits = end_logits.squeeze(-1).contiguous() 1719 | 1720 | total_loss = None 1721 | if start_positions is not None and end_positions is not None: 1722 | # If we are on multi-GPU, split add a dimension 1723 | if len(start_positions.size()) > 1: 1724 | start_positions = start_positions.squeeze(-1) 1725 | if len(end_positions.size()) > 1: 1726 | end_positions = end_positions.squeeze(-1) 1727 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1728 | ignored_index = start_logits.size(1) 1729 | start_positions = start_positions.clamp(0, ignored_index) 1730 | end_positions = end_positions.clamp(0, ignored_index) 1731 | 1732 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1733 | start_loss = loss_fct(start_logits, start_positions) 1734 | end_loss = loss_fct(end_logits, end_positions) 1735 | total_loss = (start_loss + end_loss) / 2 1736 | 1737 | if not return_dict: 1738 | output = ( 1739 | start_logits, 1740 | end_logits, 1741 | ) + outputs[1:] 1742 | return ((total_loss,) + output) if total_loss is not None else output 1743 | 1744 | return Seq2SeqQuestionAnsweringModelOutput( 1745 | loss=total_loss, 1746 | start_logits=start_logits, 1747 | end_logits=end_logits, 1748 | past_key_values=outputs.past_key_values, 1749 | decoder_hidden_states=outputs.decoder_hidden_states, 1750 | decoder_attentions=outputs.decoder_attentions, 1751 | cross_attentions=outputs.cross_attentions, 1752 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1753 | encoder_hidden_states=outputs.encoder_hidden_states, 1754 | encoder_attentions=outputs.encoder_attentions, 1755 | ) 1756 | 1757 | 1758 | class BartDecoderWrapper(BartPretrainedModel): 1759 | """ 1760 | This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is 1761 | used in combination with the [`EncoderDecoderModel`] framework. 
1762 | """ 1763 | 1764 | def __init__(self, config): 1765 | super().__init__(config) 1766 | self.decoder = BartDecoder(config) 1767 | 1768 | def forward(self, *args, **kwargs): 1769 | return self.decoder(*args, **kwargs) 1770 | 1771 | 1772 | class BartForCausalLM(BartPretrainedModel): 1773 | def __init__(self, config): 1774 | config = copy.deepcopy(config) 1775 | config.is_decoder = True 1776 | config.is_encoder_decoder = False 1777 | super().__init__(config) 1778 | self.model = BartDecoderWrapper(config) 1779 | 1780 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1781 | 1782 | # Initialize weights and apply final processing 1783 | self.post_init() 1784 | 1785 | def get_input_embeddings(self): 1786 | return self.model.decoder.embed_tokens 1787 | 1788 | def set_input_embeddings(self, value): 1789 | self.model.decoder.embed_tokens = value 1790 | 1791 | def get_output_embeddings(self): 1792 | return self.lm_head 1793 | 1794 | def set_output_embeddings(self, new_embeddings): 1795 | self.lm_head = new_embeddings 1796 | 1797 | def set_decoder(self, decoder): 1798 | self.model.decoder = decoder 1799 | 1800 | def get_decoder(self): 1801 | return self.model.decoder 1802 | 1803 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 1804 | def forward( 1805 | self, 1806 | input_ids=None, 1807 | attention_mask=None, 1808 | encoder_hidden_states=None, 1809 | encoder_attention_mask=None, 1810 | head_mask=None, 1811 | cross_attn_head_mask=None, 1812 | past_key_values=None, 1813 | inputs_embeds=None, 1814 | labels=None, 1815 | use_cache=None, 1816 | output_attentions=None, 1817 | output_hidden_states=None, 1818 | return_dict=None, 1819 | ): 1820 | r""" 1821 | Args: 1822 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1823 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1824 | provide it. 1825 | 1826 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1827 | [`PreTrainedTokenizer.__call__`] for details. 1828 | 1829 | [What are input IDs?](../glossary#input-ids) 1830 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1831 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1832 | 1833 | - 1 for tokens that are **not masked**, 1834 | - 0 for tokens that are **masked**. 1835 | 1836 | [What are attention masks?](../glossary#attention-mask) 1837 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1838 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 1839 | if the model is configured as a decoder. 1840 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 1841 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used 1842 | in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 1843 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1844 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1845 | 1846 | - 1 indicates the head is **not masked**, 1847 | - 0 indicates the head is **masked**. 
1848 | 1849 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1850 | Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: 1851 | 1852 | - 1 indicates the head is **not masked**, 1853 | - 0 indicates the head is **masked**. 1854 | 1855 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 1856 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1857 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 1858 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional 1859 | tensors are only required when the model is used as a decoder in a Sequence to Sequence model. 1860 | 1861 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 1862 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 1863 | 1864 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 1865 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 1866 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1867 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1868 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1869 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1870 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1871 | use_cache (`bool`, *optional*): 1872 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 1873 | (see `past_key_values`). 1874 | 1875 | - 1 for tokens that are **not masked**, 1876 | - 0 for tokens that are **masked**. 1877 | output_attentions (`bool`, *optional*): 1878 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1879 | returned tensors for more detail. 1880 | output_hidden_states (`bool`, *optional*): 1881 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1882 | for more detail. 1883 | return_dict (`bool`, *optional*): 1884 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 1885 | 1886 | Returns: 1887 | 1888 | Example: 1889 | 1890 | ```python 1891 | >>> from transformers import BartTokenizer, BartForCausalLM 1892 | 1893 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 1894 | >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) 1895 | >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
1896 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1897 | >>> outputs = model(**inputs) 1898 | 1899 | >>> logits = outputs.logits 1900 | >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] 1901 | >>> list(logits.shape) == expected_shape 1902 | True 1903 | ```""" 1904 | 1905 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1906 | output_hidden_states = ( 1907 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1908 | ) 1909 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1910 | 1911 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1912 | outputs = self.model.decoder( 1913 | input_ids=input_ids, 1914 | attention_mask=attention_mask, 1915 | encoder_hidden_states=encoder_hidden_states, 1916 | encoder_attention_mask=encoder_attention_mask, 1917 | head_mask=head_mask, 1918 | cross_attn_head_mask=cross_attn_head_mask, 1919 | past_key_values=past_key_values, 1920 | inputs_embeds=inputs_embeds, 1921 | use_cache=use_cache, 1922 | output_attentions=output_attentions, 1923 | output_hidden_states=output_hidden_states, 1924 | return_dict=return_dict, 1925 | ) 1926 | 1927 | logits = self.lm_head(outputs[0]) 1928 | 1929 | loss = None 1930 | if labels is not None: 1931 | loss_fct = CrossEntropyLoss() 1932 | loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) 1933 | 1934 | if not return_dict: 1935 | output = (logits,) + outputs[1:] 1936 | return (loss,) + output if loss is not None else output 1937 | 1938 | return CausalLMOutputWithCrossAttentions( 1939 | loss=loss, 1940 | logits=logits, 1941 | past_key_values=outputs.past_key_values, 1942 | hidden_states=outputs.hidden_states, 1943 | attentions=outputs.attentions, 1944 | cross_attentions=outputs.cross_attentions, 1945 | ) 1946 | 1947 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): 1948 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1949 | if attention_mask is None: 1950 | attention_mask = input_ids.new_ones(input_ids.shape) 1951 | 1952 | if past: 1953 | input_ids = input_ids[:, -1:] 1954 | # first step, decoder_cached_states are empty 1955 | return { 1956 | "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed 1957 | "attention_mask": attention_mask, 1958 | "past_key_values": past, 1959 | "use_cache": use_cache, 1960 | } 1961 | 1962 | @staticmethod 1963 | def _reorder_cache(past, beam_idx): 1964 | reordered_past = () 1965 | for layer_past in past: 1966 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1967 | return reordered_past 1968 | --------------------------------------------------------------------------------
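A few hedged sketches of the mechanics implemented in `modeling_bart_quant.py` above may help when adapting the file. First, the conditional-generation head derives `decoder_input_ids` from `labels` via `shift_tokens_right` and computes a token-level cross-entropy over the vocabulary. The snippet below re-implements the shift purely for illustration; the tensor sizes, `pad_token_id`, and `decoder_start_token_id` are assumptions for the sketch, not values read from this repository.

```python
import torch
from torch.nn import CrossEntropyLoss


def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Illustrative re-implementation of the helper called in forward():
    # shift every label one position to the right and prepend the start token.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 marks ignored label positions
    return shifted


batch, seq_len, vocab_size = 2, 5, 50265                  # dummy sizes (assumed)
labels = torch.randint(0, vocab_size, (batch, seq_len))
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)

lm_logits = torch.randn(batch, seq_len, vocab_size)       # stand-in for self.lm_head(outputs[0])
loss = CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))
print(decoder_input_ids[0], loss.item())
```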
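`BartForSequenceClassification` pools a sentence representation from the decoder hidden state at the last `<eos>` position of each sequence before applying the classification head, which is why the forward pass insists that every example contain the same number of `<eos>` tokens. A minimal sketch of that indexing, with dummy shapes and an assumed `eos_token_id`:

```python
import torch

batch, seq_len, d_model, eos_token_id = 2, 6, 8, 2         # dummy sizes (assumed)
input_ids = torch.randint(3, 100, (batch, seq_len))
input_ids[:, -1] = eos_token_id                            # every example ends with one <eos>
hidden_states = torch.randn(batch, seq_len, d_model)       # stand-in for the decoder output

eos_mask = input_ids.eq(eos_token_id)
# Select the hidden state at the *last* <eos> token of each sequence,
# mirroring the indexing used in the forward pass above.
sentence_representation = hidden_states[eos_mask, :].view(
    hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
print(sentence_representation.shape)                       # torch.Size([2, 8])
```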
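`BartForQuestionAnswering` splits a two-channel projection into start and end logits, clamps gold positions that fall outside the sequence to an ignored index, and averages the two cross-entropies. A compact sketch with dummy tensors (all sizes assumed):

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len = 2, 7                         # dummy sizes (assumed)
logits = torch.randn(batch, seq_len, 2)       # stand-in for self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()

start_positions = torch.tensor([1, 9])        # the second target lies outside the sequence
end_positions = torch.tensor([3, 9])
ignored_index = start_logits.size(1)          # out-of-range targets map here and are ignored
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss.item())
```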
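Finally, `_reorder_cache` realigns each layer's cached self-attention key/value tensors with the surviving beams during beam search; in the seq2seq head the cross-attention entries (`layer_past[2:]`) pass through unchanged because they do not depend on beam order. A toy illustration with fake cache tensors (shapes and `beam_idx` are assumptions):

```python
import torch

num_layers, beams, heads, cached_len, head_dim = 2, 4, 3, 5, 8   # dummy sizes (assumed)


def fake_layer_past():
    # (self-attn key, self-attn value, cross-attn key, cross-attn value)
    return tuple(torch.randn(beams, heads, cached_len, head_dim) for _ in range(4))


past = tuple(fake_layer_past() for _ in range(num_layers))
beam_idx = torch.tensor([2, 2, 0, 1])        # which old beam each new hypothesis continues

reordered = tuple(
    tuple(p.index_select(0, beam_idx) for p in layer_past[:2]) + layer_past[2:]
    for layer_past in past
)
# Beam slot 0 now carries the cache that previously belonged to beam 2.
assert torch.equal(reordered[0][0][0], past[0][0][2])
```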