├── ct
│   ├── __init__.py
│   ├── functional
│   │   ├── __init__.py
│   │   └── ct_loss.py
│   └── ct_loss.py
├── requirements.txt
├── pyproject.toml
├── MANIFEST.in
├── LICENSE
├── mydocstring.mustache
├── setup.py
├── README.md
└── .gitignore

/ct/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ct/functional/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.10.0

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include requirements.txt
include *.md
include *.py

# Exclude build configs
prune .circleci
prune .github

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Shaojie Jiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/mydocstring.mustache:
--------------------------------------------------------------------------------
{{! Based on Google Docstring Template }}
{{! Author Shaojie Jiang }}
{{! Inspired by Graham Harrison https://towardsdatascience.com/3-easy-steps-to-folding-docstrings-in-vscode-fbb64573611b}}
{{summaryPlaceholder}}

{{extendedSummaryPlaceholder}}

{{#parametersExist}}
Args:
{{#args}}
    {{var}} ({{typePlaceholder}}): {{descriptionPlaceholder}}
{{/args}}
{{#kwargs}}
    {{var}} ({{typePlaceholder}}, optional): {{descriptionPlaceholder}}. Defaults to {{&default}}.
{{/kwargs}}
{{/parametersExist}}

{{#exceptionsExist}}
Raises:
{{#exceptions}}
    {{type}}: {{descriptionPlaceholder}}
{{/exceptions}}
{{/exceptionsExist}}

{{#returnsExist}}
Returns:
{{#returns}}
    {{typePlaceholder}}: {{descriptionPlaceholder}}
{{/returns}}
{{/returnsExist}}

{{#yieldsExist}}
Yields:
{{#yields}}
    {{typePlaceholder}}: {{descriptionPlaceholder}}
{{/yields}}
{{/yieldsExist}}

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

with open('requirements.txt') as f:
    reqs = []
    for line in f:
        line = line.strip()
        if line:  # skip blank lines so install_requires stays valid
            reqs.append(line.split('==')[0])  # keep only the package name if the requirement is pinned with '=='

setuptools.setup(
    name="ct_loss",
    version="0.0.3",
    author="Shaojie Jiang",
    author_email="shaojiejiang.1991@gmail.com",
    description="The contrastive token loss for reducing generative repetition of autoregressive neural language models.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/ShaojieJiang/CT-Loss",
    project_urls={
        "Bug Tracker": "https://github.com/ShaojieJiang/CT-Loss/issues",
    },
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=reqs,
    packages=setuptools.find_packages(),
    python_requires=">=3.7",
)

--------------------------------------------------------------------------------
/ct/ct_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor

from ct.functional.ct_loss import contrastive_token_loss


class ContrastiveTokenLoss(torch.nn.Module):
    """A PyTorch Module wrapper for the contrastive_token_loss function.

    Args:
        ignore_index (int, optional): Target value that is ignored when computing the loss. Defaults to -100.
        pad_id (int, optional): Padding token id, used to mask out irrelevant preceding tokens. Defaults to 0.
        ct_length (Union[int, float], optional): A value in [0, 1] is interpreted as a fraction of the original
            sequence length; a value larger than 1 specifies the absolute CT length. Defaults to 0.25.
        preced_m_negatives (Union[int, float], optional): A value in [0, 1] is interpreted as a fraction of the CT
            sequence length; a value larger than 1 specifies the absolute negative window size. Defaults to 0.5.

    Returns:
        Tensor: Calculated CT loss.
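
    Example (illustrative; the shapes follow the README example):
        >>> criterion = ContrastiveTokenLoss(pad_id=999)
        >>> logits = torch.rand(10, 50, 1000)
        >>> labels = torch.randint(0, 999, (10, 50))
        >>> loss = criterion(logits, labels)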
20 | """ 21 | def __init__( 22 | self, 23 | ignore_index=-100, 24 | pad_id=0, 25 | ct_length=0.25, 26 | preced_m_negatives=0.5, 27 | ): 28 | super().__init__() 29 | self.ignore_index = ignore_index 30 | self.pad_id = pad_id 31 | self.ct_length = ct_length 32 | self.preced_m_negatives = preced_m_negatives 33 | 34 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 35 | return contrastive_token_loss( 36 | input, target, self.ignore_index, 37 | self.pad_id, self.ct_length, 38 | self.preced_m_negatives, 39 | ) 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contrastive Token loss function for PyTorch 2 | 3 | This repo is the clean (PyTorch) implementation of the **contrastive token loss** proposed in our paper: _A Simple Contrastive Learning Objective for Alleviating Neural Text Degeneration._ 4 | The source code for reproducing our results reported in the paper, including data pre-processing scripts, our trained models and interactive Google Colab notebooks, can be found in [this repo](https://github.com/ShaojieJiang/lit-seq). 5 | 6 | ## Installation 7 | 8 | `pip install ct-loss` 9 | 10 | ## Usage 11 | You can use our CT objective when **pretraining** or **finetuning** your augoregressive language models. 12 | With CT, the resulting language models will have significantly less **repetitive** generations, even with deterministic decoding such as greedy and beam search. 13 | It only takes several lines of code to use CT loss, around where you calculate PyTorch's `CrossEntropyLoss`. 14 | Here is an example: 15 | ```python 16 | import torch 17 | 18 | # Suppose we already have the model output logits and labels (sequences of token indices). 
# For example, when the batch size is 10, the sequence length is 50 and the vocabulary size is 1000:
logits = torch.rand(10, 50, 1000)
labels = torch.randint(0, 999, (10, 50))

# This is how you normally use cross-entropy for a language model:
from torch.nn import CrossEntropyLoss
ce_criterion = CrossEntropyLoss()
ce_loss = ce_criterion(logits.view(-1, 1000), labels.view(-1))

# This is how you can use our contrastive token loss:
from ct.ct_loss import ContrastiveTokenLoss
ct_criterion = ContrastiveTokenLoss(pad_id=999)  # pad tokens are needed to mask out positions in a sequence that should not be used as negatives
ct_loss = ct_criterion(logits, labels)

# In our paper (cited below), we use CE and CT together
loss = ce_loss + ct_loss

print(ce_loss, ct_loss)

# Example output (the exact values vary with the random inputs):
>>> tensor(6.9536) tensor(1.5848)
```

## Cite our paper

```
@article{jiang2022contrastive,
  doi       = {10.48550/ARXIV.2205.02517},
  url       = {https://arxiv.org/abs/2205.02517},
  author    = {Jiang, Shaojie and Zhang, Ruqing and Vakulenko, Svitlana and de Rijke, Maarten},
  title     = {A Simple Contrastive Learning Objective for Alleviating Neural Text Degeneration},
  publisher = {arXiv},
  year      = {2022},
}
```

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode/

--------------------------------------------------------------------------------
/ct/functional/ct_loss.py:
--------------------------------------------------------------------------------
from typing import Union

import torch
from torch import Tensor


def contrastive_token_loss(
    input: Tensor,
    target: Tensor,
    ignore_index: int = -100,
    pad_id: int = 0,
    ct_length: Union[int, float] = 0.25,
    preced_m_negatives: Union[int, float] = 0.5,
) -> Tensor:
    """The Contrastive Token (CT) loss function.

    Args:
        input (Tensor): Input logits.
        target (Tensor): Target token indices.
        ignore_index (int, optional): Target value that is ignored when computing the loss. Defaults to -100.
        pad_id (int, optional): Padding token id, used to mask out irrelevant preceding tokens. Defaults to 0.
        ct_length (Union[int, float], optional): A value in [0, 1] is interpreted as a fraction of the original
            sequence length; a value larger than 1 specifies the absolute CT length. Defaults to 0.25.
        preced_m_negatives (Union[int, float], optional): A value in [0, 1] is interpreted as a fraction of the CT
            sequence length; a value larger than 1 specifies the absolute negative window size. Defaults to 0.5.

    Returns:
        Tensor: Calculated CT loss.
    """
    if ct_length <= 0:  # no need to calculate the CT loss
        return input.new_zeros(())  # zero loss, returned as a tensor to match the annotation

    if ct_length <= 1:  # fraction of the total length (i.e., the CE length)
        ct_length = round(input.size(1) * ct_length)
    else:  # exact value
        ct_length = round(ct_length)

    input = input[..., :ct_length, :]
    target = target[..., :ct_length]

    assert preced_m_negatives > 0, "preced_m_negatives must be greater than 0 when using CT loss."
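    # The same fraction-vs-absolute convention used for ct_length is applied to
    # preced_m_negatives next. For example (assumed values): with a sequence
    # length of 48 and the defaults, ct_length = round(48 * 0.25) = 12 positions
    # enter the CT loss, and preced_m_negatives = round(0.5 * 12) = 6 preceding
    # tokens serve as negatives for each position.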
    if preced_m_negatives <= 1:  # fraction of ct_length
        preced_m_negatives = round(preced_m_negatives * ct_length)
    else:  # exact value
        preced_m_negatives = round(preced_m_negatives)

    if ignore_index != pad_id:
        target_with_pad = target.masked_fill(target.eq(ignore_index), pad_id)
    else:
        target_with_pad = target

    non_padding = target_with_pad != pad_id

    preced_tokens = preced_negatives(target_with_pad, preced_m_negatives, pad_id)
    positive_scores = input.gather(2, target_with_pad.unsqueeze(-1))  # label scores
    negative_scores = input.gather(2, preced_tokens)
    # Per-position CT loss: log(1 + sum_j exp(s_neg_j - s_pos)), which pushes the
    # positive token score above the scores of its preceding negative tokens.
    neg_minus_pos = negative_scores - positive_scores
    exp = neg_minus_pos.exp()

    pad_mask = preced_tokens.ne(pad_id).int()
    sum_exp = (exp * pad_mask).sum(dim=-1)  # don't use pad tokens as negatives
    losses = (1 + sum_exp).log() * non_padding.int()

    ct_loss = losses.sum() / non_padding.int().sum()

    return ct_loss


def preced_negatives(
    labels=None,
    preced_m_negatives=0,
    pad_id=0,
):
    """Gather, for each position, the preceding `preced_m_negatives` tokens to be used as negatives."""
    preced_tokens = None
    if preced_m_negatives:  # use the previous m tokens as negatives
        preced_tokens = labels.unsqueeze(1).expand(labels.size(0), labels.size(1), labels.size(1))
        # A strictly lower-triangular mask selects the tokens preceding each position ...
        mask = torch.ones_like(preced_tokens).tril(-1).bool()
        if preced_m_negatives > 0:
            # ... and triu limits the window to the last preced_m_negatives tokens.
            mask = mask.triu(-preced_m_negatives)
        preced_tokens = preced_tokens.masked_fill(~mask, pad_id)

    if preced_tokens is not None:
        preced_tokens = preced_tokens.masked_fill(preced_tokens == labels.unsqueeze(-1), pad_id)  # exclude tokens identical to the label from the negatives

    return preced_tokens
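
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the shapes and pad_id follow the README example).
    torch.manual_seed(0)
    logits = torch.rand(10, 50, 1000)
    labels = torch.randint(0, 999, (10, 50))
    loss = contrastive_token_loss(logits, labels, pad_id=999)
    print(loss)  # a scalar tensor; the exact value depends on the random inputs
--------------------------------------------------------------------------------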