├── .gitignore ├── README.md ├── data ├── csts.tar.enc └── extract.sh ├── make_test_submission.py ├── requirements.txt ├── run_sts.py ├── run_sts.sh ├── run_sts_fewshot.py └── utils ├── __init__.py ├── fewshot ├── generate_in_context_dataset.py ├── openai_utils.py └── progress_logger.py ├── progress_logger.py └── sts ├── __init__.py ├── dataset_preprocessing.py ├── modeling_encoders.py ├── modeling_utils.py └── triplet_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C-STS
2 | 
3 | This repository contains the dataset and code for the paper C-STS: Conditional Semantic Textual Similarity. [[ArXiv]](https://arxiv.org/abs/2305.15093)
4 | 
5 | 
6 | ## Table of Contents
7 | - [Data](#data)
8 | - [Code](#code)
9 | - [Fine-tuning](#fine-tuning)
10 | - [Few-shot Evaluation](#few-shot-evaluation)
11 | - [Submitting Test Results](#submitting-test-results)
12 | - [Citation](#citation)
13 | 
14 | ## Data
15 | 
16 | To avoid intentional or unintentional scraping of the C-STS dataset for LLM pre-training, which could contaminate training data and compromise evaluation, we adopt the following approach for our dataset release.
17 | 
18 | The dataset for C-STS is stored in an encrypted file named `csts.tar.enc`. To access the dataset, follow these steps:
19 | 
20 | 1. Request Access: Submit a request to obtain the decryption password by [clicking here](https://docs.google.com/forms/d/e/1FAIpQLSfoYig6I3qEBUBaNmzugnAKGpX1mSpM5cbGeO-dXq-u_sMPJQ/viewform?usp=sf_link). You will receive an email response with the password immediately.
21 | 
22 | 2. Decrypt the Dataset: Once you have received the password via email, you can decrypt the `csts.tar.enc` file using the provided `extract.sh` script. Follow the instructions below:
23 | 
24 |    - Open a terminal and navigate to the `data` directory.
25 |    - Run the following command, replacing `<password>` with the decryption password obtained via email:
26 | 
27 | ```bash
28 | bash extract.sh csts.tar.enc <password>
29 | ```
30 | 
31 | Given the correct password, this step will generate the three unencrypted dataset splits: `csts_train.csv`, `csts_validation.csv`, and `csts_test.csv`.
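A quick sanity check of the extracted files can catch a bad download or wrong password early. The snippet below is a minimal sketch (assuming it is run from the repository root with `pandas` installed); the column names are the ones the training scripts in this repository expect:

```python
import pandas as pd

# Load each extracted split from the data directory
splits = {name: pd.read_csv('data/csts_%s.csv' % name) for name in ('train', 'validation', 'test')}
for name, df in splits.items():
    # Every split carries the two sentences, the condition, and a similarity label
    assert {'sentence1', 'sentence2', 'condition', 'label'}.issubset(df.columns), name
print({name: len(df) for name, df in splits.items()})
```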
32 | 
33 | You can load the data using [datasets](https://github.com/huggingface/datasets) with the following lines:
34 | 
35 | ```python
36 | from datasets import load_dataset
37 | 
38 | dataset = load_dataset(
39 |     'csv',
40 |     data_files=
41 |     {
42 |         'train': 'data/csts_train.csv',
43 |         'validation': 'data/csts_validation.csv',
44 |         'test': 'data/csts_test.csv'
45 |     }
46 | )
47 | ```
48 | 
49 | **Important: By using this dataset, you agree to not publicly share its unencrypted contents or decryption password.**
50 | 
51 | ## Code
52 | We provide the basic training scripts and utilities for finetuning and evaluating the models in the paper. The code is adapted from the [HuggingFace Transformers](https://huggingface.co/transformers) library. Refer to the [documentation](https://huggingface.co/transformers/) for more details.
53 | 
54 | ### Fine-tuning
55 | The current code supports finetuning any encoder-only model, using the `cross_encoder`, `bi_encoder`, or `tri_encoder` settings described in the paper.
56 | You can finetune the models described in the paper using the `run_sts.sh` script. For example, to finetune the `princeton-nlp/sup-simcse-roberta-base` model on the C-STS dataset, run the following command:
57 | 
58 | ```bash
59 | MODEL=princeton-nlp/sup-simcse-roberta-base \
60 | ENCODER_TYPE=bi_encoder \
61 | LR=1e-5 \
62 | WD=0.1 \
63 | TRANSFORM=False \
64 | OBJECTIVE=mse \
65 | OUTPUT_DIR=output \
66 | TRAIN_FILE=data/csts_train.csv \
67 | EVAL_FILE=data/csts_validation.csv \
68 | TEST_FILE=data/csts_test.csv \
69 | bash run_sts.sh
70 | ```
71 | 
72 | See `run_sts.sh` for a full description of the available options and default values.
73 | 
74 | ### Few-shot Evaluation
75 | The script `run_sts_fewshot.py` can be used to evaluate large language models in a few-shot setting, with or without instructions. For example, to evaluate the `google/flan-t5-xxl` model on the C-STS dataset, run the following command:
76 | 
77 | ```bash
78 | python run_sts_fewshot.py \
79 |     --model_name_or_path google/flan-t5-xxl \
80 |     --k_shot 2 \
81 |     --prompt_name long \
82 |     --train_file data/csts_train.csv \
83 |     --validation_file data/csts_validation.csv \
84 |     --test_file data/csts_test.csv \
85 |     --output_dir output/flan-t5-xxl/k2_long \
86 |     --dtype tf32 \
87 |     --batch_size 4
88 | ```
89 | 
90 | To accommodate large models, `run_sts_fewshot.py` will use all visible GPUs to load the model in a model-parallel fashion. For smaller models, set `CUDA_VISIBLE_DEVICES` to the desired GPU ids.
91 | 
92 | Run `python run_sts_fewshot.py --help` for a full description of additional options and default values.
93 | 
94 | 
95 | ### Submitting Test Results
96 | You can get scores for your model on the test set by submitting your predictions using the `make_test_submission.py` script as follows:
97 | 
98 | ```bash
99 | python make_test_submission.py your_email@email.com /path/to/your/predictions.json
100 | ```
101 | 
102 | This script expects the test predictions file to be in the format generated automatically by the scripts above; i.e.
103 | 
104 | ```json
105 | {
106 |     "0": 1.0,
107 |     "1": 0.0,
108 |     "...":
109 |     "4731": 0.5
110 | }
111 | ```
112 | 
113 | After submission, your results will be emailed to the submitted email address, with the relevant filename in the subject.
114 | 
115 | 
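If you generate predictions outside of the scripts in this repository, the sketch below writes a file in the expected format (the `write_submission` helper is illustrative, not part of this codebase); `make_test_submission.py` checks for exactly 4732 predictions:

```python
import json

def write_submission(scores, path='predictions.json'):
    # The evaluation server expects exactly 4732 predictions,
    # keyed by the test-example index as a string ("0" ... "4731").
    if len(scores) != 4732:
        raise ValueError('expected 4732 predictions, got %d' % len(scores))
    with open(path, 'w') as f:
        json.dump({str(i): float(s) for i, s in enumerate(scores)}, f, indent=4)
```
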
116 | ## Citation
117 | ```tex
118 | @misc{deshpande2023csts,
119 |       title={CSTS: Conditional Semantic Textual Similarity}, 
120 |       author={Ameet Deshpande and Carlos E. Jimenez and Howard Chen and Vishvak Murahari and Victoria Graf and Tanmay Rajpurohit and Ashwin Kalyan and Danqi Chen and Karthik Narasimhan},
121 |       year={2023},
122 |       eprint={2305.15093},
123 |       archivePrefix={arXiv},
124 |       primaryClass={cs.CL}
125 | }
126 | ```
127 | 
--------------------------------------------------------------------------------
/data/extract.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [ $# -ne 2 ]; then
3 |     echo "Error: Please provide the path to the encrypted file and the decryption password."
4 |     echo "Usage: ./extract.sh <encrypted_file> <password>"
5 |     exit 1
6 | fi
7 | ENCRYPTED_FILE="$1"
8 | PASSWORD="$2"
9 | openssl aes-256-cbc -a -d -salt -pbkdf2 -in "$ENCRYPTED_FILE" -out csts.tar -pass pass:"$PASSWORD"
10 | EXIT_CODE=$?
11 | if [ $EXIT_CODE -ne 0 ]; then
12 |     rm -f csts.tar
13 |     echo "Error: Failed to decrypt the file."
14 |     exit $EXIT_CODE
15 | fi
16 | tar -xvf csts.tar
17 | EXIT_CODE=$?
18 | if [ $EXIT_CODE -ne 0 ]; then
19 |     echo "Error: Failed to extract the decrypted file."
20 |     exit $EXIT_CODE
21 | fi
22 | rm -f csts.tar
23 | EXIT_CODE=$?
24 | if [ $EXIT_CODE -ne 0 ]; then
25 |     echo "Error: Failed to remove the files."
26 |     exit $EXIT_CODE
27 | fi
28 | echo "Decryption and cleanup completed successfully."
29 | exit 0
30 | 
--------------------------------------------------------------------------------
/make_test_submission.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | import json
4 | import requests
5 | from pathlib import Path
6 | 
7 | 
8 | def send_post_request(email, predictions, filename):
9 |     # Prepare the data to be sent
10 |     if len(filename) > 200:
11 |         raise ValueError('Submission name (%s) longer than 200 characters. Please choose a shorter filename or set the name with --name' % filename)
12 |     data = {
13 |         'email': email,
14 |         'predictions': predictions,
15 |         'filename': filename,
16 |     }
17 |     data_str = json.dumps({'body': json.dumps(data)})
18 |     headers = {'content-type': 'application/json'}
19 |     # url = 'https://rcxnewlbk5.execute-api.us-east-2.amazonaws.com/test/eval-csts'
20 |     url = "https://0sy74d2tog.execute-api.us-east-2.amazonaws.com/dev/c-sts-eval-lambda"
21 |     # Create the request object
22 |     request_object = {
23 |         "url": url,
24 |         "headers": headers,
25 |         "data": data
26 |     }
27 |     json.dump(request_object, open('request.json', 'w'), indent=4)
28 |     response = requests.post(url, headers=headers, data=data_str)
29 |     response.raise_for_status()  # surface failed submissions instead of failing silently
30 |     print("Evaluation successful!")
31 |     print(response.json()['body'])
32 |     print("See email: \"C-STS Evaluation Results for %s\"" % filename)
33 | 
34 | 
35 | def main(email, predictions_file, name):
36 |     predictions_file = Path(predictions_file).resolve(strict=True)
37 |     if name is None:
38 |         name = predictions_file.as_posix()
39 |     if not re.match(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", email):
40 |         raise ValueError("Email %s is invalid" % email)
41 |     with open(predictions_file, 'r') as f:
42 |         preds = json.load(f)
43 |     keys, preds = zip(*sorted(preds.items(), key=lambda x: int(x[0])))
44 |     preds = list(map(float, preds))
45 |     if len(keys) != 4732:
46 |         raise ValueError("There should be exactly 4732 predictions, but got %d instead" % len(keys))
47 |     send_post_request(email, preds, name)
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     parser = argparse.ArgumentParser(description='Send email and predictions to server')
52 |     parser.add_argument('email', type=str, help='The email address to send the results to')
53 
| parser.add_argument('predictions_file', type=str, help='The path to the JSON file containing the predictions') 54 | parser.add_argument('--name', type=str, help='The name of the submission. Uses the filename if not specified') 55 | args = parser.parse_args() 56 | main(**vars(args)) 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.20.3 2 | datasets>=2.11.0 3 | huggingface_hub>=0.13.4 4 | openai>=0.27.7 5 | pandas>=2.0.0 6 | scipy>=1.11.1 7 | torch>=2.0.0 8 | transformers>=4.28.1 9 | -------------------------------------------------------------------------------- /run_sts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted code from HuggingFace run_glue.py 3 | 4 | Author: Ameet Deshpande, Carlos E. Jimenez 5 | """ 6 | import json 7 | import logging 8 | import os 9 | 10 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 11 | import random 12 | import sys 13 | from dataclasses import dataclass, field 14 | from typing import Optional 15 | 16 | import datasets 17 | import numpy as np 18 | import transformers 19 | from datasets import load_dataset 20 | from scipy.stats import pearsonr, spearmanr 21 | from transformers import ( 22 | AutoConfig, 23 | AutoTokenizer, 24 | EvalPrediction, 25 | HfArgumentParser, 26 | PrinterCallback, 27 | Trainer, 28 | ) 29 | from transformers import TrainingArguments as HFTrainingArguments 30 | from transformers import default_data_collator, set_seed 31 | from transformers.trainer_utils import get_last_checkpoint 32 | 33 | from utils.progress_logger import LogCallback 34 | from utils.sts.dataset_preprocessing import get_preprocessing_function 35 | from utils.sts.modeling_utils import DataCollatorWithPadding, get_model 36 | from utils.sts.triplet_trainer import TripletTrainer 37 | 38 | logging.basicConfig( 39 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s" 40 | ) 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | @dataclass 46 | class TrainingArguments(HFTrainingArguments): 47 | log_time_interval: int = field( 48 | default=15, 49 | metadata={ 50 | "help": ( 51 | "Log at each `log_time_interval` seconds. " 52 | "Default will be to log every 15 seconds." 53 | ) 54 | }, 55 | ) 56 | 57 | 58 | @dataclass 59 | class DataTrainingArguments: 60 | """ 61 | Arguments pertaining to what data we are going to input our model for training and eval. 62 | 63 | Using `HfArgumentParser` we can turn this class 64 | into argparse arguments to be able to specify them on 65 | the command line. 66 | """ 67 | 68 | max_seq_length: int = field( 69 | default=128, 70 | metadata={ 71 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 72 | "than this will be truncated, sequences shorter will be padded." 73 | }, 74 | ) 75 | overwrite_cache: bool = field( 76 | default=False, 77 | metadata={"help": "Overwrite the cached preprocessed datasets or not."}, 78 | ) 79 | pad_to_max_length: bool = field( 80 | default=False, 81 | metadata={ 82 | "help": "Whether to pad all samples to `max_seq_length`. " 83 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 84 | }, 85 | ) 86 | max_train_samples: Optional[int] = field( 87 | default=None, 88 | metadata={ 89 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 90 | "value if set." 
91 | }, 92 | ) 93 | max_eval_samples: Optional[int] = field( 94 | default=None, 95 | metadata={ 96 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 97 | "value if set." 98 | }, 99 | ) 100 | max_predict_samples: Optional[int] = field( 101 | default=None, 102 | metadata={ 103 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 104 | "value if set." 105 | }, 106 | ) 107 | train_file: Optional[str] = field( 108 | default=None, 109 | metadata={"help": "A csv or a json file containing the training data."}, 110 | ) 111 | validation_file: Optional[str] = field( 112 | default=None, 113 | metadata={"help": "A csv or a json file containing the validation data."}, 114 | ) 115 | test_file: Optional[str] = field( 116 | default=None, 117 | metadata={"help": "A csv or a json file containing the test data."}, 118 | ) 119 | # Dataset specific arguments 120 | max_similarity: Optional[float] = field( 121 | default=None, metadata={"help": "Maximum similarity score."} 122 | ) 123 | min_similarity: Optional[float] = field( 124 | default=None, metadata={"help": "Minimum similarity score."} 125 | ) 126 | condition_only: Optional[bool] = field( 127 | default=False, metadata={"help": "Only use condition column."} 128 | ) 129 | sentences_only: Optional[bool] = field( 130 | default=False, metadata={"help": "Only use sentences column."} 131 | ) 132 | 133 | def __post_init__(self): 134 | validation_extension = self.validation_file.split(".")[-1] 135 | if self.train_file is not None: 136 | train_extension = self.train_file.split(".")[-1] 137 | assert train_extension in [ 138 | "csv", 139 | "json", 140 | ], "`train_file` should be a csv or a json file." 141 | assert ( 142 | train_extension == validation_extension 143 | ), "`train_file` and `validation_file` should have the same extension." 144 | if self.test_file is not None: 145 | test_extension = self.test_file.split(".")[-1] 146 | assert test_extension in [ 147 | "csv", 148 | "json", 149 | ], "`test_file` should be a csv or a json file." 150 | assert ( 151 | test_extension == validation_extension 152 | ), "`test_file` and `validation_file` should have the same extension." 153 | 154 | 155 | @dataclass 156 | class ModelArguments: 157 | """ 158 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 159 | """ 160 | 161 | model_name_or_path: str = field( 162 | metadata={ 163 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 164 | } 165 | ) 166 | config_name: Optional[str] = field( 167 | default=None, 168 | metadata={ 169 | "help": "Pretrained config name or path if not the same as model_name" 170 | }, 171 | ) 172 | tokenizer_name: Optional[str] = field( 173 | default=None, 174 | metadata={ 175 | "help": "Pretrained tokenizer name or path if not the same as model_name" 176 | }, 177 | ) 178 | cache_dir: Optional[str] = field( 179 | default=None, 180 | metadata={ 181 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 182 | }, 183 | ) 184 | use_fast_tokenizer: bool = field( 185 | default=True, 186 | metadata={ 187 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 188 | }, 189 | ) 190 | model_revision: str = field( 191 | default="main", 192 | metadata={ 193 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 
194 |         },
195 |     )
196 |     use_auth_token: bool = field(
197 |         default=False,
198 |         metadata={
199 |             "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
200 |             "with private models)."
201 |         },
202 |     )
203 |     objective: Optional[str] = field(
204 |         default="mse",
205 |         metadata={
206 |             "help": "Objective function for training. Options:\
207 |                 1) mse: Regression task (uses MSELoss).\
208 |                 2) classification: Classification task (uses CrossEntropyLoss).\
209 |                 3) triplet: Regression task (uses QuadrupletLoss).\
210 |                 4) triplet_mse: Regression task (uses QuadrupletLoss combined with MSELoss)."
211 |         },
212 |     )
213 |     # What type of modeling
214 |     encoding_type: Optional[str] = field(
215 |         default="cross_encoder",
216 |         metadata={
217 |             "help": "What kind of model to choose. Options:\
218 |                 1) cross_encoder: Full encoder model.\
219 |                 2) bi_encoder: Bi-encoder model.\
220 |                 3) tri_encoder: Tri-encoder model."
221 |         },
222 |     )
223 |     # Pooler for bi-encoder
224 |     pooler_type: Optional[str] = field(
225 |         default="cls",
226 |         metadata={
227 |             "help": "Pooler type: Options:\
228 |                 1) cls: Use [CLS] token.\
229 |                 2) avg: Mean pooling."
230 |         },
231 |     )
232 |     freeze_encoder: Optional[bool] = field(
233 |         default=False, metadata={"help": "Freeze encoder weights."}
234 |     )
235 |     transform: Optional[bool] = field(
236 |         default=False,
237 |         metadata={"help": "Use a linear transformation on the encoder output"},
238 |     )
239 |     triencoder_head: Optional[str] = field(
240 |         default="hadamard",
241 |         metadata={
242 |             "help": "Tri-encoder head type: Options:\
243 |                 1) hadamard: Hadamard product.\
244 |                 2) transformer: Transformer."
245 |         },
246 |     )
247 | 
248 | 
249 | def main():
250 |     parser = HfArgumentParser(
251 |         (ModelArguments, DataTrainingArguments, TrainingArguments)
252 |     )
253 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
254 |         model_args, data_args, training_args = parser.parse_json_file(
255 |             json_file=os.path.abspath(sys.argv[1]),
256 |         )
257 |     else:
258 |         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
259 |     logging.basicConfig(
260 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
261 |         datefmt="%m/%d/%Y %H:%M:%S",
262 |         handlers=[logging.StreamHandler(sys.stdout)],
263 |     )
264 |     training_args.log_level = "info"
265 |     log_level = training_args.get_process_log_level()
266 |     logger.setLevel(log_level)
267 |     datasets.utils.logging.set_verbosity(log_level)
268 |     transformers.utils.logging.set_verbosity(log_level)
269 |     transformers.utils.logging.enable_default_handler()
270 |     transformers.utils.logging.enable_explicit_format()
271 |     logger.warning(
272 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
273 |         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
274 |     )
275 |     if model_args.objective in {"triplet", "triplet_mse"}:
276 |         training_args.dataloader_drop_last = True
277 |         training_args.per_device_eval_batch_size = 2
278 |     logger.info("Training/evaluation parameters %s" % training_args)
279 |     last_checkpoint = None
280 |     if (
281 |         os.path.isdir(training_args.output_dir)
282 |         and training_args.do_train
283 |         and not training_args.overwrite_output_dir
284 |     ):
285 |         last_checkpoint = get_last_checkpoint(training_args.output_dir)
286 |         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
287 |             raise ValueError(
288 |                 f"Output directory 
({training_args.output_dir}) already exists and is not empty. " 289 | "Use --overwrite_output_dir to overcome." 290 | ) 291 | elif ( 292 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 293 | ): 294 | logger.warning( 295 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 296 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 297 | ) 298 | set_seed(training_args.seed) 299 | data_files = {"validation": data_args.validation_file} 300 | if training_args.do_train: 301 | data_files["train"] = data_args.train_file 302 | if data_args.test_file is not None: 303 | data_files["test"] = data_args.test_file 304 | elif training_args.do_predict: 305 | raise ValueError("test_file argument is missing. required for do_predict.") 306 | for key, name in data_files.items(): 307 | logger.info(f"load a local file for {key}: {name}") 308 | if data_args.validation_file.endswith(".csv") or data_args.validation_file.endswith( 309 | ".tsv" 310 | ): 311 | # Loading a dataset from local csv files 312 | raw_datasets = load_dataset( 313 | "csv", 314 | data_files=data_files, 315 | cache_dir=model_args.cache_dir, 316 | use_auth_token=True if model_args.use_auth_token else None, 317 | ) 318 | elif data_args.validation_file.endswith(".json"): 319 | # Loading a dataset from local json files 320 | raw_datasets = load_dataset( 321 | "json", 322 | data_files=data_files, 323 | cache_dir=model_args.cache_dir, 324 | use_auth_token=True if model_args.use_auth_token else None, 325 | ) 326 | else: 327 | raise ValueError("validation_file should be a csv or a json file.") 328 | labels = set() 329 | for key in set(raw_datasets.keys()) - {"test"}: 330 | labels.update(raw_datasets[key]["label"]) 331 | if data_args.min_similarity is None: 332 | data_args.min_similarity = min(labels) 333 | logger.warning( 334 | f"Setting min_similarity: {data_args.min_similarity}. Override by setting --min_similarity." 335 | ) 336 | if data_args.max_similarity is None: 337 | data_args.max_similarity = max(labels) 338 | logger.warning( 339 | f"Setting max_similarity: {data_args.max_similarity}. Override by setting --max_similarity." 
340 | ) 341 | config = AutoConfig.from_pretrained( 342 | model_args.config_name 343 | if model_args.config_name 344 | else model_args.model_name_or_path, 345 | num_labels=1, 346 | # finetuning_task=None, 347 | cache_dir=model_args.cache_dir, 348 | revision=model_args.model_revision, 349 | use_auth_token=True if model_args.use_auth_token else None, 350 | ) 351 | tokenizer = AutoTokenizer.from_pretrained( 352 | model_args.tokenizer_name 353 | if model_args.tokenizer_name 354 | else model_args.model_name_or_path, 355 | cache_dir=model_args.cache_dir, 356 | use_fast=model_args.use_fast_tokenizer, 357 | revision=model_args.model_revision, 358 | use_auth_token=True if model_args.use_auth_token else None, 359 | ) 360 | model_cls = get_model(model_args) 361 | config.update( 362 | { 363 | "use_auth_token": model_args.use_auth_token, 364 | "model_revision": model_args.model_revision, 365 | "cache_dir": model_args.cache_dir, 366 | "model_name_or_path": model_args.model_name_or_path, 367 | "objective": model_args.objective, 368 | "pooler_type": model_args.pooler_type, 369 | "transform": model_args.transform, 370 | "triencoder_head": model_args.triencoder_head, 371 | } 372 | ) 373 | model = model_cls(config=config) 374 | if model_args.freeze_encoder: 375 | for param in model.backbone.parameters(): 376 | param.requires_grad = False 377 | sentence1_key, sentence2_key, condition_key, similarity_key = ( 378 | "sentence1", 379 | "sentence2", 380 | "condition", 381 | "label", 382 | ) 383 | # Padding strategy 384 | if data_args.pad_to_max_length: 385 | padding = "max_length" 386 | else: 387 | padding = False 388 | if data_args.max_seq_length > tokenizer.model_max_length: 389 | logger.warning( 390 | "The max_seq_length passed (%d) is larger than the maximum length for the " 391 | "model (%d). Using max_seq_length=%d." 
392 |             % (
393 |                 data_args.max_seq_length,
394 |                 tokenizer.model_max_length,
395 |                 tokenizer.model_max_length,
396 |             )
397 |         )
398 |     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
399 |     preprocess_function = get_preprocessing_function(
400 |         tokenizer,
401 |         sentence1_key,
402 |         sentence2_key,
403 |         condition_key,
404 |         similarity_key,
405 |         padding,
406 |         max_seq_length,
407 |         model_args,
408 |         scale=(data_args.min_similarity, data_args.max_similarity)
409 |         if model_args.objective in {"mse", "triplet", "triplet_mse"}
410 |         else None,
411 |         condition_only=data_args.condition_only,
412 |         sentences_only=data_args.sentences_only,
413 |     )
414 |     with training_args.main_process_first(desc="dataset map pre-processing"):
415 |         raw_datasets = raw_datasets.map(
416 |             preprocess_function,
417 |             batched=True,
418 |             load_from_cache_file=not data_args.overwrite_cache,
419 |             desc="Running tokenizer on dataset",
420 |             remove_columns=next(iter(raw_datasets.values())).column_names,  # avoid KeyError when no train split is loaded
421 |         )
422 |     if training_args.do_train:
423 |         if "train" not in raw_datasets:
424 |             raise ValueError("--do_train requires a train dataset")
425 |         train_dataset = raw_datasets["train"]
426 |         if data_args.max_train_samples is not None:
427 |             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
428 |             train_dataset = train_dataset.select(range(max_train_samples))
429 |     if training_args.do_eval:
430 |         if (
431 |             "validation" not in raw_datasets
432 |             and "validation_matched" not in raw_datasets
433 |         ):
434 |             raise ValueError("--do_eval requires a validation dataset")
435 |         eval_dataset = raw_datasets["validation"]
436 |         if data_args.max_eval_samples is not None:
437 |             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
438 |             eval_dataset = eval_dataset.select(range(max_eval_samples))
439 |     if training_args.do_predict or data_args.test_file is not None:
440 |         if "test" not in raw_datasets and "test_matched" not in raw_datasets:
441 |             raise ValueError("--do_predict requires a test dataset")
442 |         predict_dataset = raw_datasets["test"]
443 |         if data_args.max_predict_samples is not None:
444 |             max_predict_samples = min(
445 |                 len(predict_dataset), data_args.max_predict_samples
446 |             )
447 |             predict_dataset = predict_dataset.select(range(max_predict_samples))
448 |     # Log a few random samples from the training set:
449 |     if training_args.do_train:
450 |         for index in random.sample(range(len(train_dataset)), 3):
451 |             input_ids = train_dataset[index]["input_ids"]
452 |             logger.info(f"tokens: {tokenizer.decode(input_ids)}")
453 |             logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
454 | 
455 |     def compute_metrics(output: EvalPrediction):
456 |         preds = (
457 |             output.predictions[0]
458 |             if isinstance(output.predictions, tuple)
459 |             else output.predictions
460 |         )
461 |         preds = np.squeeze(preds)
462 |         return {
463 |             "mse": ((preds - output.label_ids) ** 2).mean().item(),
464 |             "pearsonr": pearsonr(preds, output.label_ids)[0],
465 |             "spearmanr": spearmanr(preds, output.label_ids)[0],
466 |         }
467 | 
468 |     if data_args.pad_to_max_length:
469 |         data_collator = default_data_collator
470 |     else:
471 |         data_collator = DataCollatorWithPadding(
472 |             pad_token_id=tokenizer.pad_token_id,
473 |             pad_token_type_id=tokenizer.pad_token_type_id,
474 |             pad_to_multiple_of=8 if training_args.fp16 else None,
475 |         )
476 |     # Initialize our Trainer
477 |     trainer_cls = (
478 |         TripletTrainer
479 |         if model_args.objective in {"triplet", "triplet_mse"}
480 |         else Trainer
481 |     )
482 |     trainer 
= trainer_cls( 483 | model=model, 484 | args=training_args, 485 | train_dataset=train_dataset if training_args.do_train else None, 486 | eval_dataset=eval_dataset if training_args.do_eval else None, 487 | compute_metrics=compute_metrics, 488 | tokenizer=tokenizer, 489 | data_collator=data_collator, 490 | ) 491 | trainer.remove_callback(PrinterCallback) 492 | trainer.add_callback(LogCallback) 493 | if training_args.do_train: 494 | checkpoint = None 495 | if training_args.resume_from_checkpoint is not None: 496 | checkpoint = training_args.resume_from_checkpoint 497 | elif last_checkpoint is not None: 498 | checkpoint = last_checkpoint 499 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 500 | metrics = train_result.metrics 501 | max_train_samples = ( 502 | data_args.max_train_samples 503 | if data_args.max_train_samples is not None 504 | else len(train_dataset) 505 | ) 506 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 507 | trainer.save_model() # Saves the tokenizer too for easy upload 508 | trainer.log_metrics("train", metrics) 509 | trainer.save_metrics("train", metrics) 510 | trainer.save_state() 511 | # Evaluation 512 | combined = {} 513 | if training_args.do_eval: 514 | logger.info("*** Evaluate ***") 515 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 516 | max_eval_samples = ( 517 | data_args.max_eval_samples 518 | if data_args.max_eval_samples is not None 519 | else len(eval_dataset) 520 | ) 521 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 522 | combined.update(metrics) 523 | trainer.log_metrics("eval", metrics) 524 | trainer.save_metrics("eval", combined) 525 | if training_args.do_train: 526 | metrics = trainer.evaluate( 527 | eval_dataset=train_dataset, metric_key_prefix="train" 528 | ) 529 | max_eval_samples = ( 530 | data_args.max_eval_samples 531 | if data_args.max_eval_samples is not None 532 | else len(eval_dataset) 533 | ) 534 | metrics["train_samples"] = min(max_eval_samples, len(train_dataset)) 535 | trainer.log_metrics("train", metrics) 536 | trainer.save_metrics("train", combined) 537 | if training_args.do_predict: 538 | logger.info("*** Predict ***") 539 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
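# Note (added): the prediction dict built below is keyed by test-example index and
# serialized with json.dump, yielding the {"0": score, ..., "4731": score} layout
# (4732 entries) that make_test_submission.py expects.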
540 |         predict_dataset = predict_dataset.remove_columns("labels")
541 |         predictions = trainer.predict(
542 |             predict_dataset, metric_key_prefix="predict"
543 |         ).predictions
544 |         predictions = (
545 |             np.squeeze(predictions)
546 |             if model_args.objective in {"mse", "triplet", "triplet_mse"}
547 |             else np.argmax(predictions, axis=1)
548 |         )
549 |         predictions = dict(enumerate(predictions.tolist()))
550 |         output_predict_file = os.path.join(
551 |             training_args.output_dir, "test_predictions.json"
552 |         )
553 |         if trainer.is_world_process_zero():
554 |             with open(output_predict_file, "w", encoding="utf-8") as outfile:
555 |                 json.dump(predictions, outfile)
556 |     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "CSTS"}
557 |     if training_args.push_to_hub:
558 |         trainer.push_to_hub(**kwargs)
559 |     else:
560 |         trainer.create_model_card(**kwargs)
561 | 
562 | 
563 | if __name__ == "__main__":
564 |     main()
565 | 
--------------------------------------------------------------------------------
/run_sts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=${MODEL:-princeton-nlp/sup-simcse-roberta-large} # pre-trained model
3 | encoding=${ENCODER_TYPE:-bi_encoder} # cross_encoder, bi_encoder, tri_encoder
4 | lr=${LR:-1e-5} # learning rate
5 | wd=${WD:-0.1} # weight decay
6 | transform=${TRANSFORM:-False} # whether to use an additional linear layer after the encoder
7 | objective=${OBJECTIVE:-mse} # mse, triplet, triplet_mse
8 | triencoder_head=${TRIENCODER_HEAD:-None} # hadamard, transformer (set for tri_encoder)
9 | seed=${SEED:-42}
10 | output_dir=${OUTPUT_DIR:-output}
11 | config=enc_${encoding}__lr_${lr}__wd_${wd}__trans_${transform}__obj_${objective}__tri_${triencoder_head}__s_${seed}
12 | train_file=${TRAIN_FILE:-data/csts_train.csv}
13 | eval_file=${EVAL_FILE:-data/csts_validation.csv}
14 | test_file=${TEST_FILE:-data/csts_test.csv}
15 | 
16 | python run_sts.py \
17 |     --output_dir "${output_dir}/${model//\//__}/${config}" \
18 |     --model_name_or_path ${model} \
19 |     --objective ${objective} \
20 |     --encoding_type ${encoding} \
21 |     --pooler_type cls \
22 |     --freeze_encoder False \
23 |     --transform ${transform} \
24 |     --triencoder_head ${triencoder_head} \
25 |     --max_seq_length 512 \
26 |     --train_file ${train_file} \
27 |     --validation_file ${eval_file} \
28 |     --test_file ${test_file} \
29 |     --condition_only False \
30 |     --sentences_only False \
31 |     --do_train \
32 |     --do_eval \
33 |     --do_predict \
34 |     --evaluation_strategy epoch \
35 |     --per_device_train_batch_size 8 \
36 |     --gradient_accumulation_steps 4 \
37 |     --learning_rate ${lr} \
38 |     --weight_decay ${wd} \
39 |     --max_grad_norm 0.0 \
40 |     --num_train_epochs 3 \
41 |     --lr_scheduler_type linear \
42 |     --warmup_ratio 0.1 \
43 |     --log_level info \
44 |     --disable_tqdm True \
45 |     --save_strategy epoch \
46 |     --save_total_limit 1 \
47 |     --seed ${seed} \
48 |     --data_seed ${seed} \
49 |     --fp16 True \
50 |     --log_time_interval 15
--------------------------------------------------------------------------------
/run_sts_fewshot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import time
7 | from argparse import ArgumentParser
8 | from pathlib import Path
9 | 
10 | import numpy as np
11 | import torch
12 | from accelerate import init_empty_weights, load_checkpoint_and_dispatch
13 | from huggingface_hub import snapshot_download
14 | from scipy.stats import pearsonr, 
spearmanr 15 | from torch.utils.data import DataLoader 16 | from transformers import (AutoConfig, AutoModelForCausalLM, 17 | AutoModelForSeq2SeqLM, AutoTokenizer, 18 | default_data_collator) 19 | 20 | from utils.fewshot.generate_in_context_dataset import make_dataset 21 | from utils.fewshot.openai_utils import (OPENAI_MODELS, authenticate, 22 | get_gpt_prediction) 23 | from utils.fewshot.progress_logger import ProgressLogger 24 | 25 | logging.basicConfig( 26 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s" 27 | ) 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | MODEL_CLASSES = { 32 | "gpt": AutoModelForCausalLM, 33 | "t5": AutoModelForSeq2SeqLM, 34 | } 35 | 36 | NO_SKIP_MODULES = { 37 | "t5": ["T5Block"], 38 | "gpt": ["GPTJBlock"], 39 | } 40 | 41 | DTYPES = { 42 | "bf16": torch.bfloat16, 43 | "fp16": torch.float16, 44 | "fp32": torch.float32, 45 | "tf32": torch.float32, # tf32 set by cuda.backend 46 | } 47 | 48 | 49 | def get_tokenizer_type(model): 50 | if ( 51 | "t5" in model.lower() 52 | or "t0" in model.lower() 53 | or "tk-" in model.lower() 54 | or "ul2" in model.lower() 55 | ): 56 | return "t5" 57 | elif "gpt" in model.lower(): 58 | return "gpt" 59 | else: 60 | raise ValueError(f"Unknown tokenizer type {model}") 61 | 62 | 63 | def extract_float(s): 64 | match = re.search(r"(\d+\.\d+|\d+)", s) 65 | if match: 66 | return float(match.group(1)) 67 | return s 68 | 69 | 70 | def eval( 71 | dataset, 72 | model, 73 | tokenizer, 74 | prefix, 75 | tokenizer_type, 76 | min_similarity, 77 | max_similarity, 78 | dataloader_num_workers, 79 | batch_size, 80 | ): 81 | start_time = time.time() 82 | if model in OPENAI_MODELS: 83 | all_preds, all_labels, examples, non_numeric = openai_model_eval( 84 | model, 85 | dataset, 86 | min_similarity, 87 | max_similarity, 88 | ) 89 | else: 90 | all_preds, all_labels, examples, non_numeric = non_openai_model_eval( 91 | model, 92 | tokenizer, 93 | tokenizer_type, 94 | dataset, 95 | dataloader_num_workers, 96 | batch_size, 97 | min_similarity, 98 | max_similarity, 99 | ) 100 | eval_time = time.time() - start_time 101 | predictions = dict(enumerate(all_preds)) 102 | logger.info(f"Example Preds: {all_preds[:3]}") 103 | logger.info(f"Example Labels: {all_labels[:3]}") 104 | results = process_results( 105 | prefix, 106 | eval_time, 107 | len(dataset), 108 | non_numeric, 109 | all_preds, 110 | all_labels, 111 | min_similarity, 112 | max_similarity, 113 | ) 114 | return results, predictions, examples 115 | 116 | 117 | def get_tokenizer_func(tokenizer, tokenizer_type): 118 | def tokenizer_func(example): 119 | return tokenizer( 120 | example["text"], 121 | padding="longest", 122 | truncation=True, 123 | return_tensors="pt", 124 | add_special_tokens=tokenizer_type == "t5", 125 | ) 126 | return tokenizer_func 127 | 128 | 129 | def openai_model_eval( 130 | model, dataset, min_similarity, max_similarity 131 | ): 132 | all_preds, all_labels, examples = [], [], [] 133 | non_numeric = 0 134 | for ix, example in ProgressLogger.wrap_iter( 135 | "eval", dataset, len(dataset), return_ix=True 136 | ): 137 | raw_pred = get_gpt_prediction(model, example["text"]) 138 | pred = extract_float(raw_pred) 139 | if type(pred) is not float: 140 | non_numeric += 1 141 | pred = torch.empty(1).uniform_(min_similarity, max_similarity).item() 142 | label = float(example["label"]) 143 | all_preds.append(pred) 144 | all_labels.append(label) 145 | examples.append( 146 | { 147 | "id": ix, 148 | "example": example["text"], 149 | "raw_pred": raw_pred, 150 | "pred": 
pred, 151 | "label": label, 152 | } 153 | ) 154 | if ix < 3: 155 | log_example(ix, example["text"], raw_pred, label) 156 | return all_preds, all_labels, examples, non_numeric 157 | 158 | 159 | def non_openai_model_eval( 160 | model, 161 | tokenizer, 162 | tokenizer_type, 163 | dataset, 164 | dataloader_num_workers, 165 | batch_size, 166 | min_similarity, 167 | max_similarity, 168 | ): 169 | preprocess_func = get_tokenizer_func(tokenizer, tokenizer_type) 170 | dataset = dataset.map( 171 | preprocess_func, batched=True, batch_size=batch_size 172 | ) 173 | generation_kwargs = { 174 | "gpt": {"max_new_tokens": 20, 'pad_token_id': tokenizer.eos_token_id}, 175 | "t5": {"max_new_tokens": 20}, 176 | }[tokenizer_type] 177 | dataloader = DataLoader( 178 | dataset, 179 | batch_size=batch_size, 180 | collate_fn=default_data_collator, 181 | num_workers=dataloader_num_workers, 182 | shuffle=False, 183 | ) 184 | non_numeric = 0 185 | all_preds, all_labels, examples = [], [], [] 186 | with torch.no_grad(): 187 | for ix, example in ProgressLogger.wrap_iter( 188 | "eval", dataloader, len(dataloader), return_ix=True 189 | ): 190 | inputs = { 191 | k: v.to(model.device) 192 | for k, v in example.items() 193 | if k in ["input_ids", "attention_mask"] 194 | } 195 | output = model.generate(**inputs, **generation_kwargs) 196 | if tokenizer_type == "gpt": 197 | output = output[:, inputs["input_ids"].shape[-1]:, ...] 198 | raw_preds = tokenizer.batch_decode(output, skip_special_tokens=True) 199 | preds, non_numeric = process_preds( 200 | raw_preds, 201 | non_numeric, 202 | min_similarity, 203 | max_similarity, 204 | ) 205 | labels = example["labels"].tolist() 206 | example_texts = tokenizer.batch_decode( 207 | inputs["input_ids"], skip_special_tokens=True 208 | ) 209 | if ix * batch_size < 3: 210 | log_examples( 211 | ix, example_texts, raw_preds, labels, batch_size 212 | ) 213 | all_preds.extend(preds) 214 | all_labels.extend(labels) 215 | examples.extend( 216 | [ 217 | { 218 | "id": cix + ix * batch_size, 219 | "example": example_text, 220 | "raw_pred": raw_pred, 221 | "pred": pred, 222 | "label": label, 223 | } 224 | for cix, example_text, raw_pred, pred, label in zip(range(len(preds)), example_texts, raw_preds, preds, labels) 225 | ] 226 | ) 227 | return all_preds, all_labels, examples, non_numeric 228 | 229 | 230 | def process_preds(raw_preds, non_numeric, min_similarity, max_similarity): 231 | preds = list(map(extract_float, raw_preds)) 232 | non_numeric += sum(1 for p in preds if type(p) is not float) 233 | preds = [ 234 | p 235 | if type(p) is float 236 | else torch.empty(1).uniform_(min_similarity, max_similarity).item() 237 | for p in preds 238 | ] 239 | return preds, non_numeric 240 | 241 | 242 | def log_example(ix, text, raw_pred, label): 243 | example_str = "Example %d:\n\t%sPRED=%s LABEL=%s" % ( 244 | ix, 245 | text.replace("\n", "\n\t"), 246 | raw_pred, 247 | label, 248 | ) 249 | logger.info(example_str) 250 | 251 | 252 | def log_examples(ix, example_texts, raw_preds, labels, batch_size): 253 | for cix in range(min(len(raw_preds), 3)): 254 | log_example( 255 | ix * batch_size + cix, 256 | example_texts[cix], 257 | raw_preds[cix], 258 | labels[cix], 259 | ) 260 | 261 | 262 | def process_results( 263 | prefix, 264 | eval_time, 265 | samples, 266 | non_numeric, 267 | all_preds, 268 | all_labels, 269 | min_similarity, 270 | max_similarity, 271 | ): 272 | scaled_preds = np.array(all_preds) 273 | invalid_preds = sum( 274 | 1 for p in scaled_preds if not min_similarity <= p <= max_similarity 275 | ) 276 | 
scaled_labels = np.array(all_labels) 277 | results = { 278 | "pearsonr": pearsonr(scaled_preds, scaled_labels)[0], 279 | "spearmanr": spearmanr(scaled_preds, scaled_labels)[0], 280 | "runtime": eval_time, 281 | "samples": samples, 282 | "samples_per_second": samples / eval_time, 283 | "non_numeric": non_numeric, 284 | "non_numeric_percent": non_numeric / samples, 285 | "mse": ((torch.tensor(all_preds) - torch.tensor(all_labels)) ** 2) 286 | .mean() 287 | .item(), 288 | "out_of_range": invalid_preds, 289 | "out_of_range_percent": invalid_preds / samples, 290 | } 291 | return {f"{prefix}_{k}": v for k, v in results.items()} 292 | 293 | 294 | def load_model_and_tokenizer(model_name_or_path, tokenizer_type, api_key, dtype): 295 | if model_name_or_path not in OPENAI_MODELS: 296 | if not torch.cuda.is_available() and dtype != 'fp32': 297 | logger.info("Using CPU, overriding dtype to fp32") 298 | dtype = torch.float32 if not torch.cuda.is_available() else DTYPES[dtype] 299 | model_cls = MODEL_CLASSES[tokenizer_type] 300 | weights_location = get_weights_location(model_name_or_path) 301 | config = AutoConfig.from_pretrained(weights_location) 302 | index_location = get_index_location(weights_location) 303 | with init_empty_weights(): 304 | logger.info(f"Instantiating model from config") 305 | model = model_cls.from_config(config) 306 | model = load_model_weights(model, index_location, dtype, tokenizer_type) 307 | model = model.eval() 308 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left" if tokenizer_type == "gpt" else "right") 309 | if tokenizer_type == "gpt": 310 | tokenizer.pad_token_id = tokenizer.eos_token_id 311 | else: 312 | if not os.path.exists(api_key): 313 | raise ValueError("api_key must be a file containing your OpenAI API key") 314 | authenticate(api_key) 315 | model = model_name_or_path 316 | tokenizer = None 317 | return model, tokenizer 318 | 319 | 320 | def get_weights_location(model_name_or_path): 321 | if not os.path.exists(model_name_or_path): 322 | return snapshot_download( 323 | repo_id=model_name_or_path, 324 | ignore_patterns=["*h5*", "*msgpack*", "*safetensors*", '*tflite*', '*rust_model.ot*'], # only download pytorch weights 325 | ) 326 | elif os.path.isdir(model_name_or_path): 327 | return model_name_or_path 328 | else: 329 | return os.path.dirname(model_name_or_path) 330 | 331 | 332 | def get_index_location(weights_location): 333 | index_location = os.path.join(weights_location, "pytorch_model.bin.index.json") 334 | if not os.path.exists(index_location): 335 | index_location = os.path.join(weights_location, "pytorch_model.bin") 336 | return index_location 337 | 338 | 339 | def load_model_weights(model, index_location, dtype, tokenizer_type): 340 | logger.info("Loading model weights with load_checkpoint_and_dispatch") 341 | model = load_checkpoint_and_dispatch( 342 | model, 343 | index_location, 344 | device_map="balanced", 345 | no_split_module_classes=NO_SKIP_MODULES[tokenizer_type], 346 | dtype=dtype, 347 | ) 348 | logger.info(f"Loaded model with load_checkpoint_and_dispatch from {index_location}") 349 | return model 350 | 351 | 352 | def save_results(results, predictions, examples, output_dir, output_file_prefix): 353 | logger.info(f"{output_file_prefix} results: %s" % json.dumps(results, indent=4)) 354 | logger.info("Writing eval_results to %s" % output_dir) 355 | with open(Path(output_dir, f"{output_file_prefix}_results.json"), "w") as f: 356 | json.dump(results, f, indent=4) 357 | with open(Path(output_dir, 
f"{output_file_prefix}_predictions.json"), "w") as f: 358 | json.dump(predictions, f, indent=4) 359 | with open(Path(output_dir, f"{output_file_prefix}_examples.json"), "w") as f: 360 | json.dump(examples, f, indent=4) 361 | 362 | 363 | def main( 364 | model_name_or_path, 365 | tokenizer_type, 366 | output_dir, 367 | train_file, 368 | validation_file, 369 | test_file, 370 | k_shot, 371 | prompt_name, 372 | seed, 373 | overwrite_output_dir, 374 | dataloader_num_workers, 375 | max_eval_samples, 376 | api_key, 377 | dtype, 378 | batch_size, 379 | ): 380 | skip_validation = overwrite_output_dir is False and Path(output_dir, "eval_results.json").exists() 381 | skip_test = overwrite_output_dir is False and Path(output_dir, "test_results.json").exists() 382 | if skip_validation: 383 | logger.info(f"Skipping validation, found eval_results.json in {output_dir}.\nSet overwrite_output_dir=True to override.") 384 | if skip_test: 385 | logger.info(f"Skipping test, found test_results.json in {output_dir}.\nSet overwrite_output_dir=True to override.") 386 | if skip_validation and skip_test: 387 | return 388 | if validation_file is None and test_file is None: 389 | logger.info("No validation or test file provided. Exiting.") 390 | return 391 | if model_name_or_path in OPENAI_MODELS: 392 | assert api_key is not None, "api_key path must be provided for OpenAI models" 393 | if dtype == "tf32": 394 | torch.backends.cuda.matmul.allow_tf32 = True 395 | if tokenizer_type is None: 396 | tokenizer_type = get_tokenizer_type(model_name_or_path) 397 | else: 398 | assert tokenizer_type in {'gpt', 't5'}, f"tokenizer_type must be one of 'gpt' or 't5', got {tokenizer_type}" 399 | logger.info(f"Using {tokenizer_type} tokenizer") 400 | config_key = f"{tokenizer_type}_k{k_shot}_prompt{prompt_name}" 401 | model, tokenizer = load_model_and_tokenizer( 402 | model_name_or_path, tokenizer_type, api_key, dtype 403 | ) 404 | max_similarity = 5.0 405 | min_similarity = 1.0 if "csts" in Path(train_file).name else 0.0 406 | is_stsb = "stsb" in Path(train_file).name 407 | logging.info("Loading dataset %s" % config_key) 408 | dataset = make_dataset( 409 | train_file=train_file, 410 | validation_file=validation_file, 411 | test_file=test_file, 412 | tokenizer_type=tokenizer_type, 413 | k_shot=k_shot, 414 | prompt_name=prompt_name, 415 | seed=seed, 416 | is_stsb=is_stsb, 417 | ) 418 | train_dataset = dataset["train"] 419 | eval_dataset, test_dataset = None, None 420 | if validation_file is not None: 421 | eval_dataset = dataset["validation"] 422 | if test_file is not None: 423 | test_dataset = dataset["test"] 424 | if max_eval_samples is not None and 'validation' in dataset: 425 | eval_dataset = eval_dataset.select(range(min(max_eval_samples, len(eval_dataset)))) 426 | logger.info( 427 | "Loaded %d train examples, %d validation examples, %d test examples" 428 | % (len(train_dataset), len(eval_dataset) if eval_dataset is not None else 0, len(test_dataset) if test_dataset is not None else 0) 429 | ) 430 | Path(output_dir).mkdir(parents=True, exist_ok=True) 431 | if validation_file is not None: 432 | logger.info("Evaluating validation dataset") 433 | eval_results, eval_predictions, eval_examples = eval( 434 | dataset=eval_dataset, 435 | model=model, 436 | tokenizer=tokenizer, 437 | prefix='eval', 438 | tokenizer_type=tokenizer_type, 439 | min_similarity=min_similarity, 440 | max_similarity=max_similarity, 441 | dataloader_num_workers=dataloader_num_workers, 442 | batch_size=batch_size, 443 | ) 444 | save_results(eval_results, 
eval_predictions, eval_examples, output_dir, "eval") 445 | if test_file is not None: 446 | logger.info("Predicting on test dataset") 447 | test_results, test_predictions, test_examples = eval( 448 | dataset=test_dataset, 449 | model=model, 450 | tokenizer=tokenizer, 451 | prefix='test', 452 | tokenizer_type=tokenizer_type, 453 | min_similarity=min_similarity, 454 | max_similarity=max_similarity, 455 | dataloader_num_workers=dataloader_num_workers, 456 | batch_size=batch_size, 457 | ) 458 | save_results(test_results, test_predictions, test_examples, output_dir, "test") 459 | logger.info("Done!") 460 | 461 | 462 | def string_to_bool(v): 463 | if isinstance(v, bool): 464 | return v 465 | if v.lower() in ("yes", "true", "t", "y", "1"): 466 | return True 467 | elif v.lower() in ("no", "false", "f", "n", "0"): 468 | return False 469 | else: 470 | raise argparse.ArgumentTypeError( 471 | f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." 472 | ) 473 | 474 | 475 | if __name__ == "__main__": 476 | parser = ArgumentParser() 477 | parser.add_argument("--model_name_or_path", type=str, required=True) 478 | parser.add_argument("--tokenizer_type", type=str, help="Tokenizer type (gpt or t5). If not provided, will be inferred from model_name_or_path.") 479 | parser.add_argument("--k_shot", type=int, required=True, help="Number of examples to use in prompt.") 480 | parser.add_argument("--prompt_name", type=str, required=True, help="Name of prompt to use. See utils/fewshot/generate_in_context_dataset.py for options.") 481 | parser.add_argument("--seed", type=int, default=42) 482 | parser.add_argument("--train_file", type=str, required=True, help="Path to train file.") 483 | parser.add_argument("--validation_file", type=str, required=False, help="Path to validation file. If not provided, will not run validation.") 484 | parser.add_argument("--test_file", type=str, required=False, help="Path to test file. If not provided, will not run test.") 485 | parser.add_argument( 486 | "--output_dir", type=str, required=True, help="Directory to save results" 487 | ) 488 | parser.add_argument( 489 | "--overwrite_output_dir", 490 | type=string_to_bool, 491 | default=False, 492 | nargs="?", 493 | const=True, 494 | help="Overwrite the content of the output directory", 495 | ) 496 | parser.add_argument("--dataloader_num_workers", type=int, default=0) 497 | parser.add_argument( 498 | "--api_key", type=str, required=False, help="Path to OpenAI API key" 499 | ) 500 | parser.add_argument( 501 | "--dtype", 502 | type=str, 503 | choices=["fp16", "bf16", "fp32", "tf32"], 504 | help="Data used for model. 
TF32 and BF16 are recommended but only supported for NVIDIA GPUs with Ampere architecture or later.", 505 | required=True, 506 | ) 507 | parser.add_argument("--batch_size", type=int, default=2) 508 | parser.add_argument("--max_eval_samples", type=int, default=None) 509 | args = parser.parse_args() 510 | main(**vars(args)) 511 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/c-sts/3bacafac67c6a359186e58d3086deaf180a43863/utils/__init__.py -------------------------------------------------------------------------------- /utils/fewshot/generate_in_context_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset, Dataset, DatasetDict 3 | import pandas as pd 4 | import random 5 | from pathlib import Path 6 | from functools import partial 7 | import logging 8 | from argparse import ArgumentParser 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | SEP_TOKENS = { 15 | 't5': '', 16 | 'gpt': ' ', 17 | } 18 | LABEL_KEY_FUNCS = { 19 | 't5': lambda x: f'{x.capitalize()}:', 20 | 'gpt': lambda x: f'{x.capitalize()}: ', 21 | } 22 | 23 | prompts = { 24 | 'none': None, 25 | 'short': 'On a scale between 1 and 5, how similar are the following two sentences with respect to the condition provided? Respond only with a score between 1 and 5.', 26 | 'long': 'Definition: Evaluate the similarity between the two sentences, with respect to the condition. Assign the pair a score between 1 and 5 as follows: 1 : The two sentences are completely dissimilar with respect to the condition. 2 : The two sentences are dissimilar, but are on a similar topic with respect to the condition. 3 : The two sentences are roughly equivalent, but some important information differs or is missing with respect to the condition. 4 : The two sentences are mostly equivalent, but some unimportant details differ with respect to the condition. 5 : The two sentences are completely equivalent with respect to the condition.', 27 | } 28 | stsb_prompts = { 29 | 'none': None, 30 | 'short': 'On a scale between 0 and 5, how similar are the following two sentences? Respond only with a score between 0 and 5.', 31 | 'long': 'Definition: Evaluate the similarity between them and classify them into classes from 0-5 as follows: 0 : The two sentences are completely dissimilar. 1 : The two sentences are not equivalent, but are on the same topic. 2 : The two sentences are not equivalent, but share some details. 3 : The two sentences are roughly equivalent, but some important information differs/missing. 4 : The two sentences are mostly equivalent, but some unimportant details differ. 
5 : The two sentences are completely equivalent, as they mean the same thing.', 32 | } 33 | def get_prompt(prompt_name, is_stsb=False): 34 | if prompt_name is None: 35 | return None 36 | else: 37 | if is_stsb: 38 | return stsb_prompts[prompt_name] 39 | else: 40 | return prompts[prompt_name] 41 | 42 | 43 | def convert_text( 44 | example, 45 | sep_token, 46 | label_key_func=lambda x: f'{x.capitalize()}: ', 47 | sentence_1_label='sentence1', 48 | sentence_2_label='sentence2', 49 | condition_label='condition', 50 | answer_label='label', 51 | is_stsb=False, 52 | ): 53 | sent_1 = example[sentence_1_label].strip() 54 | sent_2 = example[sentence_2_label].strip() 55 | condition = example[condition_label] if example[condition_label] is not None else '' # bug some conditions are None 56 | similarity = example[answer_label] 57 | if is_stsb: 58 | ex_list = [ 59 | label_key_func('input'), 60 | ' '.join([label_key_func('sentence 1'), sent_1, ]), 61 | ' '.join([label_key_func('sentence 2'), sent_2, ]), 62 | label_key_func('output'), 63 | ] 64 | else: 65 | ex_list = [ 66 | label_key_func('input'), 67 | ' '.join([label_key_func('sentence 1'), sent_1, ]), 68 | ' '.join([label_key_func('sentence 2'), sent_2, ]), 69 | ' '.join([label_key_func('condition'), condition, ]), 70 | label_key_func('output'), 71 | ] 72 | ex_str = ' '.join(map(str, ex_list)) 73 | return ex_str 74 | 75 | 76 | def add_context( 77 | example, 78 | context, 79 | prompt, 80 | sep_token, 81 | answer_label='label', 82 | label_func=lambda x: f'{float(x)}', 83 | ): 84 | if prompt is not None: 85 | ex_list = [ 86 | prompt.strip(' :'), 87 | ] 88 | else: 89 | ex_list = [] 90 | for ex in context: 91 | entry = ex['original_text'] + label_func(ex[answer_label]) 92 | ex_list.extend([entry, ]) 93 | ex_list.append(example['original_text']) # don't add a label to the last example 94 | return '\n'.join(ex_list) 95 | 96 | 97 | def add_in_context_examples(dataset, context_dataset, model, k, prompt, tokenizer_type, pairs=None): 98 | contexts = list() 99 | context_ids = list() 100 | for ix, entry in enumerate(dataset): 101 | if pairs is not None: 102 | random_pairs = random.sample(range(len(pairs)), k=(k+1)//2) 103 | context_example_ids = [x for pair in random_pairs for x in pairs[pair]][:k] 104 | else: 105 | context_example_ids = random.sample(list(set(range(len(context_dataset))) - {ix}), k=k) 106 | context_ids.append(context_example_ids) 107 | context_examples = [context_dataset[idx] for idx in context_example_ids] 108 | contexts.append(add_context(entry, context_examples, prompt, SEP_TOKENS[tokenizer_type])) 109 | dataset = dataset.add_column('context_ids', context_ids) 110 | dataset = dataset.add_column('text', contexts) 111 | return dataset 112 | 113 | def get_idx_pairs( 114 | dataset, 115 | sentence_1_label='sentence1', 116 | sentence_2_label='sentence2', 117 | condition_label='condition', 118 | answer_label='label', 119 | ): 120 | from collections import defaultdict 121 | pairs = defaultdict(list) 122 | for ix, datum in enumerate(dataset): 123 | pairs[datum[sentence_1_label] + '<-SEP->' + datum[sentence_2_label]].append(ix) 124 | pair_idxs = list(pairs.keys()) 125 | drop_count = 0 126 | for pair_idx in pair_idxs: 127 | if len(pairs[pair_idx]) != 2: 128 | drop_count += len(pairs[pair_idx]) 129 | pairs.pop(pair_idx) 130 | logger.warning('Dropping %d indices for missing pairs. 
Dataset has %d pairs total' % (drop_count, len(pair_idxs))) 131 | pairs = list(map(lambda x: sorted(pairs[x], key=lambda idx: -dataset[idx][answer_label]), pairs.keys())) 132 | # negative because we want to sort in descending order (highest similarity first) 133 | for idx1, idx2 in pairs: 134 | if (dataset[idx1][sentence_1_label] != dataset[idx2][sentence_1_label]) or (dataset[idx1][sentence_2_label] != dataset[idx2][sentence_2_label]): 135 | raise ValueError('Pairing of indices is incorrect, sentences do not match for pair %d and %d' % (idx1, idx2)) 136 | if (dataset[idx1][answer_label] < dataset[idx2][answer_label]): 137 | raise ValueError('Pairing of indices is incorrect, similarity is not in descending order for pair %d and %d' % (idx1, idx2)) 138 | return pairs 139 | 140 | def make_dataset( 141 | train_file, 142 | validation_file, 143 | test_file, 144 | tokenizer_type, 145 | k_shot, 146 | prompt_name, 147 | seed, 148 | is_stsb=False, 149 | ): 150 | convert_func = partial(convert_text, sep_token=SEP_TOKENS[tokenizer_type], label_key_func=LABEL_KEY_FUNCS[tokenizer_type], is_stsb=is_stsb) 151 | data_files = {'train': train_file} 152 | if validation_file is not None: 153 | data_files['validation'] = validation_file 154 | if test_file is not None: 155 | data_files['test'] = test_file 156 | raw_datasets = load_dataset('csv', data_files=data_files, keep_in_memory=True) 157 | raw_datasets = raw_datasets.map(lambda x: {'original_text': convert_func(x)}, batched=False, keep_in_memory=True) 158 | prompt = get_prompt(prompt_name, is_stsb) 159 | random.seed(seed) 160 | pairs = None 161 | if not is_stsb: 162 | pairs = get_idx_pairs(raw_datasets['train']) 163 | raw_datasets['train'] = add_in_context_examples(raw_datasets['train'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs)  # third argument fills the unused `model` parameter 164 | if validation_file is not None: 165 | raw_datasets['validation'] = add_in_context_examples(raw_datasets['validation'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs) 166 | if test_file is not None: 167 | raw_datasets['test'] = add_in_context_examples(raw_datasets['test'], raw_datasets['train'], tokenizer_type, k_shot, prompt, tokenizer_type, pairs) 168 | return raw_datasets 169 | 170 | def main(train_file, test_file, tokenizer_type, output_dir, k_shot, prompt_name, seed, is_stsb): 171 | dataset = make_dataset(train_file, None, test_file, tokenizer_type, k_shot, prompt_name, seed, is_stsb)  # None fills validation_file, which this CLI does not expose; omitting it shifts every later argument by one position 172 | output_file = Path(output_dir, Path(train_file).stem + f'_{tokenizer_type}_k{k_shot}_prompt{prompt_name}_seed{seed}') 173 | dataset.save_to_disk(output_file) 174 | logger.info(f'Saved to {output_file}') 175 | 176 | 177 | if __name__ == '__main__': 178 | parser = ArgumentParser() 179 | parser.add_argument('--train_file', type=str, required=True) 180 | parser.add_argument('--test_file', type=str, required=True) 181 | parser.add_argument('--tokenizer_type', type=str, choices=sorted(LABEL_KEY_FUNCS.keys()), required=True) 182 | parser.add_argument('--output_dir', type=str, default='./in_context_datasets') 183 | parser.add_argument('--k_shot', type=int, required=True) 184 | parser.add_argument('--prompt_name', type=str) 185 | parser.add_argument('--is_stsb', action='store_true') 186 | parser.add_argument('--seed', type=int, required=True) 187 | args = parser.parse_args() 188 | main(**vars(args)) 189 | -------------------------------------------------------------------------------- /utils/fewshot/openai_utils.py: 
-------------------------------------------------------------------------------- 1 | import openai 2 | import logging 3 | import time 4 | 5 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | LEGACY_MODELS = { 10 | 'davinci', 11 | 'curie', 12 | 'babbage', 13 | 'ada', 14 | } 15 | 16 | GPT3_MODELS = { 17 | 'text-davinci-003', 18 | 'text-davinci-002', 19 | 'text-davinci-001', 20 | 'text-curie-001', 21 | 'text-babbage-001', 22 | 'text-ada-001', 23 | } 24 | 25 | CHAT_MODELS = { 26 | 'gpt-4', 27 | 'gpt-4-0314', 28 | 'gpt-4-32k', 29 | 'gpt-4-32k-0314', 30 | 'gpt-3.5-turbo', 31 | 'gpt-3.5-turbo-0301', 32 | } 33 | 34 | OPENAI_MODELS = LEGACY_MODELS | GPT3_MODELS | CHAT_MODELS 35 | 36 | 37 | def parse_response(model, response, prompt): 38 | if model in CHAT_MODELS: 39 | response_text = response['choices'][0]['message']['content'] 40 | else: 41 | response_text = response['choices'][0]['text'].replace(prompt, '') 42 | response_text = response_text.strip()  # str.strip() returns a new string; assigning it keeps the whitespace fix instead of discarding it 43 | return response_text 44 | 45 | def call_chat(model, prompt): 46 | response = openai.ChatCompletion.create( 47 | model=model, 48 | messages=[{'role': 'user', 'content': prompt}, ], 49 | temperature=0, 50 | max_tokens=5, 51 | top_p=1, 52 | frequency_penalty=0.0, 53 | presence_penalty=0.0, 54 | ) 55 | return response 56 | 57 | 58 | def call_gpt(model, prompt): 59 | response = openai.Completion.create( 60 | model=model, 61 | prompt=prompt, 62 | temperature=0, 63 | max_tokens=5, 64 | top_p=1, 65 | frequency_penalty=0.0, 66 | presence_penalty=0.0, 67 | ) 68 | return response 69 | 70 | 71 | def get_gpt_prediction(model, prompt): 72 | retries = 0 73 | while retries < 3: 74 | try: 75 | if model in CHAT_MODELS: 76 | response = call_chat(model, prompt) 77 | else: 78 | response = call_gpt(model, prompt) 79 | return parse_response(model, response, prompt) 80 | except Exception as e: 81 | logger.warning('Exception while getting gpt prediction: {}'.format(e)) 82 | logger.warning(f'Retrying... {2 - retries} more time(s).') 83 | retries += 1 84 | time.sleep(20 * retries) 85 | raise Exception('Failed to get gpt prediction after 3 retries. 
Aborting run.') 86 | 87 | 88 | def authenticate(api_key): 89 | with open(api_key) as f: 90 | api_key = f.readlines()[0].strip() 91 | openai.api_key = api_key 92 | -------------------------------------------------------------------------------- /utils/fewshot/progress_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class ProgressLogger: 9 | def __init__(self, process_name, total_iters, log_time_interval=60): 10 | self.total_iters = total_iters 11 | self.name = process_name 12 | if log_time_interval < 1: 13 | log_time_interval = 1 14 | logger.info('log_time_interval must be >= 1, setting to 1') 15 | self.log_time_interval = log_time_interval 16 | self.start_time = None 17 | self.last_log_time = None 18 | self.last_log_ix = 0 19 | 20 | def __call__(self, ix): 21 | if ix == 0: 22 | self.start_time = time.time() 23 | self.last_log_time = time.time() 24 | self.log(ix) 25 | elif ix == self.total_iters - 1: 26 | self.log(ix) 27 | elif time.time() - self.last_log_time > self.log_time_interval: 28 | self.log(ix) 29 | 30 | def log(self, ix): 31 | pct_complete = ix / self.total_iters * 100 32 | iters_per_sec = (ix - self.last_log_ix) / max(time.time() - self.last_log_time, 1e-9)  # guard against a zero elapsed-time denominator on the first call 33 | remaining = (time.time() - self.start_time) / (ix + 1) * (self.total_iters - ix - 1) 34 | remaining = time.strftime('%H:%M:%S', time.gmtime(remaining)) 35 | logger.info('{}: {:_} / {:_} ({:.2f}%) ({:_.2f} iter/sec) (remaining: {})'.format( 36 | self.name, 37 | ix, 38 | self.total_iters, 39 | pct_complete, 40 | iters_per_sec, 41 | remaining, 42 | )) 43 | self.last_log_time = time.time() 44 | self.last_log_ix = ix 45 | 46 | @classmethod 47 | def wrap_iter( 48 | cls, 49 | process_name, 50 | iterable, 51 | total_iters, 52 | log_time_interval=60, 53 | return_ix=False, 54 | ): 55 | progress = cls(process_name, total_iters, log_time_interval) 56 | if return_ix: 57 | for ix, item in enumerate(iterable): 58 | progress(ix) 59 | yield ix, item 60 | else: 61 | for ix, item in enumerate(iterable): 62 | progress(ix) 63 | yield item 64 | 65 | -------------------------------------------------------------------------------- /utils/progress_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from transformers import TrainerCallback 4 | import logging 5 | 6 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class LogCallback(TrainerCallback): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.start_time = time.time() 15 | self.last_log_time = self.start_time 16 | self.log_time_interval = 0 17 | self.current_step = 0 18 | self.is_training = False 19 | self.max_steps = -1 20 | self.first_step_of_run = 0 21 | 22 | def on_train_begin(self, args, state, control, **kwargs): 23 | if state.is_local_process_zero: 24 | args.logging_steps = 1 25 | args.logging_strategy = 'steps' 26 | self.log_time_interval = args.log_time_interval 27 | if self.log_time_interval > 0: 28 | logger.info(f'Using log_time_interval {self.log_time_interval} s. 
This may override logging_steps.') 29 | self.is_training = True 30 | self.current_step = 0 31 | self.start_time = time.time() 32 | self.last_log_time = self.start_time 33 | self.max_steps = state.max_steps 34 | self.first_step_of_run = state.global_step 35 | if torch.distributed.is_initialized(): 36 | torch.distributed.barrier() 37 | 38 | def on_log(self, args, state, control, logs=None, **kwargs): 39 | _ = logs.pop('total_flos', None) 40 | if state.is_local_process_zero: 41 | if self.is_training: 42 | current_time = time.time() 43 | time_diff = current_time - self.last_log_time 44 | force = len(logs) > 3 or any([k.endswith('runtime') for k in logs]) 45 | if time_diff > self.log_time_interval or self.current_step >= self.max_steps - 1 or force: 46 | self.last_log_time = current_time 47 | steps_completed = max(self.current_step, 1) 48 | steps_since_first = max(1, self.current_step - self.first_step_of_run) 49 | remaining_steps = self.max_steps - steps_completed 50 | pct_completed = (steps_completed / self.max_steps) * 100 51 | time_since_start = current_time - self.start_time 52 | remaining_time = (time_since_start / steps_since_first) * remaining_steps 53 | update = {'completed': f'{pct_completed:.2f}% ({steps_completed:_} / {self.max_steps:_})', 'remaining time': self.format_duration(remaining_time)} 54 | logger.info(str({**logs, **update})) 55 | else: 56 | logger.info(str(logs)) 57 | 58 | def on_step_end(self, args, state, control, **kwargs): 59 | if state.is_local_process_zero: 60 | self.current_step = state.global_step 61 | 62 | def on_train_end(self, args, state, control, **kwargs): 63 | if state.is_local_process_zero: 64 | self.is_training = False 65 | 66 | @staticmethod 67 | def format_duration(seconds): 68 | hours, remainder = divmod(seconds, 3600) 69 | minutes, seconds = divmod(remainder, 60) 70 | return f'{int(hours)}:{int(minutes):02}:{int(seconds):02}' 71 | -------------------------------------------------------------------------------- /utils/sts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/c-sts/3bacafac67c6a359186e58d3086deaf180a43863/utils/sts/__init__.py -------------------------------------------------------------------------------- /utils/sts/dataset_preprocessing.py: -------------------------------------------------------------------------------- 1 | def scale_to_range(labels, _min, _max): 2 | return list(map(lambda x: (x - _min) / (_max - _min), labels)) 3 | 4 | 5 | def get_preprocessing_function( 6 | tokenizer, 7 | sentence1_key, 8 | sentence2_key, 9 | condition_key, 10 | similarity_key, 11 | padding, 12 | max_seq_length, 13 | model_args, 14 | scale=None, 15 | condition_only=False, 16 | sentences_only=False, 17 | ): 18 | 'Returns the preprocessing function for each encoding type' 19 | if model_args.encoding_type == 'bi_encoder': 20 | if condition_only or sentences_only: 21 | raise ValueError('condition_only and sentences_only do not apply to bi_encoder') 22 | def preprocess_function(examples): 23 | sent1_args = (examples[sentence1_key], examples[condition_key]) 24 | sent1_result = tokenizer(*sent1_args, padding=padding, max_length=max_seq_length, truncation=True) 25 | sent2_args = (examples[sentence2_key], examples[condition_key]) 26 | sent2_result = tokenizer(*sent2_args, padding=padding, max_length=max_seq_length, truncation=True) 27 | sent1_result['input_ids_2'] = sent2_result['input_ids'] 28 | sent1_result['attention_mask_2'] = sent2_result['attention_mask'] 29 | if 
'token_type_ids' in sent2_result: 30 | sent1_result['token_type_ids_2'] = sent2_result['token_type_ids'] 31 | sent1_result['labels'] = examples[similarity_key] 32 | if scale is not None: 33 | _min, _max = scale 34 | for label in sent1_result['labels']: 35 | if (label < _min or label > _max) and label != -1: 36 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]') 37 | sent1_result['labels'] = scale_to_range(sent1_result['labels'], _min, _max) 38 | return sent1_result 39 | elif model_args.encoding_type == 'cross_encoder': 40 | def preprocess_function(examples): 41 | if condition_only: 42 | input_args = examples[condition_key] 43 | elif sentences_only: 44 | input_args = list(map(lambda x: ' '.join([x[0], tokenizer.sep_token, x[1]]), zip(examples[sentence1_key], examples[sentence2_key]))) 45 | else: 46 | input_args = list(map(lambda x: ' '.join([x[0], tokenizer.sep_token, x[1], tokenizer.sep_token, x[2]]), zip(examples[sentence1_key], examples[sentence2_key], examples[condition_key]))) 47 | result = tokenizer(input_args, padding=padding, max_length=max_seq_length, truncation=True) 48 | result['labels'] = examples[similarity_key] 49 | if scale is not None: 50 | _min, _max = scale 51 | for label in result['labels']: 52 | if (label < _min or label > _max) and label != -1: 53 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]') 54 | result['labels'] = scale_to_range(result['labels'], _min, _max) 55 | return result 56 | elif model_args.encoding_type == 'tri_encoder': 57 | if condition_only or sentences_only: 58 | raise ValueError('condition_only and sentences_only do not apply to tri_encoder') 59 | def preprocess_function(examples): 60 | sent1_args = (examples[sentence1_key], ) 61 | sent1_result = tokenizer(*sent1_args, padding=padding, max_length=max_seq_length, truncation=True) 62 | sent2_args = (examples[sentence2_key], ) 63 | sent2_result = tokenizer(*sent2_args, padding=padding, max_length=max_seq_length, truncation=True) 64 | sent3_args = (examples[condition_key], ) 65 | sent3_result = tokenizer(*sent3_args, padding=padding, max_length=max_seq_length, truncation=True) 66 | sent1_result['input_ids_2'] = sent2_result['input_ids'] 67 | sent1_result['attention_mask_2'] = sent2_result['attention_mask'] 68 | sent1_result['input_ids_3'] = sent3_result['input_ids'] 69 | sent1_result['attention_mask_3'] = sent3_result['attention_mask'] 70 | if 'token_type_ids' in sent2_result: 71 | sent1_result['token_type_ids_2'] = sent2_result['token_type_ids'] 72 | sent1_result['token_type_ids_3'] = sent3_result['token_type_ids'] 73 | sent1_result['labels'] = examples[similarity_key] 74 | if scale is not None: 75 | _min, _max = scale 76 | for label in sent1_result['labels']: 77 | if (label < _min or label > _max) and label != -1: 78 | raise ValueError(f'Label {label} is not in the range [{_min}, {_max}]') 79 | sent1_result['labels'] = scale_to_range(sent1_result['labels'], _min, _max) 80 | return sent1_result 81 | else: 82 | raise ValueError(f'Invalid model type: {model_args.encoding_type}') 83 | return preprocess_function 84 | -------------------------------------------------------------------------------- /utils/sts/modeling_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.checkpoint 3 | from torch import nn 4 | from torch.nn.functional import cosine_similarity 5 | 6 | from transformers.activations import ACT2FN 7 | from transformers.modeling_outputs import ( 8 | 
SequenceClassifierOutput, 9 | ) 10 | from transformers import PreTrainedModel, AutoModel 11 | import logging 12 | 13 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def concat_features(*features): 19 | return torch.cat(features, dim=0) if features[0] is not None else None 20 | 21 | 22 | class QuadrupletLoss: 23 | def __init__(self, distance_function, margin=1.0): 24 | 'A cosine distance margin quadruplet loss' 25 | self.margin = margin 26 | self.distance_function = distance_function 27 | 28 | def __call__(self, pos1, pos2, neg1, neg2): 29 | dist_pos = self.distance_function(pos1, pos2) 30 | dist_neg = self.distance_function(neg1, neg2) 31 | loss = torch.clamp_min(self.margin + dist_pos - dist_neg, 0) 32 | return loss.mean() 33 | 34 | 35 | # Pooler class. Copied and adapted from SimCSE code 36 | class Pooler(nn.Module): 37 | ''' 38 | Parameter-free poolers to get the sentence embedding 39 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler. 40 | 'cls_before_pooler': [CLS] representation without the original MLP pooler. 41 | 'avg': average of the last layers' hidden states at each token. 42 | 'avg_top2': average of the last two layers. 43 | 'avg_first_last': average of the first and the last layers. 44 | ''' 45 | def __init__(self, pooler_type): 46 | super().__init__() 47 | self.pooler_type = pooler_type 48 | assert self.pooler_type in ['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 'unrecognized pooling type %s' % self.pooler_type 49 | 50 | def forward(self, attention_mask, outputs): 51 | last_hidden = outputs.last_hidden_state 52 | pooler_output = outputs.pooler_output 53 | hidden_states = outputs.hidden_states 54 | 55 | if self.pooler_type in ['cls_before_pooler', 'cls']: 56 | return last_hidden[:, 0] 57 | elif self.pooler_type == 'avg': 58 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)) 59 | elif self.pooler_type == 'avg_first_last': 60 | first_hidden = hidden_states[0] 61 | last_hidden = hidden_states[-1] 62 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 63 | return pooled_result 64 | elif self.pooler_type == 'avg_top2': 65 | second_last_hidden = hidden_states[-2] 66 | last_hidden = hidden_states[-1] 67 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 68 | return pooled_result 69 | else: 70 | raise NotImplementedError 71 | 72 | 73 | class CrossEncoderForClassification(PreTrainedModel): 74 | 'Encoder model with backbone and classification head.' 
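# The cross-encoder consumes a single packed input: dataset_preprocessing.py joins 'sentence1 <sep> sentence2 <sep> condition' (with condition_only / sentences_only variants), and the pooled representation is fed through an optional transform and a linear head to produce one regression logit (num_labels == 1) or num_labels class logits.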
75 | def __init__(self, config): 76 | super().__init__(config) 77 | self.backbone = AutoModel.from_pretrained( 78 | config.model_name_or_path, 79 | from_tf=bool('.ckpt' in config.model_name_or_path), 80 | config=config, 81 | cache_dir=config.cache_dir, 82 | revision=config.model_revision, 83 | use_auth_token=True if config.use_auth_token else None, 84 | add_pooling_layer=False, 85 | ).base_model 86 | classifier_dropout = ( 87 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 88 | ) 89 | if config.transform: 90 | self.transform = nn.Sequential( 91 | nn.Dropout(classifier_dropout), 92 | nn.Linear(config.hidden_size, config.hidden_size), 93 | ACT2FN[config.hidden_act], 94 | ) 95 | else: 96 | self.transform = None 97 | self.classifier = nn.Sequential( 98 | nn.Dropout(config.hidden_dropout_prob), 99 | nn.Linear(config.hidden_size, config.num_labels), 100 | ) 101 | self.pooler = Pooler(config.pooler_type) 102 | if config.pooler_type in {'avg_first_last', 'avg_top2'}: 103 | self.output_hidden_states = True 104 | else: 105 | self.output_hidden_states = False 106 | if config.num_labels == 1: 107 | self.reshape_function = lambda x: x.reshape(-1) 108 | if config.objective == 'mse': 109 | self.loss_fct_cls = nn.MSELoss 110 | elif config.objective in {'triplet', 'triplet_mse'}: 111 | raise NotImplementedError('Triplet loss is not implemented for CrossEncoderForClassification') 112 | else: 113 | raise ValueError(f'Only regression and triplet objectives are supported for CrossEncoderForClassification with num_labels=1. Got {config.objective}.') 114 | else: 115 | assert config.objective == 'classification' 116 | self.reshape_function = lambda x: x.reshape(-1, config.num_labels) 117 | self.loss_fct_cls = nn.CrossEntropyLoss 118 | self.post_init() 119 | 120 | def forward( 121 | self, 122 | input_ids=None, 123 | attention_mask=None, 124 | token_type_ids=None, 125 | position_ids=None, 126 | head_mask=None, 127 | inputs_embeds=None, 128 | labels=None, 129 | **kwargs, 130 | ): 131 | outputs = self.backbone( 132 | input_ids=input_ids, 133 | attention_mask=attention_mask, 134 | token_type_ids=token_type_ids, 135 | position_ids=position_ids, 136 | head_mask=head_mask, 137 | inputs_embeds=inputs_embeds, 138 | output_hidden_states=self.output_hidden_states, 139 | ) 140 | features = self.pooler(attention_mask, outputs) 141 | if self.transform is not None: 142 | features = self.transform(features) 143 | logits = self.classifier(features) 144 | reshaped_logits = self.reshape_function(logits) 145 | loss = None 146 | if labels is not None: 147 | loss = self.loss_fct_cls()(reshaped_logits, labels.view(-1)) 148 | return SequenceClassifierOutput( 149 | loss=loss, 150 | logits=logits, 151 | hidden_states=outputs.hidden_states, 152 | attentions=outputs.attentions, 153 | ) 154 | 155 | 156 | class BiEncoderForClassification(PreTrainedModel): 157 | '''Encoder model with backbone and classification head.''' 158 | def __init__(self, config): 159 | super().__init__(config) 160 | self.backbone = AutoModel.from_pretrained( 161 | config.model_name_or_path, 162 | from_tf=bool('.ckpt' in config.model_name_or_path), 163 | config=config, 164 | cache_dir=config.cache_dir, 165 | revision=config.model_revision, 166 | use_auth_token=True if config.use_auth_token else None, 167 | add_pooling_layer=False, 168 | ).base_model 169 | classifier_dropout = ( 170 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 171 | ) 172 | if config.transform: 173 
| self.transform = nn.Sequential( 174 | nn.Dropout(classifier_dropout), 175 | nn.Linear(config.hidden_size, config.hidden_size), 176 | ACT2FN[config.hidden_act], 177 | ) 178 | else: 179 | self.transform = None 180 | self.pooler = Pooler(config.pooler_type) 181 | if config.pooler_type in {'avg_first_last', 'avg_top2'}: 182 | self.output_hidden_states = True 183 | else: 184 | self.output_hidden_states = False 185 | if config.objective == 'mse': 186 | self.loss_fct_cls = nn.MSELoss 187 | self.loss_fct_kwargs = {} 188 | elif config.objective in {'triplet', 'triplet_mse'}: 189 | self.loss_fct_cls = QuadrupletLoss 190 | self.loss_fct_kwargs = {'distance_function': lambda x, y: 1.0 - cosine_similarity(x, y)} 191 | else: 192 | raise ValueError('Only regression and triplet objectives are supported for BiEncoderForClassification') 193 | self.post_init() 194 | 195 | def forward( 196 | self, 197 | input_ids=None, 198 | attention_mask=None, 199 | token_type_ids=None, 200 | position_ids=None, 201 | head_mask=None, 202 | inputs_embeds=None, 203 | input_ids_2=None, 204 | attention_mask_2=None, 205 | token_type_ids_2=None, 206 | position_ids_2=None, 207 | head_mask_2=None, 208 | inputs_embeds_2=None, 209 | labels=None, 210 | **kwargs, 211 | ): 212 | bsz = input_ids.shape[0] 213 | input_ids = concat_features(input_ids, input_ids_2) 214 | attention_mask = concat_features(attention_mask, attention_mask_2) 215 | token_type_ids = concat_features(token_type_ids, token_type_ids_2) 216 | position_ids = concat_features(position_ids, position_ids_2) 217 | head_mask = concat_features(head_mask, head_mask_2) 218 | inputs_embeds = concat_features(inputs_embeds, inputs_embeds_2) 219 | outputs = self.backbone( 220 | input_ids=input_ids, 221 | attention_mask=attention_mask, 222 | token_type_ids=token_type_ids, 223 | position_ids=position_ids, 224 | head_mask=head_mask, 225 | inputs_embeds=inputs_embeds, 226 | output_hidden_states=self.output_hidden_states, 227 | ) 228 | features = self.pooler(attention_mask, outputs) 229 | if self.transform is not None: 230 | features = self.transform(features) 231 | features_1, features_2 = torch.split(features, bsz, dim=0) # [sentence1, condtion], [sentence2, condition] 232 | loss = None 233 | if self.config.objective in {'triplet', 'triplet_mse'}: 234 | positives1, negatives1 = torch.split(features_1, bsz // 2, dim=0) 235 | positives2, negatives2 = torch.split(features_2, bsz // 2, dim=0) 236 | if labels is not None: 237 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(positives1, positives2, negatives1, negatives2) 238 | logits = cosine_similarity(features_1, features_2, dim=1) 239 | if self.config.objective in {'triplet_mse'} and labels is not None: 240 | loss += nn.MSELoss()(logits, labels) 241 | else: 242 | logits = logits.detach() 243 | else: 244 | logits = cosine_similarity(features_1, features_2, dim=1) 245 | if labels is not None: 246 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(logits, labels) 247 | return SequenceClassifierOutput( 248 | loss=loss, 249 | logits=logits, 250 | ) 251 | 252 | class TriEncoderForClassification(PreTrainedModel): 253 | def __init__(self, config): 254 | super().__init__(config) 255 | self.backbone = AutoModel.from_pretrained( 256 | config.model_name_or_path, 257 | from_tf=bool('.ckpt' in config.model_name_or_path), 258 | config=config, 259 | cache_dir=config.cache_dir, 260 | revision=config.model_revision, 261 | use_auth_token=True if config.use_auth_token else None, 262 | add_pooling_layer=False, 263 | ).base_model 264 | self.triencoder_head 
= config.triencoder_head 265 | classifier_dropout = ( 266 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 267 | ) 268 | if config.transform: 269 | self.transform = nn.Sequential( 270 | nn.Dropout(classifier_dropout), 271 | nn.Linear(config.hidden_size, config.hidden_size), 272 | ACT2FN[config.hidden_act], 273 | ) 274 | else: 275 | self.transform = None 276 | self.condition_transform = nn.Sequential( 277 | nn.Dropout(classifier_dropout), 278 | nn.Linear(config.hidden_size, config.hidden_size) 279 | ) 280 | if self.triencoder_head == 'concat': 281 | self.concat_transform = nn.Sequential( 282 | nn.Dropout(classifier_dropout), 283 | nn.Linear(config.hidden_size * 2, config.hidden_size), 284 | ACT2FN[config.hidden_act], 285 | ) 286 | elif self.triencoder_head == 'hadamard': 287 | self.concat_transform = None 288 | self.pooler = Pooler(config.pooler_type) 289 | if config.pooler_type in {'avg_first_last', 'avg_top2'}: 290 | self.output_hidden_states = True 291 | else: 292 | self.output_hidden_states = False 293 | if config.num_labels == 1: 294 | self.reshape_function = lambda x: x.reshape(-1) 295 | if config.objective == 'mse': 296 | self.loss_fct_cls = nn.MSELoss 297 | self.loss_fct_kwargs = {} 298 | elif config.objective in {'triplet', 'triplet_mse'}: 299 | self.loss_fct_cls = QuadrupletLoss 300 | self.loss_fct_kwargs = {'distance_function': lambda x, y: 1.0 - cosine_similarity(x, y)} 301 | else: 302 | raise ValueError('Only regression and triplet objectives are supported for TriEncoderForClassification') 303 | else: 304 | self.reshape_function = lambda x: x.reshape(-1, config.num_labels) 305 | self.loss_fct_cls = nn.CrossEntropyLoss 306 | self.post_init() 307 | 308 | def forward( 309 | self, 310 | input_ids=None, 311 | attention_mask=None, 312 | token_type_ids=None, 313 | position_ids=None, 314 | head_mask=None, 315 | inputs_embeds=None, 316 | input_ids_2=None, 317 | attention_mask_2=None, 318 | token_type_ids_2=None, 319 | position_ids_2=None, 320 | head_mask_2=None, 321 | inputs_embeds_2=None, 322 | input_ids_3=None, 323 | attention_mask_3=None, 324 | token_type_ids_3=None, 325 | position_ids_3=None, 326 | head_mask_3=None, 327 | inputs_embeds_3=None, 328 | labels=None, 329 | **kwargs, 330 | ): 331 | bsz = input_ids.shape[0] 332 | input_ids = concat_features(input_ids, input_ids_2, input_ids_3) 333 | attention_mask = concat_features(attention_mask, attention_mask_2, attention_mask_3) 334 | token_type_ids = concat_features(token_type_ids, token_type_ids_2, token_type_ids_3) 335 | position_ids = concat_features(position_ids, position_ids_2, position_ids_3) 336 | head_mask = concat_features(head_mask, head_mask_2, head_mask_3) 337 | inputs_embeds = concat_features(inputs_embeds, inputs_embeds_2, inputs_embeds_3) 338 | outputs = self.backbone( 339 | input_ids=input_ids, 340 | attention_mask=attention_mask, 341 | token_type_ids=token_type_ids, 342 | position_ids=position_ids, 343 | head_mask=head_mask, 344 | inputs_embeds=inputs_embeds, 345 | output_hidden_states=self.output_hidden_states, 346 | ) 347 | features = self.pooler(attention_mask, outputs) 348 | features_1, features_2, features_3 = torch.split(features, bsz, dim=0) 349 | features_3 = self.condition_transform(features_3) 350 | # do we need positional embeddings? 
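# Batches for the triplet objectives come from TripletBatchSampler (utils/sts/triplet_trainer.py): the first half of the batch holds the higher-similarity member of each sentence pair and the second half the lower-similarity member, so the index split at shape[0] // 2 below recovers (positives1, positives2, negatives1, negatives2) for QuadrupletLoss.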
351 | loss = None 352 | if self.transform is not None: 353 | features_1 = self.transform(features_1) 354 | features_2 = self.transform(features_2) 355 | if self.triencoder_head == 'concat': 356 | features_1 = torch.cat([features_1, features_3], dim=-1) 357 | features_2 = torch.cat([features_2, features_3], dim=-1) 358 | features_1 = self.concat_transform(features_1) 359 | features_2 = self.concat_transform(features_2) 360 | elif self.triencoder_head == 'hadamard': 361 | features_1 = features_1 * features_3 362 | features_2 = features_2 * features_3 363 | if self.config.objective in {'triplet', 'triplet_mse'}: 364 | positive_idxs = torch.arange(0, features_1.shape[0]//2) 365 | negative_idxs = torch.arange(features_1.shape[0]//2, features_1.shape[0]) 366 | positives1 = features_1[positive_idxs] 367 | positives2 = features_2[positive_idxs] 368 | negatives1 = features_1[negative_idxs] 369 | negatives2 = features_2[negative_idxs] 370 | if labels is not None: 371 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(positives1, positives2, negatives1, negatives2) 372 | logits = cosine_similarity(features_1, features_2, dim=1) 373 | if self.config.objective == 'triplet_mse' and labels is not None: 374 | loss += nn.MSELoss()(logits, labels) 375 | else: 376 | logits = logits.detach() 377 | else: 378 | logits = cosine_similarity(features_1, features_2, dim=1) 379 | if labels is not None: 380 | loss = self.loss_fct_cls(**self.loss_fct_kwargs)(logits, labels) 381 | return SequenceClassifierOutput( 382 | loss=loss, 383 | logits=logits, 384 | ) 385 | -------------------------------------------------------------------------------- /utils/sts/modeling_utils.py: -------------------------------------------------------------------------------- 1 | from .modeling_encoders import BiEncoderForClassification, CrossEncoderForClassification, TriEncoderForClassification 2 | import torch 3 | import numpy as np 4 | from dataclasses import dataclass 5 | from typing import List, Dict, Any, Optional 6 | 7 | 8 | @dataclass 9 | class DataCollatorWithPadding: 10 | pad_token_id: int 11 | pad_token_type_id: int = 0 12 | pad_to_multiple_of: Optional[int] = None 13 | return_tensors: str = 'pt' 14 | 15 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 16 | # get max length of all sequences in features 17 | max_length = max(max(len(feature[key]) for feature in features) for key in features[0] if key.startswith('input_ids')) 18 | if self.pad_to_multiple_of is not None: 19 | max_length = ((max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of 20 | # pad all sequences to max length 21 | out_features = {} 22 | for key in features[0].keys(): 23 | if key.startswith('input_ids') or key.startswith('attention_mask') or key.startswith('token_type_ids'): 24 | if key.startswith('input_ids'): 25 | pad_token = self.pad_token_id 26 | elif key.startswith('attention_mask'): 27 | pad_token = 0 28 | else: 29 | pad_token = self.pad_token_type_id 30 | out_features[key] = [feature[key] + [pad_token] * (max_length - len(feature[key])) for feature in features] 31 | else: 32 | out_features[key] = [feature[key] for feature in features] 33 | if self.return_tensors == 'pt': 34 | out_features = {key: torch.tensor(value) for key, value in out_features.items()} 35 | elif self.return_tensors == 'np': 36 | out_features = {key: np.array(value) for key, value in out_features.items()} 37 | return out_features 38 | 39 | 40 | def get_model(model_args): 41 | if model_args.encoding_type == 'bi_encoder': 42 | 
return BiEncoderForClassification 43 | if model_args.encoding_type == 'cross_encoder': 44 | return CrossEncoderForClassification 45 | if model_args.encoding_type == 'tri_encoder': 46 | return TriEncoderForClassification 47 | raise ValueError(f'Invalid model type: {model_args.encoding_type}') 48 | -------------------------------------------------------------------------------- /utils/sts/triplet_trainer.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from transformers import Trainer 3 | from transformers.trainer_utils import seed_worker 4 | import datasets 5 | from torch.utils.data import DataLoader, Dataset, Sampler 6 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional 7 | from torch.utils.data.distributed import DistributedSampler 8 | from collections import defaultdict 9 | import random 10 | import logging 11 | 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s: %(message)s') 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class TripletBatchSampler(Sampler[List[int]]): 18 | r'''Samples elements from the dataset, grouping pairs of (positives, negatives) together. 19 | 20 | Args: 21 | batch_size (int) 22 | generator (Generator): Generator used in sampling 23 | ''' 24 | batch_size: int 25 | generator: Optional[Callable[[int], int]] 26 | 27 | def __init__( 28 | self, 29 | batch_size: int, 30 | sentence1_key, 31 | sentence2_key, 32 | trainer: Any, 33 | generator = None 34 | ) -> None: 35 | self.batch_size = batch_size 36 | assert self.batch_size % 2 == 0, 'Batch size must be even for triplet loss' 37 | self.generator = generator 38 | self.trainer = trainer 39 | # self.trainer.train_dataset contains all the original data and feature names 40 | self.pairs = self._get_idx_pairs(self.trainer.train_dataset, sentence1_key, sentence2_key) 41 | 42 | def _get_idx_pairs(self, dataset: Dataset, sentence1_key: str, sentence2_key: str) -> List[int]: 43 | '''Get the index order of the dataset, where each index is paired with a positive and negative index 44 | ''' 45 | pairs = defaultdict(list) 46 | for ix, datum in enumerate(dataset): 47 | pairs[datum[sentence1_key] + '<-SEP->' + datum[sentence2_key]].append(ix)  # explicit separator (as in generate_in_context_dataset.py) so distinct sentence pairs cannot collide when concatenated 48 | pair_idxs = list(pairs.keys()) 49 | drop_count = 0 50 | for pair_idx in pair_idxs: 51 | if len(pairs[pair_idx]) != 2: 52 | drop_count += len(pairs[pair_idx]) 53 | pairs.pop(pair_idx) 54 | logger.warning('Dropping %d indices for missing pairs. 
Dataset has %d pairs total' % (drop_count, len(pair_idxs))) 55 | pairs = list(map(lambda x: sorted(pairs[x], key=lambda idx: -dataset[idx]['label']), pairs.keys())) 56 | # negative because we want to sort in descending order (highest similarity first) 57 | for idx1, idx2 in pairs: 58 | if (dataset[idx1][sentence1_key] != dataset[idx2][sentence1_key]) or (dataset[idx1][sentence2_key] != dataset[idx2][sentence2_key]): 59 | raise ValueError('Pairing of indices is incorrect, sentences do not match for pair %d and %d' % (idx1, idx2)) 60 | if (dataset[idx1]['label'] < dataset[idx2]['label']): 61 | raise ValueError('Pairing of indices is incorrect, similarity is not in descending order for pair %d and %d' % (idx1, idx2)) 62 | return pairs 63 | 64 | def __iter__(self): 65 | '''Generate a batch of indices with tiled positive and negative indices 66 | ''' 67 | random.shuffle(self.pairs) 68 | for i in range(0, len(self.pairs), self.batch_size // 2): 69 | batch = self.pairs[i:i+self.batch_size//2] 70 | positives, negatives = zip(*batch) 71 | yield list(positives) + list(negatives) 72 | 73 | def __len__(self) -> int: 74 | return len(self.pairs) * 2 // self.batch_size 75 | 76 | 77 | class TripletTrainer(Trainer): 78 | def get_train_dataloader(self) -> DataLoader: 79 | ''' 80 | Returns the training [`~torch.utils.data.DataLoader`]. 81 | 82 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 83 | training if necessary) otherwise. 84 | 85 | Subclass and override this method if you want to inject some custom behavior. 86 | ''' 87 | if self.train_dataset is None: 88 | raise ValueError('Trainer: training requires a train_dataset.') 89 | train_dataset = self.train_dataset 90 | data_collator = self.data_collator 91 | if isinstance(train_dataset, datasets.Dataset): 92 | train_dataset = self._remove_unused_columns(train_dataset, description='training') 93 | else: 94 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training') 95 | train_sampler = TripletBatchSampler( 96 | self.args.train_batch_size, 97 | 'sentence1', 98 | 'sentence2', 99 | self, 100 | ) 101 | return DataLoader( 102 | train_dataset, 103 | #batch_size=self._train_batch_size, 104 | batch_sampler=train_sampler, 105 | collate_fn=data_collator, 106 | #drop_last=self.args.dataloader_drop_last, 107 | num_workers=self.args.dataloader_num_workers, 108 | pin_memory=self.args.dataloader_pin_memory, 109 | worker_init_fn=seed_worker, 110 | ) 111 | --------------------------------------------------------------------------------
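For reference, the following minimal sketch (an editor's illustration, not a file from the repository) shows how the `QuadrupletLoss` defined in `utils/sts/modeling_encoders.py` behaves in isolation. The batch size (4) and hidden dimension (768) are illustrative assumptions; the cosine distance matches the one the bi- and tri-encoders pass in.

```python
import torch
from torch.nn.functional import cosine_similarity

from utils.sts.modeling_encoders import QuadrupletLoss

# Cosine distance, as wired up in BiEncoderForClassification / TriEncoderForClassification.
loss_fct = QuadrupletLoss(distance_function=lambda x, y: 1.0 - cosine_similarity(x, y), margin=1.0)

# Illustrative shapes: a batch of 4 embeddings for the high-similarity (positive)
# and low-similarity (negative) member of each sentence pair.
pos1, pos2, neg1, neg2 = (torch.randn(4, 768) for _ in range(4))

# Hinge semantics: the loss is zero once each positive pair is closer than its
# negative counterpart by at least the margin.
loss = loss_fct(pos1, pos2, neg1, neg2)
print(loss.item())
```

With random vectors the two distances are statistically indistinguishable, so the loss hovers near the margin; training drives it toward zero by pulling positive pairs together relative to negative ones.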