├── .gitignore
├── README.md
├── data
│   ├── csts.tar.enc
│   └── extract.sh
├── make_test_submission.py
├── requirements.txt
├── run_sts.py
├── run_sts.sh
├── run_sts_fewshot.py
└── utils
    ├── __init__.py
    ├── fewshot
    │   ├── generate_in_context_dataset.py
    │   ├── openai_utils.py
    │   └── progress_logger.py
    ├── progress_logger.py
    └── sts
        ├── __init__.py
        ├── dataset_preprocessing.py
        ├── modeling_encoders.py
        ├── modeling_utils.py
        └── triplet_trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C-STS
2 |
3 | This repository contains the dataset and code for the paper C-STS: Conditional Semantic Textual Similarity. [[ArXiv]](https://arxiv.org/abs/2305.15093)
4 |
5 |
6 | ## Table of Contents
7 | - [Data](#data)
8 | - [Code](#code)
9 | - [Fine-tuning](#fine-tuning)
10 | - [Few-shot Evaluation](#few-shot-evaluation)
11 | - [Submitting Test Results](#submitting-test-results)
12 | - [Citation](#citation)
13 |
14 | ## Data
15 |
16 | To avoid the intentional/unintentional scraping of the C-STS dataset for pre-training LLMs, which could cause training data contamination and impact their evaluation, we adopt the following approach for our dataset release.
17 |
18 | The dataset for C-STS is stored in an encrypted file named `csts.tar.enc`. To access the dataset, follow these steps:
19 |
20 | 1. Request Access: Submit a request to obtain the decryption password by [clicking here](https://docs.google.com/forms/d/e/1FAIpQLSfoYig6I3qEBUBaNmzugnAKGpX1mSpM5cbGeO-dXq-u_sMPJQ/viewform?usp=sf_link). You will receive an email response with the password immediately.
21 |
22 | 2. Decrypt the Dataset: Once you have received the password via email, you can decrypt the `csts.tar.enc` file using the provided `extract.sh` script. Follow the instructions below:
23 |
24 | - Open a terminal and navigate to the `data` directory.
25 |    - Run the following command, replacing `<PASSWORD>` with the decryption password obtained via email:
26 |
27 |    ```bash
28 |    bash extract.sh csts.tar.enc <PASSWORD>
29 |    ```
30 |
31 | Given the correct password, this step will generate the three unencrypted dataset splits: `csts_train.csv`, `csts_validation.csv`, and `csts_test.csv`.
32 |
33 | You can load the data using [datasets](https://github.com/huggingface/datasets) with the following lines:
34 |
35 | ```python
36 | from datasets import load_dataset
37 |
38 | dataset = load_dataset(
39 | 'csv',
40 | data_files=
41 | {
42 | 'train': 'data/csts_train.csv',
43 | 'validation': 'data/csts_validation.csv',
44 | 'test': 'data/csts_test.csv'
45 | }
46 | )
47 | ```
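Each split shares one schema: the preprocessing in `run_sts.py` reads the columns `sentence1`, `sentence2`, `condition`, and `label`. A minimal stdlib sketch of that layout, using made-up rows rather than real dataset contents:

```python
import csv
import io

# Illustrative rows only; the real csts_*.csv files come from decrypting
# csts.tar.enc as described above.
mock_split = io.StringIO(
    "sentence1,sentence2,condition,label\n"
    '"A dog runs on grass.","A cat sleeps on a couch.","The animal",1\n'
)

rows = list(csv.DictReader(mock_split))
columns = list(rows[0].keys())
print(columns)
```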
48 |
49 | **Important: By using this dataset, you agree to not publicly share its unencrypted contents or decryption password.**
50 |
51 | ## Code
52 | We provide the basic training scripts and utilities for fine-tuning and evaluating the models in the paper. The code is adapted from the [HuggingFace Transformers](https://github.com/huggingface/transformers) library. Refer to the [documentation](https://huggingface.co/transformers/) for more details.
53 |
54 | ### Fine-tuning
55 | The current code supports fine-tuning any encoder-only model, using the `cross_encoder`, `bi_encoder`, or `tri_encoder` settings described in the paper.
56 | You can fine-tune the models described in the paper using the `run_sts.sh` script. For example, to fine-tune the `princeton-nlp/sup-simcse-roberta-base` model on the C-STS dataset, run the following command:
57 |
58 | ```bash
59 | MODEL=princeton-nlp/sup-simcse-roberta-base \
60 | ENCODER_TYPE=bi_encoder \
61 | LR=1e-5 \
62 | WD=0.1 \
63 | TRANSFORM=False \
64 | OBJECTIVE=mse \
65 | OUTPUT_DIR=output \
66 | TRAIN_FILE=data/csts_train.csv \
67 | EVAL_FILE=data/csts_validation.csv \
68 | TEST_FILE=data/csts_test.csv \
69 | bash run_sts.sh
70 | ```
71 |
72 | See `run_sts.sh` for a full description of the available options and default values.
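For the `mse`, `triplet`, and `triplet_mse` objectives, `run_sts.py` passes the dataset's `(min_similarity, max_similarity)` bounds to the preprocessing step so that labels can be rescaled. A hypothetical min-max rescaling in that spirit (the repository's actual transform lives in `utils/sts/dataset_preprocessing.py` and may differ):

```python
# Hypothetical helper, not from the repository: min-max rescaling of a
# similarity label into [0, 1] given dataset-wide bounds.
def scale_label(label: float, min_similarity: float, max_similarity: float) -> float:
    return (label - min_similarity) / (max_similarity - min_similarity)

# For C-STS-style 1-5 similarity labels, a 3 maps to the midpoint.
print(scale_label(3.0, 1.0, 5.0))  # 0.5
```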
73 |
74 | ### Few-shot Evaluation
75 | The script `run_sts_fewshot.py` can be used to evaluate large language models in a few-shot setting, with or without instructions. For example, to evaluate the `google/flan-t5-xxl` model on the C-STS dataset, run the following command:
76 |
77 | ```bash
78 | python run_sts_fewshot.py \
79 | --model_name_or_path google/flan-t5-xxl \
80 | --k_shot 2 \
81 | --prompt_name long \
82 | --train_file data/csts_train.csv \
83 | --validation_file data/csts_validation.csv \
84 | --test_file data/csts_test.csv \
85 | --output_dir output/flan-t5-xxl/k2_long \
86 | --dtype tf32 \
87 | --batch_size 4
88 | ```
89 |
90 | To accommodate large models, `run_sts_fewshot.py` will use all visible GPUs to load the model with model parallelism. For smaller models, set `CUDA_VISIBLE_DEVICES` to the desired GPU ids.
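For example, a run can be restricted to GPU ids 0 and 1 by exporting `CUDA_VISIBLE_DEVICES=0,1` before launching, or from Python before any CUDA framework initializes (a generic CUDA mechanism, not specific to this repository):

```python
import os

# Must be set before torch (or any CUDA library) initializes the driver;
# only the listed device ids will be visible to the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(os.environ["CUDA_VISIBLE_DEVICES"])
```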
91 |
92 | Run `python run_sts_fewshot.py --help` for a full description of additional options and default values.
93 |
94 |
95 | ### Submitting Test Results
96 | You can get scores for your model on the test set by submitting your predictions with the `make_test_submission.py` script as follows:
97 |
98 | ```bash
99 | python make_test_submission.py your_email@email.com /path/to/your/predictions.json
100 | ```
101 |
102 | This script expects the test predictions file to be in the format generated automatically by the scripts above; i.e.
103 |
104 | ```json
105 | {
106 | "0": 1.0,
107 | "1": 0.0,
108 |     ...
109 | "4731": 0.5
110 | }
111 | ```
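If your predictions come from elsewhere, building a file in this format is straightforward. A minimal sketch with placeholder scores (replace them with your model's 4732 test predictions, ordered by example index):

```python
import json

# Placeholder scores; substitute your model's real test-set predictions.
predictions = [0.5] * 4732

# The submission format maps the string example index ("0" through "4731")
# to a float similarity score.
submission = {str(i): float(p) for i, p in enumerate(predictions)}

with open("predictions.json", "w") as f:
    json.dump(submission, f, indent=4)
```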
112 |
113 | After submission, your results will be emailed to the address you provided, with the submission filename in the subject line.
114 |
115 |
116 | ## Citation
117 | ```tex
118 | @misc{deshpande2023csts,
119 | title={CSTS: Conditional Semantic Textual Similarity},
120 | author={Ameet Deshpande and Carlos E. Jimenez and Howard Chen and Vishvak Murahari and Victoria Graf and Tanmay Rajpurohit and Ashwin Kalyan and Danqi Chen and Karthik Narasimhan},
121 | year={2023},
122 | eprint={2305.15093},
123 | archivePrefix={arXiv},
124 | primaryClass={cs.CL}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/data/extract.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [ $# -ne 2 ]; then
3 | echo "Error: Please provide the path to the encrypted file and the decryption password."
4 |     echo "Usage: ./extract.sh <encrypted_file> <password>"
5 | exit 1
6 | fi
7 | ENCRYPTED_FILE="$1"
8 | PASSWORD="$2"
9 | openssl aes-256-cbc -a -d -salt -pbkdf2 -in "$ENCRYPTED_FILE" -out csts.tar -pass pass:"$PASSWORD"
10 | EXIT_CODE=$?
11 | if [ $EXIT_CODE -ne 0 ]; then
12 | rm -f csts.tar
13 | echo "Error: Failed to decrypt the file."
14 | exit $EXIT_CODE
15 | fi
16 | tar -xvf csts.tar
17 | EXIT_CODE=$?
18 | if [ $EXIT_CODE -ne 0 ]; then
19 | echo "Error: Failed to extract the decrypted file."
20 | exit $EXIT_CODE
21 | fi
22 | rm -f csts.tar
23 | if [ $? -ne 0 ]; then
24 | echo "Error: Failed to remove the files."
25 |     exit 1
26 | fi
27 | echo "Decryption and cleanup completed successfully."
28 | exit 0
29 |
30 |
--------------------------------------------------------------------------------
/make_test_submission.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | import json
4 | import requests
5 | from pathlib import Path
6 |
7 |
8 | def send_post_request(email, predictions, filename):
9 | # Prepare the data to be sent
10 | if len(filename) > 200:
11 | raise ValueError('Submission name (%s) longer than 200 characters. Please choose a shorter filename or set the name with --name' % filename)
12 | data = {
13 | 'email': email,
14 | 'predictions': predictions,
15 | 'filename': filename,
16 | }
17 | data_str = json.dumps({'body': json.dumps(data)})
18 | headers = {'content-type': 'application/json'}
19 | # url = 'https://rcxnewlbk5.execute-api.us-east-2.amazonaws.com/test/eval-csts'
20 | url = "https://0sy74d2tog.execute-api.us-east-2.amazonaws.com/dev/c-sts-eval-lambda"
21 | # Create the request object
22 | request_object = {
23 | "url": url,
24 | "headers": headers,
25 | "data": data
26 | }
27 | json.dump(request_object, open('request.json', 'w'), indent=4)
28 |     response = requests.post(url, headers=headers, data=data_str)
29 |     if response.status_code != 200:
30 |         raise RuntimeError('Evaluation request failed with status %d: %s' % (response.status_code, response.text))
31 |     print("Evaluation successful!\n%s" % response.json()['body'])
32 |     print("See email: \"C-STS Evaluation Results for %s\"" % filename)
33 |
34 |
35 | def main(email, predictions_file, name):
36 | predictions_file = Path(predictions_file).resolve(strict=True)
37 | if name is None:
38 | name = predictions_file.as_posix()
39 |     if not re.match(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", email):
40 | raise ValueError("Email %s is invalid" % email)
41 | with open(predictions_file, 'r') as f:
42 | preds = json.load(f)
43 | keys, preds = zip(*sorted(preds.items(), key=lambda x: int(x[0])))
44 | preds = list(map(float, preds))
45 | if len(keys) != 4732:
46 | raise ValueError("There should be exactly 4732 predictions, but got %d instead" % len(keys))
47 | send_post_request(email, preds, name)
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser(description='Send email and predictions to server')
52 | parser.add_argument('email', type=str, help='The email to be sent')
53 | parser.add_argument('predictions_file', type=str, help='The path to the JSON file containing the predictions')
54 | parser.add_argument('--name', type=str, help='The name of the submission. Uses the filename if not specified')
55 | args = parser.parse_args()
56 | main(**vars(args))
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate>=0.20.3
2 | datasets>=2.11.0
3 | huggingface_hub>=0.13.4
4 | openai>=0.27.7
5 | pandas>=2.0.0
6 | scipy>=1.11.1
7 | torch>=2.0.0
8 | transformers>=4.28.1
9 |
--------------------------------------------------------------------------------
/run_sts.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted code from HuggingFace run_glue.py
3 |
4 | Author: Ameet Deshpande, Carlos E. Jimenez
5 | """
6 | import json
7 | import logging
8 | import os
9 |
10 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
11 | import random
12 | import sys
13 | from dataclasses import dataclass, field
14 | from typing import Optional
15 |
16 | import datasets
17 | import numpy as np
18 | import transformers
19 | from datasets import load_dataset
20 | from scipy.stats import pearsonr, spearmanr
21 | from transformers import (
22 | AutoConfig,
23 | AutoTokenizer,
24 | EvalPrediction,
25 | HfArgumentParser,
26 | PrinterCallback,
27 | Trainer,
28 | )
29 | from transformers import TrainingArguments as HFTrainingArguments
30 | from transformers import default_data_collator, set_seed
31 | from transformers.trainer_utils import get_last_checkpoint
32 |
33 | from utils.progress_logger import LogCallback
34 | from utils.sts.dataset_preprocessing import get_preprocessing_function
35 | from utils.sts.modeling_utils import DataCollatorWithPadding, get_model
36 | from utils.sts.triplet_trainer import TripletTrainer
37 |
38 | logging.basicConfig(
39 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s"
40 | )
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | @dataclass
46 | class TrainingArguments(HFTrainingArguments):
47 | log_time_interval: int = field(
48 | default=15,
49 | metadata={
50 | "help": (
51 | "Log at each `log_time_interval` seconds. "
52 | "Default will be to log every 15 seconds."
53 | )
54 | },
55 | )
56 |
57 |
58 | @dataclass
59 | class DataTrainingArguments:
60 | """
61 | Arguments pertaining to what data we are going to input our model for training and eval.
62 |
63 | Using `HfArgumentParser` we can turn this class
64 | into argparse arguments to be able to specify them on
65 | the command line.
66 | """
67 |
68 | max_seq_length: int = field(
69 | default=128,
70 | metadata={
71 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
72 | "than this will be truncated, sequences shorter will be padded."
73 | },
74 | )
75 | overwrite_cache: bool = field(
76 | default=False,
77 | metadata={"help": "Overwrite the cached preprocessed datasets or not."},
78 | )
79 | pad_to_max_length: bool = field(
80 | default=False,
81 | metadata={
82 | "help": "Whether to pad all samples to `max_seq_length`. "
83 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
84 | },
85 | )
86 | max_train_samples: Optional[int] = field(
87 | default=None,
88 | metadata={
89 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
90 | "value if set."
91 | },
92 | )
93 | max_eval_samples: Optional[int] = field(
94 | default=None,
95 | metadata={
96 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
97 | "value if set."
98 | },
99 | )
100 | max_predict_samples: Optional[int] = field(
101 | default=None,
102 | metadata={
103 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
104 | "value if set."
105 | },
106 | )
107 | train_file: Optional[str] = field(
108 | default=None,
109 | metadata={"help": "A csv or a json file containing the training data."},
110 | )
111 | validation_file: Optional[str] = field(
112 | default=None,
113 | metadata={"help": "A csv or a json file containing the validation data."},
114 | )
115 | test_file: Optional[str] = field(
116 | default=None,
117 | metadata={"help": "A csv or a json file containing the test data."},
118 | )
119 | # Dataset specific arguments
120 | max_similarity: Optional[float] = field(
121 | default=None, metadata={"help": "Maximum similarity score."}
122 | )
123 | min_similarity: Optional[float] = field(
124 | default=None, metadata={"help": "Minimum similarity score."}
125 | )
126 | condition_only: Optional[bool] = field(
127 | default=False, metadata={"help": "Only use condition column."}
128 | )
129 | sentences_only: Optional[bool] = field(
130 | default=False, metadata={"help": "Only use sentences column."}
131 | )
132 |
133 | def __post_init__(self):
134 | validation_extension = self.validation_file.split(".")[-1]
135 | if self.train_file is not None:
136 | train_extension = self.train_file.split(".")[-1]
137 | assert train_extension in [
138 | "csv",
139 | "json",
140 | ], "`train_file` should be a csv or a json file."
141 | assert (
142 | train_extension == validation_extension
143 | ), "`train_file` and `validation_file` should have the same extension."
144 | if self.test_file is not None:
145 | test_extension = self.test_file.split(".")[-1]
146 | assert test_extension in [
147 | "csv",
148 | "json",
149 | ], "`test_file` should be a csv or a json file."
150 | assert (
151 | test_extension == validation_extension
152 | ), "`test_file` and `validation_file` should have the same extension."
153 |
154 |
155 | @dataclass
156 | class ModelArguments:
157 | """
158 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
159 | """
160 |
161 | model_name_or_path: str = field(
162 | metadata={
163 | "help": "Path to pretrained model or model identifier from huggingface.co/models"
164 | }
165 | )
166 | config_name: Optional[str] = field(
167 | default=None,
168 | metadata={
169 | "help": "Pretrained config name or path if not the same as model_name"
170 | },
171 | )
172 | tokenizer_name: Optional[str] = field(
173 | default=None,
174 | metadata={
175 | "help": "Pretrained tokenizer name or path if not the same as model_name"
176 | },
177 | )
178 | cache_dir: Optional[str] = field(
179 | default=None,
180 | metadata={
181 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
182 | },
183 | )
184 | use_fast_tokenizer: bool = field(
185 | default=True,
186 | metadata={
187 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
188 | },
189 | )
190 | model_revision: str = field(
191 | default="main",
192 | metadata={
193 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."
194 | },
195 | )
196 | use_auth_token: bool = field(
197 | default=False,
198 | metadata={
199 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
200 | "with private models)."
201 | },
202 | )
203 | objective: Optional[str] = field(
204 | default="mse",
205 | metadata={
206 |             "help": "Objective function for training. Options:\
207 |                 1) mse: Regression objective (uses MSELoss).\
208 |                 2) classification: Classification objective (uses CrossEntropyLoss).\
209 |                 3) triplet: Regression objective (uses QuadrupletLoss).\
210 |                 4) triplet_mse: Regression objective (uses QuadrupletLoss combined with MSELoss)."
211 | },
212 | )
213 | # What type of modeling
214 | encoding_type: Optional[str] = field(
215 | default="cross_encoder",
216 | metadata={
217 | "help": "What kind of model to choose. Options:\
218 | 1) cross_encoder: Full encoder model.\
219 | 2) bi_encoder: Bi-encoder model.\
220 | 3) tri_encoder: Tri-encoder model."
221 | },
222 | )
223 | # Pooler for bi-encoder
224 | pooler_type: Optional[str] = field(
225 | default="cls",
226 | metadata={
227 | "help": "Pooler type: Options:\
228 | 1) cls: Use [CLS] token.\
229 | 2) avg: Mean pooling."
230 | },
231 | )
232 | freeze_encoder: Optional[bool] = field(
233 | default=False, metadata={"help": "Freeze encoder weights."}
234 | )
235 | transform: Optional[bool] = field(
236 | default=False,
237 | metadata={"help": "Use a linear transformation on the encoder output"},
238 | )
239 | triencoder_head: Optional[str] = field(
240 | default="hadamard",
241 | metadata={
242 | "help": "Tri-encoder head type: Options:\
243 | 1) hadamard: Hadamard product.\
244 | 2) transformer: Transformer."
245 | },
246 | )
247 |
248 |
249 | def main():
250 | parser = HfArgumentParser(
251 | (ModelArguments, DataTrainingArguments, TrainingArguments)
252 | )
253 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
254 | model_args, data_args, training_args = parser.parse_json_file(
255 | json_file=os.path.abspath(sys.argv[1]),
256 | )
257 | else:
258 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
259 | logging.basicConfig(
260 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
261 | datefmt="%m/%d/%Y %H:%M:%S",
262 | handlers=[logging.StreamHandler(sys.stdout)],
263 | )
264 | training_args.log_level = "info"
265 | log_level = training_args.get_process_log_level()
266 | logger.setLevel(log_level)
267 | datasets.utils.logging.set_verbosity(log_level)
268 | transformers.utils.logging.set_verbosity(log_level)
269 | transformers.utils.logging.enable_default_handler()
270 | transformers.utils.logging.enable_explicit_format()
271 | logger.warning(
272 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
273 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
274 | )
275 | if model_args.objective in {"triplet", "triplet_mse"}:
276 | training_args.dataloader_drop_last = True
277 | training_args.per_device_eval_batch_size = 2
278 | logger.info("Training/evaluation parameters %s" % training_args)
279 | last_checkpoint = None
280 | if (
281 | os.path.isdir(training_args.output_dir)
282 | and training_args.do_train
283 | and not training_args.overwrite_output_dir
284 | ):
285 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
286 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
287 | raise ValueError(
288 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
289 | "Use --overwrite_output_dir to overcome."
290 | )
291 | elif (
292 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
293 | ):
294 | logger.warning(
295 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
296 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
297 | )
298 | set_seed(training_args.seed)
299 | data_files = {"validation": data_args.validation_file}
300 | if training_args.do_train:
301 | data_files["train"] = data_args.train_file
302 | if data_args.test_file is not None:
303 | data_files["test"] = data_args.test_file
304 | elif training_args.do_predict:
305 | raise ValueError("test_file argument is missing. required for do_predict.")
306 | for key, name in data_files.items():
307 | logger.info(f"load a local file for {key}: {name}")
308 | if data_args.validation_file.endswith(".csv") or data_args.validation_file.endswith(
309 | ".tsv"
310 | ):
311 | # Loading a dataset from local csv files
312 | raw_datasets = load_dataset(
313 | "csv",
314 | data_files=data_files,
315 | cache_dir=model_args.cache_dir,
316 | use_auth_token=True if model_args.use_auth_token else None,
317 | )
318 | elif data_args.validation_file.endswith(".json"):
319 | # Loading a dataset from local json files
320 | raw_datasets = load_dataset(
321 | "json",
322 | data_files=data_files,
323 | cache_dir=model_args.cache_dir,
324 | use_auth_token=True if model_args.use_auth_token else None,
325 | )
326 | else:
327 | raise ValueError("validation_file should be a csv or a json file.")
328 | labels = set()
329 | for key in set(raw_datasets.keys()) - {"test"}:
330 | labels.update(raw_datasets[key]["label"])
331 | if data_args.min_similarity is None:
332 | data_args.min_similarity = min(labels)
333 | logger.warning(
334 | f"Setting min_similarity: {data_args.min_similarity}. Override by setting --min_similarity."
335 | )
336 | if data_args.max_similarity is None:
337 | data_args.max_similarity = max(labels)
338 | logger.warning(
339 | f"Setting max_similarity: {data_args.max_similarity}. Override by setting --max_similarity."
340 | )
341 | config = AutoConfig.from_pretrained(
342 | model_args.config_name
343 | if model_args.config_name
344 | else model_args.model_name_or_path,
345 | num_labels=1,
346 | # finetuning_task=None,
347 | cache_dir=model_args.cache_dir,
348 | revision=model_args.model_revision,
349 | use_auth_token=True if model_args.use_auth_token else None,
350 | )
351 | tokenizer = AutoTokenizer.from_pretrained(
352 | model_args.tokenizer_name
353 | if model_args.tokenizer_name
354 | else model_args.model_name_or_path,
355 | cache_dir=model_args.cache_dir,
356 | use_fast=model_args.use_fast_tokenizer,
357 | revision=model_args.model_revision,
358 | use_auth_token=True if model_args.use_auth_token else None,
359 | )
360 | model_cls = get_model(model_args)
361 | config.update(
362 | {
363 | "use_auth_token": model_args.use_auth_token,
364 | "model_revision": model_args.model_revision,
365 | "cache_dir": model_args.cache_dir,
366 | "model_name_or_path": model_args.model_name_or_path,
367 | "objective": model_args.objective,
368 | "pooler_type": model_args.pooler_type,
369 | "transform": model_args.transform,
370 | "triencoder_head": model_args.triencoder_head,
371 | }
372 | )
373 | model = model_cls(config=config)
374 | if model_args.freeze_encoder:
375 | for param in model.backbone.parameters():
376 | param.requires_grad = False
377 | sentence1_key, sentence2_key, condition_key, similarity_key = (
378 | "sentence1",
379 | "sentence2",
380 | "condition",
381 | "label",
382 | )
383 | # Padding strategy
384 | if data_args.pad_to_max_length:
385 | padding = "max_length"
386 | else:
387 | padding = False
388 | if data_args.max_seq_length > tokenizer.model_max_length:
389 | logger.warning(
390 | "The max_seq_length passed (%d) is larger than the maximum length for the "
391 | "model (%d). Using max_seq_length=%d."
392 | % (
393 | data_args.max_seq_length,
394 | tokenizer.model_max_length,
395 | tokenizer.model_max_length,
396 | )
397 | )
398 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
399 | preprocess_function = get_preprocessing_function(
400 | tokenizer,
401 | sentence1_key,
402 | sentence2_key,
403 | condition_key,
404 | similarity_key,
405 | padding,
406 | max_seq_length,
407 | model_args,
408 | scale=(data_args.min_similarity, data_args.max_similarity)
409 | if model_args.objective in {"mse", "triplet", "triplet_mse"}
410 | else None,
411 | condition_only=data_args.condition_only,
412 | sentences_only=data_args.sentences_only,
413 | )
414 | with training_args.main_process_first(desc="dataset map pre-processing"):
415 | raw_datasets = raw_datasets.map(
416 | preprocess_function,
417 | batched=True,
418 | load_from_cache_file=not data_args.overwrite_cache,
419 | desc="Running tokenizer on dataset",
420 |             remove_columns=raw_datasets["validation"].column_names,
421 | )
422 | if training_args.do_train:
423 | if "train" not in raw_datasets:
424 | raise ValueError("--do_train requires a train dataset")
425 | train_dataset = raw_datasets["train"]
426 | if data_args.max_train_samples is not None:
427 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
428 | train_dataset = train_dataset.select(range(max_train_samples))
429 | if training_args.do_eval:
430 | if (
431 | "validation" not in raw_datasets
432 | and "validation_matched" not in raw_datasets
433 | ):
434 | raise ValueError("--do_eval requires a validation dataset")
435 | eval_dataset = raw_datasets["validation"]
436 | if data_args.max_eval_samples is not None:
437 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
438 | eval_dataset = eval_dataset.select(range(max_eval_samples))
439 | if training_args.do_predict or data_args.test_file is not None:
440 | if "test" not in raw_datasets and "test_matched" not in raw_datasets:
441 | raise ValueError("--do_predict requires a test dataset")
442 | predict_dataset = raw_datasets["test"]
443 | if data_args.max_predict_samples is not None:
444 | max_predict_samples = min(
445 | len(predict_dataset), data_args.max_predict_samples
446 | )
447 | predict_dataset = predict_dataset.select(range(max_predict_samples))
448 | # Log a few random samples from the training set:
449 | if training_args.do_train:
450 | for index in random.sample(range(len(train_dataset)), 3):
451 | input_ids = train_dataset[index]["input_ids"]
452 | logger.info(f"tokens: {tokenizer.decode(input_ids)}")
453 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
454 |
455 | def compute_metrics(output: EvalPrediction):
456 | preds = (
457 | output.predictions[0]
458 | if isinstance(output.predictions, tuple)
459 | else output.predictions
460 | )
461 | preds = np.squeeze(preds)
462 | return {
463 | "mse": ((preds - output.label_ids) ** 2).mean().item(),
464 | "pearsonr": pearsonr(preds, output.label_ids)[0],
465 | "spearmanr": spearmanr(preds, output.label_ids)[0],
466 | }
467 |
468 | if data_args.pad_to_max_length:
469 | data_collator = default_data_collator
470 | else:
471 | data_collator = DataCollatorWithPadding(
472 | pad_token_id=tokenizer.pad_token_id,
473 | pad_token_type_id=tokenizer.pad_token_type_id,
474 | pad_to_multiple_of=8 if training_args.fp16 else None,
475 | )
476 | # Initialize our Trainer
477 | trainer_cls = (
478 | TripletTrainer
479 | if model_args.objective in {"triplet", "triplet_mse"}
480 | else Trainer
481 | )
482 | trainer = trainer_cls(
483 | model=model,
484 | args=training_args,
485 | train_dataset=train_dataset if training_args.do_train else None,
486 | eval_dataset=eval_dataset if training_args.do_eval else None,
487 | compute_metrics=compute_metrics,
488 | tokenizer=tokenizer,
489 | data_collator=data_collator,
490 | )
491 | trainer.remove_callback(PrinterCallback)
492 | trainer.add_callback(LogCallback)
493 | if training_args.do_train:
494 | checkpoint = None
495 | if training_args.resume_from_checkpoint is not None:
496 | checkpoint = training_args.resume_from_checkpoint
497 | elif last_checkpoint is not None:
498 | checkpoint = last_checkpoint
499 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
500 | metrics = train_result.metrics
501 | max_train_samples = (
502 | data_args.max_train_samples
503 | if data_args.max_train_samples is not None
504 | else len(train_dataset)
505 | )
506 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
507 | trainer.save_model() # Saves the tokenizer too for easy upload
508 | trainer.log_metrics("train", metrics)
509 | trainer.save_metrics("train", metrics)
510 | trainer.save_state()
511 | # Evaluation
512 | combined = {}
513 | if training_args.do_eval:
514 | logger.info("*** Evaluate ***")
515 | metrics = trainer.evaluate(eval_dataset=eval_dataset)
516 | max_eval_samples = (
517 | data_args.max_eval_samples
518 | if data_args.max_eval_samples is not None
519 | else len(eval_dataset)
520 | )
521 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
522 | combined.update(metrics)
523 | trainer.log_metrics("eval", metrics)
524 | trainer.save_metrics("eval", combined)
525 | if training_args.do_train:
526 | metrics = trainer.evaluate(
527 | eval_dataset=train_dataset, metric_key_prefix="train"
528 | )
529 | max_eval_samples = (
530 | data_args.max_eval_samples
531 | if data_args.max_eval_samples is not None
532 | else len(train_dataset)
533 | )
534 | metrics["train_samples"] = min(max_eval_samples, len(train_dataset))
535 | trainer.log_metrics("train", metrics)
536 | trainer.save_metrics("train", metrics)
537 | if training_args.do_predict:
538 | logger.info("*** Predict ***")
539 | # Removing the `labels` column because it contains -1, which the Trainer won't accept.
540 | predict_dataset = predict_dataset.remove_columns("labels")
541 | predictions = trainer.predict(
542 | predict_dataset, metric_key_prefix="predict"
543 | ).predictions
544 | predictions = (
545 | np.squeeze(predictions)
546 | if model_args.objective in {"mse", "triplet", "triplet_mse"}
547 | else np.argmax(predictions, axis=1)
548 | )
549 | predictions = dict(enumerate(predictions.tolist()))
550 | output_predict_file = os.path.join(
551 | training_args.output_dir, "test_predictions.json"
552 | )
553 | if trainer.is_world_process_zero():
554 | with open(output_predict_file, "w", encoding="utf-8") as outfile:
555 | json.dump(predictions, outfile)
556 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "CSTS"}
557 | if training_args.push_to_hub:
558 | trainer.push_to_hub(**kwargs)
559 | else:
560 | trainer.create_model_card(**kwargs)
561 |
562 |
563 | if __name__ == "__main__":
564 | main()
565 |
--------------------------------------------------------------------------------
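The `compute_metrics` helper in run_sts.py above reports MSE plus Pearson and Spearman correlations between predictions and labels. A minimal standalone sketch of the same computation on plain arrays (toy inputs chosen here for illustration, not taken from the dataset):

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

def compute_metrics(preds, labels):
    # Same three metrics as run_sts.py's compute_metrics, on bare arrays.
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return {
        "mse": ((preds - labels) ** 2).mean().item(),
        "pearsonr": pearsonr(preds, labels)[0],
        "spearmanr": spearmanr(preds, labels)[0],
    }

metrics = compute_metrics([1.0, 2.0, 4.0, 5.0], [1.0, 2.0, 3.0, 5.0])
```

Because both toy sequences are strictly increasing, Spearman (a rank correlation) is perfect even though the values differ, while MSE and Pearson reflect the one mismatched point.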
/run_sts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=${MODEL:-princeton-nlp/sup-simcse-roberta-large} # pre-trained model
3 | encoding=${ENCODER_TYPE:-bi_encoder} # cross_encoder, bi_encoder, tri_encoder
4 | lr=${LR:-1e-5} # learning rate
5 | wd=${WD:-0.1} # weight decay
6 | transform=${TRANSFORM:-False} # whether to use an additional linear layer after the encoder
7 | objective=${OBJECTIVE:-mse} # mse, triplet, triplet_mse
8 | triencoder_head=${TRIENCODER_HEAD:-None} # hadamard, concat (set for tri_encoder)
9 | seed=${SEED:-42}
10 | output_dir=${OUTPUT_DIR:-output}
11 | config=enc_${encoding}__lr_${lr}__wd_${wd}__trans_${transform}__obj_${objective}__tri_${triencoder_head}__s_${seed}
12 | train_file=${TRAIN_FILE:-data/csts_train.csv}
13 | eval_file=${EVAL_FILE:-data/csts_validation.csv}
14 | test_file=${TEST_FILE:-data/csts_test.csv}
15 |
16 | python run_sts.py \
17 | --output_dir "${output_dir}/${model//\//__}/${config}" \
18 | --model_name_or_path ${model} \
19 | --objective ${objective} \
20 | --encoding_type ${encoding} \
21 | --pooler_type cls \
22 | --freeze_encoder False \
23 | --transform ${transform} \
24 | --triencoder_head ${triencoder_head} \
25 | --max_seq_length 512 \
26 | --train_file ${train_file} \
27 | --validation_file ${eval_file} \
28 | --test_file ${test_file} \
29 | --condition_only False \
30 | --sentences_only False \
31 | --do_train \
32 | --do_eval \
33 | --do_predict \
34 | --evaluation_strategy epoch \
35 | --per_device_train_batch_size 8 \
36 | --gradient_accumulation_steps 4 \
37 | --learning_rate ${lr} \
38 | --weight_decay ${wd} \
39 | --max_grad_norm 0.0 \
40 | --num_train_epochs 3 \
41 | --lr_scheduler_type linear \
42 | --warmup_ratio 0.1 \
43 | --log_level info \
44 | --disable_tqdm True \
45 | --save_strategy epoch \
46 | --save_total_limit 1 \
47 | --seed ${seed} \
48 | --data_seed ${seed} \
49 | --fp16 True \
50 | --log_time_interval 15
--------------------------------------------------------------------------------
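run_sts.sh composes the output directory from the model name (with `/` replaced by `__` via `${model//\//__}`) and a config string encoding the hyperparameters. A Python sketch of the same string composition, using the script's default values:

```python
# Mirrors run_sts.sh's output-directory naming (a sketch, not part of the repo).
model = "princeton-nlp/sup-simcse-roberta-large"
encoding, lr, wd = "bi_encoder", "1e-5", "0.1"
transform, objective, tri_head, seed = "False", "mse", "None", 42

config = (
    f"enc_{encoding}__lr_{lr}__wd_{wd}__trans_{transform}"
    f"__obj_{objective}__tri_{tri_head}__s_{seed}"
)
# ${model//\//__} in bash replaces every "/" with "__":
output_dir = f"output/{model.replace('/', '__')}/{config}"
```

This flat naming keeps one directory per hyperparameter configuration, so sweeps over `LR`, `WD`, etc. never collide.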
/run_sts_fewshot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import time
7 | from argparse import ArgumentParser
8 | from pathlib import Path
9 |
10 | import numpy as np
11 | import torch
12 | from accelerate import init_empty_weights, load_checkpoint_and_dispatch
13 | from huggingface_hub import snapshot_download
14 | from scipy.stats import pearsonr, spearmanr
15 | from torch.utils.data import DataLoader
16 | from transformers import (AutoConfig, AutoModelForCausalLM,
17 | AutoModelForSeq2SeqLM, AutoTokenizer,
18 | default_data_collator)
19 |
20 | from utils.fewshot.generate_in_context_dataset import make_dataset
21 | from utils.fewshot.openai_utils import (OPENAI_MODELS, authenticate,
22 | get_gpt_prediction)
23 | from utils.fewshot.progress_logger import ProgressLogger
24 |
25 | logging.basicConfig(
26 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s"
27 | )
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | MODEL_CLASSES = {
32 | "gpt": AutoModelForCausalLM,
33 | "t5": AutoModelForSeq2SeqLM,
34 | }
35 |
36 | NO_SKIP_MODULES = {
37 | "t5": ["T5Block"],
38 | "gpt": ["GPTJBlock"],
39 | }
40 |
41 | DTYPES = {
42 | "bf16": torch.bfloat16,
43 | "fp16": torch.float16,
44 | "fp32": torch.float32,
45 | "tf32": torch.float32, # tf32 is enabled via torch.backends.cuda.matmul
46 | }
47 |
48 |
49 | def get_tokenizer_type(model):
50 | if (
51 | "t5" in model.lower()
52 | or "t0" in model.lower()
53 | or "tk-" in model.lower()
54 | or "ul2" in model.lower()
55 | ):
56 | return "t5"
57 | elif "gpt" in model.lower():
58 | return "gpt"
59 | else:
60 | raise ValueError(f"Unknown tokenizer type {model}")
61 |
62 |
63 | def extract_float(s):
64 | match = re.search(r"(\d+\.\d+|\d+)", s)
65 | if match:
66 | return float(match.group(1))
67 | return s
68 |
69 |
70 | def eval(
71 | dataset,
72 | model,
73 | tokenizer,
74 | prefix,
75 | tokenizer_type,
76 | min_similarity,
77 | max_similarity,
78 | dataloader_num_workers,
79 | batch_size,
80 | ):
81 | start_time = time.time()
82 | if model in OPENAI_MODELS:
83 | all_preds, all_labels, examples, non_numeric = openai_model_eval(
84 | model,
85 | dataset,
86 | min_similarity,
87 | max_similarity,
88 | )
89 | else:
90 | all_preds, all_labels, examples, non_numeric = non_openai_model_eval(
91 | model,
92 | tokenizer,
93 | tokenizer_type,
94 | dataset,
95 | dataloader_num_workers,
96 | batch_size,
97 | min_similarity,
98 | max_similarity,
99 | )
100 | eval_time = time.time() - start_time
101 | predictions = dict(enumerate(all_preds))
102 | logger.info(f"Example Preds: {all_preds[:3]}")
103 | logger.info(f"Example Labels: {all_labels[:3]}")
104 | results = process_results(
105 | prefix,
106 | eval_time,
107 | len(dataset),
108 | non_numeric,
109 | all_preds,
110 | all_labels,
111 | min_similarity,
112 | max_similarity,
113 | )
114 | return results, predictions, examples
115 |
116 |
117 | def get_tokenizer_func(tokenizer, tokenizer_type):
118 | def tokenizer_func(example):
119 | return tokenizer(
120 | example["text"],
121 | padding="longest",
122 | truncation=True,
123 | return_tensors="pt",
124 | add_special_tokens=tokenizer_type == "t5",
125 | )
126 | return tokenizer_func
127 |
128 |
129 | def openai_model_eval(
130 | model, dataset, min_similarity, max_similarity
131 | ):
132 | all_preds, all_labels, examples = [], [], []
133 | non_numeric = 0
134 | for ix, example in ProgressLogger.wrap_iter(
135 | "eval", dataset, len(dataset), return_ix=True
136 | ):
137 | raw_pred = get_gpt_prediction(model, example["text"])
138 | pred = extract_float(raw_pred)
139 | if type(pred) is not float:
140 | non_numeric += 1
141 | pred = torch.empty(1).uniform_(min_similarity, max_similarity).item()
142 | label = float(example["label"])
143 | all_preds.append(pred)
144 | all_labels.append(label)
145 | examples.append(
146 | {
147 | "id": ix,
148 | "example": example["text"],
149 | "raw_pred": raw_pred,
150 | "pred": pred,
151 | "label": label,
152 | }
153 | )
154 | if ix < 3:
155 | log_example(ix, example["text"], raw_pred, label)
156 | return all_preds, all_labels, examples, non_numeric
157 |
158 |
159 | def non_openai_model_eval(
160 | model,
161 | tokenizer,
162 | tokenizer_type,
163 | dataset,
164 | dataloader_num_workers,
165 | batch_size,
166 | min_similarity,
167 | max_similarity,
168 | ):
169 | preprocess_func = get_tokenizer_func(tokenizer, tokenizer_type)
170 | dataset = dataset.map(
171 | preprocess_func, batched=True, batch_size=batch_size
172 | )
173 | generation_kwargs = {
174 | "gpt": {"max_new_tokens": 20, 'pad_token_id': tokenizer.eos_token_id},
175 | "t5": {"max_new_tokens": 20},
176 | }[tokenizer_type]
177 | dataloader = DataLoader(
178 | dataset,
179 | batch_size=batch_size,
180 | collate_fn=default_data_collator,
181 | num_workers=dataloader_num_workers,
182 | shuffle=False,
183 | )
184 | non_numeric = 0
185 | all_preds, all_labels, examples = [], [], []
186 | with torch.no_grad():
187 | for ix, example in ProgressLogger.wrap_iter(
188 | "eval", dataloader, len(dataloader), return_ix=True
189 | ):
190 | inputs = {
191 | k: v.to(model.device)
192 | for k, v in example.items()
193 | if k in ["input_ids", "attention_mask"]
194 | }
195 | output = model.generate(**inputs, **generation_kwargs)
196 | if tokenizer_type == "gpt":
197 | output = output[:, inputs["input_ids"].shape[-1]:, ...]
198 | raw_preds = tokenizer.batch_decode(output, skip_special_tokens=True)
199 | preds, non_numeric = process_preds(
200 | raw_preds,
201 | non_numeric,
202 | min_similarity,
203 | max_similarity,
204 | )
205 | labels = example["labels"].tolist()
206 | example_texts = tokenizer.batch_decode(
207 | inputs["input_ids"], skip_special_tokens=True
208 | )
209 | if ix * batch_size < 3:
210 | log_examples(
211 | ix, example_texts, raw_preds, labels, batch_size
212 | )
213 | all_preds.extend(preds)
214 | all_labels.extend(labels)
215 | examples.extend(
216 | [
217 | {
218 | "id": cix + ix * batch_size,
219 | "example": example_text,
220 | "raw_pred": raw_pred,
221 | "pred": pred,
222 | "label": label,
223 | }
224 | for cix, (example_text, raw_pred, pred, label) in enumerate(zip(example_texts, raw_preds, preds, labels))
225 | ]
226 | )
227 | return all_preds, all_labels, examples, non_numeric
228 |
229 |
230 | def process_preds(raw_preds, non_numeric, min_similarity, max_similarity):
231 | preds = list(map(extract_float, raw_preds))
232 | non_numeric += sum(1 for p in preds if type(p) is not float)
233 | preds = [
234 | p
235 | if type(p) is float
236 | else torch.empty(1).uniform_(min_similarity, max_similarity).item()
237 | for p in preds
238 | ]
239 | return preds, non_numeric
240 |
241 |
242 | def log_example(ix, text, raw_pred, label):
243 | example_str = "Example %d:\n\t%sPRED=%s LABEL=%s" % (
244 | ix,
245 | text.replace("\n", "\n\t"),
246 | raw_pred,
247 | label,
248 | )
249 | logger.info(example_str)
250 |
251 |
252 | def log_examples(ix, example_texts, raw_preds, labels, batch_size):
253 | for cix in range(min(len(raw_preds), 3)):
254 | log_example(
255 | ix * batch_size + cix,
256 | example_texts[cix],
257 | raw_preds[cix],
258 | labels[cix],
259 | )
260 |
261 |
262 | def process_results(
263 | prefix,
264 | eval_time,
265 | samples,
266 | non_numeric,
267 | all_preds,
268 | all_labels,
269 | min_similarity,
270 | max_similarity,
271 | ):
272 | scaled_preds = np.array(all_preds)
273 | invalid_preds = sum(
274 | 1 for p in scaled_preds if not min_similarity <= p <= max_similarity
275 | )
276 | scaled_labels = np.array(all_labels)
277 | results = {
278 | "pearsonr": pearsonr(scaled_preds, scaled_labels)[0],
279 | "spearmanr": spearmanr(scaled_preds, scaled_labels)[0],
280 | "runtime": eval_time,
281 | "samples": samples,
282 | "samples_per_second": samples / eval_time,
283 | "non_numeric": non_numeric,
284 | "non_numeric_percent": non_numeric / samples,
285 | "mse": ((torch.tensor(all_preds) - torch.tensor(all_labels)) ** 2)
286 | .mean()
287 | .item(),
288 | "out_of_range": invalid_preds,
289 | "out_of_range_percent": invalid_preds / samples,
290 | }
291 | return {f"{prefix}_{k}": v for k, v in results.items()}
292 |
293 |
294 | def load_model_and_tokenizer(model_name_or_path, tokenizer_type, api_key, dtype):
295 | if model_name_or_path not in OPENAI_MODELS:
296 | if not torch.cuda.is_available() and dtype != 'fp32':
297 | logger.info("Using CPU, overriding dtype to fp32")
298 | dtype = torch.float32 if not torch.cuda.is_available() else DTYPES[dtype]
299 | model_cls = MODEL_CLASSES[tokenizer_type]
300 | weights_location = get_weights_location(model_name_or_path)
301 | config = AutoConfig.from_pretrained(weights_location)
302 | index_location = get_index_location(weights_location)
303 | with init_empty_weights():
304 | logger.info("Instantiating model from config")
305 | model = model_cls.from_config(config)
306 | model = load_model_weights(model, index_location, dtype, tokenizer_type)
307 | model = model.eval()
308 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left" if tokenizer_type == "gpt" else "right")
309 | if tokenizer_type == "gpt":
310 | tokenizer.pad_token_id = tokenizer.eos_token_id
311 | else:
312 | if not os.path.exists(api_key):
313 | raise ValueError("api_key must be a file containing your OpenAI API key")
314 | authenticate(api_key)
315 | model = model_name_or_path
316 | tokenizer = None
317 | return model, tokenizer
318 |
319 |
320 | def get_weights_location(model_name_or_path):
321 | if not os.path.exists(model_name_or_path):
322 | return snapshot_download(
323 | repo_id=model_name_or_path,
324 | ignore_patterns=["*h5*", "*msgpack*", "*safetensors*", '*tflite*', '*rust_model.ot*'], # only download pytorch weights
325 | )
326 | elif os.path.isdir(model_name_or_path):
327 | return model_name_or_path
328 | else:
329 | return os.path.dirname(model_name_or_path)
330 |
331 |
332 | def get_index_location(weights_location):
333 | index_location = os.path.join(weights_location, "pytorch_model.bin.index.json")
334 | if not os.path.exists(index_location):
335 | index_location = os.path.join(weights_location, "pytorch_model.bin")
336 | return index_location
337 |
338 |
339 | def load_model_weights(model, index_location, dtype, tokenizer_type):
340 | logger.info("Loading model weights with load_checkpoint_and_dispatch")
341 | model = load_checkpoint_and_dispatch(
342 | model,
343 | index_location,
344 | device_map="balanced",
345 | no_split_module_classes=NO_SKIP_MODULES[tokenizer_type],
346 | dtype=dtype,
347 | )
348 | logger.info(f"Loaded model with load_checkpoint_and_dispatch from {index_location}")
349 | return model
350 |
351 |
352 | def save_results(results, predictions, examples, output_dir, output_file_prefix):
353 | logger.info("%s results: %s" % (output_file_prefix, json.dumps(results, indent=4)))
354 | logger.info("Writing eval_results to %s" % output_dir)
355 | with open(Path(output_dir, f"{output_file_prefix}_results.json"), "w") as f:
356 | json.dump(results, f, indent=4)
357 | with open(Path(output_dir, f"{output_file_prefix}_predictions.json"), "w") as f:
358 | json.dump(predictions, f, indent=4)
359 | with open(Path(output_dir, f"{output_file_prefix}_examples.json"), "w") as f:
360 | json.dump(examples, f, indent=4)
361 |
362 |
363 | def main(
364 | model_name_or_path,
365 | tokenizer_type,
366 | output_dir,
367 | train_file,
368 | validation_file,
369 | test_file,
370 | k_shot,
371 | prompt_name,
372 | seed,
373 | overwrite_output_dir,
374 | dataloader_num_workers,
375 | max_eval_samples,
376 | api_key,
377 | dtype,
378 | batch_size,
379 | ):
380 | skip_validation = overwrite_output_dir is False and Path(output_dir, "eval_results.json").exists()
381 | skip_test = overwrite_output_dir is False and Path(output_dir, "test_results.json").exists()
382 | if skip_validation:
383 | logger.info(f"Skipping validation, found eval_results.json in {output_dir}.\nSet overwrite_output_dir=True to override.")
384 | if skip_test:
385 | logger.info(f"Skipping test, found test_results.json in {output_dir}.\nSet overwrite_output_dir=True to override.")
386 | if skip_validation and skip_test:
387 | return
388 | if validation_file is None and test_file is None:
389 | logger.info("No validation or test file provided. Exiting.")
390 | return
391 | if model_name_or_path in OPENAI_MODELS:
392 | assert api_key is not None, "api_key path must be provided for OpenAI models"
393 | if dtype == "tf32":
394 | torch.backends.cuda.matmul.allow_tf32 = True
395 | if tokenizer_type is None:
396 | tokenizer_type = get_tokenizer_type(model_name_or_path)
397 | else:
398 | assert tokenizer_type in {'gpt', 't5'}, f"tokenizer_type must be one of 'gpt' or 't5', got {tokenizer_type}"
399 | logger.info(f"Using {tokenizer_type} tokenizer")
400 | config_key = f"{tokenizer_type}_k{k_shot}_prompt{prompt_name}"
401 | model, tokenizer = load_model_and_tokenizer(
402 | model_name_or_path, tokenizer_type, api_key, dtype
403 | )
404 | max_similarity = 5.0
405 | min_similarity = 1.0 if "csts" in Path(train_file).name else 0.0
406 | is_stsb = "stsb" in Path(train_file).name
407 | logger.info("Loading dataset %s" % config_key)
408 | dataset = make_dataset(
409 | train_file=train_file,
410 | validation_file=validation_file,
411 | test_file=test_file,
412 | tokenizer_type=tokenizer_type,
413 | k_shot=k_shot,
414 | prompt_name=prompt_name,
415 | seed=seed,
416 | is_stsb=is_stsb,
417 | )
418 | train_dataset = dataset["train"]
419 | eval_dataset, test_dataset = None, None
420 | if validation_file is not None:
421 | eval_dataset = dataset["validation"]
422 | if test_file is not None:
423 | test_dataset = dataset["test"]
424 | if max_eval_samples is not None and 'validation' in dataset:
425 | eval_dataset = eval_dataset.select(range(min(max_eval_samples, len(eval_dataset))))
426 | logger.info(
427 | "Loaded %d train examples, %d validation examples, %d test examples"
428 | % (len(train_dataset), len(eval_dataset) if eval_dataset is not None else 0, len(test_dataset) if test_dataset is not None else 0)
429 | )
430 | Path(output_dir).mkdir(parents=True, exist_ok=True)
431 | if validation_file is not None:
432 | logger.info("Evaluating validation dataset")
433 | eval_results, eval_predictions, eval_examples = eval(
434 | dataset=eval_dataset,
435 | model=model,
436 | tokenizer=tokenizer,
437 | prefix='eval',
438 | tokenizer_type=tokenizer_type,
439 | min_similarity=min_similarity,
440 | max_similarity=max_similarity,
441 | dataloader_num_workers=dataloader_num_workers,
442 | batch_size=batch_size,
443 | )
444 | save_results(eval_results, eval_predictions, eval_examples, output_dir, "eval")
445 | if test_file is not None:
446 | logger.info("Predicting on test dataset")
447 | test_results, test_predictions, test_examples = eval(
448 | dataset=test_dataset,
449 | model=model,
450 | tokenizer=tokenizer,
451 | prefix='test',
452 | tokenizer_type=tokenizer_type,
453 | min_similarity=min_similarity,
454 | max_similarity=max_similarity,
455 | dataloader_num_workers=dataloader_num_workers,
456 | batch_size=batch_size,
457 | )
458 | save_results(test_results, test_predictions, test_examples, output_dir, "test")
459 | logger.info("Done!")
460 |
461 |
462 | def string_to_bool(v):
463 | if isinstance(v, bool):
464 | return v
465 | if v.lower() in ("yes", "true", "t", "y", "1"):
466 | return True
467 | elif v.lower() in ("no", "false", "f", "n", "0"):
468 | return False
469 | else:
470 | raise argparse.ArgumentTypeError(
471 | f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
472 | )
473 |
474 |
475 | if __name__ == "__main__":
476 | parser = ArgumentParser()
477 | parser.add_argument("--model_name_or_path", type=str, required=True)
478 | parser.add_argument("--tokenizer_type", type=str, help="Tokenizer type (gpt or t5). If not provided, will be inferred from model_name_or_path.")
479 | parser.add_argument("--k_shot", type=int, required=True, help="Number of examples to use in prompt.")
480 | parser.add_argument("--prompt_name", type=str, required=True, help="Name of prompt to use. See utils/fewshot/generate_in_context_dataset.py for options.")
481 | parser.add_argument("--seed", type=int, default=42)
482 | parser.add_argument("--train_file", type=str, required=True, help="Path to train file.")
483 | parser.add_argument("--validation_file", type=str, required=False, help="Path to validation file. If not provided, will not run validation.")
484 | parser.add_argument("--test_file", type=str, required=False, help="Path to test file. If not provided, will not run test.")
485 | parser.add_argument(
486 | "--output_dir", type=str, required=True, help="Directory to save results"
487 | )
488 | parser.add_argument(
489 | "--overwrite_output_dir",
490 | type=string_to_bool,
491 | default=False,
492 | nargs="?",
493 | const=True,
494 | help="Overwrite the content of the output directory",
495 | )
496 | parser.add_argument("--dataloader_num_workers", type=int, default=0)
497 | parser.add_argument(
498 | "--api_key", type=str, required=False, help="Path to OpenAI API key"
499 | )
500 | parser.add_argument(
501 | "--dtype",
502 | type=str,
503 | choices=["fp16", "bf16", "fp32", "tf32"],
504 | help="Data type used for the model. TF32 and BF16 are recommended, but are only supported on NVIDIA GPUs with Ampere architecture or later.",
505 | required=True,
506 | )
507 | parser.add_argument("--batch_size", type=int, default=2)
508 | parser.add_argument("--max_eval_samples", type=int, default=None)
509 | args = parser.parse_args()
510 | main(**vars(args))
511 |
--------------------------------------------------------------------------------
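`extract_float` in run_sts_fewshot.py above pulls the first numeric token out of a model's free-form reply; when nothing numeric is found it returns the raw string, which the caller counts as `non_numeric` and replaces with a random in-range score. A standalone copy with illustrative inputs (the example strings are made up):

```python
import re

def extract_float(s):
    # First decimal or integer in the string, as in run_sts_fewshot.py.
    # The decimal alternative comes first so "4.5" is not truncated to "4".
    match = re.search(r"(\d+\.\d+|\d+)", s)
    if match:
        return float(match.group(1))
    return s  # non-numeric replies pass through for the caller to count

score = extract_float("Similarity: 4.5 out of 5")
```

Ordering the alternation as `\d+\.\d+|\d+` matters: with `\d+` first, the regex would match only the integer part of a decimal score.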
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/c-sts/3bacafac67c6a359186e58d3086deaf180a43863/utils/__init__.py
--------------------------------------------------------------------------------
/utils/fewshot/generate_in_context_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import load_dataset, Dataset, DatasetDict
3 | import pandas as pd
4 | import random
5 | from pathlib import Path
6 | from functools import partial
7 | import logging
8 | from argparse import ArgumentParser
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | SEP_TOKENS = {
15 | 't5': '',
16 | 'gpt': ' ',
17 | }
18 | LABEL_KEY_FUNCS = {
19 | 't5': lambda x: f'{x.capitalize()}:',
20 | 'gpt': lambda x: f'{x.capitalize()}: ',
21 | }
22 |
23 | prompts = {
24 | 'none': None,
25 | 'short': 'On a scale between 1 and 5, how similar are the following two sentences with respect to the condition provided? Respond only with a score between 1 and 5.',
26 | 'long': 'Definition: Evaluate the similarity between the two sentences, with respect to the condition. Assign the pair a score between 1 and 5 as follows: 1 : The two sentences are completely dissimilar with respect to the condition. 2 : The two sentences are dissimilar, but are on a similar topic with respect to the condition. 3 : The two sentences are roughly equivalent, but some important information differs or is missing with respect to the condition. 4 : The two sentences are mostly equivalent, but some unimportant details differ with respect to the condition. 5 : The two sentences are completely equivalent with respect to the condition.',
27 | }
28 | stsb_prompts = {
29 | 'none': None,
30 | 'short': 'On a scale between 0 and 5, how similar are the following two sentences? Respond only with a score between 0 and 5.',
31 | 'long': 'Definition: Evaluate the similarity between them and classify them into classes from 0-5 as follows: 0 : The two sentences are completely dissimilar. 1 : The two sentences are not equivalent, but are on the same topic. 2 : The two sentences are not equivalent, but share some details. 3 : The two sentences are roughly equivalent, but some important information differs/missing. 4 : The two sentences are mostly equivalent, but some unimportant details differ. 5 : The two sentences are completely equivalent, as they mean the same thing.',
32 | }
33 | def get_prompt(prompt_name, is_stsb=False):
34 | if prompt_name is None:
35 | return None
36 | else:
37 | if is_stsb:
38 | return stsb_prompts[prompt_name]
39 | else:
40 | return prompts[prompt_name]
41 |
42 |
43 | def convert_text(
44 | example,
45 | sep_token,
46 | label_key_func=lambda x: f'{x.capitalize()}: ',
47 | sentence_1_label='sentence1',
48 | sentence_2_label='sentence2',
49 | condition_label='condition',
50 | answer_label='label',
51 | is_stsb=False,
52 | ):
53 | sent_1 = example[sentence_1_label].strip()
54 | sent_2 = example[sentence_2_label].strip()
55 | condition = example[condition_label] if example[condition_label] is not None else '' # workaround: some conditions are None in the data
56 | similarity = example[answer_label]
57 | if is_stsb:
58 | ex_list = [
59 | label_key_func('input'),
60 | ' '.join([label_key_func('sentence 1'), sent_1, ]),
61 | ' '.join([label_key_func('sentence 2'), sent_2, ]),
62 | label_key_func('output'),
63 | ]
64 | else:
65 | ex_list = [
66 | label_key_func('input'),
67 | ' '.join([label_key_func('sentence 1'), sent_1, ]),
68 | ' '.join([label_key_func('sentence 2'), sent_2, ]),
69 | ' '.join([label_key_func('condition'), condition, ]),
70 | label_key_func('output'),
71 | ]
72 | ex_str = ' '.join(map(str, ex_list))
73 | return ex_str
74 |
75 |
76 | def add_context(
77 | example,
78 | context,
79 | prompt,
80 | sep_token,
81 | answer_label='label',
82 | label_func=lambda x: f'{float(x)}',
83 | ):
84 | if prompt is not None:
85 | ex_list = [
86 | prompt.strip(' :'),
87 | ]
88 | else:
89 | ex_list = []
90 | for ex in context:
91 | entry = ex['original_text'] + label_func(ex[answer_label])
92 | ex_list.extend([entry, ])
93 | ex_list.append(example['original_text']) # don't add a label to the last example
94 | return '\n'.join(ex_list)
95 |
96 |
97 | def add_in_context_examples(dataset, context_dataset, model, k, prompt, tokenizer_type, pairs=None):
98 | contexts = list()
99 | context_ids = list()
100 | for ix, entry in enumerate(dataset):
101 | if pairs is not None:
102 | random_pairs = random.sample(range(len(pairs)), k=(k+1)//2)
103 | context_example_ids = [x for pair in random_pairs for x in pairs[pair]][:k]
104 | else:
105 | context_example_ids = random.sample(list(set(range(len(context_dataset))) - {ix}), k=k)
106 | context_ids.append(context_example_ids)
107 | context_examples = [context_dataset[idx] for idx in context_example_ids]
108 | contexts.append(add_context(entry, context_examples, prompt, SEP_TOKENS[tokenizer_type]))
109 | dataset = dataset.add_column('context_ids', context_ids)
110 | dataset = dataset.add_column('text', contexts)
111 | return dataset
112 |
113 | def get_idx_pairs(
114 | dataset,
115 | sentence_1_label='sentence1',
116 | sentence_2_label='sentence2',
117 | condition_label='condition',
118 | answer_label='label',
119 | ):
120 | from collections import defaultdict
121 | pairs = defaultdict(list)
122 | for ix, datum in enumerate(dataset):
123 | pairs[datum[sentence_1_label] + '<-SEP->' + datum[sentence_2_label]].append(ix)
124 | pair_idxs = list(pairs.keys())
125 | drop_count = 0
126 | for pair_idx in pair_idxs:
127 | if len(pairs[pair_idx]) != 2:
128 | drop_count += len(pairs[pair_idx])
129 | pairs.pop(pair_idx)
130 | logger.warning('Dropping %d indices for missing pairs. Dataset has %d pairs total' % (drop_count, len(pair_idxs)))
131 | pairs = list(map(lambda x: sorted(pairs[x], key=lambda idx: -dataset[idx][answer_label]), pairs.keys()))
132 | # negative because we want to sort in descending order (highest similarity first)
133 | for idx1, idx2 in pairs:
134 | if (dataset[idx1][sentence_1_label] != dataset[idx2][sentence_1_label]) or (dataset[idx1][sentence_2_label] != dataset[idx2][sentence_2_label]):
135 | raise ValueError('Pairing of indices is incorrect, sentences do not match for pair %d and %d' % (idx1, idx2))
136 | if (dataset[idx1][answer_label] < dataset[idx2][answer_label]):
137 | raise ValueError('Pairing of indices is incorrect, similarity is not in descending order for pair %d and %d' % (idx1, idx2))
138 | return pairs
139 |
140 | def make_dataset(
141 | train_file,
142 | validation_file,
143 | test_file,
144 | tokenizer_type,
145 | k_shot,
146 | prompt_name,
147 | seed,
148 | is_stsb=False,
149 | ):
150 | convert_func = partial(convert_text, sep_token=SEP_TOKENS[tokenizer_type], label_key_func=LABEL_KEY_FUNCS[tokenizer_type], is_stsb=is_stsb)
151 | data_files = {'train': train_file}
152 | if validation_file is not None:
153 | data_files['validation'] = validation_file
154 | if test_file is not None:
155 | data_files['test'] = test_file
156 | raw_datasets = load_dataset('csv', data_files=data_files, keep_in_memory=True)
157 | raw_datasets = raw_datasets.map(lambda x: {'original_text': convert_func(x)}, batched=False, keep_in_memory=True)
158 | prompt = get_prompt(prompt_name, is_stsb)
159 | random.seed(seed)
160 | pairs = None
161 | if not is_stsb:
162 | pairs = get_idx_pairs(raw_datasets['train'])
163 | raw_datasets['train'] = add_in_context_examples(raw_datasets['train'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs)
164 | if validation_file is not None:
165 | raw_datasets['validation'] = add_in_context_examples(raw_datasets['validation'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs)
166 | if test_file is not None:
167 | raw_datasets['test'] = add_in_context_examples(raw_datasets['test'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs)
168 | return raw_datasets
169 |
170 | def main(train_file, test_file, tokenizer_type, output_dir, k_shot, prompt_name, seed, is_stsb):
171 | dataset = make_dataset(train_file=train_file, validation_file=None, test_file=test_file, tokenizer_type=tokenizer_type, k_shot=k_shot, prompt_name=prompt_name, seed=seed, is_stsb=is_stsb)
172 | output_file = Path(output_dir, Path(train_file).stem + f'_{tokenizer_type}_k{k_shot}_prompt{prompt_name}_seed{seed}')
173 | dataset.save_to_disk(output_file)
174 | logger.info(f'Saved to {output_file}')
175 |
176 |
177 | if __name__ == '__main__':
178 | parser = ArgumentParser()
179 | parser.add_argument('--train_file', type=str, required=True)
180 | parser.add_argument('--test_file', type=str, required=True)
181 | parser.add_argument('--tokenizer_type', type=str, choices=sorted(LABEL_KEY_FUNCS.keys()), required=True)
182 | parser.add_argument('--output_dir', type=str, default='./in_context_datasets')
183 | parser.add_argument('--k_shot', type=int, required=True)
184 | parser.add_argument('--prompt_name', type=str)
185 | parser.add_argument('--is_stsb', action='store_true')
186 | parser.add_argument('--seed', type=int, required=True)
187 | args = parser.parse_args()
188 | main(**vars(args))
189 |
--------------------------------------------------------------------------------
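`convert_text` above linearizes one example into the `Input / Sentence 1 / Sentence 2 / Condition / Output` template before in-context examples are prepended. A self-contained sketch of the non-STS-B branch with the GPT-style label keys (`'Key: '`), on a made-up example:

```python
def label_key(x):
    # GPT-style key, as in LABEL_KEY_FUNCS['gpt'] above.
    return f"{x.capitalize()}: "

def convert_text(example):
    # Non-STS-B branch of convert_text in generate_in_context_dataset.py.
    parts = [
        label_key("input"),
        " ".join([label_key("sentence 1"), example["sentence1"].strip()]),
        " ".join([label_key("sentence 2"), example["sentence2"].strip()]),
        " ".join([label_key("condition"), example["condition"]]),
        label_key("output"),  # the similarity score goes after this key
    ]
    return " ".join(parts)

text = convert_text({
    "sentence1": "A dog runs on the beach.",
    "sentence2": "A dog sleeps indoors.",
    "condition": "The animal involved.",
})
```

The trailing `Output:` key is deliberate: `add_context` appends the gold label after it for in-context demonstrations, but leaves it empty on the final query so the model completes the score.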
/utils/fewshot/openai_utils.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import logging
3 | import time
4 |
5 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | LEGACY_MODELS = {
10 | 'davinci',
11 | 'curie',
12 | 'babbage',
13 | 'ada',
14 | }
15 |
16 | GPT3_MODELS = {
17 | 'text-davinci-003',
18 | 'text-davinci-002',
19 | 'text-davinci-001',
20 | 'text-curie-001',
21 | 'text-babbage-001',
22 | 'text-ada-001',
23 | }
24 |
25 | CHAT_MODELS = {
26 | 'gpt-4',
27 | 'gpt-4-0314',
28 | 'gpt-4-32k',
29 | 'gpt-4-32k-0314',
30 | 'gpt-3.5-turbo',
31 | 'gpt-3.5-turbo-0301',
32 | }
33 |
34 | OPENAI_MODELS = LEGACY_MODELS | GPT3_MODELS | CHAT_MODELS
35 |
36 |
37 | def parse_response(model, response, prompt):
38 | if model in CHAT_MODELS:
39 | response_text = response['choices'][0]['message']['content']
40 | else:
41 | response_text = response['choices'][0]['text'].replace(prompt, '')
42 |     response_text = response_text.strip()
43 | return response_text
44 |
45 | def call_chat(model, prompt):
46 | response = openai.ChatCompletion.create(
47 | model=model,
48 | messages= [{'role': 'user', 'content': prompt}, ],
49 | temperature=0,
50 | max_tokens=5,
51 | top_p=1,
52 | frequency_penalty=0.0,
53 | presence_penalty=0.0,
54 | )
55 | return response
56 |
57 |
58 | def call_gpt(model, prompt):
59 | response = openai.Completion.create(
60 | model=model,
61 | prompt=prompt,
62 | temperature=0,
63 | max_tokens=5,
64 | top_p=1,
65 | frequency_penalty=0.0,
66 | presence_penalty=0.0,
67 | )
68 | return response
69 |
70 |
71 | def get_gpt_prediction(model, prompt):
72 | retries = 0
73 | while retries < 3:
74 | try:
75 | if model in CHAT_MODELS:
76 | response = call_chat(model, prompt)
77 | else:
78 | response = call_gpt(model, prompt)
79 | return parse_response(model, response, prompt)
80 | except Exception as e:
81 | logger.warning('Exception while getting gpt prediction: {}'.format(e))
82 |                 logger.warning(f'Retrying... {2 - retries} more time(s).')
83 | retries += 1
84 | time.sleep(20 * retries)
85 | raise Exception('Failed to get gpt prediction after 3 retries. Aborting run.')
86 |
87 |
88 | def authenticate(api_key):
89 | with open(api_key) as f:
90 | api_key = f.readlines()[0].strip()
91 | openai.api_key = api_key
92 |
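The retry loop in `get_gpt_prediction` (up to 3 attempts with a linearly growing sleep) can be exercised without any network access. A minimal sketch of the same pattern, with a hypothetical `flaky` callable standing in for the API call and the sleep stubbed out:

```python
import time

def with_retries(fn, max_retries=3, base_delay=20, sleep=time.sleep):
    # Mirrors get_gpt_prediction: try up to max_retries times,
    # sleeping base_delay * attempt seconds after each failure.
    retries = 0
    while retries < max_retries:
        try:
            return fn()
        except Exception:
            retries += 1
            sleep(base_delay * retries)
    raise Exception(f'Failed after {max_retries} retries. Aborting run.')

# Hypothetical stand-in for the API call: fails twice, then succeeds.
calls = {'n': 0}
def flaky():
    calls['n'] += 1
    if calls['n'] < 3:
        raise RuntimeError('transient error')
    return 'ok'

result = with_retries(flaky, sleep=lambda s: None)  # stub the sleep for testing
```

Injecting the `sleep` function keeps the backoff behavior testable without actually waiting.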
--------------------------------------------------------------------------------
/utils/fewshot/progress_logger.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 |
4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | class ProgressLogger:
9 | def __init__(self, process_name, total_iters, log_time_interval=60):
10 | self.total_iters = total_iters
11 | self.name = process_name
12 | if log_time_interval < 1:
13 | log_time_interval = 1
14 | logger.info('log_time_interval must be >= 1, setting to 1')
15 | self.log_time_interval = log_time_interval
16 | self.start_time = None
17 | self.last_log_time = None
18 | self.last_log_ix = 0
19 |
20 | def __call__(self, ix):
21 | if ix == 0:
22 | self.start_time = time.time()
23 | self.last_log_time = time.time()
24 | self.log(ix)
25 | elif ix == self.total_iters - 1:
26 | self.log(ix)
27 | elif time.time() - self.last_log_time > self.log_time_interval:
28 | self.log(ix)
29 |
30 | def log(self, ix):
31 | pct_complete = ix / self.total_iters * 100
32 |         iters_per_sec = (ix - self.last_log_ix) / max(time.time() - self.last_log_time, 1e-9)
33 | remaining = (time.time() - self.start_time) / (ix + 1) * (self.total_iters - ix - 1)
34 | remaining = time.strftime('%H:%M:%S', time.gmtime(remaining))
35 | logger.info('{}: {:_} / {:_} ({:.2f}%) ({:_.2f} iter/sec) (remaining: {})'.format(
36 | self.name,
37 | ix,
38 | self.total_iters,
39 | pct_complete,
40 | iters_per_sec,
41 | remaining,
42 | ))
43 | self.last_log_time = time.time()
44 | self.last_log_ix = ix
45 |
46 | @classmethod
47 | def wrap_iter(
48 | cls,
49 | process_name,
50 | iterable,
51 | total_iters,
52 | log_time_interval=60,
53 | return_ix=False,
54 | ):
55 | progress = cls(process_name, total_iters, log_time_interval)
56 | if return_ix:
57 | for ix, item in enumerate(iterable):
58 | progress(ix)
59 | yield ix, item
60 | else:
61 | for ix, item in enumerate(iterable):
62 | progress(ix)
63 | yield item
64 |
65 |
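The remaining-time estimate in `log` is a proportional extrapolation: average time per finished iteration, multiplied by the iterations still to go. A standalone sketch of the same formula (the helper name is hypothetical):

```python
import time

def estimate_remaining(elapsed_seconds, current_ix, total_iters):
    # Same arithmetic as ProgressLogger.log: mean time per completed
    # iteration times the number of iterations remaining.
    per_iter = elapsed_seconds / (current_ix + 1)
    remaining = per_iter * (total_iters - current_ix - 1)
    return time.strftime('%H:%M:%S', time.gmtime(remaining))

eta = estimate_remaining(50.0, 49, 100)  # 1 s/iter so far, 50 iterations left
```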
--------------------------------------------------------------------------------
/utils/progress_logger.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from transformers import TrainerCallback
4 | import logging
5 |
6 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class LogCallback(TrainerCallback):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 | self.start_time = time.time()
15 | self.last_log_time = self.start_time
16 | self.log_time_interval = 0
17 | self.current_step = 0
18 | self.is_training = False
19 | self.max_steps = -1
20 | self.first_step_of_run = 0
21 |
22 | def on_train_begin(self, args, state, control, **kwargs):
23 | if state.is_local_process_zero:
24 | args.logging_steps = 1
25 | args.logging_strategy = 'steps'
26 | self.log_time_interval = args.log_time_interval
27 | if self.log_time_interval > 0:
28 | logger.info(f'Using log_time_interval {self.log_time_interval} s. This may override logging_step.')
29 | self.is_training = True
30 | self.current_step = 0
31 | self.start_time = time.time()
32 | self.last_log_time = self.start_time
33 | self.max_steps = state.max_steps
34 | self.first_step_of_run = state.global_step
35 | if torch.distributed.is_initialized():
36 | torch.distributed.barrier()
37 |
38 | def on_log(self, args, state, control, logs=None, **kwargs):
39 | _ = logs.pop('total_flos', None)
40 | if state.is_local_process_zero:
41 | if self.is_training:
42 | current_time = time.time()
43 | time_diff = current_time - self.last_log_time
44 | force = len(logs) > 3 or any([k.endswith('runtime') for k in logs])
45 | if time_diff > self.log_time_interval or self.current_step >= self.max_steps - 1 or force:
46 | self.last_log_time = current_time
47 | steps_completed = max(self.current_step, 1)
48 | steps_since_first = max(1, self.current_step - self.first_step_of_run)
49 | remaining_steps = self.max_steps - steps_completed
50 | pct_completed = (steps_completed / self.max_steps) * 100
51 | time_since_start = current_time - self.start_time
52 | remaining_time = (time_since_start / steps_since_first) * remaining_steps
53 | update = {'completed': f'{pct_completed:.2f}% ({steps_completed:_} / {self.max_steps:_})', 'remaining time': self.format_duration(remaining_time)}
54 | logger.info(str({**logs, **update}))
55 | else:
56 | logger.info(str(logs))
57 |
58 | def on_step_end(self, args, state, control, **kwargs):
59 | if state.is_local_process_zero:
60 | self.current_step = state.global_step
61 |
62 | def on_train_end(self, args, state, control, **kwargs):
63 | if state.is_local_process_zero:
64 | self.is_training = False
65 |
66 | @staticmethod
67 | def format_duration(seconds):
68 | hours, remainder = divmod(seconds, 3600)
69 | minutes, seconds = divmod(remainder, 60)
70 | return f'{int(hours)}:{int(minutes):02}:{int(seconds):02}'
71 |
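`format_duration` is easy to check in isolation; a self-contained copy of the static method:

```python
def format_duration(seconds):
    # Split a duration in seconds into H:MM:SS, as in LogCallback.
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f'{int(hours)}:{int(minutes):02}:{int(seconds):02}'

formatted = format_duration(3661)  # one hour, one minute, one second
```

Note that fractional seconds are truncated rather than rounded.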
--------------------------------------------------------------------------------
/utils/sts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/c-sts/3bacafac67c6a359186e58d3086deaf180a43863/utils/sts/__init__.py
--------------------------------------------------------------------------------
/utils/sts/dataset_preprocessing.py:
--------------------------------------------------------------------------------
1 | def scale_to_range(labels, _min, _max):
2 | return list(map(lambda x: (x - _min) / (_max - _min), labels))
3 |
4 |
5 | def get_preprocessing_function(
6 | tokenizer,
7 | sentence1_key,
8 | sentence2_key,
9 | condition_key,
10 | similarity_key,
11 | padding,
12 | max_seq_length,
13 | model_args,
14 | scale=None,
15 | condition_only=False,
16 | sentences_only=False,
17 | ):
18 |     'Returns the preprocessing function for the given encoding type'
19 | if model_args.encoding_type == 'bi_encoder':
20 | if condition_only or sentences_only:
21 |             raise ValueError('condition_only and sentences_only don\'t apply to bi_encoder')
22 | def preprocess_function(examples):
23 | sent1_args = (examples[sentence1_key], examples[condition_key])
24 | sent1_result = tokenizer(*sent1_args, padding=padding, max_length=max_seq_length, truncation=True)
25 | sent2_args = (examples[sentence2_key], examples[condition_key])
26 | sent2_result = tokenizer(*sent2_args, padding=padding, max_length=max_seq_length, truncation=True)
27 | sent1_result['input_ids_2'] = sent2_result['input_ids']
28 | sent1_result['attention_mask_2'] = sent2_result['attention_mask']
29 | if 'token_type_ids' in sent2_result:
30 | sent1_result['token_type_ids_2'] = sent2_result['token_type_ids']
31 | sent1_result['labels'] = examples[similarity_key]
32 | if scale is not None:
33 | _min, _max = scale
34 | for label in sent1_result['labels']:
35 | if (label < _min or label > _max) and label != -1:
36 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]')
37 | sent1_result['labels'] = scale_to_range(sent1_result['labels'], _min, _max)
38 | return sent1_result
39 | elif model_args.encoding_type == 'cross_encoder':
40 | def preprocess_function(examples):
41 | if condition_only:
42 | input_args = examples[condition_key]
43 | elif sentences_only:
44 | input_args = list(map(lambda x: ' '.join([x[0], tokenizer.sep_token, x[1]]), zip(examples[sentence1_key], examples[sentence2_key])))
45 | else:
46 | input_args = list(map(lambda x: ' '.join([x[0], tokenizer.sep_token, x[1], tokenizer.sep_token, x[2]]), zip(examples[sentence1_key], examples[sentence2_key], examples[condition_key])))
47 | result = tokenizer(input_args, padding=padding, max_length=max_seq_length, truncation=True)
48 | result['labels'] = examples[similarity_key]
49 | if scale is not None:
50 | _min, _max = scale
51 | for label in result['labels']:
52 | if (label < _min or label > _max) and label != -1:
53 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]')
54 | result['labels'] = scale_to_range(result['labels'], _min, _max)
55 | return result
56 | elif model_args.encoding_type == 'tri_encoder':
57 | if condition_only or sentences_only:
58 |             raise ValueError('condition_only and sentences_only don\'t apply to tri_encoder')
59 | def preprocess_function(examples):
60 | sent1_args = (examples[sentence1_key], )
61 | sent1_result = tokenizer(*sent1_args, padding=padding, max_length=max_seq_length, truncation=True)
62 | sent2_args = (examples[sentence2_key], )
63 | sent2_result = tokenizer(*sent2_args, padding=padding, max_length=max_seq_length, truncation=True)
64 | sent3_args = (examples[condition_key], )
65 | sent3_result = tokenizer(*sent3_args, padding=padding, max_length=max_seq_length, truncation=True)
66 | sent1_result['input_ids_2'] = sent2_result['input_ids']
67 | sent1_result['attention_mask_2'] = sent2_result['attention_mask']
68 | sent1_result['input_ids_3'] = sent3_result['input_ids']
69 | sent1_result['attention_mask_3'] = sent3_result['attention_mask']
70 | if 'token_type_ids' in sent2_result:
71 | sent1_result['token_type_ids_2'] = sent2_result['token_type_ids']
72 | sent1_result['token_type_ids_3'] = sent3_result['token_type_ids']
73 | sent1_result['labels'] = examples[similarity_key]
74 | if scale is not None:
75 | _min, _max = scale
76 | for label in sent1_result['labels']:
77 | if (label < _min or label > _max) and label != -1:
78 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]')
79 | sent1_result['labels'] = scale_to_range(sent1_result['labels'], _min, _max)
80 | return sent1_result
81 | else:
82 | raise ValueError(f'Invalid model type: {model_args.encoding_type}')
83 | return preprocess_function
84 |
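`scale_to_range` is plain min-max normalization onto [0, 1]; for a 1-5 similarity scale it maps 1 → 0.0, 3 → 0.5 and 5 → 1.0:

```python
def scale_to_range(labels, _min, _max):
    # Min-max normalize labels from [_min, _max] onto [0, 1].
    return list(map(lambda x: (x - _min) / (_max - _min), labels))

scaled = scale_to_range([1, 3, 5], 1, 5)
```

One caveat visible in the preprocessing functions above: a sentinel label of -1 (unlabeled test data) is deliberately let through the range check, so it is scaled to a negative value rather than rejected.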
--------------------------------------------------------------------------------
/utils/sts/modeling_encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.checkpoint
3 | from torch import nn
4 | from torch.nn.functional import cosine_similarity
5 |
6 | from transformers.activations import ACT2FN
7 | from transformers.modeling_outputs import (
8 | SequenceClassifierOutput,
9 | )
10 | from transformers import PreTrainedModel, AutoModel
11 | import logging
12 |
13 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def concat_features(*features):
19 | return torch.cat(features, dim=0) if features[0] is not None else None
20 |
21 |
22 | class QuadrupletLoss:
23 | def __init__(self, distance_function, margin=1.0):
24 | 'A cosine distance margin quadruplet loss'
25 | self.margin = margin
26 | self.distance_function = distance_function
27 |
28 | def __call__(self, pos1, pos2, neg1, neg2):
29 | dist_pos = self.distance_function(pos1, pos2)
30 | dist_neg = self.distance_function(neg1, neg2)
31 | loss = torch.clamp_min(self.margin + dist_pos - dist_neg, 0)
32 | return loss.mean()
33 |
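The loss above penalizes positive pairs that are not at least `margin` closer (in cosine distance) than the corresponding negative pairs. A dependency-free sketch of the same hinge for a single quadruplet (helper names are hypothetical):

```python
import math

def cosine_distance(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norm

def quadruplet_loss(pos1, pos2, neg1, neg2, margin=1.0):
    # Hinge on (margin + d(pos) - d(neg)), as in QuadrupletLoss.__call__.
    return max(margin + cosine_distance(pos1, pos2) - cosine_distance(neg1, neg2), 0.0)

# Identical positives (distance 0) and opposite negatives (distance 2)
# satisfy the margin, so the loss vanishes.
loss = quadruplet_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, -1.0])
```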
34 |
35 | # Pooler class. Copied and adapted from SimCSE code
36 | class Pooler(nn.Module):
37 | '''
38 | Parameter-free poolers to get the sentence embedding
39 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
40 | 'cls_before_pooler': [CLS] representation without the original MLP pooler.
41 | 'avg': average of the last layers' hidden states at each token.
42 | 'avg_top2': average of the last two layers.
43 | 'avg_first_last': average of the first and the last layers.
44 | '''
45 | def __init__(self, pooler_type):
46 | super().__init__()
47 | self.pooler_type = pooler_type
48 | assert self.pooler_type in ['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 'unrecognized pooling type %s' % self.pooler_type
49 |
50 | def forward(self, attention_mask, outputs):
51 | last_hidden = outputs.last_hidden_state
52 | pooler_output = outputs.pooler_output
53 | hidden_states = outputs.hidden_states
54 |
55 | if self.pooler_type in ['cls_before_pooler', 'cls']:
56 | return last_hidden[:, 0]
57 | elif self.pooler_type == 'avg':
58 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
59 | elif self.pooler_type == 'avg_first_last':
60 | first_hidden = hidden_states[0]
61 | last_hidden = hidden_states[-1]
62 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
63 | return pooled_result
64 | elif self.pooler_type == 'avg_top2':
65 | second_last_hidden = hidden_states[-2]
66 | last_hidden = hidden_states[-1]
67 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
68 | return pooled_result
69 | else:
70 | raise NotImplementedError
71 |
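The `'avg'` branch above is masked mean pooling: hidden states are summed only where the attention mask is 1, then divided by the number of real tokens, so padding positions never dilute the sentence embedding. A list-based sketch for a single sequence (the helper name is hypothetical):

```python
def masked_mean(hidden, mask):
    # hidden: [seq_len][dim] token vectors, mask: [seq_len] of 0/1 (padding is 0).
    dim = len(hidden[0])
    total = [0.0] * dim
    count = 0
    for vec, m in zip(hidden, mask):
        if m:
            total = [t + x for t, x in zip(total, vec)]
            count += 1
    return [t / count for t in total]

# The third token is padding and is excluded from the average.
pooled = masked_mean([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [1, 1, 0])
```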
72 |
73 | class CrossEncoderForClassification(PreTrainedModel):
74 | 'Encoder model with backbone and classification head.'
75 | def __init__(self, config):
76 | super().__init__(config)
77 | self.backbone = AutoModel.from_pretrained(
78 | config.model_name_or_path,
79 | from_tf=bool('.ckpt' in config.model_name_or_path),
80 | config=config,
81 | cache_dir=config.cache_dir,
82 | revision=config.model_revision,
83 | use_auth_token=True if config.use_auth_token else None,
84 | add_pooling_layer=False,
85 | ).base_model
86 | classifier_dropout = (
87 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
88 | )
89 | if config.transform:
90 | self.transform = nn.Sequential(
91 | nn.Dropout(classifier_dropout),
92 | nn.Linear(config.hidden_size, config.hidden_size),
93 | ACT2FN[config.hidden_act],
94 | )
95 | else:
96 | self.transform = None
97 | self.classifier = nn.Sequential(
98 | nn.Dropout(config.hidden_dropout_prob),
99 | nn.Linear(config.hidden_size, config.num_labels),
100 | )
101 | self.pooler = Pooler(config.pooler_type)
102 | if config.pooler_type in {'avg_first_last', 'avg_top2'}:
103 | self.output_hidden_states = True
104 | else:
105 | self.output_hidden_states = False
106 | if config.num_labels == 1:
107 | self.reshape_function = lambda x: x.reshape(-1)
108 | if config.objective == 'mse':
109 | self.loss_fct_cls = nn.MSELoss
110 | elif config.objective in {'triplet', 'triplet_mse'}:
111 | raise NotImplementedError('Triplet loss is not implemented for CrossEncoderForClassification')
112 | else:
113 | raise ValueError(f'Only regression and triplet objectives are supported for CrossEncoderForClassification with num_labels=1. Got {config.objective}.')
114 | else:
115 | assert config.objective == 'classification'
116 | self.reshape_function = lambda x: x.reshape(-1, config.num_labels)
117 | self.loss_fct_cls = nn.CrossEntropyLoss
118 | self.post_init()
119 |
120 | def forward(
121 | self,
122 | input_ids=None,
123 | attention_mask=None,
124 | token_type_ids=None,
125 | position_ids=None,
126 | head_mask=None,
127 | inputs_embeds=None,
128 | labels=None,
129 | **kwargs,
130 | ):
131 | outputs = self.backbone(
132 | input_ids=input_ids,
133 | attention_mask=attention_mask,
134 | token_type_ids=token_type_ids,
135 | position_ids=position_ids,
136 | head_mask=head_mask,
137 | inputs_embeds=inputs_embeds,
138 | output_hidden_states=self.output_hidden_states,
139 | )
140 | features = self.pooler(attention_mask, outputs)
141 | if self.transform is not None:
142 | features = self.transform(features)
143 | logits = self.classifier(features)
144 | reshaped_logits = self.reshape_function(logits)
145 | loss = None
146 | if labels is not None:
147 | loss = self.loss_fct_cls()(reshaped_logits, labels.view(-1))
148 | return SequenceClassifierOutput(
149 | loss=loss,
150 | logits=logits,
151 | hidden_states=outputs.hidden_states,
152 | attentions=outputs.attentions,
153 | )
154 |
155 |
156 | class BiEncoderForClassification(PreTrainedModel):
157 | '''Encoder model with backbone and classification head.'''
158 | def __init__(self, config):
159 | super().__init__(config)
160 | self.backbone = AutoModel.from_pretrained(
161 | config.model_name_or_path,
162 | from_tf=bool('.ckpt' in config.model_name_or_path),
163 | config=config,
164 | cache_dir=config.cache_dir,
165 | revision=config.model_revision,
166 | use_auth_token=True if config.use_auth_token else None,
167 | add_pooling_layer=False,
168 | ).base_model
169 | classifier_dropout = (
170 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
171 | )
172 | if config.transform:
173 | self.transform = nn.Sequential(
174 | nn.Dropout(classifier_dropout),
175 | nn.Linear(config.hidden_size, config.hidden_size),
176 | ACT2FN[config.hidden_act],
177 | )
178 | else:
179 | self.transform = None
180 | self.pooler = Pooler(config.pooler_type)
181 | if config.pooler_type in {'avg_first_last', 'avg_top2'}:
182 | self.output_hidden_states = True
183 | else:
184 | self.output_hidden_states = False
185 | if config.objective == 'mse':
186 | self.loss_fct_cls = nn.MSELoss
187 | self.loss_fct_kwargs = {}
188 | elif config.objective in {'triplet', 'triplet_mse'}:
189 | self.loss_fct_cls = QuadrupletLoss
190 | self.loss_fct_kwargs = {'distance_function': lambda x, y: 1.0 - cosine_similarity(x, y)}
191 | else:
192 | raise ValueError('Only regression and triplet objectives are supported for BiEncoderForClassification')
193 | self.post_init()
194 |
195 | def forward(
196 | self,
197 | input_ids=None,
198 | attention_mask=None,
199 | token_type_ids=None,
200 | position_ids=None,
201 | head_mask=None,
202 | inputs_embeds=None,
203 | input_ids_2=None,
204 | attention_mask_2=None,
205 | token_type_ids_2=None,
206 | position_ids_2=None,
207 | head_mask_2=None,
208 | inputs_embeds_2=None,
209 | labels=None,
210 | **kwargs,
211 | ):
212 | bsz = input_ids.shape[0]
213 | input_ids = concat_features(input_ids, input_ids_2)
214 | attention_mask = concat_features(attention_mask, attention_mask_2)
215 | token_type_ids = concat_features(token_type_ids, token_type_ids_2)
216 | position_ids = concat_features(position_ids, position_ids_2)
217 | head_mask = concat_features(head_mask, head_mask_2)
218 | inputs_embeds = concat_features(inputs_embeds, inputs_embeds_2)
219 | outputs = self.backbone(
220 | input_ids=input_ids,
221 | attention_mask=attention_mask,
222 | token_type_ids=token_type_ids,
223 | position_ids=position_ids,
224 | head_mask=head_mask,
225 | inputs_embeds=inputs_embeds,
226 | output_hidden_states=self.output_hidden_states,
227 | )
228 | features = self.pooler(attention_mask, outputs)
229 | if self.transform is not None:
230 | features = self.transform(features)
231 |         features_1, features_2 = torch.split(features, bsz, dim=0) # [sentence1, condition], [sentence2, condition]
232 | loss = None
233 | if self.config.objective in {'triplet', 'triplet_mse'}:
234 | positives1, negatives1 = torch.split(features_1, bsz // 2, dim=0)
235 | positives2, negatives2 = torch.split(features_2, bsz // 2, dim=0)
236 | if labels is not None:
237 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(positives1, positives2, negatives1, negatives2)
238 | logits = cosine_similarity(features_1, features_2, dim=1)
239 | if self.config.objective in {'triplet_mse'} and labels is not None:
240 | loss += nn.MSELoss()(logits, labels)
241 | else:
242 | logits = logits.detach()
243 | else:
244 | logits = cosine_similarity(features_1, features_2, dim=1)
245 | if labels is not None:
246 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(logits, labels)
247 | return SequenceClassifierOutput(
248 | loss=loss,
249 | logits=logits,
250 | )
251 |
252 | class TriEncoderForClassification(PreTrainedModel):
253 | def __init__(self, config):
254 | super().__init__(config)
255 | self.backbone = AutoModel.from_pretrained(
256 | config.model_name_or_path,
257 | from_tf=bool('.ckpt' in config.model_name_or_path),
258 | config=config,
259 | cache_dir=config.cache_dir,
260 | revision=config.model_revision,
261 | use_auth_token=True if config.use_auth_token else None,
262 | add_pooling_layer=False,
263 | ).base_model
264 | self.triencoder_head = config.triencoder_head
265 | classifier_dropout = (
266 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
267 | )
268 | if config.transform:
269 | self.transform = nn.Sequential(
270 | nn.Dropout(classifier_dropout),
271 | nn.Linear(config.hidden_size, config.hidden_size),
272 | ACT2FN[config.hidden_act],
273 | )
274 | else:
275 | self.transform = None
276 | self.condition_transform = nn.Sequential(
277 | nn.Dropout(classifier_dropout),
278 | nn.Linear(config.hidden_size, config.hidden_size)
279 | )
280 | if self.triencoder_head == 'concat':
281 | self.concat_transform = nn.Sequential(
282 | nn.Dropout(classifier_dropout),
283 | nn.Linear(config.hidden_size * 2, config.hidden_size),
284 | ACT2FN[config.hidden_act],
285 | )
286 | elif self.triencoder_head == 'hadamard':
287 | self.concat_transform = None
288 | self.pooler = Pooler(config.pooler_type)
289 | if config.pooler_type in {'avg_first_last', 'avg_top2'}:
290 | self.output_hidden_states = True
291 | else:
292 | self.output_hidden_states = False
293 | if config.num_labels == 1:
294 | self.reshape_function = lambda x: x.reshape(-1)
295 | if config.objective == 'mse':
296 | self.loss_fct_cls = nn.MSELoss
297 | self.loss_fct_kwargs = {}
298 | elif config.objective in {'triplet', 'triplet_mse'}:
299 | self.loss_fct_cls = QuadrupletLoss
300 | self.loss_fct_kwargs = {'distance_function': lambda x, y: 1.0 - cosine_similarity(x, y)}
301 | else:
302 | raise ValueError('Only regression and triplet objectives are supported for TriEncoderForClassification')
303 | else:
304 | self.reshape_function = lambda x: x.reshape(-1, config.num_labels)
305 | self.loss_fct_cls = nn.CrossEntropyLoss
306 | self.post_init()
307 |
308 | def forward(
309 | self,
310 | input_ids=None,
311 | attention_mask=None,
312 | token_type_ids=None,
313 | position_ids=None,
314 | head_mask=None,
315 | inputs_embeds=None,
316 | input_ids_2=None,
317 | attention_mask_2=None,
318 | token_type_ids_2=None,
319 | position_ids_2=None,
320 | head_mask_2=None,
321 | inputs_embeds_2=None,
322 | input_ids_3=None,
323 | attention_mask_3=None,
324 | token_type_ids_3=None,
325 | position_ids_3=None,
326 | head_mask_3=None,
327 | inputs_embeds_3=None,
328 | labels=None,
329 | **kwargs,
330 | ):
331 | bsz = input_ids.shape[0]
332 | input_ids = concat_features(input_ids, input_ids_2, input_ids_3)
333 | attention_mask = concat_features(attention_mask, attention_mask_2, attention_mask_3)
334 | token_type_ids = concat_features(token_type_ids, token_type_ids_2, token_type_ids_3)
335 | position_ids = concat_features(position_ids, position_ids_2, position_ids_3)
336 | head_mask = concat_features(head_mask, head_mask_2, head_mask_3)
337 | inputs_embeds = concat_features(inputs_embeds, inputs_embeds_2, inputs_embeds_3)
338 | outputs = self.backbone(
339 | input_ids=input_ids,
340 | attention_mask=attention_mask,
341 | token_type_ids=token_type_ids,
342 | position_ids=position_ids,
343 | head_mask=head_mask,
344 | inputs_embeds=inputs_embeds,
345 | output_hidden_states=self.output_hidden_states,
346 | )
347 | features = self.pooler(attention_mask, outputs)
348 | features_1, features_2, features_3 = torch.split(features, bsz, dim=0)
349 | features_3 = self.condition_transform(features_3)
350 | # do we need positional embeddings?
351 | loss = None
352 | if self.transform is not None:
353 | features_1 = self.transform(features_1)
354 | features_2 = self.transform(features_2)
355 | if self.triencoder_head == 'concat':
356 | features_1 = torch.cat([features_1, features_3], dim=-1)
357 | features_2 = torch.cat([features_2, features_3], dim=-1)
358 | features_1 = self.concat_transform(features_1)
359 | features_2 = self.concat_transform(features_2)
360 | elif self.triencoder_head == 'hadamard':
361 | features_1 = features_1 * features_3
362 | features_2 = features_2 * features_3
363 | if self.config.objective in {'triplet', 'triplet_mse'}:
364 | positive_idxs = torch.arange(0, features_1.shape[0]//2)
365 | negative_idxs = torch.arange(features_1.shape[0]//2, features_1.shape[0])
366 | positives1 = features_1[positive_idxs]
367 | positives2 = features_2[positive_idxs]
368 | negatives1 = features_1[negative_idxs]
369 | negatives2 = features_2[negative_idxs]
370 | if labels is not None:
371 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(positives1, positives2, negatives1, negatives2)
372 | logits = cosine_similarity(features_1, features_2, dim=1)
373 | if self.config.objective == 'triplet_mse' and labels is not None:
374 | loss += nn.MSELoss()(logits, labels)
375 | else:
376 | logits = logits.detach()
377 | else:
378 | logits = cosine_similarity(features_1, features_2, dim=1)
379 | if labels is not None:
380 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(logits, labels)
381 | return SequenceClassifierOutput(
382 | loss=loss,
383 | logits=logits,
384 | )
385 |
--------------------------------------------------------------------------------
/utils/sts/modeling_utils.py:
--------------------------------------------------------------------------------
1 | from .modeling_encoders import BiEncoderForClassification, CrossEncoderForClassification, TriEncoderForClassification
2 | import torch
3 | import numpy as np
4 | from dataclasses import dataclass
5 | from typing import List, Dict, Any, Optional
6 |
7 |
8 | @dataclass
9 | class DataCollatorWithPadding:
10 | pad_token_id: int
11 | pad_token_type_id: int = 0
12 | pad_to_multiple_of: Optional[int] = None
13 | return_tensors: str = 'pt'
14 |
15 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
16 | # get max length of all sequences in features
17 | max_length = max(max(len(feature[key]) for feature in features) for key in features[0] if key.startswith('input_ids'))
18 | if self.pad_to_multiple_of is not None:
19 | max_length = ((max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
20 | # pad all sequences to max length
21 | out_features = {}
22 | for key in features[0].keys():
23 | if key.startswith('input_ids') or key.startswith('attention_mask') or key.startswith('token_type_ids'):
24 | if key.startswith('input_ids'):
25 | pad_token = self.pad_token_id
26 | elif key.startswith('attention_mask'):
27 | pad_token = 0
28 | else:
29 | pad_token = self.pad_token_type_id
30 | out_features[key] = [feature[key] + [pad_token] * (max_length - len(feature[key])) for feature in features]
31 | else:
32 | out_features[key] = [feature[key] for feature in features]
33 | if self.return_tensors == 'pt':
34 | out_features = {key: torch.tensor(value) for key, value in out_features.items()}
35 | elif self.return_tensors == 'np':
36 | out_features = {key: np.array(value) for key, value in out_features.items()}
37 | return out_features
38 |
39 |
40 | def get_model(model_args):
41 | if model_args.encoding_type == 'bi_encoder':
42 | return BiEncoderForClassification
43 | if model_args.encoding_type == 'cross_encoder':
44 | return CrossEncoderForClassification
45 | if model_args.encoding_type == 'tri_encoder':
46 | return TriEncoderForClassification
47 | raise ValueError(f'Invalid model type: {model_args.encoding_type}')
48 |
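The collator pads every `input_ids*` / `attention_mask*` / `token_type_ids*` field in a batch to the longest sequence, optionally rounding that length up to a multiple (useful for tensor-core alignment). A minimal list-based sketch of the length computation and padding for one field (function name is hypothetical):

```python
def pad_batch(sequences, pad_token_id, pad_to_multiple_of=None):
    # Pad ragged sequences to a common length, as DataCollatorWithPadding does.
    max_length = max(len(seq) for seq in sequences)
    if pad_to_multiple_of is not None:
        # Round up to the next multiple (e.g. 3 -> 4 when the multiple is 4).
        max_length = ((max_length + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
    return [seq + [pad_token_id] * (max_length - len(seq)) for seq in sequences]

padded = pad_batch([[101, 7, 102], [101, 102]], pad_token_id=0, pad_to_multiple_of=4)
```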
--------------------------------------------------------------------------------
/utils/sts/triplet_trainer.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data
2 | from transformers import Trainer
3 | from transformers.trainer_utils import seed_worker
4 | import datasets
5 | from torch.utils.data import DataLoader,Dataset, Sampler
6 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
7 | from torch.utils.data.distributed import DistributedSampler
8 | from collections import defaultdict
9 | import random
10 | import logging
11 |
12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class TripletBatchSampler(Sampler[List[int]]):
18 | r'''Samples elements from the dataset, grouping pairs of (positives, negatives) together.
19 |
20 | Args:
21 | batch_size (int)
22 | generator (Generator): Generator used in sampling
23 | '''
24 | batch_size: int
25 | generator: Optional[Callable[[int], int]]
26 |
27 | def __init__(
28 | self,
29 | batch_size: int,
30 | sentence1_key,
31 | sentence2_key,
32 | trainer: Any,
33 | generator = None
34 | ) -> None:
35 | self.batch_size = batch_size
36 | assert self.batch_size % 2 == 0, 'Batch size must be even for triplet loss'
37 | self.generator = generator
38 | self.trainer = trainer
39 | # self.trainer.train_dataset contains all the original data and feature names
40 | self.pairs = self._get_idx_pairs(self.trainer.train_dataset, sentence1_key, sentence2_key)
41 |
42 | def _get_idx_pairs(self, dataset: Dataset, sentence1_key: str, sentence2_key: str) -> List[int]:
43 | '''Get the index order of the dataset, where each index is paired with a positive and negative index
44 | '''
45 | pairs = defaultdict(list)
46 | for ix, datum in enumerate(dataset):
47 |             pairs[(datum[sentence1_key], datum[sentence2_key])].append(ix)  # tuple key avoids ambiguous string concatenation
48 | pair_idxs = list(pairs.keys())
49 | drop_count = 0
50 | for pair_idx in pair_idxs:
51 | if len(pairs[pair_idx]) != 2:
52 | drop_count += len(pairs[pair_idx])
53 | pairs.pop(pair_idx)
54 | logger.warning('Dropping %d indices for missing pairs. Dataset has %d pairs total' % (drop_count, len(pair_idxs)))
55 | pairs = list(map(lambda x: sorted(pairs[x], key=lambda idx: -dataset[idx]['label']), pairs.keys()))
56 | # negative because we want to sort in descending order (highest similarity first)
57 | for idx1, idx2 in pairs:
58 | if (dataset[idx1][sentence1_key] != dataset[idx2][sentence1_key]) or (dataset[idx1][sentence2_key] != dataset[idx2][sentence2_key]):
59 | raise ValueError('Pairing of indices is incorrect, sentences do not match for pair %d and %d' % (idx1, idx2))
60 | if (dataset[idx1]['label'] < dataset[idx2]['label']):
61 | raise ValueError('Pairing of indices is incorrect, similarity is not in descending order for pair %d and %d' % (idx1, idx2))
62 | return pairs
63 |
64 | def __iter__(self):
65 | '''Yield batches of indices, laid out as all positive indices followed by all negative indices
66 | '''
67 | random.shuffle(self.pairs)
68 | # Step by half the batch size: each pair contributes one positive and one negative
69 | for i in range(0, len(self.pairs), self.batch_size // 2):
70 | batch = self.pairs[i:i + self.batch_size // 2]
71 | if len(batch) < self.batch_size // 2:
72 | break  # drop the last incomplete batch so the batch count matches __len__
73 | positives, negatives = zip(*batch)
74 | yield list(positives) + list(negatives)
72 |
73 | def __len__(self) -> int:
74 | # Number of full batches; each pair contributes two examples
75 | return len(self.pairs) * 2 // self.batch_size
75 |
76 |
77 | class TripletTrainer(Trainer):
78 | def get_train_dataloader(self) -> DataLoader:
79 | '''
80 | Returns the training [`~torch.utils.data.DataLoader`].
81 |
82 | Overrides `Trainer.get_train_dataloader` to batch with `TripletBatchSampler`, so every
83 | batch holds the positive and negative example of each sampled sentence pair.
84 | '''
87 | if self.train_dataset is None:
88 | raise ValueError('Trainer: training requires a train_dataset.')
89 | train_dataset = self.train_dataset
90 | data_collator = self.data_collator
91 | if isinstance(train_dataset, datasets.Dataset):
92 | train_dataset = self._remove_unused_columns(train_dataset, description='training')
93 | else:
94 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
95 | train_sampler = TripletBatchSampler(
96 | self.args.train_batch_size,
97 | 'sentence1',
98 | 'sentence2',
99 | self,
100 | )
101 | return DataLoader(
102 | train_dataset,
103 | # batch_size and drop_last are handled by the batch_sampler
104 | batch_sampler=train_sampler,
105 | collate_fn=data_collator,
106 | num_workers=self.args.dataloader_num_workers,
107 | pin_memory=self.args.dataloader_pin_memory,
108 | worker_init_fn=seed_worker,
109 | )
111 |
--------------------------------------------------------------------------------
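The pairing-and-tiling behavior of `TripletBatchSampler` can be sketched without any torch dependency; the sentences and labels below are invented purely for illustration:

```python
from collections import defaultdict

# Toy dataset: indices 0/1 and 2/3 share a sentence pair but differ in label;
# the higher-label row of each pair is the "positive".
dataset = [
    {'sentence1': 'a cat', 'sentence2': 'a kitten', 'label': 5.0},
    {'sentence1': 'a cat', 'sentence2': 'a kitten', 'label': 1.0},
    {'sentence1': 'a dog', 'sentence2': 'a puppy', 'label': 4.0},
    {'sentence1': 'a dog', 'sentence2': 'a puppy', 'label': 2.0},
]

# Group indices by their (sentence1, sentence2) text, as _get_idx_pairs does.
groups = defaultdict(list)
for ix, datum in enumerate(dataset):
    groups[(datum['sentence1'], datum['sentence2'])].append(ix)

# Sort each pair so the higher-similarity (positive) index comes first.
pairs = [sorted(v, key=lambda ix: -dataset[ix]['label'])
         for v in groups.values() if len(v) == 2]

# Tile each batch as all positives followed by all negatives, as __iter__ does
# (shuffling omitted here so the output is deterministic).
batch_size = 4
batches = []
for i in range(0, len(pairs), batch_size // 2):
    chunk = pairs[i:i + batch_size // 2]
    positives, negatives = zip(*chunk)
    batches.append(list(positives) + list(negatives))

print(batches)  # [[0, 2, 1, 3]]
```

With this layout a batch of size `2k` carries its `k` positive indices at positions `0..k-1` and the matching negatives at `k..2k-1`, which is the alignment a triplet-style loss can index into.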