├── NOTICE ├── .gitignore ├── requirements.txt ├── src ├── sentence_transformers_ext │ ├── __init__.py │ ├── cross_encoder_eval │ │ ├── __init__.py │ │ ├── CECorrelationEvaluatorAUCEnsemble.py │ │ ├── CECorrelationEvaluatorAUC.py │ │ └── CECorrelationEvaluatorEnsemble.py │ ├── bi_encoder_eval │ │ ├── __init__.py │ │ ├── EmbeddingSimilarityEvaluatorAUC.py │ │ ├── EmbeddingSimilarityEvaluatorAUCEnsemble.py │ │ ├── EmbeddingSimilarityEvaluator.py │ │ └── EmbeddingSimilarityEvaluatorEnsemble.py │ └── utils.py ├── eval.py ├── self_distill.py ├── data.py └── mutual_distill_parallel.py ├── CODE_OF_CONDUCT.md ├── train_self_distill.sh ├── train_mutual_distill.sh ├── CONTRIBUTING.md ├── README.md ├── LICENSE └── THIRD-PARTY-LICENSES /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/* 2 | *.zip 3 | *__pycache__/ 4 | data/* 5 | .ipynb_checkpoints -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | transformers==4.9.0 3 | sentence-transformers==2.0.0 -------------------------------------------------------------------------------- /src/sentence_transformers_ext/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .utils import * # custom -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /train_self_distill.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python src/self_distill.py \ 2 | --model_name_or_path "princeton-nlp/unsup-simcse-roberta-base" \ 3 | --batch_size_bi_encoder 128 \ 4 | --batch_size_cross_encoder 32 \ 5 | --num_epochs_bi_encoder 10 \ 6 | --num_epochs_cross_encoder 1 \ 7 | --cycle 3 \ 8 | --bi_encoder_pooling_mode cls \ 9 | --init_with_new_models \ 10 | --task sts_sickr \ 11 | --random_seed 2021 -------------------------------------------------------------------------------- /src/sentence_transformers_ext/cross_encoder_eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .CECorrelationEvaluatorEnsemble import CECorrelationEvaluatorEnsemble # custom 5 | from .CECorrelationEvaluatorAUC import CECorrelationEvaluatorAUC # custom 6 | from .CECorrelationEvaluatorAUCEnsemble import CECorrelationEvaluatorAUCEnsemble # custom 7 | -------------------------------------------------------------------------------- /train_mutual_distill.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python src/mutual_distill_parallel.py \ 2 | --device1 0 \ 3 | --device2 1 \ 4 | --batch_size_bi_encoder 128 \ 5 | --batch_size_cross_encoder 32 \ 6 | --num_epochs_bi_encoder 10 \ 7 | --num_epochs_cross_encoder 1 \ 8 | --cycle 3 \ 9 | --bi_encoder1_pooling_mode cls \ 10 | --bi_encoder2_pooling_mode cls \ 11 | --init_with_new_models \ 12 | --task sts_sickr \ 13 | --random_seed 2021 -------------------------------------------------------------------------------- /src/sentence_transformers_ext/bi_encoder_eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .EmbeddingSimilarityEvaluator import EmbeddingSimilarityEvaluator 5 | from .EmbeddingSimilarityEvaluatorEnsemble import EmbeddingSimilarityEvaluatorEnsemble # custom 6 | from .EmbeddingSimilarityEvaluatorAUC import EmbeddingSimilarityEvaluatorAUC # custom 7 | from .EmbeddingSimilarityEvaluatorAUCEnsemble import EmbeddingSimilarityEvaluatorAUCEnsemble # custom 8 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import csv 6 | 7 | def write_csv_log(output_path, csv_file, csv_headers, things_to_write): 8 | """ 9 | Append a row of results to a csv log file, writing the header row first when the file is created. 10 | Parameters 11 | ---------- 12 | output_path: a string specifying the write path 13 | csv_file: a string specifying the csv file name 14 | csv_headers: a list of column names, written only if the file does not exist yet 15 | things_to_write: a list of values (e.g. epoch, steps, scores) to be written as one row 16 | Returns 17 | ---------- 18 | None 19 | """ 20 | csv_path = os.path.join(output_path, csv_file) 21 | output_file_exists = os.path.isfile(csv_path) 22 | with open(csv_path, newline='', mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 23 | writer = csv.writer(f) 24 | if not output_file_exists: 25 | writer.writerow(csv_headers) 26 | 27 | writer.writerow(things_to_write) -------------------------------------------------------------------------------- /src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorAUCEnsemble.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from scipy.stats import pearsonr, spearmanr 6 | from sklearn.metrics import roc_auc_score 7 | from typing import List 8 | import numpy as np 9 | import os 10 | import csv 11 | from sentence_transformers.readers import InputExample 12 | 13 | from .CECorrelationEvaluatorEnsemble import CECorrelationEvaluatorEnsemble 14 | from ..utils import write_csv_log 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class CECorrelationEvaluatorAUCEnsemble(CECorrelationEvaluatorEnsemble): 19 | """ 20 | This evaluator can be used with an ensemble of CrossEncoder models. Given sentence pairs and binary gold labels, 21 | it computes the ROC AUC between the ensemble-averaged predicted scores for the sentence pairs 22 | and the gold labels. 23 | """ 24 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True): 25 | CECorrelationEvaluatorEnsemble.__init__(self, sentence_pairs, scores, name, write_csv) 26 | self.csv_headers = ["epoch", "steps", "roc_auc_score"] # overwrite parent's csv_headers 27 | 28 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 29 | if epoch != -1: 30 | if steps == -1: 31 | out_txt = " after epoch {}:".format(epoch) 32 | else: 33 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 34 | else: 35 | out_txt = ":" 36 | 37 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 38 | 39 | all_scores = [] 40 | for model in models: 41 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False) 42 | all_scores.append(pred_scores) 43 | 44 | pred_scores = np.array(all_scores).mean(0) 45 | 46 | eval_auc = roc_auc_score(self.scores, pred_scores) 47 | 48 | logger.info("roc_auc_score: {:.4f}".format(eval_auc)) 49 | 50 | if output_path is not None and self.write_csv: 51 | things_to_write = [epoch, steps, eval_auc] 52 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 53 | 54 | return eval_auc 55 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorAUC.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from scipy.stats import pearsonr, spearmanr 6 | from sklearn.metrics import roc_auc_score 7 | from typing import List 8 | import os 9 | import csv 10 | from sentence_transformers.readers import InputExample 11 | 12 | from ..utils import write_csv_log 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class CECorrelationEvaluatorAUC: 17 | """ 18 | This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary gold labels, 19 | it computes the ROC AUC between the model's predicted scores for the sentence pairs 20 | and the gold labels. 
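    A minimal usage sketch (the cross-encoder and the dev examples are placeholders):

        evaluator = CECorrelationEvaluatorAUC.from_input_examples(dev_samples, name="qqp-dev")
        auc = evaluator(cross_encoder, output_path="output/")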
21 | """ 22 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True): 23 | self.sentence_pairs = sentence_pairs 24 | self.scores = scores 25 | self.name = name 26 | 27 | self.csv_file = self.__class__.__name__ + ("_" + name if name else '') + "_results.csv" 28 | self.csv_headers = ["epoch", "steps", "roc_auc_score"] 29 | self.write_csv = write_csv 30 | 31 | @classmethod 32 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 33 | sentence_pairs = [] 34 | scores = [] 35 | 36 | for example in examples: 37 | sentence_pairs.append(example.texts) 38 | scores.append(example.label) 39 | return cls(sentence_pairs, scores, **kwargs) 40 | 41 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 42 | if epoch != -1: 43 | if steps == -1: 44 | out_txt = " after epoch {}:".format(epoch) 45 | else: 46 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 47 | else: 48 | out_txt = ":" 49 | 50 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 51 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False) 52 | 53 | 54 | eval_auc = roc_auc_score(self.scores, pred_scores) 55 | 56 | logger.info("roc_auc_score: {:.4f}".format(eval_auc)) 57 | 58 | if output_path is not None and self.write_csv: 59 | things_to_write = [epoch, steps, eval_auc] 60 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 61 | 62 | return eval_auc 63 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorEnsemble.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | from scipy.stats import pearsonr, spearmanr 6 | from typing import List 7 | import os 8 | import csv 9 | import numpy as np 10 | from sentence_transformers.readers import InputExample 11 | 12 | from ..utils import write_csv_log 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class CECorrelationEvaluatorEnsemble: 17 | """ 18 | This evaluator can be used with the CrossEncoder class. Given sentence pairs and continuous scores, 19 | it compute the pearson & spearman correlation between the predicted score for the sentence pair 20 | and the gold score. 
21 | """ 22 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True): 23 | self.sentence_pairs = sentence_pairs 24 | self.scores = scores 25 | self.name = name 26 | 27 | self.csv_file = self.__class__.__name__ + ("_" + name if name else '') + "_results.csv" 28 | self.csv_headers = ["epoch", "steps", "Pearson_Correlation", "Spearman_Correlation"] 29 | self.write_csv = write_csv 30 | 31 | @classmethod 32 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 33 | sentence_pairs = [] 34 | scores = [] 35 | 36 | for example in examples: 37 | sentence_pairs.append(example.texts) 38 | scores.append(example.label) 39 | return cls(sentence_pairs, scores, **kwargs) 40 | 41 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 42 | if epoch != -1: 43 | if steps == -1: 44 | out_txt = " after epoch {}:".format(epoch) 45 | else: 46 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 47 | else: 48 | out_txt = ":" 49 | 50 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 51 | 52 | all_scores = [] 53 | for model in models: 54 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False) 55 | all_scores.append(pred_scores) 56 | 57 | pred_scores = np.array(all_scores).mean(0) 58 | 59 | eval_pearson, _ = pearsonr(self.scores, pred_scores) 60 | eval_spearman, _ = spearmanr(self.scores, pred_scores) 61 | 62 | logger.info("Correlation:\tPearson: {:.4f}\tSpearman: {:.4f}".format(eval_pearson, eval_spearman)) 63 | 64 | if output_path is not None and self.write_csv: 65 | things_to_write = [epoch, steps, eval_pearson, eval_spearman] 66 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 67 | 68 | return eval_spearman 69 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. 
You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorAUC.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from sentence_transformers.evaluation import ( 5 | SentenceEvaluator, 6 | SimilarityFunction 7 | ) 8 | import logging 9 | import os 10 | import csv 11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances 12 | from scipy.stats import pearsonr, spearmanr 13 | import numpy as np 14 | from typing import List 15 | from sklearn.metrics import roc_auc_score 16 | from sentence_transformers.readers import InputExample 17 | 18 | from ..utils import write_csv_log 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class EmbeddingSimilarityEvaluatorAUC(SentenceEvaluator): 23 | """ 24 | Evaluate a model based on the similarity of its sentence embeddings by calculating the ROC AUC 25 | of the similarity scores against binary gold labels. 26 | The similarity metrics are cosine similarity, negated Euclidean and Manhattan distances, and the dot product. 27 | The returned score is the ROC AUC for a specified similarity function. 28 | 29 | The results are written in a CSV. If a CSV already exists, then values are appended. 
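    A minimal usage sketch (the bi-encoder and the dev examples are placeholders):

        evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(dev_samples, batch_size=32, name="qqp-dev")
        auc = evaluator(bi_encoder, output_path="output/")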
30 | """ 31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True): 32 | """ 33 | Constructs an evaluator based for the dataset 34 | 35 | The labels need to indicate the similarity between the sentences. 36 | 37 | :param sentences1: List with the first sentence in a pair 38 | :param sentences2: List with the second sentence in a pair 39 | :param scores: Similarity score between sentences1[i] and sentences2[i] 40 | :param write_csv: Write results to a CSV file 41 | """ 42 | self.sentences1 = sentences1 43 | self.sentences2 = sentences2 44 | self.scores = scores 45 | self.write_csv = write_csv 46 | 47 | assert len(self.sentences1) == len(self.sentences2) 48 | assert len(self.sentences1) == len(self.scores) 49 | 50 | self.main_similarity = main_similarity 51 | self.name = name 52 | 53 | self.batch_size = batch_size 54 | if show_progress_bar is None: 55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) 56 | self.show_progress_bar = show_progress_bar 57 | 58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv" 59 | self.csv_headers = ["epoch", "steps", "cosine_auc", "euclidean_auc", "manhattan_auc", "dot_auc"] 60 | 61 | @classmethod 62 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 63 | sentences1 = [] 64 | sentences2 = [] 65 | scores = [] 66 | 67 | for example in examples: 68 | sentences1.append(example.texts[0]) 69 | sentences2.append(example.texts[1]) 70 | scores.append(example.label) 71 | return cls(sentences1, sentences2, scores, **kwargs) 72 | 73 | 74 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 75 | if epoch != -1: 76 | if steps == -1: 77 | out_txt = " after epoch {}:".format(epoch) 78 | else: 79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 80 | else: 81 | out_txt = ":" 82 | 83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 84 | 85 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 86 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 87 | labels = self.scores 88 | 89 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2)) 90 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) 91 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) 92 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)] 93 | 94 | 95 | eval_auc_cosine = roc_auc_score(labels, cosine_scores) 96 | 97 | eval_auc_manhattan = roc_auc_score(labels, manhattan_distances) 98 | 99 | eval_auc_euclidean = roc_auc_score(labels, euclidean_distances) 100 | 101 | eval_auc_dot = roc_auc_score(labels, dot_products) 102 | 103 | logger.info("Cosine-Similarity AUC: {:.4f}".format(eval_auc_cosine)) 104 | logger.info("Manhattan-Distance AUC: {:.4f}".format(eval_auc_manhattan)) 105 | logger.info("Euclidean-Distance AUC: {:.4f}".format(eval_auc_euclidean)) 106 | logger.info("Dot-Product-Similarity AUC: {:.4f}".format(eval_auc_dot)) 107 | 108 | if output_path is not None and self.write_csv: 109 | things_to_write = [epoch, steps, 
eval_auc_cosine, eval_auc_euclidean, eval_auc_manhattan, eval_auc_dot] 110 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 111 | 112 | if self.main_similarity == SimilarityFunction.COSINE: 113 | return eval_auc_cosine 114 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN: 115 | return eval_auc_euclidean 116 | elif self.main_similarity == SimilarityFunction.MANHATTAN: 117 | return eval_auc_manhattan 118 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT: 119 | return eval_auc_dot 120 | elif self.main_similarity is None: 121 | return max(eval_auc_cosine, eval_auc_manhattan, eval_auc_euclidean, eval_auc_dot) 122 | else: 123 | raise ValueError("Unknown main_similarity value") 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 
3 | # Trans-Encoder 4 | 5 | [arxiv] · [amazon.science blog] · [5min-video] · [talk@RIKEN] · [openreview] 6 | 
19 | 20 | 21 | 22 | Code repo for **ICLR 2022** paper **_[Trans-Encoder: Unsupervised sentence-pair modelling through self- and mutual-distillations](https://arxiv.org/abs/2109.13059)_**
23 | by [Fangyu Liu](http://fangyuliu.me/about.html), [Yunlong Jiao](https://yunlongjiao.github.io/), [Jordan Massiah](https://www.linkedin.com/in/jordan-massiah-562862136/?originalSubdomain=uk), [Emine Yilmaz](https://sites.google.com/site/emineyilmaz/), [Serhii Havrylov](https://serhii-havrylov.github.io/). 24 | 25 | Trans-Encoder is a state-of-the-art unsupervised sentence similarity model. It conducts self-knowledge-distillation on top of pretrained language models by alternating between their bi- and cross-encoder forms. 26 | 27 | ## Huggingface pretrained models for STS 28 | 29 | 30 | 31 |
**Base models** 32 | 33 | |model | STS avg. | 34 | |--------|--------| 35 | |baseline: [unsup-simcse-bert-base](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 76.21 | 36 | | [trans-encoder-bi-simcse-bert-base](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-bert-base) | 80.41 | 37 | | [trans-encoder-cross-simcse-bert-base](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-bert-base) | 79.90 | 38 | |baseline: [unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.10 | 39 | | [trans-encoder-bi-simcse-roberta-base](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-roberta-base) | 80.47 | 40 | | [trans-encoder-cross-simcse-roberta-base](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-roberta-base) | **81.15** | 41 | 42 | **Large models** 43 | |model | STS avg. | 44 | |--------|--------| 45 | |baseline: [unsup-simcse-bert-large](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 78.42 | 46 | | [trans-encoder-bi-simcse-bert-large](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-bert-large) | 82.65 | 47 | | [trans-encoder-cross-simcse-bert-large](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-bert-large) | 82.52 | 48 | |baseline: [unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 78.92 | 49 | | [trans-encoder-bi-simcse-roberta-large](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-roberta-large) | **82.93** | 50 | | [trans-encoder-cross-simcse-roberta-large](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-roberta-large) | **82.93** | 51 | 52 | 53 |
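All checkpoints above are standard `sentence-transformers` models, so they can be loaded directly with that library. A minimal usage sketch (the example sentences are arbitrary):

```python
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.cross_encoder import CrossEncoder

sents = ["A man is playing a guitar.", "Someone is playing an instrument."]

# Bi-encoder: embed the two sentences independently, then compare the embeddings.
bi_encoder = SentenceTransformer("cambridgeltl/trans-encoder-bi-simcse-roberta-base")
emb1, emb2 = bi_encoder.encode(sents)
print("cosine similarity:", util.cos_sim(emb1, emb2).item())

# Cross-encoder: read the sentence pair jointly and predict a single similarity score.
cross_encoder = CrossEncoder("cambridgeltl/trans-encoder-cross-simcse-roberta-base")
print("pair score:", cross_encoder.predict([sents])[0])
```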
54 | 55 | 56 | ## Dependencies 57 | 58 | ``` 59 | torch==1.8.1 60 | transformers==4.9.0 61 | sentence-transformers==2.0.0 62 | ``` 63 | Please view [requirements.txt](https://github.com/amzn/trans-encoder/blob/main/requirements.txt) for more details. 64 | 65 | ## Data 66 | All training and evaluation data will be automatically downloaded when running the scripts. See [src/data.py](https://github.com/amzn/trans-encoder/blob/main/src/data.py) for details. 67 | 68 | ## Train 69 | 70 | `--task` options: `sts` (STS2012-2016 and STS-b), `sickr`, `sts_sickr` (STS2012-2016, STS-b, and SICK-R), `qqp`, `qnli`, `mrpc`, `snli`, `custom`. See [src/data.py](https://github.com/amzn/trans-encoder/blob/main/src/data.py) for task data details. By default using all STS data (`sts_sickr`). 71 | 72 | #### Self-distillation 73 | ```bash 74 | >> bash train_self_distill.sh 0 75 | ``` 76 | `0` denotes GPU device index. 77 | 78 | #### Mutual-distillation 79 | ```bash 80 | >> bash train_mutual_distill.sh 0,1 81 | ``` 82 | Two GPUs needed; by default using SimCSE BERT & RoBERTa base models for ensembling. Add `--use_large` for switching to large models. 83 | 84 | #### Train with your custom corpus 85 | ```bash 86 | >> CUDA_VISIBLE_DEVICES=0,1 python src/mutual_distill_parallel.py \ 87 | --batch_size_bi_encoder 128 \ 88 | --batch_size_cross_encoder 64 \ 89 | --num_epochs_bi_encoder 10 \ 90 | --num_epochs_cross_encoder 1 \ 91 | --cycle 3 \ 92 | --bi_encoder1_pooling_mode cls \ 93 | --bi_encoder2_pooling_mode cls \ 94 | --init_with_new_models \ 95 | --task custom \ 96 | --random_seed 2021 \ 97 | --custom_corpus_path CORPUS_PATH 98 | ``` 99 | `CORPUS_PATH` should point to your custom corpus in which every line should be a sentence pair in the form of `sent1||sent2`. 100 | 101 | ## Evaluate 102 | #### Evaluate a single model 103 | 104 | Bi-encoder: 105 | ```bash 106 | >> python src/eval.py \ 107 | --model_name_or_path "cambridgeltl/trans-encoder-bi-simcse-roberta-large" \ 108 | --mode bi \ 109 | --task sts_sickr 110 | ``` 111 | Cross-encoder: 112 | ```bash 113 | >> python src/eval.py \ 114 | --model_name_or_path "cambridgeltl/trans-encoder-cross-simcse-roberta-large" \ 115 | --mode cross \ 116 | --task sts_sickr 117 | ``` 118 | #### Evaluate ensemble 119 | 120 | Bi-encoder: 121 | ```bash 122 | >> python src/eval.py \ 123 | --model_name_or_path1 "cambridgeltl/trans-encoder-bi-simcse-bert-large" \ 124 | --model_name_or_path2 "cambridgeltl/trans-encoder-bi-simcse-roberta-large" \ 125 | --mode bi \ 126 | --ensemble \ 127 | --task sts_sickr 128 | ``` 129 | 130 | Cross-encoder: 131 | ```bash 132 | >> python src/eval.py \ 133 | --model_name_or_path1 "cambridgeltl/trans-encoder-cross-simcse-bert-large" \ 134 | --model_name_or_path2 "cambridgeltl/trans-encoder-cross-simcse-roberta-large" \ 135 | --mode cross \ 136 | --ensemble \ 137 | --task sts_sickr 138 | ``` 139 | 140 | ## Authors 141 | 142 | - [**Fangyu Liu**](http://fangyuliu.me/about.html): Main contributor 143 | 144 | ## Security 145 | 146 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 147 | 148 | ## License 149 | 150 | This project is licensed under the Apache-2.0 License. 151 | 152 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorAUCEnsemble.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from sentence_transformers.evaluation import ( 5 | SentenceEvaluator, 6 | SimilarityFunction 7 | ) 8 | import logging 9 | import os 10 | import csv 11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances 12 | from scipy.stats import pearsonr, spearmanr 13 | import numpy as np 14 | from typing import List 15 | from sklearn.metrics import roc_auc_score 16 | from sentence_transformers.readers import InputExample 17 | 18 | from ..utils import write_csv_log 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class EmbeddingSimilarityEvaluatorAUCEnsemble(SentenceEvaluator): 23 | """ 24 | Evaluate an ensemble of models based on the similarity of their sentence embeddings by calculating the ROC AUC 25 | of the similarity scores against binary gold labels. 26 | The similarity metrics are cosine similarity, negated Euclidean and Manhattan distances, and the dot product; a ROC AUC is computed per model and then averaged across the ensemble. 27 | The returned score is the averaged ROC AUC for a specified similarity function. 28 | 29 | The results are written in a CSV. If a CSV already exists, then values are appended. 30 | """ 31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True): 32 | """ 33 | Constructs an evaluator for the dataset 34 | 35 | The labels need to indicate whether the sentences in each pair match (binary 0/1). 36 | 37 | :param sentences1: List with the first sentence in a pair 38 | :param sentences2: List with the second sentence in a pair 39 | :param scores: Binary gold label (0 or 1) for the pair sentences1[i] and sentences2[i] 40 | :param write_csv: Write results to a CSV file 41 | """ 42 | self.sentences1 = sentences1 43 | self.sentences2 = sentences2 44 | self.scores = scores 45 | self.write_csv = write_csv 46 | 47 | assert len(self.sentences1) == len(self.sentences2) 48 | assert len(self.sentences1) == len(self.scores) 49 | 50 | self.main_similarity = main_similarity 51 | self.name = name 52 | 53 | self.batch_size = batch_size 54 | if show_progress_bar is None: 55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) 56 | self.show_progress_bar = show_progress_bar 57 | 58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv" 59 | self.csv_headers = ["epoch", "steps", "cosine_auc", "euclidean_auc", "manhattan_auc", "dot_auc"] 60 | 61 | @classmethod 62 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 63 | sentences1 = [] 64 | sentences2 = [] 65 | scores = [] 66 | 67 | for example in examples: 68 | sentences1.append(example.texts[0]) 69 | sentences2.append(example.texts[1]) 70 | scores.append(example.label) 71 | return cls(sentences1, sentences2, scores, **kwargs) 72 | 73 | 74 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 75 | if epoch != -1: 76 | if steps == -1: 77 | out_txt = " after epoch {}:".format(epoch) 78 | else: 79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 80 | else: 81 | out_txt = ":" 82 | 83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 84 | 85 | labels = self.scores 86 | 87 | auc_cosine_scores_all_models = [] 88 | auc_manhattan_distances_all_models = [] 89 | auc_euclidean_distances_all_models = [] 90 | auc_dot_products_all_models = [] 91 | 
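        # Note: this ensemble computes a ROC AUC per model and then averages the AUC values,
        # whereas EmbeddingSimilarityEvaluatorEnsemble averages the raw per-pair scores across
        # models before evaluating.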
92 | # compute a ROC AUC for each model and average the AUC scores 93 | for model in models: 94 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 95 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 96 | 97 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2)) 98 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) 99 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) 100 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)] 101 | auc_cosine_scores_all_models.append(roc_auc_score(labels, cosine_scores)) 102 | auc_manhattan_distances_all_models.append(roc_auc_score(labels, manhattan_distances)) 103 | auc_euclidean_distances_all_models.append(roc_auc_score(labels, euclidean_distances)) 104 | auc_dot_products_all_models.append(roc_auc_score(labels, dot_products)) 105 | 106 | eval_auc_cosine = np.array(auc_cosine_scores_all_models).mean(0) 107 | eval_auc_manhattan = np.array(auc_manhattan_distances_all_models).mean(0) 108 | eval_auc_euclidean = np.array(auc_euclidean_distances_all_models).mean(0) 109 | eval_auc_dot = np.array(auc_dot_products_all_models).mean(0) 110 | 111 | logger.info("Cosine-Similarity AUC: {:.4f}".format(eval_auc_cosine)) 112 | logger.info("Manhattan-Distance AUC: {:.4f}".format(eval_auc_manhattan)) 113 | logger.info("Euclidean-Distance AUC: {:.4f}".format(eval_auc_euclidean)) 114 | logger.info("Dot-Product-Similarity AUC: {:.4f}".format(eval_auc_dot)) 115 | 116 | if output_path is not None and self.write_csv: 117 | things_to_write = [epoch, steps, eval_auc_cosine, eval_auc_euclidean, 118 | eval_auc_manhattan, eval_auc_dot] 119 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 120 | 121 | if self.main_similarity == SimilarityFunction.COSINE: 122 | return eval_auc_cosine 123 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN: 124 | return eval_auc_euclidean 125 | elif self.main_similarity == SimilarityFunction.MANHATTAN: 126 | return eval_auc_manhattan 127 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT: 128 | return eval_auc_dot 129 | elif self.main_similarity is None: 130 | return max(eval_auc_cosine, eval_auc_manhattan, eval_auc_euclidean, eval_auc_dot) 131 | else: 132 | raise ValueError("Unknown main_similarity value") 133 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from sentence_transformers.evaluation import ( 5 | SentenceEvaluator, 6 | SimilarityFunction 7 | ) 8 | import logging 9 | import os 10 | import csv 11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances 12 | from scipy.stats import pearsonr, spearmanr 13 | import numpy as np 14 | from typing import List 15 | from sentence_transformers.readers import InputExample 16 | 17 | from ..utils import write_csv_log 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class EmbeddingSimilarityEvaluator(SentenceEvaluator): 22 | """ 23 | Evaluate a model based on the similarity of the embeddings by calculating the Spearman and Pearson rank correlation 24 | in comparison to the gold standard labels. 25 | The metrics are the cosine similarity as well as euclidean and Manhattan distance 26 | The returned score is the Spearman correlation with a specified metric. 27 | 28 | The results are written in a CSV. If a CSV already exists, then values are appended. 29 | """ 30 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True): 31 | """ 32 | Constructs an evaluator based for the dataset 33 | 34 | The labels need to indicate the similarity between the sentences. 35 | 36 | :param sentences1: List with the first sentence in a pair 37 | :param sentences2: List with the second sentence in a pair 38 | :param scores: Similarity score between sentences1[i] and sentences2[i] 39 | :param write_csv: Write results to a CSV file 40 | """ 41 | self.sentences1 = sentences1 42 | self.sentences2 = sentences2 43 | self.scores = scores 44 | self.write_csv = write_csv 45 | 46 | assert len(self.sentences1) == len(self.sentences2) 47 | assert len(self.sentences1) == len(self.scores) 48 | 49 | self.main_similarity = main_similarity 50 | self.name = name 51 | 52 | self.batch_size = batch_size 53 | if show_progress_bar is None: 54 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) 55 | self.show_progress_bar = show_progress_bar 56 | 57 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv" 58 | self.csv_headers = ["epoch", "steps", "cosine_pearson", "cosine_spearman", "euclidean_pearson", "euclidean_spearman", "manhattan_pearson", "manhattan_spearman", "dot_pearson", "dot_spearman"] 59 | 60 | @classmethod 61 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 62 | sentences1 = [] 63 | sentences2 = [] 64 | scores = [] 65 | 66 | for example in examples: 67 | sentences1.append(example.texts[0]) 68 | sentences2.append(example.texts[1]) 69 | scores.append(example.label) 70 | return cls(sentences1, sentences2, scores, **kwargs) 71 | 72 | 73 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 74 | if epoch != -1: 75 | if steps == -1: 76 | out_txt = " after epoch {}:".format(epoch) 77 | else: 78 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 79 | else: 80 | out_txt = ":" 81 | 82 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 83 | 84 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 85 | embeddings2 = 
model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 86 | labels = self.scores 87 | 88 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2)) 89 | 90 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) 91 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) 92 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)] 93 | 94 | eval_pearson_cosine, _ = pearsonr(labels, cosine_scores) 95 | eval_spearman_cosine, _ = spearmanr(labels, cosine_scores) 96 | 97 | eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances) 98 | eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances) 99 | 100 | eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances) 101 | eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances) 102 | 103 | eval_pearson_dot, _ = pearsonr(labels, dot_products) 104 | eval_spearman_dot, _ = spearmanr(labels, dot_products) 105 | 106 | logger.info("Cosine-Similarity :\tPearson: {:.4f}\tSpearman: {:.4f}".format( 107 | eval_pearson_cosine, eval_spearman_cosine)) 108 | logger.info("Manhattan-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 109 | eval_pearson_manhattan, eval_spearman_manhattan)) 110 | logger.info("Euclidean-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 111 | eval_pearson_euclidean, eval_spearman_euclidean)) 112 | logger.info("Dot-Product-Similarity:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 113 | eval_pearson_dot, eval_spearman_dot)) 114 | 115 | if output_path is not None and self.write_csv: 116 | things_to_write = [epoch, steps, eval_pearson_cosine, eval_spearman_cosine, eval_pearson_euclidean, 117 | eval_spearman_euclidean, eval_pearson_manhattan, eval_spearman_manhattan, eval_pearson_dot, eval_spearman_dot] 118 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 119 | 120 | if self.main_similarity == SimilarityFunction.COSINE: 121 | return eval_spearman_cosine 122 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN: 123 | return eval_spearman_euclidean 124 | elif self.main_similarity == SimilarityFunction.MANHATTAN: 125 | return eval_spearman_manhattan 126 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT: 127 | return eval_spearman_dot 128 | elif self.main_similarity is None: 129 | return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot) 130 | else: 131 | raise ValueError("Unknown main_similarity value") 132 | -------------------------------------------------------------------------------- /src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorEnsemble.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from sentence_transformers.evaluation import ( 5 | SentenceEvaluator, 6 | SimilarityFunction 7 | ) 8 | import logging 9 | import os 10 | import csv 11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances 12 | from scipy.stats import pearsonr, spearmanr 13 | import numpy as np 14 | from typing import List 15 | 16 | from sentence_transformers.readers import InputExample 17 | 18 | from ..utils import write_csv_log 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class EmbeddingSimilarityEvaluatorEnsemble(SentenceEvaluator): 23 | """ 24 | Evaluate a model based on the similarity of the embeddings by calculating the Spearman and Pearson rank correlation 25 | in comparison to the gold standard labels. 26 | The metrics are the cosine similarity as well as euclidean and Manhattan distance 27 | The returned score is the Spearman correlation with a specified metric. 28 | 29 | The results are written in a CSV. If a CSV already exists, then values are appended. 30 | """ 31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = None, name: str = '', show_progress_bar: bool = False, write_csv: bool = True): 32 | """ 33 | Constructs an evaluator based for the dataset 34 | 35 | The labels need to indicate the similarity between the sentences. 36 | 37 | :param sentences1: List with the first sentence in a pair 38 | :param sentences2: List with the second sentence in a pair 39 | :param scores: Similarity score between sentences1[i] and sentences2[i] 40 | :param write_csv: Write results to a CSV file 41 | """ 42 | self.sentences1 = sentences1 43 | self.sentences2 = sentences2 44 | self.scores = scores 45 | self.write_csv = write_csv 46 | 47 | assert len(self.sentences1) == len(self.sentences2) 48 | assert len(self.sentences1) == len(self.scores) 49 | 50 | self.main_similarity = main_similarity 51 | self.name = name 52 | 53 | self.batch_size = batch_size 54 | if show_progress_bar is None: 55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) 56 | self.show_progress_bar = show_progress_bar 57 | 58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv" 59 | self.csv_headers = ["epoch", "steps", "cosine_pearson", "cosine_spearman", "euclidean_pearson", "euclidean_spearman", "manhattan_pearson", "manhattan_spearman", "dot_pearson", "dot_spearman"] 60 | 61 | @classmethod 62 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 63 | sentences1 = [] 64 | sentences2 = [] 65 | scores = [] 66 | 67 | for example in examples: 68 | sentences1.append(example.texts[0]) 69 | sentences2.append(example.texts[1]) 70 | scores.append(example.label) 71 | return cls(sentences1, sentences2, scores, **kwargs) 72 | 73 | 74 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 75 | if epoch != -1: 76 | if steps == -1: 77 | out_txt = " after epoch {}:".format(epoch) 78 | else: 79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 80 | else: 81 | out_txt = ":" 82 | 83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt) 84 | 85 | labels = self.scores 86 | 87 | cosine_scores_all_models = [] 88 | manhattan_distances_all_models = [] 89 | euclidean_distances_all_models = [] 90 | dot_products_all_models = [] 91 | 
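        # Note: this ensemble averages the raw per-pair scores (cosine, Manhattan, Euclidean,
        # dot product) across models first, and only then computes the Pearson/Spearman
        # correlations on the averaged scores.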
92 | # compute average predictions of all models 93 | for model in models: 94 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 95 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True) 96 | 97 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2)) 98 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) 99 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) 100 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)] 101 | cosine_scores_all_models.append(cosine_scores) 102 | manhattan_distances_all_models.append(manhattan_distances) 103 | euclidean_distances_all_models.append(euclidean_distances) 104 | dot_products_all_models.append(dot_products) 105 | 106 | cosine_scores = np.array(cosine_scores_all_models).mean(0) 107 | manhattan_distances = np.array(manhattan_distances_all_models).mean(0) 108 | euclidean_distances = np.array(euclidean_distances_all_models).mean(0) 109 | dot_products = np.array(dot_products_all_models).mean(0) 110 | 111 | eval_pearson_cosine, _ = pearsonr(labels, cosine_scores) 112 | eval_spearman_cosine, _ = spearmanr(labels, cosine_scores) 113 | 114 | eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances) 115 | eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances) 116 | 117 | eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances) 118 | eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances) 119 | 120 | eval_pearson_dot, _ = pearsonr(labels, dot_products) 121 | eval_spearman_dot, _ = spearmanr(labels, dot_products) 122 | 123 | logger.info("Cosine-Similarity :\tPearson: {:.4f}\tSpearman: {:.4f}".format( 124 | eval_pearson_cosine, eval_spearman_cosine)) 125 | logger.info("Manhattan-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 126 | eval_pearson_manhattan, eval_spearman_manhattan)) 127 | logger.info("Euclidean-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 128 | eval_pearson_euclidean, eval_spearman_euclidean)) 129 | logger.info("Dot-Product-Similarity:\tPearson: {:.4f}\tSpearman: {:.4f}".format( 130 | eval_pearson_dot, eval_spearman_dot)) 131 | 132 | if output_path is not None and self.write_csv: 133 | things_to_write = [epoch, steps, eval_pearson_cosine, eval_spearman_cosine, eval_pearson_euclidean, 134 | eval_spearman_euclidean, eval_pearson_manhattan, eval_spearman_manhattan, eval_pearson_dot, eval_spearman_dot] 135 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write) 136 | 137 | if self.main_similarity == SimilarityFunction.COSINE: 138 | return eval_spearman_cosine 139 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN: 140 | return eval_spearman_euclidean 141 | elif self.main_similarity == SimilarityFunction.MANHATTAN: 142 | return eval_spearman_manhattan 143 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT: 144 | return eval_spearman_dot 145 | elif self.main_similarity is None: 146 | return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot) 147 | else: 148 | raise ValueError("Unknown main_similarity value") 149 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | # 
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from torch.utils.data import DataLoader 5 | import torch.nn.functional as F 6 | from sentence_transformers import models, losses, util, SentenceTransformer 7 | from sentence_transformers.cross_encoder import CrossEncoder 8 | from sentence_transformers import LoggingHandler, SentenceTransformer 9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator 10 | from sentence_transformers.readers import InputExample 11 | from datetime import datetime 12 | import argparse 13 | import logging 14 | import csv 15 | import torch 16 | import sys 17 | import random 18 | import tqdm 19 | import gzip 20 | import os 21 | import pandas as pd 22 | import numpy as np 23 | 24 | from data import load_data 25 | 26 | from sentence_transformers_ext.bi_encoder_eval import ( 27 | EmbeddingSimilarityEvaluator, 28 | EmbeddingSimilarityEvaluatorEnsemble, 29 | EmbeddingSimilarityEvaluatorAUC, 30 | EmbeddingSimilarityEvaluatorAUCEnsemble 31 | ) 32 | from sentence_transformers_ext.cross_encoder_eval import ( 33 | CECorrelationEvaluatorEnsemble, 34 | CECorrelationEvaluatorAUC, 35 | CECorrelationEvaluatorAUCEnsemble 36 | ) 37 | 38 | 39 | def eval_encoder(all_test, encoder, task="sts", enc_type="bi", ensemble=False): 40 | """ 41 | Evaluate bi- or cross-encoders; Spearman/Pearson correlation is reported for STS-style tasks and ROC AUC for binary pair tasks. 42 | Parameters 43 | ---------- 44 | all_test: a dict of all test sets 45 | encoder: a bi- or cross-encoder, or a list of encoders when ensemble is True 46 | enc_type: a string specifying whether the encoder is a bi- or cross-encoder 47 | ensemble: a bool value indicating whether multiple encoders are used in input 48 | Returns 49 | ---------- 50 | None 51 | """ 52 | scores = [] 53 | for name, data in all_test.items(): 54 | if task in ["sts", "sickr", "sts_sickr", "custom"]: 55 | if enc_type == "bi": 56 | if ensemble: 57 | test_evaluator = EmbeddingSimilarityEvaluatorEnsemble.from_input_examples(data, name=name) 58 | else: 59 | test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(data, name=name) 60 | elif enc_type == "cross": 61 | if ensemble: 62 | test_evaluator = CECorrelationEvaluatorEnsemble.from_input_examples(data, name=name) 63 | else: 64 | test_evaluator = CECorrelationEvaluator.from_input_examples(data, name=name) 65 | else: 66 | raise NotImplementedError() 67 | else: 68 | if enc_type == "bi": 69 | if ensemble: 70 | test_evaluator = EmbeddingSimilarityEvaluatorAUCEnsemble.from_input_examples(data, name=name) 71 | else: 72 | test_evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(data, name=name) 73 | elif enc_type == "cross": 74 | if ensemble: 75 | test_evaluator = CECorrelationEvaluatorAUCEnsemble.from_input_examples(data, name=name) 76 | else: 77 | test_evaluator = CECorrelationEvaluatorAUC.from_input_examples(data, name=name) 78 | else: 79 | raise NotImplementedError() 80 | scores += [test_evaluator(encoder)] 81 | logging.info (" & ".join(["%.2f" % (s*100) for s in scores])) 82 | logging.info (f"***** test's avg score (Spearman's rho or ROC AUC): {sum(scores)/len(scores):.4f} ****") 83 | 84 | 85 | def main(): 86 | 87 | #### Just some code to print debug information to stdout 88 | logging.basicConfig(format='%(asctime)s - %(message)s', 89 | datefmt='%Y-%m-%d %H:%M:%S', 90 | level=logging.INFO, 91 | handlers=[LoggingHandler()]) 92 | 93 | 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--model_name_or_path", type=str, 96 | default="princeton-nlp/unsup-simcse-bert-base-uncased", 97 | help="Transformers' model name or path") 98 | 
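    # --task selects the dataset; eval_encoder reports Spearman/Pearson correlation for the
    # STS-style tasks (sts, sickr, sts_sickr, custom) and ROC AUC for the binary pair tasks
    # (qqp, qnli, mrpc, snli).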
parser.add_argument("--task", type=str, default="sts", 99 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}') 100 | parser.add_argument("--mode", type=str, default='bi', help="{cross|bi}") 101 | parser.add_argument("--device", type=int, default=0) 102 | parser.add_argument("--bi_encoder_pooling_mode", type=str, default='cls', 103 | help="{cls|mean}") 104 | parser.add_argument("--ensemble", action="store_true") 105 | parser.add_argument("--model_name_or_path1", type=str, 106 | default="princeton-nlp/unsup-simcse-bert-base-uncased") 107 | parser.add_argument("--model_name_or_path2", type=str, 108 | default="princeton-nlp/unsup-simcse-roberta-base") 109 | parser.add_argument("--bi_encoder_pooling_mode1", type=str, default="cls") 110 | parser.add_argument("--bi_encoder_pooling_mode2", type=str, default="cls") 111 | parser.add_argument("--quick_test", action="store_true") 112 | 113 | args = parser.parse_args() 114 | print (args) 115 | 116 | ### read datasets 117 | all_pairs, all_test, dev_samples = load_data(args.task) 118 | 119 | if args.quick_test: 120 | all_pairs = all_pairs[:5000] # for quick testing 121 | 122 | print ("|raw sentence pairs|:", len(all_pairs)) 123 | print ("|dev set|:", len(dev_samples)) 124 | for key in all_test: 125 | print ("|test set: %s|" % key, len(all_test[key])) 126 | 127 | model_name = args.model_name_or_path 128 | model_name1 = args.model_name_or_path1 129 | model_name2 = args.model_name_or_path2 130 | 131 | max_seq_length = 32 132 | device=args.device 133 | 134 | if not args.ensemble: 135 | logging.info ("########## load model and evaluate ##########") 136 | 137 | if args.mode == "bi": 138 | 139 | ###### Bi-encoder (sentence-transformers) ###### 140 | logging.info(f"Loading bi-encoder model: {model_name}") 141 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings 142 | word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) 143 | 144 | # Apply mean pooling to get one fixed sized sentence vector 145 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode) 146 | 147 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device) 148 | 149 | # eval bi-encoder 150 | logging.info ("Evaluate bi-encoder...") 151 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi") 152 | 153 | 154 | elif args.mode == "cross": 155 | 156 | ###### cross-encoder (sentence-transformers) ###### 157 | logging.info(f"Loading cross-encoder model: {model_name}") 158 | 159 | cross_encoder = CrossEncoder(model_name, device=device) 160 | 161 | # eval cross-encoder 162 | logging.info ("Evaluate cross-encoder...") 163 | eval_encoder(all_test, cross_encoder, task=args.task, enc_type="cross") 164 | 165 | else: 166 | raise NotImplementedError() 167 | 168 | else: 169 | logging.info ("########## load models and evaluate ##########") 170 | 171 | if args.mode == "bi": 172 | ###### Bi-encoder (sentence-transformers) ###### 173 | logging.info(f"Loading bi-encoder1 model: {model_name1}") 174 | logging.info(f"Loading bi-encoder2 model: {model_name2}") 175 | 176 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings 177 | word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length) 178 | word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length) 179 | 180 | # Apply mean pooling to get one fixed sized 
sentence vector 181 | pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode1) 182 | pooling_model2 = models.Pooling(word_embedding_model2.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode2) 183 | 184 | bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device) 185 | bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device) 186 | 187 | # eval bi-encoder 188 | logging.info ("Evaluate bi-encoder (ensembled)...") 189 | eval_encoder(all_test, [bi_encoder1, bi_encoder2], task=args.task, enc_type="bi", ensemble=True) 190 | 191 | 192 | elif args.mode == "cross": 193 | 194 | ###### cross-encoder (sentence-transformers) ###### 195 | logging.info(f"Loading cross-encoder1 model: {model_name1}") 196 | logging.info(f"Loading cross-encoder2 model: {model_name2}") 197 | 198 | cross_encoder1 = CrossEncoder(model_name1, device=device) 199 | cross_encoder2 = CrossEncoder(model_name2, device=device) 200 | 201 | # eval cross-encoder 202 | logging.info ("Evaluate cross-encoder (ensembled)...") 203 | eval_encoder(all_test, [cross_encoder1, cross_encoder2], task=args.task, enc_type="cross", ensemble=True) 204 | 205 | else: 206 | raise NotImplementedError() 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | The Amazon Open-Source Code - "trans-encoder" - includes the following third-party software/licensing: 2 | 3 | ** sentence-transformers - https://github.com/UKPLab/sentence-transformers 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 
33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. 
If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "{}" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright {yyyy} {name of copyright owner} 194 | 195 | Licensed under the Apache License, Version 2.0 (the "License"); 196 | you may not use this file except in compliance with the License. 197 | You may obtain a copy of the License at 198 | 199 | http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | Unless required by applicable law or agreed to in writing, software 202 | distributed under the License is distributed on an "AS IS" BASIS, 203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 204 | See the License for the specific language governing permissions and 205 | limitations under the License. 
206 | -------------------------------------------------------------------------------- /src/self_distill.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | from sentence_transformers import models, losses, util, SentenceTransformer, LoggingHandler 8 | from sentence_transformers.cross_encoder import CrossEncoder 9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator 10 | from sentence_transformers.readers import InputExample 11 | from datetime import datetime 12 | import argparse 13 | import logging 14 | import sys 15 | import random 16 | import tqdm 17 | import math 18 | import os 19 | import numpy as np 20 | import pandas as pd 21 | 22 | # imports from local modules 23 | from data import load_data 24 | from eval import eval_encoder 25 | 26 | from sentence_transformers_ext.bi_encoder_eval import ( 27 | EmbeddingSimilarityEvaluator, 28 | EmbeddingSimilarityEvaluatorEnsemble, 29 | EmbeddingSimilarityEvaluatorAUC, 30 | EmbeddingSimilarityEvaluatorAUCEnsemble 31 | ) 32 | from sentence_transformers_ext.cross_encoder_eval import ( 33 | CECorrelationEvaluatorEnsemble, 34 | CECorrelationEvaluatorAUC, 35 | CECorrelationEvaluatorAUCEnsemble 36 | ) 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--model_name_or_path", type=str, 40 | default="princeton-nlp/unsup-simcse-bert-base-uncased", 41 | help="Transformers' model name or path") 42 | parser.add_argument("--task", type=str, default='sts', 43 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}') 44 | parser.add_argument("--device", type=int, default=0) 45 | parser.add_argument("--cycle", type=int, default=3) 46 | parser.add_argument("--bi_encoder_pooling_mode", type=str, default='cls', 47 | help="{cls|mean}") 48 | parser.add_argument("--num_epochs_cross_encoder", type=int, default=1) 49 | parser.add_argument("--num_epochs_bi_encoder", type=int, default=10) 50 | parser.add_argument("--batch_size_cross_encoder", type=int, default=32) 51 | parser.add_argument("--batch_size_bi_encoder", type=int, default=128) 52 | parser.add_argument("--init_with_new_models", action="store_true") 53 | parser.add_argument("--random_seed", type=int, default=2021) 54 | #parser.add_argument("--use_raw_data_from_all_tasks", action="store_true") 55 | parser.add_argument("--add_snli_data", type=int, default=0) 56 | parser.add_argument("--custom_corpus_path", type=str, default=None) 57 | parser.add_argument("--num_training_pairs", type=int, default=None) 58 | parser.add_argument("--save_all_predictions", action="store_true") 59 | parser.add_argument("--quick_test", action="store_true") 60 | 61 | 62 | args = parser.parse_args() 63 | print (args) 64 | 65 | torch.manual_seed(args.random_seed) 66 | 67 | #### Just some code to print debug information to stdout 68 | logging.basicConfig(format="%(asctime)s - %(message)s", 69 | datefmt="%Y-%m-%d %H:%M:%S", 70 | level=logging.INFO, 71 | handlers=[LoggingHandler()]) 72 | 73 | 74 | ### read datasets 75 | all_pairs, all_test, dev_samples = load_data(args.task, fpath=args.custom_corpus_path) 76 | 77 | # load pairs from other tasks (disabled) 78 | """ 79 | if args.use_raw_data_from_all_tasks: 80 | all_pairs_qqp, _, _ = load_data("qqp") 81 | all_pairs_qnli, _, _ = load_data("qnli") 82 | all_pairs_mrpc, _, _ = load_data("mrpc") 83 | all_pairs = 
all_pairs + all_pairs_qqp + all_pairs_qnli + all_pairs_mrpc 84 | """ 85 | 86 | if args.add_snli_data != 0: 87 | random.seed(args.random_seed) 88 | all_pairs_snli, _, _ = load_data("snli") 89 | all_pairs_snli_sampled = random.sample(all_pairs_snli, args.add_snli_data) 90 | all_pairs = all_pairs + all_pairs_snli_sampled 91 | 92 | if args.quick_test: 93 | all_pairs = all_pairs[:1000] # for quick testing 94 | 95 | # randomly select training pairs for control study 96 | if args.num_training_pairs is not None: 97 | print ("before sampling |raw sentence pairs|:", len(all_pairs)) 98 | if args.num_training_pairs == -1: 99 | # use all 100 | pass 101 | else: 102 | random.seed(args.random_seed) 103 | all_pairs = random.sample(all_pairs, args.num_training_pairs) 104 | 105 | print ("|raw sentence pairs|:", len(all_pairs)) 106 | print ("|dev set|:", len(dev_samples)) 107 | for key in all_test: 108 | print ("|test set: %s|" % key, len(all_test[key])) 109 | 110 | model_name = args.model_name_or_path #"princeton-nlp/unsup-simcse-bert-base-uncased" 111 | simcse2base = { 112 | "princeton-nlp/unsup-simcse-roberta-base": "roberta-base", 113 | "princeton-nlp/unsup-simcse-roberta-large": "roberta-large", 114 | "princeton-nlp/unsup-simcse-bert-base-uncased": "bert-base-uncased", 115 | "princeton-nlp/unsup-simcse-bert-large-uncased": "bert-large-uncased"} 116 | batch_size_cross_encoder = args.batch_size_cross_encoder 117 | batch_size_bi_encoder = args.batch_size_bi_encoder 118 | num_epochs_cross_encoder = args.num_epochs_cross_encoder 119 | num_epochs_bi_encoder = args.num_epochs_bi_encoder 120 | max_seq_length = 32 121 | total_cycle = args.cycle 122 | device = args.device 123 | 124 | logging.info ("########## load base model and evaluate ##########") 125 | 126 | ###### Bi-encoder (sentence-transformers) ###### 127 | logging.info(f"Loading bi-encoder model: {model_name}") 128 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings 129 | word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) 130 | 131 | # Apply the pooling set by --bi_encoder_pooling_mode (cls or mean) to get one fixed-size sentence vector 132 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode) 133 | 134 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device) 135 | 136 | # eval bi-encoder 137 | logging.info ("Evaluate bi-encoder...") 138 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi") 139 | 140 | start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 141 | 142 | bi_encoder_dev_scores = [] 143 | cross_encoder_dev_scores = [] 144 | bi_encoder_path = model_name 145 | all_predictions_bi_encoder = {} 146 | all_predictions_cross_encoder = {} 147 | 148 | 149 | for cycle in range(total_cycle): 150 | cycle += 1 151 | logging.info (f"########## cycle {cycle:.0f} starts ##########") 152 | 153 | ###### label data with bi-encoder ###### 154 | # label sentence pairs with bi-encoder 155 | logging.info ("Label sentence pairs with bi-encoder...") 156 | 157 | # Two lists of sentences 158 | sents1 = [p[0] for p in all_pairs] 159 | sents2 = [p[1] for p in all_pairs] 160 | 161 | # Compute embeddings for both lists 162 | embeddings1 = bi_encoder.encode(sents1, convert_to_tensor=True) 163 | embeddings2 = bi_encoder.encode(sents2, convert_to_tensor=True) 164 | 165 | # Compute cosine similarities 166 | cosine_scores = F.cosine_similarity(embeddings1, embeddings2) 167 | 168 | # save the predictions
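# note: `cycle` starts at 1, so the key "bi_encoder_cycle_0" below holds the scores of the
# initial, not-yet-distilled SimCSE bi-encoder, while the cross-encoder keys
# ("cross_encoder_cycle_1", ...) start counting from 1.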
169 | all_predictions_bi_encoder["bi_encoder_cycle_"+str(cycle-1)] = cosine_scores.cpu().numpy() 170 | 171 | # form (self-labelled) train set 172 | train_samples = [] 173 | 174 | for i in range(len(sents1)): 175 | if args.task in ["qnli"]: 176 | train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i])) # qnli pairs are directional; keep the original order only 177 | else: 178 | train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i])) 179 | train_samples.append(InputExample(texts=[sents2[i], sents1[i]], label=cosine_scores[i])) # symmetric tasks: also add the reversed pair 180 | 181 | del bi_encoder, embeddings1, embeddings2, cosine_scores 182 | torch.cuda.empty_cache() 183 | 184 | ###### Cross-encoder learning ###### 185 | if args.init_with_new_models: 186 | bi_encoder_path = simcse2base[model_name] # always re-init from the base PLM (otherwise bi_encoder_path keeps pointing at the latest bi-encoder checkpoint) 187 | logging.info(f"Loading cross-encoder model: {bi_encoder_path}") 188 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for cross-encoder model 189 | cross_encoder = CrossEncoder(bi_encoder_path, num_labels=1, device=device, max_length=64) 190 | 191 | # We wrap train_samples (which is a List[InputExample]) into a pytorch DataLoader 192 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size_cross_encoder) 193 | 194 | # We add an evaluator, which evaluates the performance during training 195 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]: 196 | evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='dev') 197 | else: 198 | evaluator = CECorrelationEvaluatorAUC.from_input_examples(dev_samples, name='dev') 199 | 200 | # Configure the training 201 | warmup_steps = math.ceil(len(train_dataloader) * num_epochs_cross_encoder * 0.1) # 10% of train data for warm-up 202 | logging.info(f"Warmup-steps: {warmup_steps}") 203 | 204 | cross_encoder_path = f"output/cross-encoder/" \ 205 | f"{args.task}_cycle{cycle}_{model_name.replace('/', '-')}-{start_time}" 206 | 207 | # Train the cross-encoder model 208 | cross_encoder.fit( 209 | train_dataloader=train_dataloader, 210 | evaluator=evaluator, 211 | evaluation_steps=200, 212 | use_amp=True, 213 | epochs=num_epochs_cross_encoder, 214 | warmup_steps=warmup_steps, 215 | output_path=cross_encoder_path) 216 | 217 | 218 | 219 | cross_encoder = CrossEncoder(cross_encoder_path, max_length=64, device=device) 220 | #cross_encoder = CrossEncoder(cross_encoder_path, device=device) 221 | 222 | # eval cross-encoder 223 | dev_score = evaluator(cross_encoder) 224 | cross_encoder_dev_scores.append(dev_score) 225 | logging.info (f"***** dev's spearman's rho: {dev_score:.4f} *****") 226 | 227 | ###### label data with cross-encoder ###### 228 | # label sentence pairs with cross-encoder 229 | logging.info ("Label sentence pairs with cross-encoder...") 230 | silver_scores = cross_encoder.predict(all_pairs) 231 | silver_samples = list(InputExample(texts=[data[0], data[1]], label=score) for \ 232 | data, score in zip(all_pairs, silver_scores)) 233 | 234 | del cross_encoder 235 | torch.cuda.empty_cache() 236 | 237 | # save the predictions 238 | all_predictions_cross_encoder["cross_encoder_cycle_"+str(cycle)] = silver_scores 239 | 240 | ###### Bi-encoder learning ###### 241 | if args.init_with_new_models: 242 | cross_encoder_path = model_name # always re-init from the SimCSE checkpoint (otherwise this cycle's trained cross-encoder is reused) 243 | logging.info(f"Loading bi-encoder model: {cross_encoder_path}") 244 | word_embedding_model = models.Transformer(cross_encoder_path, max_seq_length=32) 245 | pooling_model = 
models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode) 246 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device) 247 | 248 | train_dataloader = DataLoader(silver_samples, shuffle=True, batch_size=batch_size_bi_encoder) 249 | train_loss = losses.CosineSimilarityLoss(model=bi_encoder) 250 | 251 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]: 252 | evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='dev') 253 | else: 254 | evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(dev_samples, name='dev') 255 | 256 | # Configure the training. 257 | #warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up 258 | #logging.info(f"Warmup-steps: {warmup_steps}") 259 | 260 | bi_encoder_path = f"output/bi-encoder/" \ 261 | f"{args.task}_cycle{cycle}_{model_name.replace('/', '-')}-{start_time}" 262 | 263 | # Train the bi-encoder model 264 | bi_encoder.fit( 265 | train_objectives=[(train_dataloader, train_loss)], 266 | evaluator=evaluator, 267 | epochs=num_epochs_bi_encoder, 268 | evaluation_steps=200, 269 | warmup_steps=0, 270 | output_path=bi_encoder_path, 271 | optimizer_params= {"lr": 5e-5}, 272 | use_amp=True, 273 | ) 274 | 275 | bi_encoder = SentenceTransformer(bi_encoder_path, device=device) 276 | 277 | # eval bi-encoder 278 | dev_score = evaluator(bi_encoder) 279 | bi_encoder_dev_scores.append(dev_score) 280 | logging.info (f"***** dev's spearman's rho: {dev_score:.4f} *****") 281 | 282 | logging.info (bi_encoder_dev_scores) 283 | logging.info (cross_encoder_dev_scores) 284 | 285 | # best bi-encoder 286 | best_cycle_bi_encoder = np.argmax(bi_encoder_dev_scores)+1 287 | best_cycle_bi_encoder_path = f"output/bi-encoder/"\ 288 | f"{args.task}_cycle{best_cycle_bi_encoder}_{model_name.replace('/', '-')}-{start_time}" 289 | # eval bi-encoder 290 | logging.info (f"Evaluate best bi-encoder (from cycle {best_cycle_bi_encoder})...") 291 | bi_encoder = SentenceTransformer(best_cycle_bi_encoder_path, device=device) 292 | logging.info (best_cycle_bi_encoder_path) 293 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi") 294 | 295 | 296 | # best cross-encoder 297 | best_cycle_cross_encoder = np.argmax(cross_encoder_dev_scores)+1 298 | best_cycle_cross_encoder_path = f"output/cross-encoder/"\ 299 | f"{args.task}_cycle{best_cycle_cross_encoder}_{model_name.replace('/', '-')}-{start_time}" 300 | # eval cross-encoder 301 | logging.info (f"Evaluate best cross-encoder (from cycle {best_cycle_cross_encoder})...") 302 | logging.info (best_cycle_cross_encoder_path) 303 | #cross_encoder = CrossEncoder(best_cycle_cross_encoder_path, max_length=64, device=device) 304 | cross_encoder = CrossEncoder(best_cycle_cross_encoder_path, device=device) 305 | eval_encoder(all_test, cross_encoder, task=args.task, enc_type="cross") 306 | 307 | # save all predictions 308 | best_cycle_bi_encoder_csv_path = f"output/bi-encoder/" \ 309 | f"{args.task}_cycle{best_cycle_bi_encoder}_{model_name.replace('/', '-')}-{start_time}_all_preds.csv" 310 | best_cycle_cross_encoder_csv_path = f"output/cross-encoder/" \ 311 | f"{args.task}_cycle{best_cycle_cross_encoder}_{model_name.replace('/', '-')}-{start_time}_all_preds.csv" 312 | 313 | if args.save_all_predictions: 314 | bi_pred_df = pd.DataFrame(np.array(list(all_predictions_bi_encoder.values())).T, columns=list(all_predictions_bi_encoder.keys())) 315 | cross_pred_df = 
pd.DataFrame(np.array(list(all_predictions_cross_encoder.values())).T, columns=list(all_predictions_cross_encoder.keys())) 316 | bi_pred_df.to_csv(best_cycle_bi_encoder_csv_path, index=False) 317 | cross_pred_df.to_csv(best_cycle_cross_encoder_csv_path, index=False) 318 | 319 | logging.info ("\n") 320 | print (args) 321 | logging.info ("\n") 322 | logging.info ("***** END *****") 323 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from datasets import load_dataset 5 | from sentence_transformers.readers import InputExample 6 | from sentence_transformers import util 7 | import pandas as pd 8 | import os 9 | import gzip 10 | import csv 11 | import logging 12 | import subprocess 13 | from zipfile import ZipFile 14 | 15 | SCORE = "score" 16 | SPLIT = "split" 17 | SENTENCE = "sentence" 18 | SENTENCE1 = "sentence1" 19 | SENTENCE2 = "sentence2" 20 | QUESTION = "question" 21 | QUESTION1 = "question1" 22 | QUESTION2 = "question2" 23 | 24 | def load_snli(): 25 | """ 26 | Load the SNLI dataset (https://nlp.stanford.edu/projects/snli/) from huggingface dataset portal. 27 | Parameters 28 | ---------- 29 | None 30 | Returns 31 | ---------- 32 | all_pairs: a list of sentence pairs from the SNLI dataset 33 | """ 34 | 35 | all_pairs = [] 36 | 37 | dataset = load_dataset("snli") 38 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["train"]] 39 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["validation"]] 40 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["test"]] 41 | 42 | return all_pairs, None, None 43 | 44 | def load_sts(): 45 | """ 46 | Load the STS datasets: 47 | STS 2012: https://www.cs.york.ac.uk/semeval-2012/task6/ 48 | STS 2013: http://ixa2.si.ehu.eus/sts/ 49 | STS 2014: https://alt.qcri.org/semeval2014/task10/ 50 | STS 2015: https://alt.qcri.org/semeval2015/task2/ 51 | STS 2016: https://alt.qcri.org/semeval2016/task1/ 52 | STS-Benchmark: http://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark 53 | Parameters 54 | ---------- 55 | None 56 | Returns 57 | ---------- 58 | all_pairs: a list of sentence pairs from the STS datasets 59 | all_test: a dict of all test sets of the STS datasets 60 | dev_samples: a list of InputExample instances as the dev set 61 | """ 62 | 63 | # Check if STS datasets exist. If not, download and extract them 64 | sts_dataset_path = "data/" 65 | if not os.path.exists(os.path.join(sts_dataset_path, "STS_data")): 66 | logging.info("Dataset not found. Download")
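# nb: the download below shells out to wget via subprocess.run, so wget must be available
# on PATH; --no-check-certificate is passed, presumably to tolerate the host's TLS setup.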
Download") 67 | zip_save_path = "data/STS_data.zip" 68 | #os.system("wget https://fangyuliu.me/data/STS_data.zip -P data/") 69 | subprocess.run(["wget", "--no-check-certificate", "https://fangyuliu.me/data/STS_data.zip", "-P", "data/"]) 70 | with ZipFile(zip_save_path, "r") as zipIn: 71 | zipIn.extractall(sts_dataset_path) 72 | 73 | all_pairs = [] 74 | all_test = {} 75 | dedup = set() 76 | 77 | # read sts 2012-2016 78 | for year in ["2012","2013","2014","2015","2016"]: 79 | all_test[year] = [] 80 | for year in ["2012","2013","2014","2015","2016"]: 81 | df = pd.read_csv(f"data/STS_data/en/{year}.test.tsv", delimiter="\t", 82 | quoting=csv.QUOTE_NONE, encoding="utf-8", names=[SCORE, SENTENCE1, SENTENCE2]) 83 | for row in df.iterrows(): 84 | if str(row[1][SCORE]) == "nan": continue 85 | all_test[year].append(InputExample(texts=[row[1][SENTENCE1], row[1][SENTENCE2]], label=row[1][SCORE])) 86 | 87 | df = pd.read_csv("data/STS_data/en/2012_to_2016.test.tsv", delimiter="\t", 88 | quoting=csv.QUOTE_NONE, encoding="utf-8", names=[SCORE, SENTENCE1, SENTENCE2]) 89 | 90 | for row in df.iterrows(): 91 | concat = row[1][SENTENCE1]+row[1][SENTENCE2] 92 | if concat in dedup: 93 | continue 94 | all_pairs.append([row[1][SENTENCE1], row[1][SENTENCE2]]) 95 | dedup.add(concat) 96 | 97 | # sts-b 98 | # Check if STS-B exsists. If not, download and extract it 99 | sts_dataset_path = "data/stsbenchmark.tsv.gz" 100 | if not os.path.exists(sts_dataset_path): 101 | util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) 102 | # read sts-b 103 | dev_samples_stsb = [] 104 | test_samples_stsb = [] 105 | with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: 106 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) 107 | for row in reader: 108 | score = float(row[SCORE]) / 5.0 # Normalize score to range 0 ... 1 109 | 110 | if row[SPLIT] == "dev": 111 | dev_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score)) 112 | elif row[SPLIT] == "test": 113 | test_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score)) 114 | 115 | # add (non-duplicated) sentence pair to all_pairs 116 | concat = row[SENTENCE1]+row[SENTENCE2] 117 | if concat in dedup: 118 | continue 119 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]]) 120 | dedup.add(concat) 121 | 122 | all_test["stsb"] = test_samples_stsb 123 | dev_samples = dev_samples_stsb 124 | return all_pairs, all_test, dev_samples 125 | 126 | def load_sickr(): 127 | """ 128 | Load the SICK-R dataset: http://clic.cimec.unitn.it/composes/sick.html 129 | Parameters 130 | ---------- 131 | None 132 | Returns 133 | ---------- 134 | all_pairs: a list of sentence pairs from the SICK-R dataset 135 | all_test: a dict of all test sets 136 | dev_samples: a list of InputExample instances as the dev set 137 | """ 138 | 139 | 140 | sts_dataset_path = "data/" 141 | if not os.path.exists(os.path.join(sts_dataset_path, "STS_data")): 142 | logging.info("Dataset not found. 
Download") 143 | zip_save_path = "data/STS_data.zip" 144 | subprocess.run(["wget", "--no-check-certificate", "https://fangyuliu.me/data/STS_data.zip", "-P", "data/"]) 145 | with ZipFile(zip_save_path, "r") as zipIn: 146 | zipIn.extractall(sts_dataset_path) 147 | 148 | all_pairs = [] 149 | all_test = {} 150 | dedup = set() 151 | 152 | # read sickr 153 | test_samples_sickr = [] 154 | dev_samples_sickr = [] 155 | 156 | df = pd.read_csv("data/STS_data/en/SICK_annotated.txt", delimiter="\t", 157 | quoting=csv.QUOTE_NONE, encoding="utf-8") 158 | 159 | for row in df.iterrows(): 160 | row = row[1] 161 | score = row["relatedness_score"] / 5.0 162 | if row["SemEval_set"] == "TEST": 163 | test_samples_sickr.append(InputExample(texts=[row["sentence_A"], row["sentence_B"]], label=score)) 164 | elif row["SemEval_set"] == "TRIAL": 165 | dev_samples_sickr.append(InputExample(texts=[row["sentence_A"], row["sentence_B"]], label=score)) 166 | 167 | concat = row["sentence_A"]+row["sentence_B"] 168 | if concat in dedup: 169 | continue 170 | all_pairs.append([row["sentence_A"], row["sentence_B"]]) 171 | dedup.add(concat) 172 | 173 | all_test["sickr"] = test_samples_sickr 174 | dev_samples = dev_samples_sickr 175 | return all_pairs, all_test, dev_samples 176 | 177 | def load_qqp(): 178 | """ 179 | Load the QQP dataset (https://www.kaggle.com/c/quora-question-pairs) from huggingface dataset portal. 180 | Parameters 181 | ---------- 182 | None 183 | Returns 184 | ---------- 185 | all_pairs: a list of sentence pairs from the QQP dataset 186 | all_test: a dict of all test sets 187 | dev_samples: a list of InputExample instances as the dev set 188 | """ 189 | 190 | all_pairs = [] 191 | all_test = {} 192 | 193 | dev_samples_qqp = [] 194 | test_samples_qqp = [] 195 | 196 | # Check if the QQP dataset exists. If not, download and extract 197 | qqp_dataset_path = "data/quora-IR-dataset" 198 | if not os.path.exists(qqp_dataset_path): 199 | logging.info("Dataset not found. 
Download") 200 | zip_save_path = 'data/quora-IR-dataset.zip' 201 | util.http_get(url='https://sbert.net/datasets/quora-IR-dataset.zip', path=zip_save_path) 202 | with ZipFile(zip_save_path, 'r') as zipIn: 203 | zipIn.extractall(qqp_dataset_path) 204 | 205 | qqp_datapoints_cut_train = 10000 206 | qqp_datapoints_cut_val = 1000 207 | qqp_datapoints_cut_test = 10000 208 | 209 | with open(os.path.join(qqp_dataset_path, "classification/train_pairs.tsv"), encoding="utf8") as fIn: 210 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) 211 | for i, row in enumerate(reader): 212 | if i == qqp_datapoints_cut_train: break 213 | all_pairs.append([row[QUESTION1], row[QUESTION2]]) 214 | 215 | with open(os.path.join(qqp_dataset_path, "classification/dev_pairs.tsv"), encoding="utf8") as fIn: 216 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) 217 | for i, row in enumerate(reader): 218 | if i == qqp_datapoints_cut_val: break 219 | dev_samples_qqp.append(InputExample(texts=[row[QUESTION1], row[QUESTION2]], label=int(row['is_duplicate']))) 220 | all_pairs.append([row[QUESTION1], row[QUESTION2]]) 221 | 222 | with open(os.path.join(qqp_dataset_path, "classification/test_pairs.tsv"), encoding="utf8") as fIn: 223 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) 224 | for i, row in enumerate(reader): 225 | if i == qqp_datapoints_cut_test: break 226 | test_samples_qqp.append(InputExample(texts=[row[QUESTION1], row[QUESTION2]], label=int(row['is_duplicate']))) 227 | all_pairs.append([row[QUESTION1], row[QUESTION2]]) 228 | 229 | all_test["qqp"] = test_samples_qqp 230 | dev_samples = dev_samples_qqp 231 | 232 | return all_pairs, all_test, dev_samples 233 | 234 | def load_qnli(): 235 | """ 236 | Load the QNLI dataset (part of GLUE: https://gluebenchmark.com/) from huggingface dataset portal. 237 | Parameters 238 | ---------- 239 | None 240 | Returns 241 | ---------- 242 | all_pairs: a list of sentence pairs from the QNLI dataset 243 | all_test: a dict of all test sets 244 | dev_samples: a list of InputExample instances as the dev set 245 | """ 246 | 247 | all_pairs = [] 248 | all_test = {} 249 | 250 | dev_samples_qnli = [] 251 | test_samples_qnli = [] 252 | 253 | dataset = load_dataset("glue", "qnli") 254 | 255 | qnli_datapoints_cut_train = 10000 256 | 257 | for i, row in enumerate(dataset["train"]): 258 | if i == qnli_datapoints_cut_train: break 259 | all_pairs.append([row[QUESTION], row[SENTENCE]]) 260 | 261 | for row in dataset["validation"]: 262 | label = 0 if row["label"]==1 else 1 263 | dev_samples_qnli.append( 264 | InputExample(texts=[row[QUESTION], row[SENTENCE]], label=label)) 265 | all_pairs.append([row[QUESTION], row[SENTENCE]]) 266 | 267 | # test labels of qnli are not given, use the first 1k in dev set as test 268 | all_test["qnli"] = dev_samples_qnli[1000:] 269 | dev_samples = dev_samples_qnli[:1000] 270 | 271 | return all_pairs, all_test, dev_samples 272 | 273 | def load_mrpc(): 274 | """ 275 | Load the MRPC dataset (https://www.microsoft.com/en-us/download/details.aspx?id=52398) from huggingface dataset portal. 
276 | Parameters 277 | ---------- 278 | None 279 | Returns 280 | ---------- 281 | all_pairs: a list of sentence pairs from the MRPC dataset 282 | all_test: a dict of all test sets 283 | dev_samples: a list of InputExample instances as the dev set 284 | """ 285 | all_pairs = [] 286 | all_test = {} 287 | 288 | dev_samples_mrpc = [] 289 | test_samples_mrpc = [] 290 | 291 | dataset = load_dataset("glue", "mrpc") 292 | 293 | for row in dataset["train"]: 294 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]]) 295 | 296 | for row in dataset["validation"]: 297 | dev_samples_mrpc.append( 298 | InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=int(row["label"]))) 299 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]]) 300 | 301 | for row in dataset["test"]: 302 | test_samples_mrpc.append( 303 | InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=int(row["label"]))) 304 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]]) 305 | 306 | all_test["mrpc"] = test_samples_mrpc 307 | dev_samples = dev_samples_mrpc 308 | 309 | return all_pairs, all_test, dev_samples 310 | 311 | def load_sts_and_sickr(): 312 | """ 313 | Load both STS and SICK-R datasets. Use STS-B's dev set for dev. 314 | Parameters 315 | ---------- 316 | None 317 | Returns 318 | ---------- 319 | all_pairs: a list of sentence pairs from the STS+SICK-R dataset 320 | all_test: a dict of all test sets 321 | dev_samples: a list of InputExample instances as the dev set 322 | """ 323 | 324 | all_pairs_sts, all_test_sts, dev_samples_sts = load_sts() 325 | all_pairs_sickr, all_test_sickr, dev_samples_sickr = load_sickr() 326 | all_pairs = all_pairs_sts+all_pairs_sickr 327 | all_test = {**all_test_sts, **all_test_sickr} 328 | return all_pairs, all_test, dev_samples_sts # sts-b's dev is used 329 | 330 | def load_custom(fpath): 331 | """ 332 | Load custom sentence-pair corpus. Use STS-B's dev set for dev. 333 | Parameters 334 | ---------- 335 | fpath: path to the training file, where sentence pairs are formatted as 'sent1||sent2' 336 | Returns 337 | ---------- 338 | all_pairs: a list of sentence pairs from the custom corpus 339 | all_test: a dict of all test sets 340 | dev_samples: a list of InputExample instances as the dev set 341 | """ 342 | all_pairs = [] 343 | all_test = {} 344 | 345 | # load custom training corpus 346 | with open(fpath, "r") as f: 347 | lines = f.readlines() 348 | for line in lines: 349 | line = line.strip() 350 | if len(line.split("||")) != 2: continue # skip malformed lines 351 | sent1, sent2 = line.split("||") 352 | all_pairs.append([sent1, sent2])
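# Hypothetical input, for illustration only (these sentences are not shipped with the repo).
# Each line of the custom corpus holds one pair, separated by "||":
#   A man is playing a guitar.||Someone plays an instrument.
#   A woman is slicing an onion.||Someone is cutting vegetables.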
353 | 354 | # load STS-b dev/test set 355 | # Check if STS-B exists. If not, download and extract it 356 | sts_dataset_path = "data/stsbenchmark.tsv.gz" 357 | if not os.path.exists(sts_dataset_path): 358 | util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) 359 | # read sts-b 360 | dev_samples_stsb = [] 361 | test_samples_stsb = [] 362 | with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: 363 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) 364 | for row in reader: 365 | score = float(row[SCORE]) / 5.0 # Normalize score to range 0 ... 1 366 | 367 | if row[SPLIT] == "dev": 368 | dev_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score)) 369 | elif row[SPLIT] == "test": 370 | test_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score)) 371 | 372 | # add sentence pair to all_pairs (disabled: training uses only the custom corpus) 373 | #all_pairs.append([row[SENTENCE1], row[SENTENCE2]]) 374 | 375 | all_test["stsb"] = test_samples_stsb 376 | dev_samples = dev_samples_stsb 377 | 378 | return all_pairs, all_test, dev_samples 379 | 380 | 381 | task_loader_dict = { 382 | "sts": load_sts, 383 | "sickr": load_sickr, 384 | "sts_sickr": load_sts_and_sickr, 385 | "qqp": load_qqp, 386 | "qnli": load_qnli, 387 | "mrpc": load_mrpc, 388 | "snli": load_snli, 389 | "custom": load_custom 390 | } 391 | 392 | def load_data(task, fpath=None): 393 | """ 394 | A unified dataset loader for all tasks. 395 | Parameters 396 | ---------- 397 | task: a string specifying dataset/task to be loaded (for possible options see 'task_loader_dict') 398 | fpath: path to a custom training corpus; only used when task == "custom" 399 | Returns 400 | ---------- 401 | all_pairs: a list of sentence pairs from the specified dataset 402 | all_test: a dict of all test sets 403 | dev_samples: a list of InputExample instances as the dev set 404 | """ 405 | if task not in task_loader_dict.keys(): 406 | raise NotImplementedError() 407 | if task == "custom": 408 | return task_loader_dict[task](fpath) 409 | else: 410 | return task_loader_dict[task]() 411 | 412 | 413 | if __name__ == "__main__": 414 | # test if all datasets can be properly loaded ("custom" is skipped: it needs a corpus path) 415 | for task in task_loader_dict: 416 | if task == "custom": continue 417 | print (f"loading {task}...") 418 | load_data(task) 419 | print ("done.") -------------------------------------------------------------------------------- /src/mutual_distill_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
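# Mutual distillation, run in parallel on two devices with two heterogeneous PLMs
# (by default SimCSE-BERT and SimCSE-RoBERTa): in each cycle the two bi-encoders'
# cosine scores are averaged into self-labels, two cross-encoders are trained on them,
# the cross-encoders' averaged predictions become silver labels, and two freshly
# initialized bi-encoders are trained on those.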
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | from sentence_transformers import models, losses, util, SentenceTransformer, LoggingHandler 8 | from sentence_transformers.cross_encoder import CrossEncoder 9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator 10 | from sentence_transformers.readers import InputExample 11 | from datetime import datetime 12 | import argparse 13 | import logging 14 | import sys 15 | import random 16 | import tqdm 17 | import math 18 | import os 19 | import numpy as np 20 | 21 | # imports from local modules 22 | from data import load_data 23 | from eval import eval_encoder 24 | 25 | from sentence_transformers_ext.bi_encoder_eval import ( 26 | EmbeddingSimilarityEvaluator, 27 | EmbeddingSimilarityEvaluatorEnsemble, 28 | EmbeddingSimilarityEvaluatorAUC, 29 | EmbeddingSimilarityEvaluatorAUCEnsemble 30 | ) 31 | from sentence_transformers_ext.cross_encoder_eval import ( 32 | CECorrelationEvaluatorEnsemble, 33 | CECorrelationEvaluatorAUC, 34 | CECorrelationEvaluatorAUCEnsemble 35 | ) 36 | 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("--task", type=str, default='sts', 40 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}') 41 | parser.add_argument("--device1", type=int, default=0) 42 | parser.add_argument("--device2", type=int, default=0) 43 | parser.add_argument("--cycle", type=int, default=3) 44 | parser.add_argument("--num_epochs_cross_encoder", type=int, default=1) 45 | parser.add_argument("--num_epochs_bi_encoder", type=int, default=10) 46 | parser.add_argument("--batch_size_cross_encoder", type=int, default=32) 47 | parser.add_argument("--batch_size_bi_encoder", type=int, default=128) 48 | parser.add_argument("--init_with_new_models", action="store_true") 49 | parser.add_argument("--use_large", action="store_true") 50 | parser.add_argument("--bi_encoder1_pooling_mode", type=str, 51 | default='cls', help="{cls|mean}") 52 | parser.add_argument("--bi_encoder2_pooling_mode", type=str, 53 | default='cls', help="{cls|mean}") 54 | parser.add_argument("--random_seed", type=int, default=2021) 55 | parser.add_argument("--custom_corpus_path", type=str, default=None) 56 | parser.add_argument("--quick_test", action="store_true") 57 | 58 | 59 | 60 | args = parser.parse_args() 61 | print (args) 62 | 63 | torch.manual_seed(args.random_seed) 64 | 65 | #### Just some code to print debug information to stdout 66 | logging.basicConfig(format="%(asctime)s - %(message)s", 67 | datefmt="%Y-%m-%d %H:%M:%S", 68 | level=logging.INFO, 69 | handlers=[LoggingHandler()]) 70 | 71 | ### read datasets 72 | all_pairs, all_test, dev_samples = load_data(args.task, fpath=args.custom_corpus_path) 73 | if args.quick_test: 74 | all_pairs = all_pairs[:1000] # for quick testing 75 | 76 | print ("|raw sentence pairs|:", len(all_pairs)) 77 | print ("|dev set|:", len(dev_samples)) 78 | for key in all_test: 79 | print ("|test set: %s|" % key, len(all_test[key])) 80 | 81 | 82 | if not args.use_large: 83 | model_name1 = "princeton-nlp/unsup-simcse-bert-base-uncased" 84 | model_name2 = "princeton-nlp/unsup-simcse-roberta-base" 85 | else: 86 | model_name1 = "princeton-nlp/unsup-simcse-bert-large-uncased" 87 | model_name2 = "princeton-nlp/unsup-simcse-roberta-large" 88 | 89 | 90 | simcse2base = { 91 | "princeton-nlp/unsup-simcse-roberta-base": "roberta-base", 92 | "princeton-nlp/unsup-simcse-roberta-large": "roberta-large",
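# maps each SimCSE checkpoint to the vanilla PLM it was tuned from; the cross-encoders
# below are re-initialized from these base models at the start of every cycle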
"princeton-nlp/unsup-simcse-bert-base-uncased": "bert-base-uncased", 94 | "princeton-nlp/unsup-simcse-bert-large-uncased": "bert-large-uncased" 95 | } 96 | 97 | batch_size_cross_encoder = args.batch_size_cross_encoder 98 | batch_size_bi_encoder = args.batch_size_bi_encoder 99 | num_epochs_cross_encoder = args.num_epochs_cross_encoder 100 | num_epochs_bi_encoder = args.num_epochs_bi_encoder 101 | max_seq_length = 32 102 | total_cycle = args.cycle 103 | device1=args.device1 104 | device2=args.device2 105 | 106 | logging.info ("########## load base models and evaluate ##########") 107 | 108 | ###### Bi-encoder (sentence-transformers) ###### 109 | logging.info(f"Loading bi-encoder model1: {model_name1}") 110 | logging.info(f"Loading bi-encoder model2: {model_name2}") 111 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings 112 | word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length) 113 | word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length) 114 | 115 | # Apply mean pooling to get one fixed sized sentence vector 116 | pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(), 117 | pooling_mode=args.bi_encoder1_pooling_mode) # bert 118 | pooling_model2 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(), 119 | pooling_mode=args.bi_encoder2_pooling_mode) # roberta 120 | 121 | bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device1) 122 | bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device2) 123 | 124 | # eval bi-encoder 125 | logging.info ("Evaluate bi-encoder (ensembled)...") 126 | scores = [] 127 | for name, data in all_test.items(): 128 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]: 129 | test_evaluator = EmbeddingSimilarityEvaluatorEnsemble.from_input_examples(data, name=name) 130 | else: 131 | test_evaluator = EmbeddingSimilarityEvaluatorAUCEnsemble.from_input_examples(data, name=name) 132 | scores += [test_evaluator([bi_encoder1, bi_encoder2])] 133 | logging.info (f"***** test's avg spearman's rho: {sum(scores)/len(scores):.4f} ****") 134 | 135 | start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 136 | 137 | bi_encoder_dev_scores = [] 138 | cross_encoder_dev_scores = [] 139 | 140 | for cycle in range(1, total_cycle+1): 141 | logging.info (f"########## cycle {cycle:.0f} starts ##########") 142 | 143 | ###### label data with bi-encoder ###### 144 | # label sentence pairs with bi-encoder 145 | logging.info ("Label sentence pairs...") 146 | 147 | # Two lists of sentences 148 | sents1 = [p[0] for p in all_pairs] 149 | sents2 = [p[1] for p in all_pairs] 150 | 151 | #Compute embedding for both lists 152 | embeddings1 = bi_encoder1.encode(sents1, convert_to_tensor=True) 153 | embeddings2 = bi_encoder1.encode(sents2, convert_to_tensor=True) 154 | 155 | #Compute cosine-similarits 156 | cosine_scores1 = F.cosine_similarity(embeddings1, embeddings2) 157 | 158 | embeddings1 = bi_encoder2.encode(sents1, convert_to_tensor=True) 159 | embeddings2 = bi_encoder2.encode(sents2, convert_to_tensor=True) 160 | cosine_scores2 = F.cosine_similarity(embeddings1, embeddings2) 161 | 162 | cosine_scores = torch.stack([cosine_scores1.cpu(), cosine_scores2.cpu()]).mean(0) 163 | 164 | # form (self-labelled) train set 165 | train_samples = [] 166 | 167 | for i in range(len(sents1)): 168 | if args.task in ["qnli"]: 169 | 
            train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i]))
        else:
            # symmetric tasks: train on both orderings of each pair
            train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i]))
            train_samples.append(InputExample(texts=[sents2[i], sents1[i]], label=cosine_scores[i]))

    del bi_encoder1, bi_encoder2, embeddings1, embeddings2, cosine_scores1, cosine_scores2
    torch.cuda.empty_cache()

    ###### Cross-encoder learning ######
    logging.info(f"Loading cross-encoder1 model: {simcse2base[model_name1]}")
    logging.info(f"Loading cross-encoder2 model: {simcse2base[model_name2]}")
    # Use a Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for the cross-encoder
    cross_encoder1 = CrossEncoder(simcse2base[model_name1], num_labels=1, device=device1, max_length=64)
    cross_encoder2 = CrossEncoder(simcse2base[model_name2], num_labels=1, device=device2, max_length=64)

    # Wrap train_samples (a List[InputExample]) in a PyTorch DataLoader
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size_cross_encoder)

    # Add an evaluator that measures performance during training
    if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
        evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='dev')
    else:
        evaluator = CECorrelationEvaluatorAUC.from_input_examples(dev_samples, name='dev')

    # Configure the training
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs_cross_encoder * 0.1)  # 10% of train data for warm-up
    logging.info(f"Warmup-steps: {warmup_steps}")

    cross_encoder_path1 = f"output/cross-encoder/" \
        f"{args.task}_cycle{cycle}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
    cross_encoder_path2 = f"output/cross-encoder/" \
        f"{args.task}_cycle{cycle}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"

    # Train both cross-encoder models on the same bi-encoder-labelled data
    cross_encoder1.fit(
        train_dataloader=train_dataloader,
        evaluator=evaluator,
        evaluation_steps=200,
        use_amp=True,
        epochs=num_epochs_cross_encoder,
        warmup_steps=warmup_steps,
        output_path=cross_encoder_path1)

    cross_encoder2.fit(
        train_dataloader=train_dataloader,
        evaluator=evaluator,
        evaluation_steps=200,
        use_amp=True,
        epochs=num_epochs_cross_encoder,
        warmup_steps=warmup_steps,
        output_path=cross_encoder_path2)

    # Reload the best checkpoints saved during training
    cross_encoder1 = CrossEncoder(cross_encoder_path1, max_length=64, device=device1)
    cross_encoder2 = CrossEncoder(cross_encoder_path2, max_length=64, device=device2)

    dev_score1 = evaluator(cross_encoder1)
    dev_score2 = evaluator(cross_encoder2)
    cross_encoder_dev_scores.append([dev_score1, dev_score2])
    logging.info(f"***** dev's spearman's rho: cross-encoder1 {dev_score1:.4f}, cross-encoder2 {dev_score2:.4f} *****")

    ###### label data with the cross-encoders ######
    logging.info("Label sentence pairs...")
    silver_scores1 = cross_encoder1.predict(all_pairs)
    silver_scores2 = cross_encoder2.predict(all_pairs)
    silver_scores = np.array([silver_scores1, silver_scores2]).mean(0)
    silver_samples = list(InputExample(texts=[data[0], data[1]], label=score) for
                          data, score in zip(all_pairs, silver_scores))

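    # free the cross-encoders' GPU memory before re-initializing the bi-encoders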
    del cross_encoder1, cross_encoder2
    torch.cuda.empty_cache()

    ###### Bi-encoder learning ######
    logging.info(f"Loading bi-encoder1 model: {model_name1}")
    logging.info(f"Loading bi-encoder2 model: {model_name2}")
    # Re-initialize the bi-encoders from the SimCSE checkpoints
    word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length)
    word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length)

    # Apply pooling (cls or mean, per the CLI flags) to get one fixed-sized sentence vector
    pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(),
                                    pooling_mode=args.bi_encoder1_pooling_mode)  # bert
    pooling_model2 = models.Pooling(word_embedding_model2.get_word_embedding_dimension(),
                                    pooling_mode=args.bi_encoder2_pooling_mode)  # roberta

    bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device1)
    bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device2)

    train_dataloader = DataLoader(silver_samples, shuffle=True, batch_size=batch_size_bi_encoder)
    train_loss1 = losses.CosineSimilarityLoss(model=bi_encoder1)
    train_loss2 = losses.CosineSimilarityLoss(model=bi_encoder2)

    if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="dev")
    else:
        evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(dev_samples, name="dev")

    bi_encoder_path1 = f"output/bi-encoder/" \
        f"{args.task}_cycle{cycle}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
    bi_encoder_path2 = f"output/bi-encoder/" \
        f"{args.task}_cycle{cycle}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"

    # Train the bi-encoders on the cross-encoder-labelled (silver) pairs; no warm-up is used here
    bi_encoder1.fit(
        train_objectives=[(train_dataloader, train_loss1)],
        evaluator=evaluator,
        epochs=num_epochs_bi_encoder,
        evaluation_steps=200,
        warmup_steps=0,
        output_path=bi_encoder_path1,
        optimizer_params={"lr": 5e-5},
        use_amp=True,
    )

    bi_encoder2.fit(
        train_objectives=[(train_dataloader, train_loss2)],
        evaluator=evaluator,
        epochs=num_epochs_bi_encoder,
        evaluation_steps=200,
        warmup_steps=0,
        output_path=bi_encoder_path2,
        optimizer_params={"lr": 5e-5},
        use_amp=True,
    )

    # Reload the best checkpoints saved during training
    bi_encoder1 = SentenceTransformer(bi_encoder_path1, device=device1)
    bi_encoder2 = SentenceTransformer(bi_encoder_path2, device=device2)

    dev_score1 = evaluator(bi_encoder1)
    dev_score2 = evaluator(bi_encoder2)
    bi_encoder_dev_scores.append([dev_score1, dev_score2])
    logging.info(f"***** dev's spearman's rho: bi-encoder1 {dev_score1:.4f}, bi-encoder2 {dev_score2:.4f} *****")

    del bi_encoder1, bi_encoder2
    torch.cuda.empty_cache()

print(cross_encoder_dev_scores)
print(bi_encoder_dev_scores)

# best bi-encoder
logging.info("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
bi_encoder_dev_scores1 = [p[0] for p in bi_encoder_dev_scores]
bi_encoder_dev_scores2 = [p[1] for p in bi_encoder_dev_scores]
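# np.argmax is 0-based while cycles are numbered from 1, hence the +1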
best_cycle_bi_encoder1 = np.argmax(bi_encoder_dev_scores1) + 1
best_cycle_bi_encoder2 = np.argmax(bi_encoder_dev_scores2) + 1

best_cycle_bi_encoder_path1 = f"output/bi-encoder/" \
    f"{args.task}_cycle{best_cycle_bi_encoder1}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
best_cycle_bi_encoder_path2 = f"output/bi-encoder/" \
    f"{args.task}_cycle{best_cycle_bi_encoder2}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"

# eval the best bi-encoders
logging.info("£££££ Evaluate best bi-encoder1...")
bi_encoder1 = SentenceTransformer(best_cycle_bi_encoder_path1, device=device1)
logging.info(best_cycle_bi_encoder_path1)
eval_encoder(all_test, bi_encoder1, task=args.task, enc_type="bi")

logging.info("£££££ Evaluate best bi-encoder2...")
bi_encoder2 = SentenceTransformer(best_cycle_bi_encoder_path2, device=device2)
logging.info(best_cycle_bi_encoder_path2)
eval_encoder(all_test, bi_encoder2, task=args.task, enc_type="bi")

# eval the best bi-encoders (ensembled)
logging.info("£££££ Evaluate best bi-encoders (ensembled)...")
eval_encoder(all_test, [bi_encoder1, bi_encoder2], task=args.task, enc_type="bi", ensemble=True)

# best cross-encoder
logging.info("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
cross_encoder_dev_scores1 = [p[0] for p in cross_encoder_dev_scores]
cross_encoder_dev_scores2 = [p[1] for p in cross_encoder_dev_scores]
best_cycle_cross_encoder1 = np.argmax(cross_encoder_dev_scores1) + 1
best_cycle_cross_encoder2 = np.argmax(cross_encoder_dev_scores2) + 1

best_cycle_cross_encoder_path1 = f"output/cross-encoder/" \
    f"{args.task}_cycle{best_cycle_cross_encoder1}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
best_cycle_cross_encoder_path2 = f"output/cross-encoder/" \
    f"{args.task}_cycle{best_cycle_cross_encoder2}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"

# eval the best cross-encoders
logging.info("£££££ Evaluate best cross-encoder1...")
logging.info(best_cycle_cross_encoder_path1)
cross_encoder1 = CrossEncoder(best_cycle_cross_encoder_path1, device=device1)
eval_encoder(all_test, cross_encoder1, task=args.task, enc_type="cross")

logging.info("£££££ Evaluate best cross-encoder2...")
logging.info(best_cycle_cross_encoder_path2)
cross_encoder2 = CrossEncoder(best_cycle_cross_encoder_path2, device=device2)
eval_encoder(all_test, cross_encoder2, task=args.task, enc_type="cross")

# eval the best cross-encoders (ensembled)
logging.info("£££££ Evaluate best cross-encoders (ensembled)...")
eval_encoder(all_test, [cross_encoder1, cross_encoder2], task=args.task, enc_type="cross", ensemble=True)

logging.info("\n")
print(args)
logging.info("\n")
logging.info("***** END *****")
--------------------------------------------------------------------------------