## Overview
This is the official code accompanying the paper [Semantic Re-Tuning via Contrastive Tension](https://openreview.net/pdf?id=Ov_sMNau-PF).
The paper was accepted at ICLR-2021, and the official reviews and responses can be found at [OpenReview](https://openreview.net/forum?id=Ov_sMNau-PF).

Contrastive Tension (CT) is a fully self-supervised algorithm for re-tuning already pre-trained transformer language models, and it achieves state-of-the-art (SOTA) sentence embeddings for Semantic Textual Similarity (STS). Hence, all that is required is a pre-trained model and a modestly large text corpus. The results presented in the paper were produced with text data sampled from Wikipedia.

This repository contains:
* A TensorFlow 2 implementation of the CT algorithm
* State-of-the-art pre-trained STS models
* TensorFlow 2 inference code
* PyTorch inference code

### Requirements
While it is possible that other versions work equally well, we have used the following:

* Python = 3.6.9
* Transformers = 4.1.1

## Usage
All the models and tokenizers are available via the Huggingface interface and can be loaded for both TensorFlow and PyTorch:
```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('Contrastive-Tension/RoBerta-Large-CT-STSb')

TF_model = transformers.TFAutoModel.from_pretrained('Contrastive-Tension/RoBerta-Large-CT-STSb')
PT_model = transformers.AutoModel.from_pretrained('Contrastive-Tension/RoBerta-Large-CT-STSb')
```

### Inference
To perform inference with the pre-trained models (or other Huggingface models), please see the script [ExampleBatchInference.py](ExampleBatchInference.py).
The most important thing to remember when running inference is to apply the attention masks to the batch output vectors before mean pooling, as is done in the example script.
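
For reference, here is a minimal PyTorch sketch of such masked mean pooling. The `mean_pool` helper below is purely illustrative (it is not part of the repository's API), and the pooling in the example script may differ in details:

```python
import torch
import transformers

def mean_pool(model_output, attention_mask):
    # (batch, seq_len, hidden) token embeddings from the final layer
    token_embeddings = model_output[0]
    # Expand the mask so that padded positions contribute nothing to the sum
    mask = attention_mask.unsqueeze(-1).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # number of real tokens per sentence
    return summed / counts  # (batch, hidden) sentence embeddings

tokenizer = transformers.AutoTokenizer.from_pretrained('Contrastive-Tension/RoBerta-Large-CT-STSb')
model = transformers.AutoModel.from_pretrained('Contrastive-Tension/RoBerta-Large-CT-STSb')

batch = tokenizer(["A first sentence.", "A second sentence."],
                  padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
    output = model(**batch)
embeddings = mean_pool(output, batch['attention_mask'])  # one vector per sentence
```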

### CT Training
To run CT on your own models and text data, see [ExampleTraining.py](ExampleTraining.py) for a comprehensive example. This file currently creates a dummy corpus of random text; simply replace this with whatever corpus you like.
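
For intuition, the CT objective itself is simple: two independent copies of the same pre-trained model embed sentence pairs, and a binary cross-entropy loss on the dot product of the two embeddings pushes identical-sentence pairs together and random-sentence pairs apart. Below is an illustrative PyTorch sketch of the objective (the repository's actual implementation is in TensorFlow 2, and `embed` is assumed to be a masked mean-pooling function like the one sketched above):

```python
import torch
import torch.nn.functional as F

def ct_loss(model1, model2, embed, sentences_a, sentences_b, labels):
    # model1 and model2 start as two independent copies of the same
    # pre-trained checkpoint, e.g. model2 = copy.deepcopy(model1).
    e1 = embed(model1, sentences_a)  # (batch, hidden)
    e2 = embed(model2, sentences_b)  # (batch, hidden)
    logits = (e1 * e2).sum(dim=-1)   # dot product acts as the similarity logit
    # labels are 1.0 where both models saw the same sentence, and 0.0 for
    # randomly paired sentences (the paper uses K = 7 negatives per positive)
    return F.binary_cross_entropy_with_logits(logits, labels)
```

After training, one of the two resulting models is kept and used as the final sentence encoder.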


## Pre-trained Models
Note that these models are not trained with the exact hyperparameters disclosed in the original CT paper. Rather, the parameters stem from a short follow-up paper currently under review, which once again pushes the SOTA.

All evaluation is done using the [SentEval](https://github.com/facebookresearch/SentEval) framework and reports (Pearson / Spearman) correlations.
### Unsupervised / Zero-Shot
As both the pre-training of BERT and CT itself are fully self-supervised, the models tuned only with CT require no labeled data whatsoever.
The NLI models, however, are first fine-tuned towards a natural language inference task, which requires labeled data.

| Model | Avg Unsupervised STS | STS-b | #Parameters |
| ----------------------------------|:-----:|:-----:|:-----:|
| **Fully Unsupervised** | | | |
| [BERT-Distil-CT](https://huggingface.co/Contrastive-Tension/BERT-Distil-CT) | 75.12 / 75.04 | 78.63 / 77.91 | 66 M |
| [BERT-Base-CT](https://huggingface.co/Contrastive-Tension/BERT-Base-CT) | 73.55 / 73.36 | 75.49 / 73.31 | 108 M |
| [BERT-Large-CT](https://huggingface.co/Contrastive-Tension/BERT-Large-CT) | 77.12 / 76.93 | 80.75 / 79.82 | 334 M |
| **Using NLI Data** | | | |
| [BERT-Distil-NLI-CT](https://huggingface.co/Contrastive-Tension/BERT-Distil-NLI-CT) | 76.65 / 76.63 | 79.74 / 81.01 | 66 M |
| [BERT-Base-NLI-CT](https://huggingface.co/Contrastive-Tension/BERT-Base-NLI-CT) | 76.05 / 76.28 | 79.98 / 81.47 | 108 M |
| [BERT-Large-NLI-CT](https://huggingface.co/Contrastive-Tension/BERT-Large-NLI-CT) | 77.42 / 77.41 | 80.92 / 81.66 | 334 M |

### Supervised
These models are fine-tuned directly with STS data, using a modified version of the supervised training objective proposed by [S-BERT](https://arxiv.org/abs/1908.10084) (a sketch of the unmodified S-BERT objective follows the table below).
To our knowledge, our RoBerta-Large-CT-STSb is the current SOTA model for STS via sentence embeddings.

| Model | STS-b | #Parameters |
| ----------------------------------|:-----:|:-----:|
| [BERT-Distil-CT-STSb](https://huggingface.co/Contrastive-Tension/BERT-Distil-CT-STSb) | 84.85 / 85.46 | 66 M |
| [BERT-Base-CT-STSb](https://huggingface.co/Contrastive-Tension/BERT-Base-CT-STSb) | 85.31 / 85.76 | 108 M |
| [BERT-Large-CT-STSb](https://huggingface.co/Contrastive-Tension/BERT-Large-CT-STSb) | 85.86 / 86.47 | 334 M |
| [RoBerta-Large-CT-STSb](https://huggingface.co/Contrastive-Tension/RoBerta-Large-CT-STSb) | 87.56 / 88.42 | 334 M |
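
For context, the unmodified S-BERT objective regresses the cosine similarity of two sentence embeddings against the gold STS score. Here is a minimal sketch of that baseline objective (the exact modification used for the models above is not detailed here):

```python
import torch
import torch.nn.functional as F

def sbert_regression_loss(emb1, emb2, gold_scores):
    # Cosine similarity per sentence pair, in [-1, 1]
    sim = F.cosine_similarity(emb1, emb2)
    # Gold STS-b scores lie in [0, 5]; normalize to [0, 1] as in the
    # Sentence-Transformers training examples
    return F.mse_loss(sim, gold_scores / 5.0)
```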

### Other Languages

| Model | Language | #Parameters |
| ----------------------------------|:-----:|:-----:|
| [BERT-Base-Swe-CT-STSb](https://huggingface.co/Contrastive-Tension/BERT-Base-Swe-CT-STSb/tree/main) | Swedish | 108 M |


## License
Distributed under the MIT License. See `LICENSE` for more information.


## Contact
If you have questions regarding the paper, please consider creating a comment via the official [OpenReview submission](https://openreview.net/forum?id=Ov_sMNau-PF).
If you have questions regarding the code or otherwise related to this GitHub page, please open an [issue](https://github.com/FreddeFrallan/Contrastive-Tension/issues).

For other purposes, feel free to contact me directly at: Fredrik.Carlsson@ri.se

## Acknowledgements
* [SentEval](https://github.com/facebookresearch/SentEval)
* [Huggingface](https://huggingface.co/)
* [Sentence-Transformers](https://github.com/UKPLab/sentence-transformers)
* [Best README Template](https://github.com/othneildrew/Best-README-Template)
--------------------------------------------------------------------------------
/STSData/Dataset.py:
--------------------------------------------------------------------------------
def getUniqueCaptions(dataset, sortOnSize=True):
    # Collect every unique sentence occurring in the (s1, s2, score) triples.
    captions = set()
    for s1, s2, _ in dataset:
        captions.add(s1)
        captions.add(s2)
    if sortOnSize:
        # Sorting by length lets batches contain similarly sized sentences.
        return sorted(captions, key=len)
    else:
        return list(captions)


def _readAndLoadSTSBData(name):
    # Each line of the STS-b files is tab-separated:
    # genre, filename, year, ids, score, sentence1, sentence2, [...]
    data = []
    with open("STSData/{}".format(name), 'r') as fp:
        for line in fp.readlines():
            genre, filename, year, ids, score, sentence1, sentence2 = line.strip().split('\t')[:7]
            data.append((sentence1, sentence2, float(score)))
    return data


def loadTestData():
    return _readAndLoadSTSBData("sts-test.csv")


def loadDevData():
    return _readAndLoadSTSBData("sts-dev.csv")


def loadTrainData():
    return _readAndLoadSTSBData("sts-train.csv")
--------------------------------------------------------------------------------
/STSData/Evaluation.py:
--------------------------------------------------------------------------------
'''
This is just a simple script to get you up and running.
If you are aiming to publish your own results, please consider relying on SentEval for evaluation.
https://github.com/facebookresearch/SentEval
'''

from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr, spearmanr
from ContrastiveTension import Inference
from STSData import Dataset
import numpy as np
import tqdm


def evalCorrelationScores(sent2Vecs, dataset):
    # Compare the cosine similarity of each sentence pair against the human score.
    similarityScores, humanScores = [], []
    for s1, s2, score in dataset:
        humanScores.append(score)
        similarityScores.append(cosine_similarity([sent2Vecs[s1]], [sent2Vecs[s2]])[0][0])

    x, y = np.array(similarityScores), np.array(humanScores)
    pearResults = pearsonr(x, y)
    spearResults = spearmanr(x, y)

    return {'Pearson': pearResults[0], 'Spearman': spearResults[0]}


def evaluateSTS(model, tokenizer, batch_size=512, use_dev_data=False):
    data = Dataset.loadDevData() if use_dev_data else Dataset.loadTestData()
    texts = Dataset.getUniqueCaptions(data)

    # Embed all unique sentences in batches, then score the dataset pairs.
    sent2Vec = {}
    for i in tqdm.tqdm(range(0, len(texts), batch_size), "Generating Eval Embeddings"):
        batchTexts = texts[i:i + batch_size]
        embs = Inference.tensorflowGenerateSentenceEmbeddings(model, tokenizer, batchTexts)

        for txt, emb in zip(batchTexts, embs):
            sent2Vec[txt] = emb

    return evalCorrelationScores(sent2Vec, data)
--------------------------------------------------------------------------------
/STSData/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FreddeFrallan/Contrastive-Tension/75293a883344b389bfe726a3d43d6b48b29b55fb/STSData/__init__.py
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

from setuptools import setup, find_packages

setup(name='ContrastiveTension',
      version='0.0.1',
      description='ContrastiveTension',
      author='',
      author_email='',
      packages=find_packages(),
      )
--------------------------------------------------------------------------------