├── NOTICE
├── .gitignore
├── requirements.txt
├── src
│   ├── sentence_transformers_ext
│   │   ├── __init__.py
│   │   ├── cross_encoder_eval
│   │   │   ├── __init__.py
│   │   │   ├── CECorrelationEvaluatorAUCEnsemble.py
│   │   │   ├── CECorrelationEvaluatorAUC.py
│   │   │   └── CECorrelationEvaluatorEnsemble.py
│   │   ├── bi_encoder_eval
│   │   │   ├── __init__.py
│   │   │   ├── EmbeddingSimilarityEvaluatorAUC.py
│   │   │   ├── EmbeddingSimilarityEvaluatorAUCEnsemble.py
│   │   │   ├── EmbeddingSimilarityEvaluator.py
│   │   │   └── EmbeddingSimilarityEvaluatorEnsemble.py
│   │   └── utils.py
│   ├── eval.py
│   ├── self_distill.py
│   ├── data.py
│   └── mutual_distill_parallel.py
├── CODE_OF_CONDUCT.md
├── train_self_distill.sh
├── train_mutual_distill.sh
├── CONTRIBUTING.md
├── README.md
├── LICENSE
└── THIRD-PARTY-LICENSES
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | output/*
2 | *.zip
3 | *__pycache__/
4 | data/*
5 | .ipynb_checkpoints
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | transformers==4.9.0
3 | sentence-transformers==2.0.0
4 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .utils import * # custom
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/train_self_distill.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$1 python src/self_distill.py \
2 | --model_name_or_path "princeton-nlp/unsup-simcse-roberta-base" \
3 | --batch_size_bi_encoder 128 \
4 | --batch_size_cross_encoder 32 \
5 | --num_epochs_bi_encoder 10 \
6 | --num_epochs_cross_encoder 1 \
7 | --cycle 3 \
8 | --bi_encoder_pooling_mode cls \
9 | --init_with_new_models \
10 | --task sts_sickr \
11 | --random_seed 2021
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/cross_encoder_eval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .CECorrelationEvaluatorEnsemble import CECorrelationEvaluatorEnsemble # custom
5 | from .CECorrelationEvaluatorAUC import CECorrelationEvaluatorAUC # custom
6 | from .CECorrelationEvaluatorAUCEnsemble import CECorrelationEvaluatorAUCEnsemble # custom
7 |
--------------------------------------------------------------------------------
/train_mutual_distill.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$1 python src/mutual_distill_parallel.py \
2 | --device1 0 \
3 | --device2 1 \
4 | --batch_size_bi_encoder 128 \
5 | --batch_size_cross_encoder 32 \
6 | --num_epochs_bi_encoder 10 \
7 | --num_epochs_cross_encoder 1 \
8 | --cycle 3 \
9 | --bi_encoder1_pooling_mode cls \
10 | --bi_encoder2_pooling_mode cls \
11 | --init_with_new_models \
12 | --task sts_sickr \
13 | --random_seed 2021
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/bi_encoder_eval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .EmbeddingSimilarityEvaluator import EmbeddingSimilarityEvaluator
5 | from .EmbeddingSimilarityEvaluatorEnsemble import EmbeddingSimilarityEvaluatorEnsemble # custom
6 | from .EmbeddingSimilarityEvaluatorAUC import EmbeddingSimilarityEvaluatorAUC # custom
7 | from .EmbeddingSimilarityEvaluatorAUCEnsemble import EmbeddingSimilarityEvaluatorAUCEnsemble # custom
8 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import os
5 | import csv
6 |
7 | def write_csv_log(output_path, csv_file, csv_headers, things_to_write):
8 | """
9 | Write logs to a csv file.
10 | Parameters
11 | ----------
12 | output_path: a string specifying the write path
13 | csv_file: a string specifying the csv file name
14 | csv_headers: a list of column names; things_to_write: a list of values for one row
15 | Returns
16 | ----------
17 | None
18 | """
19 | csv_path = os.path.join(output_path, csv_file)
20 | output_file_exists = os.path.isfile(csv_path)
21 | with open(csv_path, newline='', mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
22 | writer = csv.writer(f)
23 | if not output_file_exists:
24 | writer.writerow(csv_headers)
25 |
26 | writer.writerow(things_to_write)
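27 | 
28 | # Example (hypothetical values): appends one row to output/eval_log.csv,
29 | # writing the header first only if the file does not yet exist:
30 | #   write_csv_log("output", "eval_log.csv", ["epoch", "steps", "score"], [0, 100, 0.83])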
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorAUCEnsemble.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import logging
5 | from scipy.stats import pearsonr, spearmanr
6 | from sklearn.metrics import roc_auc_score
7 | from typing import List
8 | import numpy as np
9 | import os
10 | import csv
11 | from sentence_transformers.readers import InputExample
12 |
13 | from .CECorrelationEvaluatorEnsemble import CECorrelationEvaluatorEnsemble
14 | from ..utils import write_csv_log
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | class CECorrelationEvaluatorAUCEnsemble(CECorrelationEvaluatorEnsemble):
19 | """
20 | This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary gold labels,
21 | it averages the predicted scores from an ensemble of cross-encoders and computes the ROC AUC
22 | between the averaged predictions and the gold labels.
23 | """
24 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True):
25 | CECorrelationEvaluatorEnsemble.__init__(self, sentence_pairs, scores, name, write_csv)
26 | self.csv_headers = ["epoch", "steps", "roc_auc_score"] # overwrite parent's csv_headers
27 |
28 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
29 | if epoch != -1:
30 | if steps == -1:
31 | out_txt = " after epoch {}:".format(epoch)
32 | else:
33 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
34 | else:
35 | out_txt = ":"
36 |
37 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
38 |
39 | all_scores = []
40 | for model in models:
41 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False)
42 | all_scores.append(pred_scores)
43 |
44 | pred_scores = np.array(all_scores).mean(0)
45 |
46 | eval_auc = roc_auc_score(self.scores, pred_scores)
47 |
48 | logger.info("roc_auc_score: {:.4f}".format(eval_auc))
49 |
50 | if output_path is not None and self.write_csv:
51 | things_to_write = [epoch, steps, eval_auc]
52 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
53 |
54 | return eval_auc
55 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorAUC.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import logging
5 | from scipy.stats import pearsonr, spearmanr
6 | from sklearn.metrics import roc_auc_score
7 | from typing import List
8 | import os
9 | import csv
10 | from sentence_transformers.readers import InputExample
11 |
12 | from ..utils import write_csv_log
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | class CECorrelationEvaluatorAUC:
17 | """
18 | This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary gold labels,
19 | it computes the ROC AUC between the predicted score for each sentence pair
20 | and the gold label.
21 | """
22 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True):
23 | self.sentence_pairs = sentence_pairs
24 | self.scores = scores
25 | self.name = name
26 |
27 | self.csv_file = self.__class__.__name__ + ("_" + name if name else '') + "_results.csv"
28 | self.csv_headers = ["epoch", "steps", "roc_auc_score"]
29 | self.write_csv = write_csv
30 |
31 | @classmethod
32 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
33 | sentence_pairs = []
34 | scores = []
35 |
36 | for example in examples:
37 | sentence_pairs.append(example.texts)
38 | scores.append(example.label)
39 | return cls(sentence_pairs, scores, **kwargs)
40 |
41 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
42 | if epoch != -1:
43 | if steps == -1:
44 | out_txt = " after epoch {}:".format(epoch)
45 | else:
46 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
47 | else:
48 | out_txt = ":"
49 |
50 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
51 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False)
52 |
53 |
54 | eval_auc = roc_auc_score(self.scores, pred_scores)
55 |
56 | logger.info("roc_auc_score: {:.4f}".format(eval_auc))
57 |
58 | if output_path is not None and self.write_csv:
59 | things_to_write = [epoch, steps, eval_auc]
60 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
61 |
62 | return eval_auc
63 |
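64 | # Example usage (hypothetical data; labels must be binary for ROC AUC):
65 | #   evaluator = CECorrelationEvaluatorAUC(
66 | #       sentence_pairs=[["a cat", "a feline"], ["hello", "a car"]],
67 | #       scores=[1, 0], name="dev")
68 | #   auc = evaluator(cross_encoder, output_path="output")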
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/cross_encoder_eval/CECorrelationEvaluatorEnsemble.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import logging
5 | from scipy.stats import pearsonr, spearmanr
6 | from typing import List
7 | import os
8 | import csv
9 | import numpy as np
10 | from sentence_transformers.readers import InputExample
11 |
12 | from ..utils import write_csv_log
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | class CECorrelationEvaluatorEnsemble:
17 | """
18 | This evaluator can be used with the CrossEncoder class. Given sentence pairs and continuous gold scores,
19 | it averages the predicted scores from an ensemble of cross-encoders and computes the Pearson and
20 | Spearman correlations between the averaged predictions and the gold scores.
21 | """
22 | def __init__(self, sentence_pairs: List[List[str]], scores: List[float], name: str='', write_csv: bool = True):
23 | self.sentence_pairs = sentence_pairs
24 | self.scores = scores
25 | self.name = name
26 |
27 | self.csv_file = self.__class__.__name__ + ("_" + name if name else '') + "_results.csv"
28 | self.csv_headers = ["epoch", "steps", "Pearson_Correlation", "Spearman_Correlation"]
29 | self.write_csv = write_csv
30 |
31 | @classmethod
32 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
33 | sentence_pairs = []
34 | scores = []
35 |
36 | for example in examples:
37 | sentence_pairs.append(example.texts)
38 | scores.append(example.label)
39 | return cls(sentence_pairs, scores, **kwargs)
40 |
41 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
42 | if epoch != -1:
43 | if steps == -1:
44 | out_txt = " after epoch {}:".format(epoch)
45 | else:
46 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
47 | else:
48 | out_txt = ":"
49 |
50 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
51 |
52 | all_scores = []
53 | for model in models:
54 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False)
55 | all_scores.append(pred_scores)
56 |
57 | pred_scores = np.array(all_scores).mean(0)
58 |
59 | eval_pearson, _ = pearsonr(self.scores, pred_scores)
60 | eval_spearman, _ = spearmanr(self.scores, pred_scores)
61 |
62 | logger.info("Correlation:\tPearson: {:.4f}\tSpearman: {:.4f}".format(eval_pearson, eval_spearman))
63 |
64 | if output_path is not None and self.write_csv:
65 | things_to_write = [epoch, steps, eval_pearson, eval_spearman]
66 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
67 |
68 | return eval_spearman
69 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to work on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorAUC.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from sentence_transformers.evaluation import (
5 | SentenceEvaluator,
6 | SimilarityFunction
7 | )
8 | import logging
9 | import os
10 | import csv
11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
12 | from scipy.stats import pearsonr, spearmanr
13 | import numpy as np
14 | from typing import List
15 | from sklearn.metrics import roc_auc_score
16 | from sentence_transformers.readers import InputExample
17 |
18 | from ..utils import write_csv_log
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | class EmbeddingSimilarityEvaluatorAUC(SentenceEvaluator):
23 | """
24 | Evaluate a model by computing the ROC AUC of its embedding similarities against binary
25 | gold standard labels.
26 | The similarity metrics are cosine similarity, dot product, and negated euclidean and Manhattan distances.
27 | The returned score is the ROC AUC for the specified similarity metric.
28 |
29 | The results are written in a CSV. If a CSV already exists, then values are appended.
30 | """
31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True):
32 | """
33 | Constructs an evaluator for the dataset.
34 |
35 | The labels need to indicate the similarity between the sentences.
36 |
37 | :param sentences1: List with the first sentence in a pair
38 | :param sentences2: List with the second sentence in a pair
39 | :param scores: Similarity score between sentences1[i] and sentences2[i]
40 | :param write_csv: Write results to a CSV file
41 | """
42 | self.sentences1 = sentences1
43 | self.sentences2 = sentences2
44 | self.scores = scores
45 | self.write_csv = write_csv
46 |
47 | assert len(self.sentences1) == len(self.sentences2)
48 | assert len(self.sentences1) == len(self.scores)
49 |
50 | self.main_similarity = main_similarity
51 | self.name = name
52 |
53 | self.batch_size = batch_size
54 | if show_progress_bar is None:
55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
56 | self.show_progress_bar = show_progress_bar
57 |
58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv"
59 | self.csv_headers = ["epoch", "steps", "cosine_auc", "euclidean_auc", "manhattan_auc", "dot_auc"]
60 |
61 | @classmethod
62 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
63 | sentences1 = []
64 | sentences2 = []
65 | scores = []
66 |
67 | for example in examples:
68 | sentences1.append(example.texts[0])
69 | sentences2.append(example.texts[1])
70 | scores.append(example.label)
71 | return cls(sentences1, sentences2, scores, **kwargs)
72 |
73 |
74 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
75 | if epoch != -1:
76 | if steps == -1:
77 | out_txt = " after epoch {}:".format(epoch)
78 | else:
79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
80 | else:
81 | out_txt = ":"
82 |
83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
84 |
85 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
86 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
87 | labels = self.scores
88 |
89 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
90 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
91 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
92 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
93 |
94 |
95 | eval_auc_cosine = roc_auc_score(labels, cosine_scores)
96 |
97 | eval_auc_manhattan = roc_auc_score(labels, manhattan_distances)
98 |
99 | eval_auc_euclidean = roc_auc_score(labels, euclidean_distances)
100 |
101 | eval_auc_dot = roc_auc_score(labels, dot_products)
102 |
103 | logger.info("Cosine-Similarity AUC: {:.4f}".format(eval_auc_cosine))
104 | logger.info("Manhattan-Distance AUC: {:.4f}".format(eval_auc_manhattan))
105 | logger.info("Euclidean-Distance AUC: {:.4f}".format(eval_auc_euclidean))
106 | logger.info("Dot-Product-Similarity AUC: {:.4f}".format(eval_auc_dot))
107 |
108 | if output_path is not None and self.write_csv:
109 | things_to_write = [epoch, steps, eval_auc_cosine, eval_auc_euclidean, eval_auc_manhattan, eval_auc_dot]
110 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
111 |
112 | if self.main_similarity == SimilarityFunction.COSINE:
113 | return eval_auc_cosine
114 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN:
115 | return eval_auc_euclidean
116 | elif self.main_similarity == SimilarityFunction.MANHATTAN:
117 | return eval_auc_manhattan
118 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT:
119 | return eval_auc_dot
120 | elif self.main_similarity is None:
121 | return max(eval_auc_cosine, eval_auc_manhattan, eval_auc_euclidean, eval_auc_dot)
122 | else:
123 | raise ValueError("Unknown main_similarity value")
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Trans-Encoder
2 | 
3 | [arxiv] · [amazon.science blog] · [5min-video] · [talk@RIKEN] · [openreview]
4 | 
22 | Code repo for **ICLR 2022** paper **_[Trans-Encoder: Unsupervised sentence-pair modelling through self- and mutual-distillations](https://arxiv.org/abs/2109.13059)_**
23 | by [Fangyu Liu](http://fangyuliu.me/about.html), [Yunlong Jiao](https://yunlongjiao.github.io/), [Jordan Massiah](https://www.linkedin.com/in/jordan-massiah-562862136/?originalSubdomain=uk), [Emine Yilmaz](https://sites.google.com/site/emineyilmaz/), [Serhii Havrylov](https://serhii-havrylov.github.io/).
24 |
25 | Trans-Encoder is a state-of-the-art unsupervised sentence similarity model. It conducts self-knowledge-distillation on top of pretrained language models by alternating between their bi- and cross-encoder forms.
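
To make the alternation concrete, below is a minimal sketch of one distillation cycle written against the sentence-transformers API. It is a simplified illustration, not the repo's actual training loop (see [src/self_distill.py](https://github.com/amzn/trans-encoder/blob/main/src/self_distill.py)); the sentence pairs, model choice, and hyper-parameters here are assumptions.

```python
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.cross_encoder import CrossEncoder

# Illustrative unlabelled pairs; in practice this is a large corpus.
pairs = [
    ("A man is playing a guitar.", "Someone is playing an instrument."),
    ("A kid is riding a bike.", "The weather is nice today."),
]

bi_encoder = SentenceTransformer("princeton-nlp/unsup-simcse-roberta-base")
cross_encoder = CrossEncoder("princeton-nlp/unsup-simcse-roberta-base", num_labels=1)

for cycle in range(3):
    # bi -> cross: bi-encoder cosine similarities become soft targets for the cross-encoder
    emb1 = bi_encoder.encode([s1 for s1, _ in pairs], convert_to_tensor=True)
    emb2 = bi_encoder.encode([s2 for _, s2 in pairs], convert_to_tensor=True)
    sims = torch.nn.functional.cosine_similarity(emb1, emb2).tolist()
    ce_data = [InputExample(texts=[s1, s2], label=s) for (s1, s2), s in zip(pairs, sims)]
    cross_encoder.fit(train_dataloader=DataLoader(ce_data, shuffle=True, batch_size=32), epochs=1)

    # cross -> bi: cross-encoder predictions become regression targets for the bi-encoder
    ce_scores = cross_encoder.predict([list(p) for p in pairs])
    bi_data = [InputExample(texts=[s1, s2], label=float(s)) for (s1, s2), s in zip(pairs, ce_scores)]
    loader = DataLoader(bi_data, shuffle=True, batch_size=128)
    bi_encoder.fit(train_objectives=[(loader, losses.CosineSimilarityLoss(bi_encoder))], epochs=1)
```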
26 |
27 | ## Huggingface pretrained models for STS
28 |
29 |
30 | #### Base models
31 | 
32 | |model | STS avg. |
33 | |--------|--------|
34 | |baseline: [unsup-simcse-bert-base](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 76.21 |
35 | | [trans-encoder-bi-simcse-bert-base](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-bert-base) | 80.41 |
36 | | [trans-encoder-cross-simcse-bert-base](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-bert-base) | 79.90 |
37 | |baseline: [unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.10 |
38 | | [trans-encoder-bi-simcse-roberta-base](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-roberta-base) | 80.47 |
39 | | [trans-encoder-cross-simcse-roberta-base](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-roberta-base) | **81.15** |
40 | 
41 | #### Large models
42 | 
43 | |model | STS avg. |
44 | |--------|--------|
45 | |baseline: [unsup-simcse-bert-large](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 78.42 |
46 | | [trans-encoder-bi-simcse-bert-large](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-bert-large) | 82.65 |
47 | | [trans-encoder-cross-simcse-bert-large](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-bert-large) | 82.52 |
48 | |baseline: [unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 78.92 |
49 | | [trans-encoder-bi-simcse-roberta-large](https://huggingface.co/cambridgeltl/trans-encoder-bi-simcse-roberta-large) | **82.93** |
50 | | [trans-encoder-cross-simcse-roberta-large](https://huggingface.co/cambridgeltl/trans-encoder-cross-simcse-roberta-large) | **82.93** |
51 | 
52 | 
53 | 
54 | 
55 |
56 | ## Dependencies
57 |
58 | ```
59 | torch==1.8.1
60 | transformers==4.9.0
61 | sentence-transformers==2.0.0
62 | ```
63 | Please view [requirements.txt](https://github.com/amzn/trans-encoder/blob/main/requirements.txt) for more details.
64 |
65 | ## Data
66 | All training and evaluation data will be automatically downloaded when running the scripts. See [src/data.py](https://github.com/amzn/trans-encoder/blob/main/src/data.py) for details.
67 |
68 | ## Train
69 |
70 | `--task` options: `sts` (STS2012-2016 and STS-b), `sickr`, `sts_sickr` (STS2012-2016, STS-b, and SICK-R), `qqp`, `qnli`, `mrpc`, `snli`, `custom`. See [src/data.py](https://github.com/amzn/trans-encoder/blob/main/src/data.py) for task data details. By default, all STS data (`sts_sickr`) is used.
71 |
72 | #### Self-distillation
73 | ```bash
74 | >> bash train_self_distill.sh 0
75 | ```
76 | `0` denotes the GPU device index.
77 |
78 | #### Mutual-distillation
79 | ```bash
80 | >> bash train_mutual_distill.sh 0,1
81 | ```
82 | Two GPUs are needed; by default, SimCSE BERT and RoBERTa base models are used for ensembling. Add `--use_large` to switch to large models.
83 |
84 | #### Train with your custom corpus
85 | ```bash
86 | >> CUDA_VISIBLE_DEVICES=0,1 python src/mutual_distill_parallel.py \
87 | --batch_size_bi_encoder 128 \
88 | --batch_size_cross_encoder 64 \
89 | --num_epochs_bi_encoder 10 \
90 | --num_epochs_cross_encoder 1 \
91 | --cycle 3 \
92 | --bi_encoder1_pooling_mode cls \
93 | --bi_encoder2_pooling_mode cls \
94 | --init_with_new_models \
95 | --task custom \
96 | --random_seed 2021 \
97 | --custom_corpus_path CORPUS_PATH
98 | ```
99 | `CORPUS_PATH` should point to your custom corpus, in which every line is a sentence pair of the form `sent1||sent2`.
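
For illustration, a tiny (hypothetical) corpus file could contain:

```
A man is playing a guitar.||Someone is playing an instrument.
A kid is riding a bike.||The weather is nice today.
```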
100 |
101 | ## Evaluate
102 | #### Evaluate a single model
103 |
104 | Bi-encoder:
105 | ```bash
106 | >> python src/eval.py \
107 | --model_name_or_path "cambridgeltl/trans-encoder-bi-simcse-roberta-large" \
108 | --mode bi \
109 | --task sts_sickr
110 | ```
111 | Cross-encoder:
112 | ```bash
113 | >> python src/eval.py \
114 | --model_name_or_path "cambridgeltl/trans-encoder-cross-simcse-roberta-large" \
115 | --mode cross \
116 | --task sts_sickr
117 | ```
118 | #### Evaluate ensemble
119 |
120 | Bi-encoder:
121 | ```bash
122 | >> python src/eval.py \
123 | --model_name_or_path1 "cambridgeltl/trans-encoder-bi-simcse-bert-large" \
124 | --model_name_or_path2 "cambridgeltl/trans-encoder-bi-simcse-roberta-large" \
125 | --mode bi \
126 | --ensemble \
127 | --task sts_sickr
128 | ```
129 |
130 | Cross-encoder:
131 | ```bash
132 | >> python src/eval.py \
133 | --model_name_or_path1 "cambridgeltl/trans-encoder-cross-simcse-bert-large" \
134 | --model_name_or_path2 "cambridgeltl/trans-encoder-cross-simcse-roberta-large" \
135 | --mode cross \
136 | --ensemble \
137 | --task sts_sickr
138 | ```
139 |
140 | ## Authors
141 |
142 | - [**Fangyu Liu**](http://fangyuliu.me/about.html): Main contributor
143 |
144 | ## Security
145 |
146 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
147 |
148 | ## License
149 |
150 | This project is licensed under the Apache-2.0 License.
151 |
152 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorAUCEnsemble.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from sentence_transformers.evaluation import (
5 | SentenceEvaluator,
6 | SimilarityFunction
7 | )
8 | import logging
9 | import os
10 | import csv
11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
12 | from scipy.stats import pearsonr, spearmanr
13 | import numpy as np
14 | from typing import List
15 | from sklearn.metrics import roc_auc_score
16 | from sentence_transformers.readers import InputExample
17 |
18 | from ..utils import write_csv_log
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | class EmbeddingSimilarityEvaluatorAUCEnsemble(SentenceEvaluator):
23 | """
24 | Evaluate an ensemble of models by computing, for each model, the ROC AUC of its embedding
25 | similarities against binary gold standard labels, and then averaging the AUC scores across models.
26 | The similarity metrics are cosine similarity, dot product, and negated euclidean and Manhattan distances.
27 | The returned score is the averaged ROC AUC for the specified similarity metric.
28 |
29 | The results are written in a CSV. If a CSV already exists, then values are appended.
30 | """
31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True):
32 | """
33 | Constructs an evaluator for the dataset.
34 |
35 | The labels need to indicate the similarity between the sentences.
36 |
37 | :param sentences1: List with the first sentence in a pair
38 | :param sentences2: List with the second sentence in a pair
39 | :param scores: Similarity score between sentences1[i] and sentences2[i]
40 | :param write_csv: Write results to a CSV file
41 | """
42 | self.sentences1 = sentences1
43 | self.sentences2 = sentences2
44 | self.scores = scores
45 | self.write_csv = write_csv
46 |
47 | assert len(self.sentences1) == len(self.sentences2)
48 | assert len(self.sentences1) == len(self.scores)
49 |
50 | self.main_similarity = main_similarity
51 | self.name = name
52 |
53 | self.batch_size = batch_size
54 | if show_progress_bar is None:
55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
56 | self.show_progress_bar = show_progress_bar
57 |
58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv"
59 | self.csv_headers = ["epoch", "steps", "cosine_auc", "euclidean_auc", "manhattan_auc", "dot_auc"]
60 |
61 | @classmethod
62 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
63 | sentences1 = []
64 | sentences2 = []
65 | scores = []
66 |
67 | for example in examples:
68 | sentences1.append(example.texts[0])
69 | sentences2.append(example.texts[1])
70 | scores.append(example.label)
71 | return cls(sentences1, sentences2, scores, **kwargs)
72 |
73 |
74 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
75 | if epoch != -1:
76 | if steps == -1:
77 | out_txt = " after epoch {}:".format(epoch)
78 | else:
79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
80 | else:
81 | out_txt = ":"
82 |
83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
84 |
85 | labels = self.scores
86 |
87 | auc_cosine_scores_all_models = []
88 | auc_manhattan_distances_all_models = []
89 | auc_euclidean_distances_all_models = []
90 | auc_dot_products_all_models = []
91 |
92 | # compute each model's AUC scores, then average them across models
93 | for model in models:
94 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
95 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
96 |
97 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
98 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
99 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
100 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
101 | auc_cosine_scores_all_models.append(roc_auc_score(labels, cosine_scores))
102 | auc_manhattan_distances_all_models.append(roc_auc_score(labels, manhattan_distances))
103 | auc_euclidean_distances_all_models.append(roc_auc_score(labels, euclidean_distances))
104 | auc_dot_products_all_models.append(roc_auc_score(labels, dot_products))
105 |
106 | eval_auc_cosine = np.array(auc_cosine_scores_all_models).mean(0)
107 | eval_auc_manhattan = np.array(auc_manhattan_distances_all_models).mean(0)
108 | eval_auc_euclidean = np.array(auc_euclidean_distances_all_models).mean(0)
109 | eval_auc_dot = np.array(auc_dot_products_all_models).mean(0)
110 |
111 | logger.info("Cosine-Similarity AUC: {:.4f}".format(eval_auc_cosine))
112 | logger.info("Manhattan-Distance AUC: {:.4f}".format(eval_auc_manhattan))
113 | logger.info("Euclidean-Distance AUC: {:.4f}".format(eval_auc_euclidean))
114 | logger.info("Dot-Product-Similarity AUC: {:.4f}".format(eval_auc_dot))
115 |
116 | if output_path is not None and self.write_csv:
117 | things_to_write = [epoch, steps, eval_auc_cosine, eval_auc_euclidean,
118 | eval_auc_manhattan, eval_auc_dot]
119 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
120 |
121 | if self.main_similarity == SimilarityFunction.COSINE:
122 | return eval_auc_cosine
123 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN:
124 | return eval_auc_euclidean
125 | elif self.main_similarity == SimilarityFunction.MANHATTAN:
126 | return eval_auc_manhattan
127 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT:
128 | return eval_auc_dot
129 | elif self.main_similarity is None:
130 | return max(eval_auc_cosine, eval_auc_manhattan, eval_auc_euclidean, eval_auc_dot)
131 | else:
132 | raise ValueError("Unknown main_similarity value")
133 |
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from sentence_transformers.evaluation import (
5 | SentenceEvaluator,
6 | SimilarityFunction
7 | )
8 | import logging
9 | import os
10 | import csv
11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
12 | from scipy.stats import pearsonr, spearmanr
13 | import numpy as np
14 | from typing import List
15 | from sentence_transformers.readers import InputExample
16 |
17 | from ..utils import write_csv_log
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class EmbeddingSimilarityEvaluator(SentenceEvaluator):
22 | """
23 | Evaluate a model based on the similarity of the embeddings by calculating the Spearman and Pearson rank correlation
24 | in comparison to the gold standard labels.
25 | The metrics are cosine similarity, dot product, and negated euclidean and Manhattan distances.
26 | The returned score is the Spearman correlation for the specified metric.
27 |
28 | The results are written in a CSV. If a CSV already exists, then values are appended.
29 | """
30 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = SimilarityFunction.COSINE, name: str = '', show_progress_bar: bool = False, write_csv: bool = True):
31 | """
32 | Constructs an evaluator for the dataset.
33 |
34 | The labels need to indicate the similarity between the sentences.
35 |
36 | :param sentences1: List with the first sentence in a pair
37 | :param sentences2: List with the second sentence in a pair
38 | :param scores: Similarity score between sentences1[i] and sentences2[i]
39 | :param write_csv: Write results to a CSV file
40 | """
41 | self.sentences1 = sentences1
42 | self.sentences2 = sentences2
43 | self.scores = scores
44 | self.write_csv = write_csv
45 |
46 | assert len(self.sentences1) == len(self.sentences2)
47 | assert len(self.sentences1) == len(self.scores)
48 |
49 | self.main_similarity = main_similarity
50 | self.name = name
51 |
52 | self.batch_size = batch_size
53 | if show_progress_bar is None:
54 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
55 | self.show_progress_bar = show_progress_bar
56 |
57 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv"
58 | self.csv_headers = ["epoch", "steps", "cosine_pearson", "cosine_spearman", "euclidean_pearson", "euclidean_spearman", "manhattan_pearson", "manhattan_spearman", "dot_pearson", "dot_spearman"]
59 |
60 | @classmethod
61 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
62 | sentences1 = []
63 | sentences2 = []
64 | scores = []
65 |
66 | for example in examples:
67 | sentences1.append(example.texts[0])
68 | sentences2.append(example.texts[1])
69 | scores.append(example.label)
70 | return cls(sentences1, sentences2, scores, **kwargs)
71 |
72 |
73 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
74 | if epoch != -1:
75 | if steps == -1:
76 | out_txt = " after epoch {}:".format(epoch)
77 | else:
78 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
79 | else:
80 | out_txt = ":"
81 |
82 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
83 |
84 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
85 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
86 | labels = self.scores
87 |
88 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
89 |
90 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
91 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
92 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
93 |
94 | eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
95 | eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)
96 |
97 | eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
98 | eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)
99 |
100 | eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
101 | eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)
102 |
103 | eval_pearson_dot, _ = pearsonr(labels, dot_products)
104 | eval_spearman_dot, _ = spearmanr(labels, dot_products)
105 |
106 | logger.info("Cosine-Similarity :\tPearson: {:.4f}\tSpearman: {:.4f}".format(
107 | eval_pearson_cosine, eval_spearman_cosine))
108 | logger.info("Manhattan-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
109 | eval_pearson_manhattan, eval_spearman_manhattan))
110 | logger.info("Euclidean-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
111 | eval_pearson_euclidean, eval_spearman_euclidean))
112 | logger.info("Dot-Product-Similarity:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
113 | eval_pearson_dot, eval_spearman_dot))
114 |
115 | if output_path is not None and self.write_csv:
116 | things_to_write = [epoch, steps, eval_pearson_cosine, eval_spearman_cosine, eval_pearson_euclidean,
117 | eval_spearman_euclidean, eval_pearson_manhattan, eval_spearman_manhattan, eval_pearson_dot, eval_spearman_dot]
118 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
119 |
120 | if self.main_similarity == SimilarityFunction.COSINE:
121 | return eval_spearman_cosine
122 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN:
123 | return eval_spearman_euclidean
124 | elif self.main_similarity == SimilarityFunction.MANHATTAN:
125 | return eval_spearman_manhattan
126 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT:
127 | return eval_spearman_dot
128 | elif self.main_similarity is None:
129 | return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot)
130 | else:
131 | raise ValueError("Unknown main_similarity value")
132 |
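133 | # Example usage (hypothetical data): build the evaluator from InputExamples and score
134 | # a bi-encoder; the returned value is Spearman's rho for the main similarity function.
135 | #   examples = [InputExample(texts=["a cat", "a feline"], label=0.9),
136 | #               InputExample(texts=["a cat", "a car"], label=0.1)]
137 | #   evaluator = EmbeddingSimilarityEvaluator.from_input_examples(examples, name="dev")
138 | #   spearman = evaluator(bi_encoder, output_path="output")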
--------------------------------------------------------------------------------
/src/sentence_transformers_ext/bi_encoder_eval/EmbeddingSimilarityEvaluatorEnsemble.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from sentence_transformers.evaluation import (
5 | SentenceEvaluator,
6 | SimilarityFunction
7 | )
8 | import logging
9 | import os
10 | import csv
11 | from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
12 | from scipy.stats import pearsonr, spearmanr
13 | import numpy as np
14 | from typing import List
15 |
16 | from sentence_transformers.readers import InputExample
17 |
18 | from ..utils import write_csv_log
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | class EmbeddingSimilarityEvaluatorEnsemble(SentenceEvaluator):
23 | """
24 | Evaluate an ensemble of models by averaging their embedding-similarity predictions and
25 | calculating the Spearman and Pearson rank correlations against the gold standard labels.
26 | The metrics are cosine similarity, dot product, and negated euclidean and Manhattan distances.
27 | The returned score is the Spearman correlation for the specified metric.
28 |
29 | The results are written in a CSV. If a CSV already exists, then values are appended.
30 | """
31 | def __init__(self, sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: SimilarityFunction = None, name: str = '', show_progress_bar: bool = False, write_csv: bool = True):
32 | """
33 | Constructs an evaluator for the dataset.
34 |
35 | The labels need to indicate the similarity between the sentences.
36 |
37 | :param sentences1: List with the first sentence in a pair
38 | :param sentences2: List with the second sentence in a pair
39 | :param scores: Similarity score between sentences1[i] and sentences2[i]
40 | :param write_csv: Write results to a CSV file
41 | """
42 | self.sentences1 = sentences1
43 | self.sentences2 = sentences2
44 | self.scores = scores
45 | self.write_csv = write_csv
46 |
47 | assert len(self.sentences1) == len(self.sentences2)
48 | assert len(self.sentences1) == len(self.scores)
49 |
50 | self.main_similarity = main_similarity
51 | self.name = name
52 |
53 | self.batch_size = batch_size
54 | if show_progress_bar is None:
55 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
56 | self.show_progress_bar = show_progress_bar
57 |
58 | self.csv_file = self.__class__.__name__ + ("_"+name if name else '')+"_results.csv"
59 | self.csv_headers = ["epoch", "steps", "cosine_pearson", "cosine_spearman", "euclidean_pearson", "euclidean_spearman", "manhattan_pearson", "manhattan_spearman", "dot_pearson", "dot_spearman"]
60 |
61 | @classmethod
62 | def from_input_examples(cls, examples: List[InputExample], **kwargs):
63 | sentences1 = []
64 | sentences2 = []
65 | scores = []
66 |
67 | for example in examples:
68 | sentences1.append(example.texts[0])
69 | sentences2.append(example.texts[1])
70 | scores.append(example.label)
71 | return cls(sentences1, sentences2, scores, **kwargs)
72 |
73 |
74 | def __call__(self, models, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
75 | if epoch != -1:
76 | if steps == -1:
77 | out_txt = " after epoch {}:".format(epoch)
78 | else:
79 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
80 | else:
81 | out_txt = ":"
82 |
83 | logger.info(self.__class__.__name__+": Evaluating the model on " + self.name + " dataset" + out_txt)
84 |
85 | labels = self.scores
86 |
87 | cosine_scores_all_models = []
88 | manhattan_distances_all_models = []
89 | euclidean_distances_all_models = []
90 | dot_products_all_models = []
91 |
92 | # compute average predictions of all models
93 | for model in models:
94 | embeddings1 = model.encode(self.sentences1, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
95 | embeddings2 = model.encode(self.sentences2, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True)
96 |
97 | cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
98 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
99 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
100 | dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
101 | cosine_scores_all_models.append(cosine_scores)
102 | manhattan_distances_all_models.append(manhattan_distances)
103 | euclidean_distances_all_models.append(euclidean_distances)
104 | dot_products_all_models.append(dot_products)
105 |
106 | cosine_scores = np.array(cosine_scores_all_models).mean(0)
107 | manhattan_distances = np.array(manhattan_distances_all_models).mean(0)
108 | euclidean_distances = np.array(euclidean_distances_all_models).mean(0)
109 | dot_products = np.array(dot_products_all_models).mean(0)
110 |
111 | eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
112 | eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)
113 |
114 | eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
115 | eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)
116 |
117 | eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
118 | eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)
119 |
120 | eval_pearson_dot, _ = pearsonr(labels, dot_products)
121 | eval_spearman_dot, _ = spearmanr(labels, dot_products)
122 |
123 | logger.info("Cosine-Similarity :\tPearson: {:.4f}\tSpearman: {:.4f}".format(
124 | eval_pearson_cosine, eval_spearman_cosine))
125 | logger.info("Manhattan-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
126 | eval_pearson_manhattan, eval_spearman_manhattan))
127 | logger.info("Euclidean-Distance:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
128 | eval_pearson_euclidean, eval_spearman_euclidean))
129 | logger.info("Dot-Product-Similarity:\tPearson: {:.4f}\tSpearman: {:.4f}".format(
130 | eval_pearson_dot, eval_spearman_dot))
131 |
132 | if output_path is not None and self.write_csv:
133 | things_to_write = [epoch, steps, eval_pearson_cosine, eval_spearman_cosine, eval_pearson_euclidean,
134 | eval_spearman_euclidean, eval_pearson_manhattan, eval_spearman_manhattan, eval_pearson_dot, eval_spearman_dot]
135 | write_csv_log(output_path=output_path, csv_file=self.csv_file, csv_headers=self.csv_headers, things_to_write=things_to_write)
136 |
137 | if self.main_similarity == SimilarityFunction.COSINE:
138 | return eval_spearman_cosine
139 | elif self.main_similarity == SimilarityFunction.EUCLIDEAN:
140 | return eval_spearman_euclidean
141 | elif self.main_similarity == SimilarityFunction.MANHATTAN:
142 | return eval_spearman_manhattan
143 | elif self.main_similarity == SimilarityFunction.DOT_PRODUCT:
144 | return eval_spearman_dot
145 | elif self.main_similarity is None:
146 | return max(eval_spearman_cosine, eval_spearman_manhattan, eval_spearman_euclidean, eval_spearman_dot)
147 | else:
148 | raise ValueError("Unknown main_similarity value")
149 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from torch.utils.data import DataLoader
5 | import torch.nn.functional as F
6 | from sentence_transformers import models, losses, util, SentenceTransformer
7 | from sentence_transformers.cross_encoder import CrossEncoder
8 | from sentence_transformers import LoggingHandler, SentenceTransformer
9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
10 | from sentence_transformers.readers import InputExample
11 | from datetime import datetime
12 | import argparse
13 | import logging
14 | import csv
15 | import torch
16 | import sys
17 | import random
18 | import tqdm
19 | import gzip
20 | import os
21 | import pandas as pd
22 | import numpy as np
23 |
24 | from data import load_data
25 |
26 | from sentence_transformers_ext.bi_encoder_eval import (
27 | EmbeddingSimilarityEvaluator,
28 | EmbeddingSimilarityEvaluatorEnsemble,
29 | EmbeddingSimilarityEvaluatorAUC,
30 | EmbeddingSimilarityEvaluatorAUCEnsemble
31 | )
32 | from sentence_transformers_ext.cross_encoder_eval import (
33 | CECorrelationEvaluatorEnsemble,
34 | CECorrelationEvaluatorAUC,
35 | CECorrelationEvaluatorAUCEnsemble
36 | )
37 |
38 |
39 | def eval_encoder(all_test, encoder, task="sts", enc_type="bi", ensemble=False):
40 | """
41 | Evaluate bi- or cross-encoders.
42 | Parameters
43 | ----------
44 | all_test: a dict of all test sets
45 | encoder: a bi- or cross-encoder (or a list of encoders when ensemble=True)
46 | task: a string naming the task; enc_type: a string specifying whether the encoder is a bi- or cross-encoder
47 | ensemble: a bool value indicating whether multiple encoders are used in input
48 | Returns
49 | ----------
50 | None
51 | """
52 | scores = []
53 | for name, data in all_test.items():
54 | if task in ["sts", "sickr", "sts_sickr", "custom"]:
55 | if enc_type == "bi":
56 | if ensemble:
57 | test_evaluator = EmbeddingSimilarityEvaluatorEnsemble.from_input_examples(data, name=name)
58 | else:
59 | test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(data, name=name)
60 | elif enc_type == "cross":
61 | if ensemble:
62 | test_evaluator = CECorrelationEvaluatorEnsemble.from_input_examples(data, name=name)
63 | else:
64 | test_evaluator = CECorrelationEvaluator.from_input_examples(data, name=name)
65 | else:
66 | raise NotImplementedError()
67 | else:
68 | if enc_type == "bi":
69 | if ensemble:
70 | test_evaluator = EmbeddingSimilarityEvaluatorAUCEnsemble.from_input_examples(data, name=name)
71 | else:
72 | test_evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(data, name=name)
73 | elif enc_type == "cross":
74 | if ensemble:
75 | test_evaluator = CECorrelationEvaluatorAUCEnsemble.from_input_examples(data, name=name)
76 | else:
77 | test_evaluator = CECorrelationEvaluatorAUC.from_input_examples(data, name=name)
78 | else:
79 | raise NotImplementedError()
80 | scores += [test_evaluator(encoder)]
81 | logging.info (" & ".join(["%.2f" % (s*100) for s in scores]))
82 | logging.info (f"***** test's avg score (Spearman's rho or ROC AUC): {sum(scores)/len(scores):.4f} ****")
83 |
84 |
85 | def main():
86 |
87 | #### Just some code to print debug information to stdout
88 | logging.basicConfig(format='%(asctime)s - %(message)s',
89 | datefmt='%Y-%m-%d %H:%M:%S',
90 | level=logging.INFO,
91 | handlers=[LoggingHandler()])
92 |
93 |
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument("--model_name_or_path", type=str,
96 | default="princeton-nlp/unsup-simcse-bert-base-uncased",
97 | help="Transformers' model name or path")
98 | parser.add_argument("--task", type=str, default="sts",
99 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}')
100 | parser.add_argument("--mode", type=str, default='bi', help="{cross|bi}")
101 | parser.add_argument("--device", type=int, default=0)
102 | parser.add_argument("--bi_encoder_pooling_mode", type=str, default='cls',
103 | help="{cls|mean}")
104 | parser.add_argument("--ensemble", action="store_true")
105 | parser.add_argument("--model_name_or_path1", type=str,
106 | default="princeton-nlp/unsup-simcse-bert-base-uncased")
107 | parser.add_argument("--model_name_or_path2", type=str,
108 | default="princeton-nlp/unsup-simcse-roberta-base")
109 | parser.add_argument("--bi_encoder_pooling_mode1", type=str, default="cls")
110 | parser.add_argument("--bi_encoder_pooling_mode2", type=str, default="cls")
111 | parser.add_argument("--quick_test", action="store_true")
112 |
113 | args = parser.parse_args()
114 | print (args)
115 |
116 | ### read datasets
117 | all_pairs, all_test, dev_samples = load_data(args.task)
118 |
119 | if args.quick_test:
120 | all_pairs = all_pairs[:5000] # for quick testing
121 |
122 | print ("|raw sentence pairs|:", len(all_pairs))
123 | print ("|dev set|:", len(dev_samples))
124 | for key in all_test:
125 | print ("|test set: %s|" % key, len(all_test[key]))
126 |
127 | model_name = args.model_name_or_path
128 | model_name1 = args.model_name_or_path1
129 | model_name2 = args.model_name_or_path2
130 |
131 | max_seq_length = 32
132 | device=args.device
133 |
134 | if not args.ensemble:
135 | logging.info ("########## load model and evaluate ##########")
136 |
137 | if args.mode == "bi":
138 |
139 | ###### Bi-encoder (sentence-transformers) ######
140 | logging.info(f"Loading bi-encoder model: {model_name}")
141 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
142 | word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
143 |
144 | # Apply mean pooling to get one fixed sized sentence vector
145 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode)
146 |
147 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
148 |
149 | # eval bi-encoder
150 | logging.info ("Evaluate bi-encoder...")
151 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi")
152 |
153 |
154 | elif args.mode == "cross":
155 |
156 | ###### cross-encoder (sentence-transformers) ######
157 | logging.info(f"Loading cross-encoder model: {model_name}")
158 |
159 | cross_encoder = CrossEncoder(model_name, device=device)
160 |
161 | # eval cross-encoder
162 | logging.info ("Evaluate cross-encoder...")
163 | eval_encoder(all_test, cross_encoder, task=args.task, enc_type="cross")
164 |
165 | else:
166 | raise NotImplementedError()
167 |
168 | else:
169 | logging.info ("########## load models and evaluate ##########")
170 |
171 | if args.mode == "bi":
172 | ###### Bi-encoder (sentence-transformers) ######
173 | logging.info(f"Loading bi-encoder1 model: {model_name1}")
174 | logging.info(f"Loading bi-encoder2 model: {model_name2}")
175 |
176 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
177 | word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length)
178 | word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length)
179 |
180 | # Apply mean pooling to get one fixed sized sentence vector
181 | pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode1)
182 | pooling_model2 = models.Pooling(word_embedding_model2.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode2)
183 |
184 | bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device)
185 | bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device)
186 |
187 | # eval bi-encoder
188 | logging.info ("Evaluate bi-encoder (ensembled)...")
189 | eval_encoder(all_test, [bi_encoder1, bi_encoder2], task=args.task, enc_type="bi", ensemble=True)
190 |
191 |
192 | elif args.mode == "cross":
193 |
194 | ###### cross-encoder (sentence-transformers) ######
195 | logging.info(f"Loading cross-encoder1 model: {model_name1}")
196 | logging.info(f"Loading cross-encoder2 model: {model_name2}")
197 |
198 | cross_encoder1 = CrossEncoder(model_name1, device=device)
199 | cross_encoder2 = CrossEncoder(model_name2, device=device)
200 |
201 | # eval cross-encoder
202 | logging.info ("Evaluate cross-encoder (ensembled)...")
203 | eval_encoder(all_test, [cross_encoder1, cross_encoder2], task=args.task, enc_type="cross", ensemble=True)
204 |
205 | else:
206 | raise NotImplementedError()
207 |
208 | if __name__ == "__main__":
209 | main()
210 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
--------------------------------------------------------------------------------
/THIRD-PARTY-LICENSES:
--------------------------------------------------------------------------------
1 | The Amazon Open-Source Code - "trans-encoder" - includes the following third-party software/licensing:
2 |
3 | ** sentence-transformers - https://github.com/UKPLab/sentence-transformers
4 |
5 | Apache License
6 | Version 2.0, January 2004
7 | http://www.apache.org/licenses/
8 |
9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10 |
11 | 1. Definitions.
12 |
13 | "License" shall mean the terms and conditions for use, reproduction,
14 | and distribution as defined by Sections 1 through 9 of this document.
15 |
16 | "Licensor" shall mean the copyright owner or entity authorized by
17 | the copyright owner that is granting the License.
18 |
19 | "Legal Entity" shall mean the union of the acting entity and all
20 | other entities that control, are controlled by, or are under common
21 | control with that entity. For the purposes of this definition,
22 | "control" means (i) the power, direct or indirect, to cause the
23 | direction or management of such entity, whether by contract or
24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
25 | outstanding shares, or (iii) beneficial ownership of such entity.
26 |
27 | "You" (or "Your") shall mean an individual or Legal Entity
28 | exercising permissions granted by this License.
29 |
30 | "Source" form shall mean the preferred form for making modifications,
31 | including but not limited to software source code, documentation
32 | source, and configuration files.
33 |
34 | "Object" form shall mean any form resulting from mechanical
35 | transformation or translation of a Source form, including but
36 | not limited to compiled object code, generated documentation,
37 | and conversions to other media types.
38 |
39 | "Work" shall mean the work of authorship, whether in Source or
40 | Object form, made available under the License, as indicated by a
41 | copyright notice that is included in or attached to the work
42 | (an example is provided in the Appendix below).
43 |
44 | "Derivative Works" shall mean any work, whether in Source or Object
45 | form, that is based on (or derived from) the Work and for which the
46 | editorial revisions, annotations, elaborations, or other modifications
47 | represent, as a whole, an original work of authorship. For the purposes
48 | of this License, Derivative Works shall not include works that remain
49 | separable from, or merely link (or bind by name) to the interfaces of,
50 | the Work and Derivative Works thereof.
51 |
52 | "Contribution" shall mean any work of authorship, including
53 | the original version of the Work and any modifications or additions
54 | to that Work or Derivative Works thereof, that is intentionally
55 | submitted to Licensor for inclusion in the Work by the copyright owner
56 | or by an individual or Legal Entity authorized to submit on behalf of
57 | the copyright owner. For the purposes of this definition, "submitted"
58 | means any form of electronic, verbal, or written communication sent
59 | to the Licensor or its representatives, including but not limited to
60 | communication on electronic mailing lists, source code control systems,
61 | and issue tracking systems that are managed by, or on behalf of, the
62 | Licensor for the purpose of discussing and improving the Work, but
63 | excluding communication that is conspicuously marked or otherwise
64 | designated in writing by the copyright owner as "Not a Contribution."
65 |
66 | "Contributor" shall mean Licensor and any individual or Legal Entity
67 | on behalf of whom a Contribution has been received by Licensor and
68 | subsequently incorporated within the Work.
69 |
70 | 2. Grant of Copyright License. Subject to the terms and conditions of
71 | this License, each Contributor hereby grants to You a perpetual,
72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73 | copyright license to reproduce, prepare Derivative Works of,
74 | publicly display, publicly perform, sublicense, and distribute the
75 | Work and such Derivative Works in Source or Object form.
76 |
77 | 3. Grant of Patent License. Subject to the terms and conditions of
78 | this License, each Contributor hereby grants to You a perpetual,
79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80 | (except as stated in this section) patent license to make, have made,
81 | use, offer to sell, sell, import, and otherwise transfer the Work,
82 | where such license applies only to those patent claims licensable
83 | by such Contributor that are necessarily infringed by their
84 | Contribution(s) alone or by combination of their Contribution(s)
85 | with the Work to which such Contribution(s) was submitted. If You
86 | institute patent litigation against any entity (including a
87 | cross-claim or counterclaim in a lawsuit) alleging that the Work
88 | or a Contribution incorporated within the Work constitutes direct
89 | or contributory patent infringement, then any patent licenses
90 | granted to You under this License for that Work shall terminate
91 | as of the date such litigation is filed.
92 |
93 | 4. Redistribution. You may reproduce and distribute copies of the
94 | Work or Derivative Works thereof in any medium, with or without
95 | modifications, and in Source or Object form, provided that You
96 | meet the following conditions:
97 |
98 | (a) You must give any other recipients of the Work or
99 | Derivative Works a copy of this License; and
100 |
101 | (b) You must cause any modified files to carry prominent notices
102 | stating that You changed the files; and
103 |
104 | (c) You must retain, in the Source form of any Derivative Works
105 | that You distribute, all copyright, patent, trademark, and
106 | attribution notices from the Source form of the Work,
107 | excluding those notices that do not pertain to any part of
108 | the Derivative Works; and
109 |
110 | (d) If the Work includes a "NOTICE" text file as part of its
111 | distribution, then any Derivative Works that You distribute must
112 | include a readable copy of the attribution notices contained
113 | within such NOTICE file, excluding those notices that do not
114 | pertain to any part of the Derivative Works, in at least one
115 | of the following places: within a NOTICE text file distributed
116 | as part of the Derivative Works; within the Source form or
117 | documentation, if provided along with the Derivative Works; or,
118 | within a display generated by the Derivative Works, if and
119 | wherever such third-party notices normally appear. The contents
120 | of the NOTICE file are for informational purposes only and
121 | do not modify the License. You may add Your own attribution
122 | notices within Derivative Works that You distribute, alongside
123 | or as an addendum to the NOTICE text from the Work, provided
124 | that such additional attribution notices cannot be construed
125 | as modifying the License.
126 |
127 | You may add Your own copyright statement to Your modifications and
128 | may provide additional or different license terms and conditions
129 | for use, reproduction, or distribution of Your modifications, or
130 | for any such Derivative Works as a whole, provided Your use,
131 | reproduction, and distribution of the Work otherwise complies with
132 | the conditions stated in this License.
133 |
134 | 5. Submission of Contributions. Unless You explicitly state otherwise,
135 | any Contribution intentionally submitted for inclusion in the Work
136 | by You to the Licensor shall be under the terms and conditions of
137 | this License, without any additional terms or conditions.
138 | Notwithstanding the above, nothing herein shall supersede or modify
139 | the terms of any separate license agreement you may have executed
140 | with Licensor regarding such Contributions.
141 |
142 | 6. Trademarks. This License does not grant permission to use the trade
143 | names, trademarks, service marks, or product names of the Licensor,
144 | except as required for reasonable and customary use in describing the
145 | origin of the Work and reproducing the content of the NOTICE file.
146 |
147 | 7. Disclaimer of Warranty. Unless required by applicable law or
148 | agreed to in writing, Licensor provides the Work (and each
149 | Contributor provides its Contributions) on an "AS IS" BASIS,
150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151 | implied, including, without limitation, any warranties or conditions
152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153 | PARTICULAR PURPOSE. You are solely responsible for determining the
154 | appropriateness of using or redistributing the Work and assume any
155 | risks associated with Your exercise of permissions under this License.
156 |
157 | 8. Limitation of Liability. In no event and under no legal theory,
158 | whether in tort (including negligence), contract, or otherwise,
159 | unless required by applicable law (such as deliberate and grossly
160 | negligent acts) or agreed to in writing, shall any Contributor be
161 | liable to You for damages, including any direct, indirect, special,
162 | incidental, or consequential damages of any character arising as a
163 | result of this License or out of the use or inability to use the
164 | Work (including but not limited to damages for loss of goodwill,
165 | work stoppage, computer failure or malfunction, or any and all
166 | other commercial damages or losses), even if such Contributor
167 | has been advised of the possibility of such damages.
168 |
169 | 9. Accepting Warranty or Additional Liability. While redistributing
170 | the Work or Derivative Works thereof, You may choose to offer,
171 | and charge a fee for, acceptance of support, warranty, indemnity,
172 | or other liability obligations and/or rights consistent with this
173 | License. However, in accepting such obligations, You may act only
174 | on Your own behalf and on Your sole responsibility, not on behalf
175 | of any other Contributor, and only if You agree to indemnify,
176 | defend, and hold each Contributor harmless for any liability
177 | incurred by, or claims asserted against, such Contributor by reason
178 | of your accepting any such warranty or additional liability.
179 |
180 | END OF TERMS AND CONDITIONS
181 |
182 | APPENDIX: How to apply the Apache License to your work.
183 |
184 | To apply the Apache License to your work, attach the following
185 | boilerplate notice, with the fields enclosed by brackets "{}"
186 | replaced with your own identifying information. (Don't include
187 | the brackets!) The text should be enclosed in the appropriate
188 | comment syntax for the file format. We also recommend that a
189 | file or class name and description of purpose be included on the
190 | same "printed page" as the copyright notice for easier
191 | identification within third-party archives.
192 |
193 | Copyright {yyyy} {name of copyright owner}
194 |
195 | Licensed under the Apache License, Version 2.0 (the "License");
196 | you may not use this file except in compliance with the License.
197 | You may obtain a copy of the License at
198 |
199 | http://www.apache.org/licenses/LICENSE-2.0
200 |
201 | Unless required by applicable law or agreed to in writing, software
202 | distributed under the License is distributed on an "AS IS" BASIS,
203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
204 | See the License for the specific language governing permissions and
205 | limitations under the License.
206 |
--------------------------------------------------------------------------------
/src/self_distill.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 | from sentence_transformers import models, losses, util, SentenceTransformer, LoggingHandler
8 | from sentence_transformers.cross_encoder import CrossEncoder
9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
10 | from sentence_transformers.readers import InputExample
11 | from datetime import datetime
12 | import argparse
13 | import logging
14 | import sys
15 | import random
16 | import tqdm
17 | import math
18 | import os
19 | import numpy as np
20 | import pandas as pd
21 |
22 | # import from local codes
23 | from data import load_data
24 | from eval import eval_encoder
25 |
26 | from sentence_transformers_ext.bi_encoder_eval import (
27 | EmbeddingSimilarityEvaluator,
28 | EmbeddingSimilarityEvaluatorEnsemble,
29 | EmbeddingSimilarityEvaluatorAUC,
30 | EmbeddingSimilarityEvaluatorAUCEnsemble
31 | )
32 | from sentence_transformers_ext.cross_encoder_eval import (
33 | CECorrelationEvaluatorEnsemble,
34 | CECorrelationEvaluatorAUC,
35 | CECorrelationEvaluatorAUCEnsemble
36 | )
37 |
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--model_name_or_path", type=str,
40 | default="princeton-nlp/unsup-simcse-bert-base-uncased",
41 | help="Transformers' model name or path")
42 | parser.add_argument("--task", type=str, default='sts',
43 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}')
44 | parser.add_argument("--device", type=int, default=0)
45 | parser.add_argument("--cycle", type=int, default=3)
46 | parser.add_argument("--bi_encoder_pooling_mode", type=str, default='cls',
47 | help="{cls|mean}")
48 | parser.add_argument("--num_epochs_cross_encoder", type=int, default=1)
49 | parser.add_argument("--num_epochs_bi_encoder", type=int, default=10)
50 | parser.add_argument("--batch_size_cross_encoder", type=int, default=32)
51 | parser.add_argument("--batch_size_bi_encoder", type=int, default=128)
52 | parser.add_argument("--init_with_new_models", action="store_true")
53 | parser.add_argument("--random_seed", type=int, default=2021)
54 | #parser.add_argument("--use_raw_data_from_all_tasks", action="store_true")
55 | parser.add_argument("--add_snli_data", type=int, default=0)
56 | parser.add_argument("--custom_corpus_path", type=str, default=None)
57 | parser.add_argument("--num_training_pairs", type=int, default=None)
58 | parser.add_argument("--save_all_predictions", action="store_true")
59 | parser.add_argument("--quick_test", action="store_true")
60 |
61 |
62 | args = parser.parse_args()
63 | print (args)
64 |
65 | torch.manual_seed(args.random_seed)
66 |
67 | #### Just some code to print debug information to stdout
68 | logging.basicConfig(format="%(asctime)s - %(message)s",
69 | datefmt="%Y-%m-%d %H:%M:%S",
70 | level=logging.INFO,
71 | handlers=[LoggingHandler()])
72 |
73 |
74 | ### read datasets
75 | all_pairs, all_test, dev_samples = load_data(args.task, fpath=args.custom_corpus_path)
76 |
77 | # load_pairs from other tasks
78 | """
79 | if args.use_raw_data_from_all_tasks:
80 | all_pairs_qqp, _, _ = load_data("qqp")
81 | all_pairs_qnli, _, _ = load_data("qnli")
82 | all_pairs_mrpc, _, _ = load_data("mrpc")
83 | all_pairs = all_pairs + all_pairs_qqp + all_pairs_qnli + all_pairs_mrpc
84 | """
85 |
86 | if args.add_snli_data != 0:
87 | random.seed(args.random_seed)
88 | all_pairs_snli, _, _ = load_data("snli")
89 | all_pairs_snli_sampled = random.sample(all_pairs_snli, args.add_snli_data)
90 | all_pairs = all_pairs + all_pairs_snli_sampled
91 |
92 | if args.quick_test:
93 | all_pairs = all_pairs[:1000] # for quick testing
94 |
95 | # randomly select training pairs for control study
96 | if args.num_training_pairs is not None:
97 | print ("before sampling |raw sentence pairs|:", len(all_pairs))
98 | if args.num_training_pairs == -1:
99 | # use all
100 | pass
101 | else:
102 | random.seed(args.random_seed)
103 | all_pairs = random.sample(all_pairs, args.num_training_pairs)
104 |
105 | print ("|raw sentence pairs|:", len(all_pairs))
106 | print ("|dev set|:", len(dev_samples))
107 | for key in all_test:
108 | print ("|test set: %s|" % key, len(all_test[key]))
109 |
110 | model_name = args.model_name_or_path #"princeton-nlp/unsup-simcse-bert-base-uncased"
111 | simcse2base = {
112 | "princeton-nlp/unsup-simcse-roberta-base": "roberta-base",
113 | "princeton-nlp/unsup-simcse-roberta-large": "roberta-large",
114 | "princeton-nlp/unsup-simcse-bert-base-uncased": "bert-base-uncased",
115 | "princeton-nlp/unsup-simcse-bert-large-uncased": "bert-large-uncased"}
116 | batch_size_cross_encoder = args.batch_size_cross_encoder
117 | batch_size_bi_encoder = args.batch_size_bi_encoder
118 | num_epochs_cross_encoder = args.num_epochs_cross_encoder
119 | num_epochs_bi_encoder = args.num_epochs_bi_encoder
120 | max_seq_length = 32
121 | total_cycle = args.cycle
122 | device = args.device
123 |
124 | logging.info ("########## load base model and evaluate ##########")
125 |
126 | ###### Bi-encoder (sentence-transformers) ######
127 | logging.info(f"Loading bi-encoder model: {model_name}")
128 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
129 | word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
130 |
131 | # Pool token embeddings into one fixed-size sentence vector (cls or mean, per --bi_encoder_pooling_mode)
132 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode)
133 |
134 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
135 |
136 | # eval bi-encoder
137 | logging.info ("Evaluate bi-encoder...")
138 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi")
139 |
140 | start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
141 |
142 | bi_encoder_dev_scores = []
143 | cross_encoder_dev_scores = []
144 | bi_encoder_path = model_name
145 | all_predictions_bi_encoder = {}
146 | all_predictions_cross_encoder = {}
147 |
148 |
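# Self-distillation cycle, summarising the loop below: (1) the current bi-encoder
# pseudo-labels every raw sentence pair with a cosine similarity; (2) a
# cross-encoder is trained on those silver scores; (3) the trained cross-encoder
# re-labels all pairs; (4) a new bi-encoder is trained on the cross-encoder's
# scores. Dev scores are recorded each cycle so the best checkpoints can be
# selected at the end.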
149 | for cycle in range(1, total_cycle + 1):
150 |     logging.info (f"########## cycle {cycle} starts ##########")
152 |
153 | ###### label data with bi-encoder ######
154 | # label sentence pairs with bi-encoder
155 | logging.info ("Label sentence pairs with bi-encoder...")
156 |
157 | # Two lists of sentences
158 | sents1 = [p[0] for p in all_pairs]
159 | sents2 = [p[1] for p in all_pairs]
160 |
161 |     # Compute embeddings for both lists
162 | embeddings1 = bi_encoder.encode(sents1, convert_to_tensor=True)
163 | embeddings2 = bi_encoder.encode(sents2, convert_to_tensor=True)
164 |
165 |     # Compute cosine similarities
166 | cosine_scores = F.cosine_similarity(embeddings1, embeddings2)
167 |
168 | # save the predictions
169 | all_predictions_bi_encoder["bi_encoder_cycle_"+str(cycle-1)] = cosine_scores.cpu().numpy()
170 |
171 | # form (self-labelled) train set
172 | train_samples = []
173 |
174 | for i in range(len(sents1)):
175 |         train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i]))
176 |         if args.task not in ["qnli"]:
177 |             # qnli pairs are directional (question, sentence), so only symmetric
178 |             # tasks also get the reversed training pair
179 |             train_samples.append(InputExample(texts=[sents2[i], sents1[i]], label=cosine_scores[i]))
180 |
181 | del bi_encoder, embeddings1, embeddings2, cosine_scores
182 | torch.cuda.empty_cache()
183 |
184 | ###### Cross-encoder learning ######
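    # Note on initialisation: with --init_with_new_models, each cycle restarts the
    # cross-encoder from the raw pre-trained LM (via simcse2base); otherwise it
    # warm-starts from the checkpoint currently held in bi_encoder_path (the
    # previous cycle's bi-encoder output, or the original SimCSE model in cycle 1).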
185 | if args.init_with_new_models:
186 |         bi_encoder_path = simcse2base[model_name]  # always restart from the raw pre-trained LM
187 | logging.info(f"Loading cross-encoder model: {bi_encoder_path}")
188 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for cross-encoder model
189 | cross_encoder = CrossEncoder(bi_encoder_path, num_labels=1, device=device, max_length=64)
190 |
191 |     # We wrap train_samples (a List[InputExample] with self-labelled scores) into a pytorch DataLoader
192 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size_cross_encoder)
193 |
194 | # We add an evaluator, which evaluates the performance during training
195 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
196 | evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='dev')
197 | else:
198 | evaluator = CECorrelationEvaluatorAUC.from_input_examples(dev_samples, name='dev')
199 |
200 | # Configure the training
201 | warmup_steps = math.ceil(len(train_dataloader) * num_epochs_cross_encoder * 0.1) #10% of train data for warm-up
202 | logging.info(f"Warmup-steps: {warmup_steps}")
203 |
204 | cross_encoder_path = f"output/cross-encoder/" \
205 | f"{args.task}_cycle{cycle}_{model_name.replace('/', '-')}-{start_time}"
206 |
207 | # Train the cross-encoder model
208 | cross_encoder.fit(
209 | train_dataloader=train_dataloader,
210 | evaluator=evaluator,
211 | evaluation_steps=200,
212 | use_amp=True,
213 | epochs=num_epochs_cross_encoder,
214 | warmup_steps=warmup_steps,
215 | output_path=cross_encoder_path)
216 |
219 | cross_encoder = CrossEncoder(cross_encoder_path, max_length=64, device=device)
220 | #cross_encoder = CrossEncoder(cross_encoder_path, device=device)
221 |
222 | # eval cross-encoder
223 | dev_score = evaluator(cross_encoder)
224 | cross_encoder_dev_scores.append(dev_score)
225 |     logging.info (f"***** dev score (Spearman's rho or AUC): {dev_score:.4f} *****")
226 |
227 | ###### label data with cross-encoder ######
228 | # label sentence pairs with cross-encoder
229 | logging.info ("Label sentence pairs with cross-encoder...")
230 | silver_scores = cross_encoder.predict(all_pairs)
231 | silver_samples = list(InputExample(texts=[data[0], data[1]], label=score) for \
232 | data, score in zip(all_pairs, silver_scores))
233 |
234 | del cross_encoder
235 | torch.cuda.empty_cache()
236 |
237 | # save the predictions
238 | all_predictions_cross_encoder["cross_encoder_cycle_"+str(cycle)] = silver_scores
239 |
240 | ###### Bi-encoder learning ######
241 | if args.init_with_new_models:
242 | cross_encoder_path = model_name # always use new model (SimCSE)
243 | logging.info(f"Loading bi-encoder model: {cross_encoder_path}")
244 | word_embedding_model = models.Transformer(cross_encoder_path, max_seq_length=32)
245 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.bi_encoder_pooling_mode)
246 | bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
247 |
248 | train_dataloader = DataLoader(silver_samples, shuffle=True, batch_size=batch_size_bi_encoder)
249 | train_loss = losses.CosineSimilarityLoss(model=bi_encoder)
250 |
251 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
252 | evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='dev')
253 | else:
254 | evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(dev_samples, name='dev')
255 |
256 | # Configure the training.
257 | #warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up
258 | #logging.info(f"Warmup-steps: {warmup_steps}")
259 |
260 | bi_encoder_path = f"output/bi-encoder/" \
261 | f"{args.task}_cycle{cycle}_{model_name.replace('/', '-')}-{start_time}"
262 |
263 | # Train the bi-encoder model
264 | bi_encoder.fit(
265 | train_objectives=[(train_dataloader, train_loss)],
266 | evaluator=evaluator,
267 | epochs=num_epochs_bi_encoder,
268 | evaluation_steps=200,
269 | warmup_steps=0,
270 | output_path=bi_encoder_path,
271 | optimizer_params= {"lr": 5e-5},
272 | use_amp=True,
273 | )
274 |
275 | bi_encoder = SentenceTransformer(bi_encoder_path, device=device)
276 |
277 | # eval bi-encoder
278 | dev_score = evaluator(bi_encoder)
279 | bi_encoder_dev_scores.append(dev_score)
280 |     logging.info (f"***** dev score (Spearman's rho or AUC): {dev_score:.4f} *****")
281 |
282 | logging.info (bi_encoder_dev_scores)
283 | logging.info (cross_encoder_dev_scores)
284 |
285 | # best bi-encoder
286 | best_cycle_bi_encoder = np.argmax(bi_encoder_dev_scores)+1
287 | best_cycle_bi_encoder_path = f"output/bi-encoder/"\
288 | f"{args.task}_cycle{best_cycle_bi_encoder}_{model_name.replace('/', '-')}-{start_time}"
289 | # eval bi-encoder
290 | logging.info (f"Evaluate best bi-encoder (from cycle {best_cycle_bi_encoder})...")
291 | bi_encoder = SentenceTransformer(best_cycle_bi_encoder_path, device=device)
292 | logging.info (best_cycle_bi_encoder_path)
293 | eval_encoder(all_test, bi_encoder, task=args.task, enc_type="bi")
294 |
295 |
296 | # best cross-encoder
297 | best_cycle_cross_encoder = np.argmax(cross_encoder_dev_scores)+1
298 | best_cycle_cross_encoder_path = f"output/cross-encoder/"\
299 | f"{args.task}_cycle{best_cycle_cross_encoder}_{model_name.replace('/', '-')}-{start_time}"
300 | # eval cross-encoder
301 | logging.info (f"Evaluate best cross-encoder (from cycle {best_cycle_cross_encoder})...")
302 | logging.info (best_cycle_cross_encoder_path)
303 | #cross_encoder = CrossEncoder(best_cycle_cross_encoder_path, max_length=64, device=device)
304 | cross_encoder = CrossEncoder(best_cycle_cross_encoder_path, device=device)
305 | eval_encoder(all_test, cross_encoder, task=args.task, enc_type="cross")
306 |
307 | # save all predictions
308 | best_cycle_bi_encoder_csv_path = f"output/bi-encoder/" \
309 | f"{args.task}_cycle{best_cycle_bi_encoder}_{model_name.replace('/', '-')}-{start_time}_all_preds.csv"
310 | best_cycle_cross_encoder_csv_path = f"output/cross-encoder/" \
311 | f"{args.task}_cycle{best_cycle_cross_encoder}_{model_name.replace('/', '-')}-{start_time}_all_preds.csv"
312 |
313 | if args.save_all_predictions:
314 | bi_pred_df = pd.DataFrame(np.array(list(all_predictions_bi_encoder.values())).T, columns=list(all_predictions_bi_encoder.keys()))
315 | cross_pred_df = pd.DataFrame(np.array(list(all_predictions_cross_encoder.values())).T, columns=list(all_predictions_cross_encoder.keys()))
316 | bi_pred_df.to_csv(best_cycle_bi_encoder_csv_path, index=False)
317 | cross_pred_df.to_csv(best_cycle_cross_encoder_csv_path, index=False)
318 |
319 | logging.info ("\n")
320 | print (args)
321 | logging.info ("\n")
322 | logging.info ("***** END *****")
323 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from datasets import load_dataset
5 | from sentence_transformers.readers import InputExample
6 | from sentence_transformers import util
7 | import pandas as pd
8 | import os
9 | import gzip
10 | import csv
11 | import logging
12 | import subprocess
13 | from zipfile import ZipFile
14 |
15 | SCORE = "score"
16 | SPLIT = "split"
17 | SENTENCE = "sentence"
18 | SENTENCE1 = "sentence1"
19 | SENTENCE2 = "sentence2"
20 | QUESTION = "question"
21 | QUESTION1 = "question1"
22 | QUESTION2 = "question2"
23 |
24 | def load_snli():
25 | """
26 | Load the SNLI dataset (https://nlp.stanford.edu/projects/snli/) from huggingface dataset portal.
27 | Parameters
28 | ----------
29 | None
30 | Returns
31 | ----------
32 | all_pairs: a list of sentence pairs from the SNLI dataset
33 | """
34 |
35 | all_pairs = []
36 |
37 | dataset = load_dataset("snli")
38 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["train"]]
39 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["validation"]]
40 | all_pairs += [(row["premise"], row["hypothesis"]) for row in dataset["test"]]
41 |
42 | return all_pairs, None, None
43 |
44 | def load_sts():
45 | """
46 | Load the STS datasets:
47 | STS 2012: https://www.cs.york.ac.uk/semeval-2012/task6/
48 | STS 2013: http://ixa2.si.ehu.eus/sts/
49 | STS 2014: https://alt.qcri.org/semeval2014/task10/
50 | STS 2015: https://alt.qcri.org/semeval2015/task2/
51 | STS 2016: https://alt.qcri.org/semeval2016/task1/
52 | STS-Benchmark: http://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark
53 | Parameters
54 | ----------
55 | None
56 | Returns
57 | ----------
58 | all_pairs: a list of sentence pairs from the STS datasets
59 | all_test: a dict of all test sets of the STS datasets
60 | dev_samples: a list of InputExample instances as the dev set
61 | """
62 |
63 |     # Check if the STS datasets exist. If not, download and extract them
64 | sts_dataset_path = "data/"
65 | if not os.path.exists(os.path.join(sts_dataset_path, "STS_data")):
66 |         logging.info("Dataset not found. Downloading...")
67 | zip_save_path = "data/STS_data.zip"
68 | #os.system("wget https://fangyuliu.me/data/STS_data.zip -P data/")
69 | subprocess.run(["wget", "--no-check-certificate", "https://fangyuliu.me/data/STS_data.zip", "-P", "data/"])
70 | with ZipFile(zip_save_path, "r") as zipIn:
71 | zipIn.extractall(sts_dataset_path)
72 |
73 | all_pairs = []
74 | all_test = {}
75 | dedup = set()
76 |
77 | # read sts 2012-2016
78 |     for year in ["2012","2013","2014","2015","2016"]:
79 |         all_test[year] = []
81 | df = pd.read_csv(f"data/STS_data/en/{year}.test.tsv", delimiter="\t",
82 | quoting=csv.QUOTE_NONE, encoding="utf-8", names=[SCORE, SENTENCE1, SENTENCE2])
83 | for row in df.iterrows():
84 | if str(row[1][SCORE]) == "nan": continue
85 | all_test[year].append(InputExample(texts=[row[1][SENTENCE1], row[1][SENTENCE2]], label=row[1][SCORE]))
86 |
87 | df = pd.read_csv("data/STS_data/en/2012_to_2016.test.tsv", delimiter="\t",
88 | quoting=csv.QUOTE_NONE, encoding="utf-8", names=[SCORE, SENTENCE1, SENTENCE2])
89 |
90 | for row in df.iterrows():
91 | concat = row[1][SENTENCE1]+row[1][SENTENCE2]
92 | if concat in dedup:
93 | continue
94 | all_pairs.append([row[1][SENTENCE1], row[1][SENTENCE2]])
95 | dedup.add(concat)
96 |
97 | # sts-b
98 |     # Check if STS-B exists. If not, download it
99 | sts_dataset_path = "data/stsbenchmark.tsv.gz"
100 | if not os.path.exists(sts_dataset_path):
101 | util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
102 | # read sts-b
103 | dev_samples_stsb = []
104 | test_samples_stsb = []
105 | with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
106 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
107 | for row in reader:
108 | score = float(row[SCORE]) / 5.0 # Normalize score to range 0 ... 1
109 |
110 | if row[SPLIT] == "dev":
111 | dev_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score))
112 | elif row[SPLIT] == "test":
113 | test_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score))
114 |
115 | # add (non-duplicated) sentence pair to all_pairs
116 | concat = row[SENTENCE1]+row[SENTENCE2]
117 | if concat in dedup:
118 | continue
119 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]])
120 | dedup.add(concat)
121 |
122 | all_test["stsb"] = test_samples_stsb
123 | dev_samples = dev_samples_stsb
124 | return all_pairs, all_test, dev_samples
125 |
126 | def load_sickr():
127 | """
128 | Load the SICK-R dataset: http://clic.cimec.unitn.it/composes/sick.html
129 | Parameters
130 | ----------
131 | None
132 | Returns
133 | ----------
134 | all_pairs: a list of sentence pairs from the SICK-R dataset
135 | all_test: a dict of all test sets
136 | dev_samples: a list of InputExample instances as the dev set
137 | """
138 |
140 | sts_dataset_path = "data/"
141 | if not os.path.exists(os.path.join(sts_dataset_path, "STS_data")):
142 |         logging.info("Dataset not found. Downloading...")
143 | zip_save_path = "data/STS_data.zip"
144 | subprocess.run(["wget", "--no-check-certificate", "https://fangyuliu.me/data/STS_data.zip", "-P", "data/"])
145 | with ZipFile(zip_save_path, "r") as zipIn:
146 | zipIn.extractall(sts_dataset_path)
147 |
148 | all_pairs = []
149 | all_test = {}
150 | dedup = set()
151 |
152 | # read sickr
153 | test_samples_sickr = []
154 | dev_samples_sickr = []
155 |
156 | df = pd.read_csv("data/STS_data/en/SICK_annotated.txt", delimiter="\t",
157 | quoting=csv.QUOTE_NONE, encoding="utf-8")
158 |
159 | for row in df.iterrows():
160 | row = row[1]
161 | score = row["relatedness_score"] / 5.0
162 | if row["SemEval_set"] == "TEST":
163 | test_samples_sickr.append(InputExample(texts=[row["sentence_A"], row["sentence_B"]], label=score))
164 | elif row["SemEval_set"] == "TRIAL":
165 | dev_samples_sickr.append(InputExample(texts=[row["sentence_A"], row["sentence_B"]], label=score))
166 |
167 | concat = row["sentence_A"]+row["sentence_B"]
168 | if concat in dedup:
169 | continue
170 | all_pairs.append([row["sentence_A"], row["sentence_B"]])
171 | dedup.add(concat)
172 |
173 | all_test["sickr"] = test_samples_sickr
174 | dev_samples = dev_samples_sickr
175 | return all_pairs, all_test, dev_samples
176 |
177 | def load_qqp():
178 | """
179 | Load the QQP dataset (https://www.kaggle.com/c/quora-question-pairs) from huggingface dataset portal.
180 | Parameters
181 | ----------
182 | None
183 | Returns
184 | ----------
185 | all_pairs: a list of sentence pairs from the QQP dataset
186 | all_test: a dict of all test sets
187 | dev_samples: a list of InputExample instances as the dev set
188 | """
189 |
190 | all_pairs = []
191 | all_test = {}
192 |
193 | dev_samples_qqp = []
194 | test_samples_qqp = []
195 |
196 | # Check if the QQP dataset exists. If not, download and extract
197 | qqp_dataset_path = "data/quora-IR-dataset"
198 | if not os.path.exists(qqp_dataset_path):
199 |         logging.info("Dataset not found. Downloading...")
200 | zip_save_path = 'data/quora-IR-dataset.zip'
201 | util.http_get(url='https://sbert.net/datasets/quora-IR-dataset.zip', path=zip_save_path)
202 | with ZipFile(zip_save_path, 'r') as zipIn:
203 | zipIn.extractall(qqp_dataset_path)
204 |
205 | qqp_datapoints_cut_train = 10000
206 | qqp_datapoints_cut_val = 1000
207 | qqp_datapoints_cut_test = 10000
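    # QQP is large, so the corpus is subsampled: up to 10k train pairs are used
    # for self-labelling, plus 1k dev pairs and 10k test pairs (dev/test pairs
    # are also added to all_pairs as raw sentence pairs).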
208 |
209 | with open(os.path.join(qqp_dataset_path, "classification/train_pairs.tsv"), encoding="utf8") as fIn:
210 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
211 | for i, row in enumerate(reader):
212 | if i == qqp_datapoints_cut_train: break
213 | all_pairs.append([row[QUESTION1], row[QUESTION2]])
214 |
215 | with open(os.path.join(qqp_dataset_path, "classification/dev_pairs.tsv"), encoding="utf8") as fIn:
216 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
217 | for i, row in enumerate(reader):
218 | if i == qqp_datapoints_cut_val: break
219 | dev_samples_qqp.append(InputExample(texts=[row[QUESTION1], row[QUESTION2]], label=int(row['is_duplicate'])))
220 | all_pairs.append([row[QUESTION1], row[QUESTION2]])
221 |
222 | with open(os.path.join(qqp_dataset_path, "classification/test_pairs.tsv"), encoding="utf8") as fIn:
223 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
224 | for i, row in enumerate(reader):
225 | if i == qqp_datapoints_cut_test: break
226 | test_samples_qqp.append(InputExample(texts=[row[QUESTION1], row[QUESTION2]], label=int(row['is_duplicate'])))
227 | all_pairs.append([row[QUESTION1], row[QUESTION2]])
228 |
229 | all_test["qqp"] = test_samples_qqp
230 | dev_samples = dev_samples_qqp
231 |
232 | return all_pairs, all_test, dev_samples
233 |
234 | def load_qnli():
235 | """
236 | Load the QNLI dataset (part of GLUE: https://gluebenchmark.com/) from huggingface dataset portal.
237 | Parameters
238 | ----------
239 | None
240 | Returns
241 | ----------
242 | all_pairs: a list of sentence pairs from the QNLI dataset
243 | all_test: a dict of all test sets
244 | dev_samples: a list of InputExample instances as the dev set
245 | """
246 |
247 | all_pairs = []
248 | all_test = {}
249 |
250 | dev_samples_qnli = []
251 | test_samples_qnli = []
252 |
253 | dataset = load_dataset("glue", "qnli")
254 |
255 | qnli_datapoints_cut_train = 10000
256 |
257 | for i, row in enumerate(dataset["train"]):
258 | if i == qnli_datapoints_cut_train: break
259 | all_pairs.append([row[QUESTION], row[SENTENCE]])
260 |
261 | for row in dataset["validation"]:
262 |         label = 0 if row["label"]==1 else 1 # GLUE qnli labels: 0 = entailment, 1 = not_entailment; flipped so 1 marks the positive class
263 | dev_samples_qnli.append(
264 | InputExample(texts=[row[QUESTION], row[SENTENCE]], label=label))
265 | all_pairs.append([row[QUESTION], row[SENTENCE]])
266 |
267 |     # qnli test labels are not public: keep the first 1k validation examples as dev and use the rest as test
268 | all_test["qnli"] = dev_samples_qnli[1000:]
269 | dev_samples = dev_samples_qnli[:1000]
270 |
271 | return all_pairs, all_test, dev_samples
272 |
273 | def load_mrpc():
274 | """
275 | Load the MRPC dataset (https://www.microsoft.com/en-us/download/details.aspx?id=52398) from huggingface dataset portal.
276 | Parameters
277 | ----------
278 | None
279 | Returns
280 | ----------
281 | all_pairs: a list of sentence pairs from the MRPC dataset
282 | all_test: a dict of all test sets
283 | dev_samples: a list of InputExample instances as the dev set
284 | """
285 | all_pairs = []
286 | all_test = {}
287 |
288 | dev_samples_mrpc = []
289 | test_samples_mrpc = []
290 |
291 | dataset = load_dataset("glue", "mrpc")
292 |
293 | for row in dataset["train"]:
294 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]])
295 |
296 | for row in dataset["validation"]:
297 | dev_samples_mrpc.append(
298 | InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=int(row["label"])))
299 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]])
300 |
301 | for row in dataset["test"]:
302 | test_samples_mrpc.append(
303 | InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=int(row["label"])))
304 | all_pairs.append([row[SENTENCE1], row[SENTENCE2]])
305 |
306 | all_test["mrpc"] = test_samples_mrpc
307 | dev_samples = dev_samples_mrpc
308 |
309 | return all_pairs, all_test, dev_samples
310 |
311 | def load_sts_and_sickr():
312 | """
313 | Load both STS and SICK-R datasets. Use STS-B's dev set for dev.
314 | Parameters
315 | ----------
316 | None
317 | Returns
318 | ----------
319 | all_pairs: a list of sentence pairs from the STS+SICK-R dataset
320 | all_test: a dict of all test sets
321 | dev_samples: a list of InputExample instances as the dev set
322 | """
323 |
324 | all_pairs_sts, all_test_sts, dev_samples_sts = load_sts()
325 | all_pairs_sickr, all_test_sickr, dev_samples_sickr = load_sickr()
326 | all_pairs = all_pairs_sts+all_pairs_sickr
327 | all_test = {**all_test_sts, **all_test_sickr}
328 | return all_pairs, all_test, dev_samples_sts # sts-b's dev is used
329 |
330 | def load_custom(fpath):
331 | """
332 | Load custom sentence-pair corpus. Use STS-B's dev set for dev.
333 | Parameters
334 | ----------
335 | fpath: path to the training file, where sentence pairs are formatted as 'sent1||sent2'
336 | Returns
337 | ----------
338 |     all_pairs: a list of sentence pairs read from the custom corpus
339 | all_test: a dict of all test sets
340 | dev_samples: a list of InputExample instances as the dev set
341 | """
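    # Hypothetical example of the expected corpus format (one pair per line,
    # joined by '||'); lines without exactly one separator are skipped below:
    #
    #   A man is playing a guitar.||Someone is playing an instrument.
    #   A woman is slicing an onion.||The kid rides a bike in the park.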
342 | all_pairs = []
343 | all_test = {}
344 |
345 | # load custom training corpus
346 | with open(fpath, "r") as f:
347 | lines = f.readlines()
348 | for line in lines:
349 | line = line.strip()
350 |             if len(line.split("||")) != 2: continue # skip malformed lines
351 | sent1, sent2 = line.split("||")
352 | all_pairs.append([sent1, sent2])
353 |
354 | # load STS-b dev/test set
355 |     # Check if STS-B exists. If not, download it
356 | sts_dataset_path = "data/stsbenchmark.tsv.gz"
357 | if not os.path.exists(sts_dataset_path):
358 | util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
359 | # read sts-b
360 | dev_samples_stsb = []
361 | test_samples_stsb = []
362 | with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
363 | reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
364 | for row in reader:
365 | score = float(row[SCORE]) / 5.0 # Normalize score to range 0 ... 1
366 |
367 | if row[SPLIT] == "dev":
368 | dev_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score))
369 | elif row[SPLIT] == "test":
370 | test_samples_stsb.append(InputExample(texts=[row[SENTENCE1], row[SENTENCE2]], label=score))
371 |
372 |             # add sentence pair to all_pairs (disabled for custom corpora)
373 | #all_pairs.append([row[SENTENCE1], row[SENTENCE2]])
374 |
375 | all_test["stsb"] = test_samples_stsb
376 | dev_samples = dev_samples_stsb
377 |
378 | return all_pairs, all_test, dev_samples
379 |
380 |
381 | task_loader_dict = {
382 | "sts": load_sts,
383 | "sickr": load_sickr,
384 | "sts_sickr": load_sts_and_sickr,
385 | "qqp": load_qqp,
386 | "qnli": load_qnli,
387 | "mrpc": load_mrpc,
388 | "snli": load_snli,
389 | "custom": load_custom
390 | }
391 |
392 | def load_data(task, fpath=None):
393 | """
394 | A unified dataset loader for all tasks.
395 | Parameters
396 | ----------
397 |     task: a string specifying dataset/task to be loaded (see 'task_loader_dict'); for task == 'custom', pass the corpus path via fpath
398 | Returns
399 | ----------
400 | all_pairs: a list of sentence pairs from the specified dataset
401 | all_test: a dict of all test sets
402 | dev_samples: a list of InputExample instances as the dev set
403 | """
404 |     if task not in task_loader_dict:
405 | raise NotImplementedError()
406 | if task == "custom":
407 | return task_loader_dict[task](fpath)
408 | else:
409 | return task_loader_dict[task]()
410 |
411 |
412 | if __name__ == "__main__":
413 | # test if all datasets can be properly loaded
414 | for task in task_loader_dict:
415 | print (f"loading {task}...")
416 | load_data(task)
417 | print ("done.")
--------------------------------------------------------------------------------
/src/mutual_distill_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 | from sentence_transformers import models, losses, util, SentenceTransformer, LoggingHandler
8 | from sentence_transformers.cross_encoder import CrossEncoder
9 | from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
10 | from sentence_transformers.readers import InputExample
11 | from datetime import datetime
12 | import argparse
13 | import logging
14 | import sys
15 | import random
16 | import tqdm
17 | import math
18 | import os
19 | import numpy as np
20 |
21 | # import from local codes
22 | from data import load_data
23 | from eval import eval_encoder
24 |
25 | from sentence_transformers_ext.bi_encoder_eval import (
26 | EmbeddingSimilarityEvaluator,
27 | EmbeddingSimilarityEvaluatorEnsemble,
28 | EmbeddingSimilarityEvaluatorAUC,
29 | EmbeddingSimilarityEvaluatorAUCEnsemble
30 | )
31 | from sentence_transformers_ext.cross_encoder_eval import (
32 | CECorrelationEvaluatorEnsemble,
33 | CECorrelationEvaluatorAUC,
34 | CECorrelationEvaluatorAUCEnsemble
35 | )
36 |
37 | parser = argparse.ArgumentParser()
38 |
39 | parser.add_argument("--task", type=str, default='sts',
40 | help='{sts|sickr|sts_sickr|qqp|qnli|mrpc|snli|custom}')
41 | parser.add_argument("--device1", type=int, default=0)
42 | parser.add_argument("--device2", type=int, default=0)
43 | parser.add_argument("--cycle", type=int, default=3)
44 | parser.add_argument("--num_epochs_cross_encoder", type=int, default=1)
45 | parser.add_argument("--num_epochs_bi_encoder", type=int, default=10)
46 | parser.add_argument("--batch_size_cross_encoder", type=int, default=32)
47 | parser.add_argument("--batch_size_bi_encoder", type=int, default=128)
48 | parser.add_argument("--init_with_new_models", action="store_true")
49 | parser.add_argument("--use_large", action="store_true")
50 | parser.add_argument("--bi_encoder1_pooling_mode", type=str,
51 | default='cls', help="{cls|mean}")
52 | parser.add_argument("--bi_encoder2_pooling_mode", type=str,
53 | default='cls', help="{cls|mean}")
54 | parser.add_argument("--random_seed", type=int, default=2021)
55 | parser.add_argument("--custom_corpus_path", type=str, default=None)
56 | parser.add_argument("--quick_test", action="store_true")
57 |
58 |
59 |
60 | args = parser.parse_args()
61 | print (args)
62 |
63 | torch.manual_seed(args.random_seed)
64 |
65 | #### Just some code to print debug information to stdout
66 | logging.basicConfig(format="%(asctime)s - %(message)s",
67 | datefmt="%Y-%m-%d %H:%M:%S",
68 | level=logging.INFO,
69 | handlers=[LoggingHandler()])
70 |
71 | ### read datasets
72 | all_pairs, all_test, dev_samples = load_data(args.task, fpath=args.custom_corpus_path)
73 | if args.quick_test:
74 | all_pairs = all_pairs[:1000] # for quick test
75 |
76 | print ("|raw sentence pairs|:", len(all_pairs))
77 | print ("|dev set|:", len(dev_samples))
78 | for key in all_test:
79 | print ("|test set: %s|" % key, len(all_test[key]))
80 |
81 |
82 | if not args.use_large:
83 | model_name1 = "princeton-nlp/unsup-simcse-bert-base-uncased"
84 | model_name2 = "princeton-nlp/unsup-simcse-roberta-base"
85 | else:
86 | model_name1 = "princeton-nlp/unsup-simcse-bert-large-uncased"
87 | model_name2 = "princeton-nlp/unsup-simcse-roberta-large"
88 |
89 |
90 | simcse2base = {
91 | "princeton-nlp/unsup-simcse-roberta-base": "roberta-base",
92 | "princeton-nlp/unsup-simcse-roberta-large": "roberta-large",
93 | "princeton-nlp/unsup-simcse-bert-base-uncased": "bert-base-uncased",
94 | "princeton-nlp/unsup-simcse-bert-large-uncased": "bert-large-uncased"
95 | }
96 |
97 | batch_size_cross_encoder = args.batch_size_cross_encoder
98 | batch_size_bi_encoder = args.batch_size_bi_encoder
99 | num_epochs_cross_encoder = args.num_epochs_cross_encoder
100 | num_epochs_bi_encoder = args.num_epochs_bi_encoder
101 | max_seq_length = 32
102 | total_cycle = args.cycle
103 | device1 = args.device1
104 | device2 = args.device2
105 |
106 | logging.info ("########## load base models and evaluate ##########")
107 |
108 | ###### Bi-encoder (sentence-transformers) ######
109 | logging.info(f"Loading bi-encoder model1: {model_name1}")
110 | logging.info(f"Loading bi-encoder model2: {model_name2}")
111 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
112 | word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length)
113 | word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length)
114 |
115 | # Pool token embeddings into one fixed-size sentence vector (cls or mean pooling)
116 | pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(),
117 |                                 pooling_mode=args.bi_encoder1_pooling_mode)  # bert
118 | pooling_model2 = models.Pooling(word_embedding_model2.get_word_embedding_dimension(),
119 |                                 pooling_mode=args.bi_encoder2_pooling_mode)  # roberta
120 |
121 | bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device1)
122 | bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device2)
123 |
124 | # Evaluate the zero-shot bi-encoder ensemble as a baseline before any distillation
125 | logging.info ("Evaluate bi-encoder (ensembled)...")
126 | scores = []
127 | for name, data in all_test.items():
128 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
129 | test_evaluator = EmbeddingSimilarityEvaluatorEnsemble.from_input_examples(data, name=name)
130 | else:
131 | test_evaluator = EmbeddingSimilarityEvaluatorAUCEnsemble.from_input_examples(data, name=name)
132 | scores += [test_evaluator([bi_encoder1, bi_encoder2])]
133 | logging.info (f"***** avg test score (rho/AUC): {sum(scores)/len(scores):.4f} *****")
134 |
135 | start_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
136 |
137 | bi_encoder_dev_scores = []
138 | cross_encoder_dev_scores = []
139 |
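140 | # Each cycle: (1) the current bi-encoders label all raw pairs with averaged cosine scores;
141 | # (2) two cross-encoders are trained on these silver labels; (3) the trained cross-encoders
142 | # relabel the pairs; (4) fresh bi-encoders are trained on the averaged cross-encoder scores.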
140 | for cycle in range(1, total_cycle+1):
141 |     logging.info (f"########## cycle {cycle} starts ##########")
142 |
143 | ###### label data with bi-encoder ######
144 | # label sentence pairs with bi-encoder
145 | logging.info ("Label sentence pairs...")
146 |
147 | # Two lists of sentences
148 | sents1 = [p[0] for p in all_pairs]
149 | sents2 = [p[1] for p in all_pairs]
150 |
151 |     # Compute embeddings for both lists
152 | embeddings1 = bi_encoder1.encode(sents1, convert_to_tensor=True)
153 | embeddings2 = bi_encoder1.encode(sents2, convert_to_tensor=True)
154 |
155 |     # Compute cosine similarities
156 | cosine_scores1 = F.cosine_similarity(embeddings1, embeddings2)
157 |
158 | embeddings1 = bi_encoder2.encode(sents1, convert_to_tensor=True)
159 | embeddings2 = bi_encoder2.encode(sents2, convert_to_tensor=True)
160 | cosine_scores2 = F.cosine_similarity(embeddings1, embeddings2)
161 |
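162 |     # Mutual signal: average the two bi-encoders' cosine scores for every pair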
162 | cosine_scores = torch.stack([cosine_scores1.cpu(), cosine_scores2.cpu()]).mean(0)
163 |
164 | # form (self-labelled) train set
165 | train_samples = []
166 |
167 |     for i in range(len(sents1)):
168 |         train_samples.append(InputExample(texts=[sents1[i], sents2[i]], label=cosine_scores[i]))
169 |         # QNLI pairs are directional (question, sentence), so only the symmetric tasks
170 |         # also receive the reversed pair as an extra training example
171 |         if args.task not in ["qnli"]:
172 |             train_samples.append(InputExample(texts=[sents2[i], sents1[i]], label=cosine_scores[i]))
173 |
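174 |     # Free GPU memory before the cross-encoders are loaded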
174 | del bi_encoder1, bi_encoder2, embeddings1, embeddings2, cosine_scores1, cosine_scores2
175 | torch.cuda.empty_cache()
176 |
177 | ###### Cross-encoder learning ######
178 | logging.info(f"Loading cross-encoder1 model: {simcse2base[model_name1]}")
179 | logging.info(f"Loading cross-encoder2 model: {simcse2base[model_name2]}")
180 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for cross-encoder model
181 | cross_encoder1 = CrossEncoder(simcse2base[model_name1], num_labels=1, device=device1, max_length=64)
182 | cross_encoder2 = CrossEncoder(simcse2base[model_name2], num_labels=1, device=device2, max_length=64)
183 |
184 |     # Wrap train_samples (a List[InputExample]) in a pytorch DataLoader
185 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size_cross_encoder)
186 |
187 | # We add an evaluator, which evaluates the performance during training
188 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
189 | evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='dev')
190 | else:
191 | evaluator = CECorrelationEvaluatorAUC.from_input_examples(dev_samples, name='dev')
192 |
193 | # Configure the training
194 |     warmup_steps = math.ceil(len(train_dataloader) * num_epochs_cross_encoder * 0.1)  # 10% of train steps for warm-up
195 | logging.info(f"Warmup-steps: {warmup_steps}")
196 |
197 | cross_encoder_path1 = f"output/cross-encoder/" \
198 | f"{args.task}_cycle{cycle}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
199 | cross_encoder_path2 = f"output/cross-encoder/" \
200 | f"{args.task}_cycle{cycle}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"
201 |
202 |
203 |     # Train cross-encoder1 on the bi-encoder-labelled pairs
204 | cross_encoder1.fit(
205 | train_dataloader=train_dataloader,
206 | evaluator=evaluator,
207 | evaluation_steps=200,
208 | use_amp=True,
209 | epochs=num_epochs_cross_encoder,
210 | warmup_steps=warmup_steps,
211 | output_path=cross_encoder_path1)
212 |
213 |     # Train cross-encoder2 on the same silver-labelled pairs
214 | cross_encoder2.fit(
215 | train_dataloader=train_dataloader,
216 | evaluator=evaluator,
217 | evaluation_steps=200,
218 | use_amp=True,
219 | epochs=num_epochs_cross_encoder,
220 | warmup_steps=warmup_steps,
221 | output_path=cross_encoder_path2)
222 |
223 |
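224 |     # fit() saved the best checkpoint (by dev score) to output_path; reload it for labelling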
224 | cross_encoder1 = CrossEncoder(cross_encoder_path1, max_length=64, device=device1)
225 | cross_encoder2 = CrossEncoder(cross_encoder_path2, max_length=64, device=device2)
226 |
227 | dev_score1 = evaluator(cross_encoder1)
228 | dev_score2 = evaluator(cross_encoder2)
229 | cross_encoder_dev_scores.append([dev_score1, dev_score2])
230 |     logging.info (f"***** dev score (rho/AUC): cross-encoder1 {dev_score1:.4f}, cross-encoder2 {dev_score2:.4f} *****")
231 |
232 | ###### label data with cross-encoder ######
233 | # label sentence pairs with cross-encoder
234 | logging.info ("Label sentence pairs...")
235 | silver_scores1 = cross_encoder1.predict(all_pairs)
236 | silver_scores2 = cross_encoder2.predict(all_pairs)
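237 |     # Average the two cross-encoders' predictions to form the new silver labels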
237 |     silver_scores = np.array([silver_scores1, silver_scores2]).mean(0)
238 |     silver_samples = [InputExample(texts=[data[0], data[1]], label=score)
239 |                       for data, score in zip(all_pairs, silver_scores)]
240 |
241 | del cross_encoder1, cross_encoder2
242 | torch.cuda.empty_cache()
243 |
244 | ###### Bi-encoder learning ######
245 |
246 | logging.info(f"Loading bi-encoder1 model: {model_name1}")
247 | logging.info(f"Loading bi-encoder2 model: {model_name2}")
248 | # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
249 | word_embedding_model1 = models.Transformer(model_name1, max_seq_length=max_seq_length)
250 | word_embedding_model2 = models.Transformer(model_name2, max_seq_length=max_seq_length)
251 |
252 |     # Pool token embeddings into one fixed-size sentence vector (cls or mean pooling)
253 |     pooling_model1 = models.Pooling(word_embedding_model1.get_word_embedding_dimension(),
254 |                                     pooling_mode=args.bi_encoder1_pooling_mode)  # bert
255 |     pooling_model2 = models.Pooling(word_embedding_model2.get_word_embedding_dimension(),
256 |                                     pooling_mode=args.bi_encoder2_pooling_mode)  # roberta
257 |
258 | bi_encoder1 = SentenceTransformer(modules=[word_embedding_model1, pooling_model1], device=device1)
259 | bi_encoder2 = SentenceTransformer(modules=[word_embedding_model2, pooling_model2], device=device2)
260 |
261 | train_dataloader = DataLoader(silver_samples, shuffle=True, batch_size=batch_size_bi_encoder)
262 | train_loss1 = losses.CosineSimilarityLoss(model=bi_encoder1)
263 | train_loss2 = losses.CosineSimilarityLoss(model=bi_encoder2)
264 |
265 | if args.task in ["sts", "sickr", "sts_sickr", "custom"]:
266 | evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="dev")
267 | else:
268 | evaluator = EmbeddingSimilarityEvaluatorAUC.from_input_examples(dev_samples, name="dev")
269 |
270 |     # Configure the training. Unlike the cross-encoder, the bi-encoders are
271 |     # fine-tuned without warmup (warmup_steps=0 in the fit() calls below).
273 |
274 | bi_encoder_path1 = f"output/bi-encoder/" \
275 | f"{args.task}_cycle{cycle}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
276 | bi_encoder_path2 = f"output/bi-encoder/" \
277 | f"{args.task}_cycle{cycle}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"
278 |
279 | bi_encoder1.fit(
280 | train_objectives=[(train_dataloader, train_loss1)],
281 | evaluator=evaluator,
282 | epochs=num_epochs_bi_encoder,
283 | evaluation_steps=200,
284 | warmup_steps=0,
285 | output_path=bi_encoder_path1,
286 |         optimizer_params={"lr": 5e-5},
287 | use_amp=True,
288 | )
289 |
290 | bi_encoder2.fit(
291 | train_objectives=[(train_dataloader, train_loss2)],
292 | evaluator=evaluator,
293 | epochs=num_epochs_bi_encoder,
294 | evaluation_steps=200,
295 | warmup_steps=0,
296 | output_path=bi_encoder_path2,
297 |         optimizer_params={"lr": 5e-5},
298 | use_amp=True,
299 | )
300 |
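301 |     # Reload the best bi-encoder checkpoints saved by fit() during training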
301 | bi_encoder1 = SentenceTransformer(bi_encoder_path1, device=device1)
302 | bi_encoder2 = SentenceTransformer(bi_encoder_path2, device=device2)
303 |
304 | dev_score1 = evaluator(bi_encoder1)
305 | dev_score2 = evaluator(bi_encoder2)
306 | bi_encoder_dev_scores.append([dev_score1, dev_score2])
307 |     logging.info (f"***** dev score (rho/AUC): bi-encoder1 {dev_score1:.4f}, bi-encoder2 {dev_score2:.4f} *****")
308 |
309 | del bi_encoder1, bi_encoder2
310 | torch.cuda.empty_cache()
311 |
312 | print ("cross-encoder dev scores per cycle:", cross_encoder_dev_scores)
313 | print ("bi-encoder dev scores per cycle:", bi_encoder_dev_scores)
314 |
315 | # best bi-encoder
316 | logging.info ("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
317 | bi_encoder_dev_scores1 = [p[0] for p in bi_encoder_dev_scores]
318 | bi_encoder_dev_scores2 = [p[1] for p in bi_encoder_dev_scores]
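319 | # +1 maps the 0-based argmax index to the 1-based cycle number used in the checkpoint paths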
319 | best_cycle_bi_encoder1 = np.argmax(bi_encoder_dev_scores1)+1
320 | best_cycle_bi_encoder2 = np.argmax(bi_encoder_dev_scores2)+1
321 |
322 | best_cycle_bi_encoder_path1 = f"output/bi-encoder/" \
323 | f"{args.task}_cycle{best_cycle_bi_encoder1}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
324 | best_cycle_bi_encoder_path2 = f"output/bi-encoder/" \
325 | f"{args.task}_cycle{best_cycle_bi_encoder2}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"
326 |
327 | # eval bi-encoder
328 | logging.info ("£££££ Evaluate best bi-encoder1...")
329 | bi_encoder1 = SentenceTransformer(best_cycle_bi_encoder_path1, device=device1)
330 | logging.info (best_cycle_bi_encoder_path1)
331 | eval_encoder(all_test, bi_encoder1, task=args.task, enc_type="bi")
332 |
333 | logging.info ("£££££ Evaluate best bi-encoder2...")
334 | bi_encoder2 = SentenceTransformer(best_cycle_bi_encoder_path2, device=device2)
335 | logging.info (best_cycle_bi_encoder_path2)
336 | eval_encoder(all_test, bi_encoder2, task=args.task, enc_type="bi")
337 |
338 | # eval bi-encoder (ensembled)
339 | logging.info ("£££££ Evaluate best bi-encoders (ensembled)...")
340 | eval_encoder(all_test, [bi_encoder1, bi_encoder2], task=args.task, enc_type="bi", ensemble=True)
341 |
342 | # best cross-encoder
343 | logging.info ("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
344 | cross_encoder_dev_scores1 = [p[0] for p in cross_encoder_dev_scores]
345 | cross_encoder_dev_scores2 = [p[1] for p in cross_encoder_dev_scores]
346 | best_cycle_cross_encoder1 = np.argmax(cross_encoder_dev_scores1)+1
347 | best_cycle_cross_encoder2 = np.argmax(cross_encoder_dev_scores2)+1
348 |
349 | best_cycle_cross_encoder_path1 = f"output/cross-encoder/" \
350 | f"{args.task}_cycle{best_cycle_cross_encoder1}_mutual_parallel_{model_name1.replace('/', '-')}-{start_time}"
351 | best_cycle_cross_encoder_path2 = f"output/cross-encoder/" \
352 | f"{args.task}_cycle{best_cycle_cross_encoder2}_mutual_parallel_{model_name2.replace('/', '-')}-{start_time}"
353 |
354 |
355 | # eval cross-encoder
356 | logging.info ("£££££ Evaluate best cross-encoder1...")
357 | logging.info (best_cycle_cross_encoder_path1)
359 | cross_encoder1 = CrossEncoder(best_cycle_cross_encoder_path1, device=device1)
360 | eval_encoder(all_test, cross_encoder1, task=args.task, enc_type="cross")
361 |
362 | logging.info ("£££££ Evaluate best cross-encoder2...")
363 | logging.info (best_cycle_cross_encoder_path2)
365 | cross_encoder2 = CrossEncoder(best_cycle_cross_encoder_path2, device=device2)
366 | eval_encoder(all_test, cross_encoder2, task=args.task, enc_type="cross")
367 |
368 | # eval cross-encoder (ensembled)
369 | logging.info ("£££££ Evaluate best cross-encoders (ensembled)...")
370 | eval_encoder(all_test, [cross_encoder1, cross_encoder2], task=args.task, enc_type="cross", ensemble=True)
371 |
372 |
373 | logging.info ("\n")
374 | print (args)
375 | logging.info ("\n")
376 | logging.info ("***** END *****")
377 |
--------------------------------------------------------------------------------