├── .gitignore ├── LICENSE ├── README.md ├── bash_scripts ├── basic_tests.sh ├── hpc_setup.sh ├── operational_test.sh ├── setup.sh └── train_tests.sh ├── data ├── __init__.py ├── annotated_quality_debates_loader.py ├── cleaning_scripts │ ├── consultancy_to_double_consultancy_converter.py │ ├── debates_format_converter.py │ └── replicate_and_bin_format_converter.py ├── dataset.py ├── datasets │ ├── annotated-quality-debates │ │ ├── annotated-data-set.p │ │ └── classifier.p │ ├── quality-debates │ │ └── debates-readable.jsonl │ ├── quality │ │ ├── QuALITY.v1.0.1.htmlstripped.dev │ │ ├── QuALITY.v1.0.1.htmlstripped.test │ │ └── QuALITY.v1.0.1.htmlstripped.train │ ├── quote-relevance │ │ └── quote-relevance.p │ └── scratchpad-quality-debates │ │ └── scratchpad-quality-debates.p ├── judge_preferences_loader.py ├── loader_utils.py ├── quality_debates_loader.py ├── quality_judging_loader.py ├── quality_loader.py ├── quote_relevance_loader.py └── scratchpad_quality_debates_loader.py ├── debate ├── __init__.py ├── agent.py ├── debate_round.py ├── debater.py ├── judge.py ├── speech_format.py └── transcript.py ├── experiments ├── __init__.py ├── annotator.py ├── configs │ ├── backup_experiments.yaml │ ├── example_experiment.yaml │ ├── standard_experiment.yaml │ └── test_experiment.yaml ├── experiment_loader.py ├── power_pair_scheduler.py ├── quotes_collector.py └── results_collector.py ├── models ├── __init__.py ├── anthropic_model.py ├── arbitrary_attribute_model.py ├── deterministic_model.py ├── human_model.py ├── llm_model.py ├── model.py ├── model_utils.py ├── offline_model.py ├── openai_model.py ├── random_model.py ├── repetitive_model.py └── served_model.py ├── outputs ├── graphs │ └── .gitignore ├── runs │ └── .gitignore ├── stats │ └── .gitignore └── transcripts │ └── .gitignore ├── prompts ├── __init__.py ├── configs │ └── prompts.yaml └── parser.py ├── requirements.txt ├── scripts ├── generate_quote_labels.py ├── load_model.py ├── merge.py ├── oai_finetune.py ├── push_to_hub.py ├── run_debate.py ├── run_iterative_dpo.py ├── run_ppo.py ├── run_sft.py └── script_utils.py ├── train ├── .DS_Store ├── __init__.py ├── configs │ ├── dpo_config.yaml │ ├── ppo_config.yaml │ └── sft_config.yaml ├── impl │ ├── __init__.py │ ├── llama_with_gradient_checkpointing_impl.py │ ├── smoothed_dpo_trainer.py │ └── verbose_ppo_trainer.py ├── iterative_dpo_trainer.py ├── ppo_trainer.py ├── row_converter.py ├── sft_trainer.py └── train_utils.py └── utils ├── __init__.py ├── constants.py ├── flash_attn_utils.py ├── input_utils.py ├── logger_utils.py ├── quote_utils.py ├── save_utils.py ├── string_utils.py └── timer_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .DS_Store 3 | datasets/ 4 | **/__pycache__/* 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Samuel Arnesen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The 
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bash_scripts/basic_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHON_PROGRAM="scripts/run_debate.py" 4 | ARGUMENTS=("--configuration=Simple_Test --num_iters=1" "--configuration=Single_Test --num_iters=1" "--configuration=Batched_Test --num_iters=40" "--configuration=Quality_Test --num_iters=10" "--configuration=Quality_Test --num_iters=10" "--configuration=BoN_Test --num_iters=10" "--configuration=Previous_Run_To_Replicate_Test --num_iters=10" "--configuration=Stub_LLM_Test --num_iters=10" "--configuration=Consultancy_Test --num_iters=10" "--configuration=Empty_Round_Test --num_iters=10") 5 | COMMON_ARGS=("--local --test --suppress_graphs --log_level=INFO") 6 | 7 | # Loop over each argument and run the Python program 8 | for ARG in "${ARGUMENTS[@]}"; do 9 | eval $(echo python "$PYTHON_PROGRAM" "${COMMON_ARGS[@]}" "$ARG") 10 | #echo python "$PYTHON_PROGRAM" "${COMMON_ARGS[@]}" "$ARG" 11 | 12 | # Check if the Python script exited with an error 13 | if [ $? -ne 0 ]; then 14 | echo "$ARG failed" 15 | echo $(echo python "$PYTHON_PROGRAM" "${COMMON_ARGS[@]}" "$ARG") 16 | break 17 | fi 18 | done 19 | -------------------------------------------------------------------------------- /bash_scripts/hpc_setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install bitsandbytes; 4 | pip install --pre -v torch --index-url https://download.pytorch.org/whl/nightly/cu121 --prefix=/ext3/miniconda3; 5 | MAX_JOBS=4 FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --force-reinstall --upgrade flash-attn --no-build-isolation; -------------------------------------------------------------------------------- /bash_scripts/operational_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/run_debate.py --configuration=$1_Test --num_iters=2502 --local --test --log_level=INFO --force_save_results #--suppress_graphs #--force_save_transcripts -------------------------------------------------------------------------------- /bash_scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m spacy download en_core_web_sm -------------------------------------------------------------------------------- /bash_scripts/train_tests.sh: -------------------------------------------------------------------------------- 1 | python scripts/run_sft.py --config="Test" --load_only --local --test 2 | python scripts/run_ppo.py --config='Test' --load_only --local --test 3 | python scripts/run_iterative_dpo.py --config='Test - Iterative' --load_only --local --test 4 | python scripts/run_sft.py --config="Test - Open" --load_only --local --test 5 | 
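For reference, each entry in basic_tests.sh above expands into a single run_debate.py call once the common flags are prepended. The lines below are an illustrative, standalone equivalent of the first loop iteration and are not a file in this repository; they assume the Simple_Test configuration is defined in the experiment configs.

#!/bin/bash
# Hypothetical standalone expansion of the first basic_tests.sh entry.
python scripts/run_debate.py --local --test --suppress_graphs --log_level=INFO \
    --configuration=Simple_Test --num_iters=1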
-------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import loader_utils 2 | from .annotated_quality_debates_loader import Annotation, AnnotatedQualityDebatesDataset, AnnotatedQualityDebatesLoader 3 | from .dataset import ( 4 | AnnotationBracket, 5 | AnnotationData, 6 | AnnotationTag, 7 | DataRow, 8 | DatasetConfig, 9 | DatasetType, 10 | JudgePreferenceDataRow, 11 | JudgingProbeDataRow, 12 | RawDataLoader, 13 | RawDataset, 14 | SpeakerType, 15 | SpeechData, 16 | SplitType, 17 | ) 18 | from .judge_preferences_loader import JudgePreferencesLoader, JudgePreferencesDataset, JudgePreferencesLoader, RewardType 19 | from .quality_debates_loader import ( 20 | QualityDebatesDataset, 21 | QualityConsultancyLoader, 22 | QualityDebatesLoader, 23 | QualityModelBasedDebateDataset, 24 | QualityModelBasedDebateLoader, 25 | QualityTranscriptsLoader, 26 | ) 27 | from .quality_judging_loader import QualityJudgingDataset, QualityJudgingLoader 28 | from .quality_loader import QualityDataset, QualityLoader 29 | from .quote_relevance_loader import ( 30 | QuoteRelevanceDataset, 31 | QuoteRelevanceLoader, 32 | QuoteRelevanceProcessedBatchItem, 33 | QuoteRelevanceTopicInfo, 34 | ) 35 | from .scratchpad_quality_debates_loader import ScratchpadQualityDebatesDataset, ScratchpadQualityDebatesLoader 36 | -------------------------------------------------------------------------------- /data/annotated_quality_debates_loader.py: -------------------------------------------------------------------------------- 1 | from data.dataset import ( 2 | AnnotationBracket, 3 | AnnotationData, 4 | AnnotationTag, 5 | DataRow, 6 | DatasetType, 7 | RawDataLoader, 8 | RawDataset, 9 | SpeakerType, 10 | SpeechData, 11 | SplitType, 12 | ) 13 | from data.quality_debates_loader import QualityDebatesDataset, QualityDebatesLoader 14 | import utils.constants as constants 15 | 16 | from pydantic import BaseModel 17 | 18 | from difflib import SequenceMatcher 19 | from enum import Enum 20 | from typing import Optional, Union 21 | import copy 22 | import pickle 23 | import os 24 | import re 25 | import sys 26 | 27 | 28 | class Annotation(BaseModel): 29 | text: str 30 | clean: str 31 | metrics: dict[str | AnnotationTag, float] 32 | 33 | 34 | class AnnotatedQualityDebatesDataset(RawDataset): 35 | def __init__(self, dataset: QualityDebatesDataset, annotations_file_path: str): 36 | """ 37 | Dataset where the transcripts of the human debates are annotated with stylistic tags (e.g. statement, rebuttal) 38 | 39 | Params: 40 | dataset: a normal QualityDebateDataset 41 | annotations_file_path: path to the file with all the annotations. We will match the speeches in this 42 | file to the speeches in the normal QualityDebateDataset to attach the annotations 43 | """ 44 | 45 | super().__init__(DatasetType.ANNOTATED_QUALITY_DEBATES) 46 | self.data = { 47 | SplitType.TRAIN: dataset.get_data(SplitType.TRAIN), 48 | SplitType.VAL: dataset.get_data(SplitType.VAL), 49 | SplitType.TEST: dataset.get_data(SplitType.TEST), 50 | } 51 | self.__add_annotation(annotations_file_path=annotations_file_path) 52 | 53 | def get_data(self, split: SplitType = SplitType.TRAIN) -> list[DataRow]: 54 | """Returns all the data for a given split""" 55 | if split not in self.data: 56 | raise ValueError(f"Split type {split} is not recognized. 
Only TRAIN, VAL, and TEST are recognized") 57 | return self.data[split] 58 | 59 | def get_batch(self, split: SplitType = SplitType.TRAIN, batch_size: int = 1) -> list[DataRow]: 60 | """Returns a subset of the data for a given split""" 61 | if batch_size < 1: 62 | raise ValueError(f"Batch size must be >= 1. Inputted batch size was {batch_size}") 63 | data_to_return = self.data[split][self.idxs[split] : min(self.idxs[split] + batch_size, len(self.data[split]))] 64 | self.idxs[split] = self.idxs[split] + batch_size if self.idxs[split] + batch_size < len(self.data[split]) else 0 65 | return data_to_return 66 | 67 | def get_example(self, split: SplitType = SplitType.TRAIN, idx: int = 0) -> DataRow: 68 | """Returns an individual row in the dataset""" 69 | return self.data[split][idx % len(self.data[split])] 70 | 71 | @classmethod 72 | def meets_threshold( 73 | cls, tag: AnnotationTag, bracket: AnnotationBracket, threshold: float, positive: bool, speech: SpeechData 74 | ): 75 | """ 76 | Checks whether a given speech meets all the required filters. 77 | 78 | Params: 79 | tag: Which annotation tag to filter for (e.g. refutuation) 80 | bracket: When combined with the `threshold' and 'tag' params, this determines what qualifies as an eligible example. 81 | For example, a tag of 'refutation', a threshold of 0.9, and a bracket of 'high' means that you want samples 82 | that are in at least the 90th percentile of having the most refutation. By contrast, a bracket of 'low' in that 83 | situation would mean one would want to be in the 90th percentile of having the least refutation. 84 | threshold: When combined with the `threshold' and 'bracket' params, this determines what qualifies as an eligible example. 85 | positive: If one wants to filter for rows that explicitly meet the other criteria (true) or explicitly 86 | meet the opposite of the criteria (false) 87 | speech: The speech that is being evaluated for those criteria. 88 | """ 89 | if not speech.annotation or not speech.annotation.percentiles or tag not in speech.annotation.percentiles: 90 | return False 91 | 92 | if positive: 93 | if bracket == AnnotationBracket.HIGH: 94 | return speech.annotation.percentiles[tag] > threshold 95 | elif bracket == AnnotationBracket.LOW: 96 | return speech.annotation.percentiles[tag] < (1 - threshold) 97 | else: 98 | return speech.annotation.percentiles[tag] <= threshold and speech.annotation.percentiles[tag] >= ( 99 | 1 - threshold 100 | ) 101 | else: 102 | if bracket == AnnotationBracket.HIGH: 103 | return speech.annotation.percentiles[tag] < (1 - threshold) 104 | elif bracket == AnnotationBracket.LOW: 105 | return speech.annotation.percentiles[tag] > threshold 106 | else: 107 | return speech.annotation.percentiles[tag] <= threshold and speech.annotation.percentiles[tag] >= ( 108 | 1 - threshold 109 | ) 110 | 111 | def get_annotation_examples( 112 | self, 113 | tag: AnnotationTag, 114 | bracket: AnnotationBracket, 115 | threshold: float, 116 | positive: bool, 117 | source_row: Optional[DataRow] = None, 118 | ) -> list[SpeechData]: 119 | """ 120 | Filters the dataset to provide some few-shot examples of the same tag. For instance, if one wants to instruct a model 121 | to 'use a style that has a lot of refutation', one could use this method to fetch some examples of other 122 | speeches from the dataset that also have a lot of refutation. 123 | 124 | Params: 125 | tag: Which annotation tag to filter for (e.g. 
refutuation) 126 | bracket: When combined with the `threshold' and 'tag' params, this determines what qualifies as an eligible example. 127 | For example, a tag of 'refutation', a threshold of 0.9, and a bracket of 'high' means that you want samples 128 | that are in at least the 90th percentile of having the most refutation. By contrast, a bracket of 'low' in that 129 | situation would mean one would want to be in the 90th percentile of having the least refutation. 130 | threshold: When combined with the `threshold' and 'bracket' params, this determines what qualifies as an eligible example. 131 | positive: If one wants to filter for rows that explicitly meet the other criteria (true) or explicitly 132 | meet the opposite of the criteria (false) 133 | source_row: An optional row to exclude in case one is using this for supervised finetuning. If one 134 | is trying to induce a model to generate a particular speech using few-shot examples, one shouldn't pass 135 | in the target speech as one of the examples. 136 | 137 | Returns: 138 | A list of speeches that meet the specified criteria. 139 | """ 140 | eligible_examples = [] 141 | for row in filter(lambda x: not source_row or source_row.story_title != x.story_title, self.data[SplitType.TRAIN]): 142 | for speech in filter( 143 | lambda x: AnnotatedQualityDebatesDataset.meets_threshold(tag, bracket, threshold, positive, x), row.speeches 144 | ): 145 | eligible_examples.append(speech) 146 | return eligible_examples 147 | 148 | def __add_annotation(self, annotations_file_path: str) -> None: 149 | def match_speeches(speech: SpeechData, annotations: list[Annotation]): 150 | cleaned_speech = re.sub("\s+", " ", speech.text) 151 | for annotation in annotations: 152 | cleaned_annotation = re.sub("\s+", " ", annotation.clean).lstrip().rstrip() 153 | ratio = SequenceMatcher(None, cleaned_annotation, cleaned_speech).ratio() 154 | if cleaned_annotation == cleaned_speech or ratio > 0.99: 155 | return annotation 156 | return None 157 | 158 | with open(annotations_file_path, "rb") as f: 159 | id_to_speeches = pickle.load(f) 160 | 161 | annotated_speeches = [] 162 | for split in [SplitType.TRAIN, SplitType.VAL, SplitType.TEST]: 163 | for i, row in enumerate(self.data[split]): 164 | annotations = [Annotation(**entry) for entry in id_to_speeches[row.debate_id]] 165 | for annotation in annotations: 166 | annotation.metrics = {AnnotationTag[key.upper()]: value for key, value in annotation.metrics.items()} 167 | for speech in row.speeches: 168 | matching = match_speeches(speech, annotations) 169 | speech.annotation = AnnotationData(percents={}, percentiles={}) 170 | if matching: 171 | speech.annotation = AnnotationData(percents=copy.deepcopy(matching.metrics), percentiles={}) 172 | annotated_speeches.append(speech) 173 | 174 | for tag in AnnotationTag: 175 | distribution = sorted([speech.annotation.percents[tag] for speech in annotated_speeches]) 176 | for idx, speech in enumerate(annotated_speeches): 177 | for i in range(len(distribution)): 178 | if speech.annotation.percents[tag] <= distribution[i]: 179 | speech.annotation.percentiles[tag] = i / len(distribution) 180 | break 181 | if speech.annotation.percents[tag] > distribution[-1]: 182 | speech.annotation.percentiles[tag] = 1 183 | 184 | 185 | class AnnotatedQualityDebatesLoader(RawDataLoader): 186 | DEFAULT_ANNOTATIONS_FILE_PATH = ( 187 | os.environ[constants.SRC_ROOT] + "data/datasets/annotated-quality-debates/annotated-data-set.p" 188 | ) 189 | 190 | @classmethod 191 | def load( 192 | cls, 193 | 
full_dataset_filepath: Optional[str] = None, 194 | deduplicate: bool = False, 195 | supplemental_file_paths: Optional[dict[str, str]] = None, 196 | **kwargs, 197 | ) -> AnnotatedQualityDebatesDataset: 198 | """Constructs an AnnotatedQualityDebatesDataset""" 199 | annotations_file_path = ( 200 | supplemental_file_paths.get("annotations_file_path", AnnotatedQualityDebatesLoader.DEFAULT_ANNOTATIONS_FILE_PATH) 201 | if supplemental_file_paths 202 | else AnnotatedQualityDebatesLoader.DEFAULT_ANNOTATIONS_FILE_PATH 203 | ) 204 | 205 | quality_debates_dataset = QualityDebatesLoader.load( 206 | full_dataset_filepath=full_dataset_filepath, deduplicate=deduplicate 207 | ) 208 | return AnnotatedQualityDebatesDataset(dataset=quality_debates_dataset, annotations_file_path=annotations_file_path) 209 | -------------------------------------------------------------------------------- /data/cleaning_scripts/consultancy_to_double_consultancy_converter.py: -------------------------------------------------------------------------------- 1 | import os, sys, json 2 | 3 | src_root = "/Users/samarnesen/nyu/debate/nyu-debate-modeling/" 4 | os.environ["SRC_ROOT"] = src_root 5 | sys.path.insert(0, src_root) 6 | 7 | from utils import input_utils, InputType 8 | 9 | output_prefix = "/Users/samarnesen/nyu/scratch/runs/double-consultancy-llama/" 10 | file_prefix = "2024-04-04_20:45:09.521845" 11 | 12 | texts = [ 13 | json.loads(x) 14 | for x in input_utils.read_file_texts( 15 | f"{src_root}outputs/transcripts/{file_prefix}", input_type=InputType.JSON_TRANSCRIPT 16 | ) 17 | ] 18 | 19 | idx = 0 20 | new_files = [] 21 | while idx < len(texts): 22 | first = texts[idx] 23 | second = texts[idx + 1] 24 | assert first["metadata"]["question"] == second["metadata"]["question"] 25 | assert first["metadata"]["first_debater_answer"] != second["metadata"]["first_debater_answer"] 26 | assert first["metadata"]["first_debater_answer"] == second["metadata"]["second_debater_answer"] 27 | 28 | first_speakers = set([speech["speaker"] for speech in first["speeches"]]) 29 | second_speakers = set([speech["speaker"] for speech in second["speeches"]]) 30 | 31 | first_speeches = [speech for speech in filter(lambda x: x["speaker"] == "Debater_A", first["speeches"])] 32 | second_speeches = [speech for speech in filter(lambda x: x["speaker"] == "Debater_B", second["speeches"])] 33 | 34 | new = {"metadata": first["metadata"], "speeches": first_speeches + second_speeches} 35 | 36 | with open(f"{output_prefix}{file_prefix}_{idx // 2}_0.json", "w") as f: 37 | json.dump(new, f) 38 | 39 | idx += 2 40 | -------------------------------------------------------------------------------- /data/cleaning_scripts/debates_format_converter.py: -------------------------------------------------------------------------------- 1 | """This script loads debates that John Hughes, Dan Valentine, and Akbir Khan ran with GPT-4 (https://github.com/akbir/debate), 2 | and reformats them to the format that QualityDebatesLoader expects. 
3 | """ 4 | 5 | import json 6 | import pandas as pd 7 | import re 8 | import sys 9 | 10 | 11 | """ 12 | external_debate_sources = [ 13 | "/Users/samarnesen/Downloads/sp/claude2.1_Bo16_claude2.1_Bo16/debate_sim/data0.csv", 14 | "/Users/samarnesen/Downloads/sp/claude2.1_Bo4_Co8_claude2.1_Bo4_Co8/debate_sim/data0.csv", 15 | "/Users/samarnesen/Downloads/sp/claude2.1_Bo4_claude2.1_Bo4/debate_sim/data0.csv", 16 | "/Users/samarnesen/Downloads/sp/claude2.1_Bo8_claude2.1_Bo8/debate_sim/data0.csv", 17 | "/Users/samarnesen/Downloads/sp/claude2.1_Co16_claude2.1_Co16/debate_sim/data0.csv", 18 | "/Users/samarnesen/Downloads/sp/claude2.1_Co2_claude2.1_Co2/debate_sim/data0.csv", 19 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo16_gpt4t_Bo16/debate_sim/data0.csv", 20 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo1_gpt4t_Bo1/debate_sim/data0.csv", 21 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo32_gpt4t_Bo32/debate_sim/data0.csv", 22 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo4_Co8_gpt4t_Bo4_Co8/debate_sim/data0.csv", 23 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo4_gpt4t_Bo4/debate_sim/data0.csv", 24 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo8_gpt4t_Bo8/debate_sim/data0.csv", 25 | "/Users/samarnesen/Downloads/sp/gpt4t_Co16_gpt4t_Co16/debate_sim/data0.csv", 26 | ] 27 | 28 | """ 29 | output_gpt_only = "/Users/samarnesen/nyu/scratch/combined_gpt_debates.jsonl" 30 | output_combined = "/Users/samarnesen/nyu/scratch/combined_human_and_gpt4_debates.jsonl" 31 | 32 | 33 | external_debate_sources = [ 34 | "/Users/samarnesen/Downloads/sp/gpt4t_Bo32_gpt4t_Bo32/debate_sim/data0.csv", 35 | ] 36 | 37 | # output_gpt_only = "/Users/samarnesen/nyu/scratch/adj_datasets/gpt3.5_bo16.jsonl" 38 | 39 | 40 | def get_debaters_from_file_path(file_path: str): 41 | filename = file_path.split("/")[-3] 42 | split_debaters = filename.split("_") 43 | debater_one = {"model_type": None, "bo": 0, "co": 0} 44 | debater_two = {"model_type": None, "bo": 0, "co": 0} 45 | current = None 46 | for comp in filename.split("_"): 47 | if re.match("Bo\d", comp): 48 | current["bo"] = int(re.match("Bo(\d)", comp).group(1)) 49 | elif re.match("Co\d", comp): 50 | current["co"] = int(re.match("Co(\d)", comp).group(1)) 51 | else: 52 | if not current: 53 | current = debater_one 54 | else: 55 | current = debater_two 56 | current["model_type"] = comp 57 | 58 | return debater_one, debater_two 59 | 60 | 61 | def get_external_debates(file_path: str): 62 | def create_turn(text: str, correct: bool = True, swap: bool = False, role: str = "Debater"): 63 | return { 64 | "chars": -1, 65 | "charLimit": -1, 66 | "quoteChars": -1, 67 | "quoteCharLimit": -1, 68 | "index": 0 if correct != swap else 1, 69 | "probabilities": None, 70 | "role": role, 71 | "text": text, 72 | } 73 | 74 | debates = [] 75 | 76 | df = pd.read_csv(file_path) 77 | one, two = get_debaters_from_file_path(file_path) 78 | for i, row in df.iterrows(): 79 | swap = i % 2 == 1 80 | transcript = json.loads(row["transcript"]) 81 | if swap: 82 | answers = [transcript["answers"]["incorrect"], transcript["answers"]["correct"]] 83 | else: 84 | answers = [transcript["answers"]["correct"], transcript["answers"]["incorrect"]] 85 | turns = [] 86 | 87 | rd = transcript["rounds"][0] 88 | correct_turn = create_turn(text=rd["correct"], correct=True, swap=swap) 89 | incorrect_turn = create_turn(rd["incorrect"], correct=False, swap=swap) 90 | turns.extend([correct_turn, incorrect_turn] if not swap else [incorrect_turn, correct_turn]) 91 | turns.append(create_turn(text="", role="Judge")) 92 | 93 | new_debate = { 94 | "storyId": "-1", 95 | 
"storyTitle": row["story_title"], 96 | "story": row["story"], 97 | "question": row["question"], 98 | "answers": answers, 99 | "debateId": "-1", 100 | "judge": "-1", 101 | "turns": turns, 102 | "isJudgeCorrect": False, 103 | "correctAnswer": row["correct answer"], 104 | "debaters": [one, two], 105 | } 106 | debates.append(new_debate) 107 | 108 | return debates 109 | 110 | 111 | def deduplicate(debates: list[dict]): 112 | story_id_to_debate = {} 113 | for debate in debates: 114 | key = debate["storyTitle"] + "_" + debate["question"] 115 | if key not in story_id_to_debate: 116 | story_id_to_debate[key] = debate 117 | else: 118 | existing = story_id_to_debate[key]["debaters"][0] 119 | current = debate["debaters"][0] 120 | if existing["model_type"] != current["model_type"]: 121 | story_id_to_debate[key] = ( 122 | story_id_to_debate[key] 123 | if existing["model_type"] == "gpt4t" and current["model_type"] == "claude2.1" 124 | else debate 125 | ) 126 | elif existing["bo"] != current["bo"]: 127 | story_id_to_debate[key] = story_id_to_debate[key] if existing["bo"] > current["bo"] else debate 128 | elif existing["co"] != current["co"]: 129 | story_id_to_debate[key] = story_id_to_debate[key] if existing["co"] > current["co"] else debate 130 | for debate in debates: 131 | del debate["debaters"] 132 | 133 | return list(story_id_to_debate.values()) 134 | 135 | 136 | if __name__ == "__main__": 137 | gpt_debates = [] 138 | for source in external_debate_sources: 139 | external_debates = get_external_debates(source) 140 | for debate in external_debates: 141 | is_truncated = False 142 | for turn in debate["turns"]: 143 | if not turn["text"]: 144 | print("MISSING") 145 | if "TRUNCATED" in turn["text"]: 146 | is_truncated = True 147 | if not is_truncated: 148 | gpt_debates.append(debate) 149 | 150 | """ 151 | with open(output_gpt_only, "w+") as f: 152 | for debate in gpt_debates: 153 | f.write(json.dumps(debate)) 154 | f.write("\n") 155 | """ 156 | 157 | with open("data/datasets/quality-debates/debates-readable.jsonl", "r") as human_f: 158 | lines = human_f.readlines() 159 | human_debates = [json.loads(line) for line in lines] 160 | 161 | """ 162 | with open(output_combined, "w") as f: 163 | all_debates = human_debates + gpt_debates 164 | for debate in all_debates: 165 | f.write(json.dumps(debate)) 166 | f.write("\n") 167 | """ 168 | 169 | print(len(gpt_debates)) 170 | print(len(human_debates)) 171 | -------------------------------------------------------------------------------- /data/cleaning_scripts/replicate_and_bin_format_converter.py: -------------------------------------------------------------------------------- 1 | """This script loads debates that John Hughes, Dan Valentine, and Akbir Khan ran with GPT-4 (https://github.com/akbir/debate), 2 | and reformats them to the format that QualityDebatesLoader expects. 
3 | """ 4 | 5 | import json 6 | import pandas as pd 7 | import re 8 | import sys 9 | from typing import Optional 10 | 11 | 12 | output_gpt_only = "/Users/samarnesen/nyu/scratch/binned_gpt_debates_and_consultancies.jsonl" 13 | output_combined = "/Users/samarnesen/nyu/scratch/binned_human_and_gpt4_debates_and_consultancies.jsonl" 14 | 15 | external_debate_sources = [ 16 | "/Users/samarnesen/Downloads/llm_debate_dataset/llm_debate_human_judge_dataset.csv", 17 | ] 18 | 19 | 20 | def get_debaters_from_file_path(file_path: str): 21 | filename = file_path.split("/")[-3] 22 | split_debaters = filename.split("_") 23 | debater_one = {"model_type": None, "bo": 0, "co": 0} 24 | debater_two = {"model_type": None, "bo": 0, "co": 0} 25 | current = None 26 | for comp in filename.split("_"): 27 | if re.match("Bo\d", comp): 28 | current["bo"] = int(re.match("Bo(\d)", comp).group(1)) 29 | elif re.match("Co\d", comp): 30 | current["co"] = int(re.match("Co(\d)", comp).group(1)) 31 | else: 32 | if not current: 33 | current = debater_one 34 | else: 35 | current = debater_two 36 | current["model_type"] = comp 37 | 38 | return debater_one, debater_two 39 | 40 | 41 | def get_external_debates(file_path: str): 42 | def create_turn( 43 | text: str, 44 | correct: bool = True, 45 | swap: bool = False, 46 | role: str = "Debater", 47 | probs: Optional[tuple[float, float]] = None, 48 | is_judge: bool = False, 49 | ): 50 | index = 0 if correct != swap else 1 51 | return { 52 | "chars": -1, 53 | "charLimit": -1, 54 | "quoteChars": -1, 55 | "quoteCharLimit": -1, 56 | "index": index if not is_judge else None, 57 | "probabilities": probs, 58 | "role": role, 59 | "text": text, 60 | } 61 | 62 | debates = [] 63 | 64 | df = pd.read_csv(file_path) 65 | df = df[~df["judge_id"].isin([11, 17, 20, 22, 29, 32, 34, 35, 36, 37])] 66 | df = df[(df["debate_method"] == "debate") | (df["debate_method"] == "consultancy")] 67 | one, two = get_debaters_from_file_path(file_path) 68 | for i, row in df.iterrows(): 69 | swap = i % 2 == 1 70 | transcript = json.loads(row["transcript"]) 71 | if swap: 72 | answers = [transcript["answers"]["incorrect"], transcript["answers"]["correct"]] 73 | else: 74 | answers = [transcript["answers"]["correct"], transcript["answers"]["incorrect"]] 75 | 76 | rd = transcript["rounds"][0] 77 | is_debate = rd["correct"] is not None and rd["incorrect"] is not None 78 | correct_turn = create_turn(text=rd["correct"], correct=True, swap=swap) if rd["correct"] else None 79 | incorrect_turn = create_turn(text=rd["incorrect"], correct=False, swap=swap) if rd["incorrect"] else None 80 | 81 | turns = [correct_turn, incorrect_turn] if not swap else [incorrect_turn, correct_turn] 82 | turns = [turn for turn in filter(lambda x: x is not None, turns)] 83 | 84 | judge_probs = [0, 0] 85 | if not swap: 86 | if row["correct"]: # then A is correct and they voted for A 87 | judge_probs = (row["confidence"] / 100, 1 - (row["confidence"] / 100)) 88 | else: # then A is correct and they voted for B 89 | judge_probs = (1 - (row["confidence"] / 100), row["confidence"] / 100) 90 | else: 91 | if row["correct"]: # then B is correct and they voted for B 92 | judge_probs = (1 - (row["confidence"] / 100), row["confidence"] / 100) 93 | else: # then B is correct and they voted for A 94 | judge_probs = (row["confidence"] / 100, 1 - (row["confidence"] / 100)) 95 | turns.append(create_turn(text="", role="Judge", probs=judge_probs, is_judge=True)) 96 | 97 | new_debate = { 98 | "storyId": "-1", 99 | "storyTitle": row["story_title"], 100 | "story": 
row["story_title"], 101 | "question": row["question"], 102 | "answers": answers, 103 | "debateId": "-1", 104 | "judge": "-1", 105 | "turns": turns, 106 | "isJudgeCorrect": False, 107 | "correctAnswer": transcript["answers"]["correct"], 108 | "debaters": [one, two], 109 | } 110 | 111 | debates.append(new_debate) 112 | 113 | return debates 114 | 115 | 116 | def deduplicate(debates: list[dict]): 117 | story_id_to_debate = {} 118 | for debate in debates: 119 | key = debate["storyTitle"] + "_" + debate["question"] 120 | if key not in story_id_to_debate: 121 | story_id_to_debate[key] = debate 122 | else: 123 | existing = story_id_to_debate[key]["debaters"][0] 124 | current = debate["debaters"][0] 125 | if existing["model_type"] != current["model_type"]: 126 | story_id_to_debate[key] = ( 127 | story_id_to_debate[key] 128 | if existing["model_type"] == "gpt4t" and current["model_type"] == "claude2.1" 129 | else debate 130 | ) 131 | elif existing["bo"] != current["bo"]: 132 | story_id_to_debate[key] = story_id_to_debate[key] if existing["bo"] > current["bo"] else debate 133 | elif existing["co"] != current["co"]: 134 | story_id_to_debate[key] = story_id_to_debate[key] if existing["co"] > current["co"] else debate 135 | for debate in debates: 136 | del debate["debaters"] 137 | 138 | return list(story_id_to_debate.values()) 139 | 140 | 141 | if __name__ == "__main__": 142 | gpt_debates = [] 143 | for source in external_debate_sources: 144 | external_debates = get_external_debates(source) 145 | for debate in external_debates: 146 | is_truncated = False 147 | for turn in debate["turns"]: 148 | if "TRUNCATED" in turn["text"]: 149 | is_truncated = True 150 | if not is_truncated: 151 | gpt_debates.append(debate) 152 | 153 | with open("data/datasets/quality-debates/debates-readable.jsonl", "r") as human_f: 154 | lines = human_f.readlines() 155 | human_debates = [json.loads(line) for line in lines] 156 | 157 | with open(output_combined, "w") as f: 158 | all_debates = human_debates + gpt_debates 159 | for debate in all_debates: 160 | f.write(json.dumps(debate)) 161 | f.write("\n") 162 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Any, Optional, Union 5 | 6 | from pydantic import BaseModel, field_validator 7 | import torch 8 | 9 | 10 | class SplitType(Enum): 11 | TRAIN = 1 12 | VAL = 2 13 | TEST = 3 14 | 15 | 16 | class DatasetType(Enum): 17 | QUALITY = (1, True) 18 | QUALITY_DEBATES = (2, True) 19 | JUDGE_PREFERENCES = (3, True) 20 | ANNOTATED_QUALITY_DEBATES = (4, True) 21 | SCRATCHPAD_QUALITY_DEBATES = (5, True) 22 | QUOTE_RELEVANCE = (6, True) 23 | JUDGING_PROBE = (7, True) 24 | QUALITY_CONSULTANCY = (8, True) 25 | CORRECTNESS_JUDGE_PREFERENCES = (9, True) 26 | EXTERNAL_HUGGINGFACE = (10, False) 27 | 28 | def __init__(self, idx: int, is_instantiable: bool): 29 | self.id = idx 30 | self.is_instantiable = is_instantiable 31 | 32 | 33 | class DatasetConfig(BaseModel): 34 | dataset_type: DatasetType 35 | full_dataset_file_path: Optional[str | list[str]] = None 36 | train_file_path: Optional[str] = None 37 | val_file_path: Optional[str] = None 38 | test_file_path: Optional[str] = None 39 | supplemental_file_paths: dict[str, str | list[str]] = {} 40 | split_type: SplitType = SplitType.TRAIN 41 | combine_train_and_val: bool = False 42 | flip_sides: bool = False 43 | 
shuffle_deterministically: bool = False 44 | 45 | @field_validator("split_type", mode="before") 46 | @classmethod 47 | def validate_tournament_type(cls, split_type: str): 48 | return SplitType[split_type.upper()] 49 | 50 | @field_validator("dataset_type", mode="before") 51 | @classmethod 52 | def validate_dataset_type(cls, dataset_type: str): 53 | return DatasetType[dataset_type.upper()] 54 | 55 | 56 | class SpeakerType(Enum): 57 | DEBATER = 1 58 | JUDGE = 2 59 | 60 | 61 | class AnnotationTag(Enum): 62 | QUOTE = 0 63 | SUMMARY = 1 64 | REFUTATION = 2 65 | ANALYSIS = 3 66 | REPLY = 4 67 | FLOURISH = 5 68 | FRAMING = 6 69 | STATEMENT = 7 70 | LOGIC = 8 71 | Q_CONTEXT = 9 72 | POSITION = 10 73 | OOB_QUOTE = 11 74 | PROMISE = 12 75 | 76 | 77 | class AnnotationBracket(Enum): 78 | HIGH = 1 79 | LOW = 2 80 | NEUTRAL = 3 81 | 82 | 83 | class AnnotationData(BaseModel): 84 | percents: Optional[dict[AnnotationTag | str, float]] = None 85 | percentiles: Optional[dict[AnnotationTag | str, float]] = None 86 | 87 | 88 | class SpeechData(BaseModel): 89 | text: str 90 | position: int 91 | speaker_type: SpeakerType 92 | scratchpad: Optional[str] = None 93 | annotation: Optional[AnnotationData] = None 94 | probabilities: Optional[tuple[float, float]] = None 95 | 96 | 97 | class DataRow(BaseModel): 98 | background_text: str 99 | question: Optional[str] = None 100 | positions: Optional[tuple[str, str]] = None 101 | speeches: Optional[list[SpeechData]] = None 102 | correct_index: Optional[int] = None 103 | debate_id: Optional[str] = None 104 | story_title: Optional[str] = None 105 | 106 | 107 | class JudgePreferenceDataRow(BaseModel): 108 | prompt: str 109 | chosen: str 110 | rejected: str 111 | preference: float = 1.0 112 | 113 | 114 | @dataclass 115 | class JudgingProbeDataRow: 116 | internal_representation: torch.tensor 117 | target: torch.tensor 118 | 119 | 120 | class RawDataset(ABC): 121 | def __init__(self, dataset_type: DatasetType): 122 | self.dataset_type = dataset_type 123 | 124 | def get_data(self, split: SplitType = SplitType.TRAIN) -> list[tuple[str, Any]]: 125 | """Fetches all the data for a given split of the data""" 126 | pass 127 | 128 | def get_batch(self, split: SplitType = SplitType.TRAIN, batch_size: int = 1) -> list[tuple[str, Any]]: 129 | """Gets a subset of the data""" 130 | pass 131 | 132 | def get_example(self, split: SplitType = SplitType.TRAIN, idx: int = 0) -> DataRow: 133 | """Returns an individual row at the specified index""" 134 | pass 135 | 136 | def get_dataset_type(self): 137 | """Gets the name of the dataset""" 138 | return self.dataset_type 139 | 140 | def merge(self, other): 141 | """Combines the data from two datasets""" 142 | pass 143 | 144 | 145 | class RawDataLoader(ABC): 146 | @classmethod 147 | def load( 148 | cls, 149 | full_dataset_filepath: Optional[str] = None, 150 | train_filepath: Optional[str] = None, 151 | validation_filepath: Optional[str] = None, 152 | test_filepath: Optional[str] = None, 153 | supplemental_file_paths: Optional[str] = None, 154 | combine_train_and_val: bool = False, 155 | ) -> RawDataset: 156 | """Constructs a dataset""" 157 | pass 158 | -------------------------------------------------------------------------------- /data/datasets/annotated-quality-debates/annotated-data-set.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelarnesen/nyu-debate-modeling/455615d3d6fb1a0ebc158f7eb894bfd1aa63a90a/data/datasets/annotated-quality-debates/annotated-data-set.p 
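As a brief orientation for the DatasetConfig model shown in data/dataset.py above: its field validators accept lower-case strings for the enum-valued fields and upcase them before validation. The snippet below is an illustrative sketch rather than a file in this repository; it assumes the repository root is importable and its requirements (pydantic, torch) are installed, and it reuses the debates-readable.jsonl path from the directory tree.

from data.dataset import DatasetConfig, DatasetType, SplitType

# Illustrative only: the "before"-mode validators convert these strings into
# DatasetType.QUALITY_DEBATES and SplitType.TRAIN.
config = DatasetConfig(
    dataset_type="quality_debates",
    split_type="train",
    full_dataset_file_path="data/datasets/quality-debates/debates-readable.jsonl",
)
assert config.dataset_type == DatasetType.QUALITY_DEBATES
assert config.split_type == SplitType.TRAIN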
-------------------------------------------------------------------------------- /data/datasets/annotated-quality-debates/classifier.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelarnesen/nyu-debate-modeling/455615d3d6fb1a0ebc158f7eb894bfd1aa63a90a/data/datasets/annotated-quality-debates/classifier.p -------------------------------------------------------------------------------- /data/datasets/quote-relevance/quote-relevance.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelarnesen/nyu-debate-modeling/455615d3d6fb1a0ebc158f7eb894bfd1aa63a90a/data/datasets/quote-relevance/quote-relevance.p -------------------------------------------------------------------------------- /data/datasets/scratchpad-quality-debates/scratchpad-quality-debates.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelarnesen/nyu-debate-modeling/455615d3d6fb1a0ebc158f7eb894bfd1aa63a90a/data/datasets/scratchpad-quality-debates/scratchpad-quality-debates.p -------------------------------------------------------------------------------- /data/judge_preferences_loader.py: -------------------------------------------------------------------------------- 1 | from data.dataset import DataRow, DatasetType, JudgePreferenceDataRow, RawDataLoader, RawDataset, SplitType 2 | from data.quality_loader import QualityLoader 3 | from utils import InputType, input_utils, quote_utils 4 | import utils.constants as constants 5 | 6 | from enum import Enum, auto 7 | from typing import Any, Optional 8 | import json, math, random 9 | 10 | 11 | class RewardType(Enum): 12 | LOG_PROB = auto() 13 | PROB = auto() 14 | LOGIT = auto() 15 | SIGMOID = auto() 16 | BINARY = auto() 17 | 18 | 19 | class JudgePreferencesDataset(RawDataset): 20 | def __init__(self, train_data: list[str, Any], val_data: list[str, Any], test_data: list[str, Any]): 21 | """ 22 | A dataset of judge preferences from a previous best-of-n run. Each row is a pair of speeches with one 23 | labelled as the chosen speech and the other as the rejected speech. 24 | """ 25 | super().__init__(DatasetType.JUDGE_PREFERENCES) 26 | self.data = { 27 | SplitType.TRAIN: self.__convert_batch_to_rows(train_data), 28 | SplitType.VAL: self.__convert_batch_to_rows(val_data), 29 | SplitType.TEST: self.__convert_batch_to_rows(test_data), 30 | } 31 | self.idxs = {SplitType.TRAIN: 0, SplitType.VAL: 0, SplitType.TEST: 0} 32 | 33 | def get_data(self, split: SplitType = SplitType.TRAIN) -> list[JudgePreferenceDataRow]: 34 | """Returns all the data for a given split""" 35 | if split not in self.data: 36 | raise ValueError(f"Split type {split} is not recognized. Only TRAIN, VAL, and TEST are recognized") 37 | return self.data[split] 38 | 39 | def get_batch(self, split: SplitType = SplitType.TRAIN, batch_size: int = 1) -> list[JudgePreferenceDataRow]: 40 | """Returns a subset of the data for a given split""" 41 | if batch_size < 1: 42 | raise ValueError(f"Batch size must be >= 1. 
Inputted batch size was {batch_size}") 43 | data_to_return = self.data[split][self.idxs[split] : min(self.idxs[split] + batch_size, len(self.data[split]))] 44 | self.idxs[split] = self.idxs[split] + batch_size if self.idxs[split] + batch_size < len(self.data[split]) else 0 45 | return data_to_return 46 | 47 | def get_example(self, split: SplitType = SplitType.TRAIN, idx: int = 0) -> JudgePreferenceDataRow: 48 | """Returns an individual row in the dataset""" 49 | return self.data[split][idx % len(self.data[split])] 50 | 51 | def merge(self, other: RawDataset): 52 | """Combines the data from two datasets""" 53 | for key in filter(lambda x: x in other.data, self.data): 54 | self.data[key] += other.data[key] 55 | 56 | def __convert_batch_to_rows(self, train_data: list[tuple[str, str, str, float]]): 57 | return [ 58 | JudgePreferenceDataRow(prompt=instruction, chosen=chosen, rejected=rejected, preference=preference) 59 | for instruction, chosen, rejected, preference in train_data 60 | ] 61 | 62 | 63 | class JudgePreferencesLoader(RawDataLoader): 64 | MIN_GAP = 0.00 65 | 66 | @classmethod 67 | def process_row( 68 | cls, data: dict[Any, Any], reward_type: RewardType = RewardType.LOG_PROB, **kwargs 69 | ) -> list[tuple[str, str, str, float]]: 70 | def clean_speech(speech: str) -> str: 71 | speech = speech.replace(constants.INVALID_QUOTE_TAG, constants.QUOTE_TAG).replace( 72 | constants.INVALID_UNQUOTE_TAG, constants.UNQUOTE_TAG 73 | ) 74 | return quote_utils.clean_up_quotes(speech_content=speech) 75 | 76 | def get_preference(selected: dict, rejected: dict) -> float: 77 | selected_pref = selected["supplemental"]["preference"] 78 | rejected_pref = rejected["preference"] 79 | if reward_type == RewardType.LOGIT: 80 | selected_over_rejected = selected_pref * (1 - rejected_pref) 81 | rejected_over_selected = rejected_pref * (1 - selected_pref) 82 | return selected_over_rejected / (selected_over_rejected + rejected_over_selected) 83 | elif reward_type == RewardType.PROB: 84 | multiplier = kwargs.get("multiplier", 5.75) 85 | return math.exp(multiplier * selected_pref) / ( 86 | math.exp(multiplier * selected_pref) + math.exp(multiplier * rejected_pref) 87 | ) 88 | elif reward_type == RewardType.SIGMOID: 89 | multiplier = kwargs.get("multiplier", 5) 90 | temperature = kwargs.get("temperature", 0.125) 91 | mean = 0.5 92 | selected_reward = multiplier / (1 + math.exp(-((selected_pref - mean) / temperature))) 93 | rejected_reward = multiplier / (1 + math.exp(-((rejected_pref - mean) / temperature))) 94 | return math.exp(selected_reward) / (math.exp(selected_reward) + math.exp(rejected_reward)) 95 | elif reward_type == RewardType.BINARY: 96 | return 1.0 97 | else: 98 | multiplier = kwargs.get("multiplier", 2.25) 99 | return (selected_pref**multiplier) / ((rejected_pref**multiplier) + (selected_pref**multiplier)) 100 | 101 | outputs = [] 102 | for selected in filter( 103 | lambda x: x["speaker"] in [constants.DEFAULT_DEBATER_A_NAME, constants.DEFAULT_DEBATER_B_NAME], 104 | data["speeches"], 105 | ): 106 | instruction = selected["supplemental"]["prompt"] 107 | rejected = sorted(selected["supplemental"]["rejected_responses"], key=lambda x: x["preference"])[0] 108 | if selected["supplemental"]["preference"] - rejected["preference"] > JudgePreferencesLoader.MIN_GAP: 109 | selected_speech = clean_speech(selected["content"]) 110 | rejected_speech = clean_speech(rejected["speech"]) 111 | preference = get_preference(selected, rejected) 112 | outputs.append((instruction, selected_speech, rejected_speech, preference)) 113 
| return outputs 114 | 115 | @classmethod 116 | def load( 117 | cls, full_dataset_filepath: str | list[str], reward_type: RewardType = RewardType.LOG_PROB, **kwargs 118 | ) -> JudgePreferencesDataset: 119 | """ 120 | Constructs a JudgePreferencesDataset. 121 | 122 | Params: 123 | full_dataset_filepath: This is the *prefix* of the files with all the Best-of-N generations. 124 | 125 | Returns: 126 | A JudgePreferencesDataset where each row has a chosen and a rejected speech. 127 | """ 128 | 129 | train_data = [] 130 | input_texts = input_utils.read_file_texts(base_path=full_dataset_filepath, input_type=InputType.JSON_TRANSCRIPT) 131 | for text in input_texts: 132 | train_data.extend(JudgePreferencesLoader.process_row(json.loads(text), reward_type=reward_type, **kwargs)) 133 | 134 | return JudgePreferencesDataset( 135 | train_data=train_data, 136 | val_data=[], 137 | test_data=[], 138 | ) 139 | 140 | 141 | class CorrectnessJudgePreferencesLoader(RawDataLoader): 142 | @classmethod 143 | def load(cls, full_dataset_filepath: str | list[str], **kwargs) -> JudgePreferencesDataset: 144 | """ 145 | Constructs a CorrectnessJudgePreferencesDataset. This is a modified JudgePreferencesDataset where a speech is marked 146 | as "chosen" if it defends the correct side and "rejected" if it defends the incorrect side. The "preference" is just 147 | the outright win probability. 148 | 149 | Params: 150 | full_dataset_filepath: This is the *prefix* of the files with all the Best-of-N generations. 151 | 152 | Returns: 153 | A JudgePreferencesDataset where each row has a chosen and a rejected speech. 154 | """ 155 | 156 | def clean_speech(speech: str) -> str: 157 | speech = speech.replace(constants.INVALID_QUOTE_TAG, constants.QUOTE_TAG).replace( 158 | constants.INVALID_UNQUOTE_TAG, constants.UNQUOTE_TAG 159 | ) 160 | return quote_utils.clean_up_quotes(speech_content=speech) 161 | 162 | def get_actual_judge_score(speeches: list[dict[Any, Any]], name: str) -> Optional[float]: 163 | for i in range(len(speeches)): 164 | speech = speeches[len(speeches) - i - 1] 165 | if speech["speaker"] == constants.DEFAULT_JUDGE_NAME and speech["supplemental"]["probabilistic_decision"]: 166 | return speech["supplemental"]["probabilistic_decision"][name] 167 | return None 168 | 169 | train_data = [] 170 | input_texts = input_utils.read_file_texts(base_path=full_dataset_filepath, input_type=InputType.JSON_TRANSCRIPT) 171 | for text in input_texts: 172 | data = json.loads(text) 173 | for selected in filter( 174 | lambda x: x["speaker"] in [constants.DEFAULT_DEBATER_A_NAME, constants.DEFAULT_DEBATER_B_NAME], 175 | data["speeches"], 176 | ): 177 | instruction = selected["supplemental"]["prompt"] 178 | 179 | speech_preference_pairs = [(selected["content"], selected["supplemental"]["preference"])] + [ 180 | (rejected["speech"], rejected["preference"]) 181 | for rejected in selected["supplemental"]["rejected_responses"] 182 | ] 183 | random_selected_speech, random_selected_preference = random.choice(speech_preference_pairs) 184 | 185 | if not selected["supplemental"]["preference"]: 186 | random_selected_preference = get_actual_judge_score(data["speeches"], selected["speaker"]) 187 | 188 | is_correct = ( 189 | data["metadata"]["first_debater_correct"] and selected["speaker"] == constants.DEFAULT_DEBATER_A_NAME 190 | ) or ( 191 | not data["metadata"]["first_debater_correct"] and selected["speaker"] == constants.DEFAULT_DEBATER_B_NAME 192 | ) 193 | if is_correct: 194 | train_data.append((instruction, clean_speech(random_selected_speech), "", 
random_selected_preference)) 195 | else: 196 | train_data.append((instruction, "", clean_speech(random_selected_speech), random_selected_preference)) 197 | 198 | return JudgePreferencesDataset( 199 | train_data=train_data, 200 | val_data=[], 201 | test_data=[], 202 | ) 203 | -------------------------------------------------------------------------------- /data/loader_utils.py: -------------------------------------------------------------------------------- 1 | from data.dataset import RawDataLoader, DatasetType 2 | from data.annotated_quality_debates_loader import AnnotatedQualityDebatesLoader 3 | from data.judge_preferences_loader import CorrectnessJudgePreferencesLoader, JudgePreferencesLoader 4 | from data.scratchpad_quality_debates_loader import ScratchpadQualityDebatesLoader 5 | from data.quality_loader import QualityLoader 6 | from data.quality_debates_loader import QualityConsultancyLoader, QualityDebatesLoader 7 | from data.quality_judging_loader import QualityJudgingLoader 8 | from data.quote_relevance_loader import QuoteRelevanceLoader 9 | 10 | from enum import Enum 11 | from typing import Type 12 | 13 | 14 | def get_loader_type(dataset_type: DatasetType) -> Type[RawDataLoader]: 15 | """Returns the class associated with the inputted DatasetType""" 16 | if dataset_type == DatasetType.QUALITY: 17 | return QualityLoader 18 | elif dataset_type == DatasetType.QUALITY_DEBATES: 19 | return QualityDebatesLoader 20 | elif dataset_type == DatasetType.JUDGE_PREFERENCES: 21 | return JudgePreferencesLoader 22 | elif dataset_type == DatasetType.ANNOTATED_QUALITY_DEBATES: 23 | return AnnotatedQualityDebatesLoader 24 | elif dataset_type == DatasetType.SCRATCHPAD_QUALITY_DEBATES: 25 | return ScratchpadQualityDebatesLoader 26 | elif dataset_type == DatasetType.QUOTE_RELEVANCE: 27 | return QuoteRelevanceLoader 28 | elif dataset_type == DatasetType.JUDGING_PROBE: 29 | return QualityJudgingLoader 30 | elif dataset_type == DatasetType.QUALITY_CONSULTANCY: 31 | return QualityConsultancyLoader 32 | elif dataset_type == DatasetType.CORRECTNESS_JUDGE_PREFERENCES: 33 | return CorrectnessJudgePreferencesLoader 34 | 35 | raise Exception(f"Loader {dataset_type} not found") 36 | -------------------------------------------------------------------------------- /data/quality_judging_loader.py: -------------------------------------------------------------------------------- 1 | from data.dataset import DataRow, DatasetType, JudgingProbeDataRow, RawDataLoader, RawDataset, SplitType 2 | from data.quality_loader import QualityLoader 3 | from utils import InputType, input_utils 4 | import utils.constants as constants 5 | 6 | import torch 7 | 8 | from typing import Any, Optional 9 | import base64 10 | import io 11 | import json 12 | 13 | 14 | class QualityJudgingDataset(RawDataset): 15 | def __init__(self, train_data: list[str, Any], val_data: list[str, Any], test_data: list[str, Any]): 16 | """ 17 | A dataset of judge internal representations, mapped to a target (whether it corresponds to the correct side). 
18 | """ 19 | super().__init__(DatasetType.JUDGING_PROBE) 20 | self.data = { 21 | SplitType.TRAIN: self.__convert_batch_to_rows(train_data), 22 | SplitType.VAL: self.__convert_batch_to_rows(val_data), 23 | SplitType.TEST: self.__convert_batch_to_rows(test_data), 24 | } 25 | self.idxs = {SplitType.TRAIN: 0, SplitType.VAL: 0, SplitType.TEST: 0} 26 | 27 | def get_data(self, split: SplitType = SplitType.TRAIN) -> list[JudgingProbeDataRow]: 28 | """Returns all the data for a given split""" 29 | if split not in self.data: 30 | raise ValueError(f"Split type {split} is not recognized. Only TRAIN, VAL, and TEST are recognized") 31 | return self.data[split] 32 | 33 | def get_batch(self, split: SplitType = SplitType.TRAIN, batch_size: int = 1) -> list[JudgingProbeDataRow]: 34 | """Returns a subset of the data for a given split""" 35 | if batch_size < 1: 36 | raise ValueError(f"Batch size must be >= 1. Inputted batch size was {batch_size}") 37 | data_to_return = self.data[split][self.idxs[split] : min(self.idxs[split] + batch_size, len(self.data[split]))] 38 | self.idxs[split] = self.idxs[split] + batch_size if self.idxs[split] + batch_size < len(self.data[split]) else 0 39 | return data_to_return 40 | 41 | def get_example(self, split: SplitType = SplitType.TRAIN, idx: int = 0) -> JudgingProbeDataRow: 42 | """Returns an individual row in the dataset""" 43 | return self.data[split][idx % len(self.data[split])] 44 | 45 | def __convert_batch_to_rows(self, train_data: list[tuple[torch.tensor, torch.tensor]]): 46 | return [ 47 | JudgingProbeDataRow(internal_representation=internal_representation, target=target) 48 | for internal_representation, target in train_data 49 | ] 50 | 51 | 52 | class QualityJudgingLoader(RawDataLoader): 53 | @classmethod 54 | def load( 55 | cls, 56 | full_dataset_filepath: str | list[str], 57 | supplemental_file_paths: Optional[dict[str, str]] = None, 58 | linear_idxs: Optional[list[int]] = None, 59 | combine_train_and_val: bool = False, 60 | **kwargs, 61 | ) -> QualityJudgingDataset: 62 | """ 63 | Constructs a QualityJudgingDataset. 64 | 65 | Params: 66 | full_dataset_filepath: This is the *prefix* of the files with all the stored internal representations 67 | supplemental_file_paths: An optional dictionary of paths that could be used to support the creation 68 | of the dataset. In this case, the relevant one would be quality_file_path. 
69 | linear_idxs: list of layer indexes that should be used for the linear probes 70 | combine_train_and_val: if the validation set should be merged into the training set (for when one is done 71 | with validation and just wants to train on the whole dataset) 72 | Returns: 73 | A QualityJudgingDataset where each row has an internal representation tensor and a target winning percentage 74 | """ 75 | 76 | # move this to the quality dataset 77 | def get_original_data_row(data: dict[Any, Any], dataset: RawDataset) -> DataRow: 78 | debate_identifier = data["metadata"]["debate_identifier"] 79 | question = data["metadata"]["question"] 80 | story_title = debate_identifier.replace("_" + question, "") 81 | for row in dataset.get_data(split=SplitType.TRAIN): 82 | if row.story_title == story_title and row.question == question: 83 | return row 84 | raise Exception(f"A row with title {story_title} and question {question} could not be found in the dataset") 85 | 86 | device = "cuda" if torch.cuda.is_available() else "cpu" 87 | quality_filepath = (supplemental_file_paths or {}).get("quality_file_path", QualityLoader.DEFAULT_TRAIN_PATH) 88 | quality_dataset = QualityLoader.load(full_dataset_filepath=quality_filepath) 89 | 90 | data_list = [] 91 | input_texts = input_utils.read_file_texts(base_path=full_dataset_filepath, input_type=InputType.JSON_TRANSCRIPT) 92 | for text in input_texts: 93 | data = json.loads(text) 94 | row = get_original_data_row(data=data, dataset=quality_dataset) 95 | for speech in filter( 96 | lambda x: x["speaker"] == constants.DEFAULT_JUDGE_NAME and x["supplemental"]["internal_representations"], 97 | data["speeches"], 98 | ): 99 | internal_representations = speech["supplemental"]["internal_representations"] 100 | decoded_representation = base64.b64decode(internal_representations) 101 | big_buffer = io.BytesIO(decoded_representation) 102 | decoded_list = torch.load(big_buffer, map_location=device) 103 | relevant_entries = [decoded_list[int(idx)] for idx in (linear_idxs or [-1])] # relevant parts to concat 104 | x = torch.cat(relevant_entries, dim=0) 105 | y = torch.tensor([1, 0] if row.correct_index == 0 else [0, 1]).float() 106 | data_list.append((x, y)) 107 | 108 | train_data = data_list[0 : int(0.8 * len(data_list))] 109 | val_data = data_list[int(0.8 * len(data_list)) :] 110 | if combine_train_and_val: 111 | train_data = data_list 112 | val_data = [] 113 | 114 | return QualityJudgingDataset( 115 | train_data=train_data, 116 | val_data=val_data, 117 | test_data=[], 118 | ) 119 | -------------------------------------------------------------------------------- /data/quote_relevance_loader.py: -------------------------------------------------------------------------------- 1 | from data.dataset import DataRow, DatasetType, RawDataLoader, RawDataset, SpeakerType, SpeechData, SplitType 2 | from data.quality_loader import QualityLoader, QualityDataset 3 | from data.scratchpad_quality_debates_loader import ScratchpadQualityDebatesLoader, ScratchpadQualityDebatesDataset 4 | import utils.constants as constants 5 | 6 | from typing import Any, Optional 7 | 8 | from pydantic import BaseModel 9 | import json 10 | import os 11 | import pickle 12 | 13 | 14 | class QuoteRelevanceTopicInfo(BaseModel): 15 | question: str 16 | a_position: str 17 | b_position: str 18 | 19 | 20 | class QuoteRelevanceProcessedBatchItem(BaseModel): 21 | a_quote_map: dict[str, int] 22 | b_quote_map: dict[str, int] 23 | question_info: QuoteRelevanceTopicInfo 24 | 25 | 26 | class QuoteRelevanceDataset(QualityDataset): 27 | 
FILTER_THRESHOLD = 5 28 | 29 | def __init__( 30 | self, 31 | train_data: list[dict[str, Any]], 32 | val_data: list[dict[str, Any]], 33 | test_data: list[dict[str, Any]], 34 | quote_label_file_path: str, 35 | scratchpad_dataset: ScratchpadQualityDebatesDataset, 36 | ): 37 | """Dataset that builds on top of the quality dataset but there are scratchpads added that contain 38 | the most relevant quotes from the passage""" 39 | super().__init__( 40 | train_data=train_data, 41 | val_data=val_data, 42 | test_data=test_data, 43 | override_type=DatasetType.QUOTE_RELEVANCE, 44 | allow_multiple_positions_per_question=True, 45 | ) 46 | self.__match_processed_quotes_to_stories( 47 | quote_label_file_path=quote_label_file_path, scratchpad_dataset=scratchpad_dataset 48 | ) 49 | 50 | def __match_processed_quotes_to_stories( 51 | self, quote_label_file_path: str, scratchpad_dataset: ScratchpadQualityDebatesDataset 52 | ): 53 | def standardize_string(input_string: str): 54 | return input_string.strip().lower() 55 | 56 | with open(quote_label_file_path, "rb") as f: 57 | quote_labels = pickle.load(f) 58 | 59 | pairs = [] 60 | for i, item in enumerate(quote_labels): 61 | question_info = item.question_info 62 | for j, row in enumerate(self.data[SplitType.TRAIN]): 63 | positions = [standardize_string(position) for position in row.positions] 64 | if ( 65 | standardize_string(row.question) == standardize_string(question_info.question) 66 | and standardize_string(question_info.a_position) in positions 67 | and standardize_string(question_info.b_position) in positions 68 | ): 69 | pairs.append((item, row)) 70 | break 71 | 72 | rows_to_use = [] 73 | for item, row in pairs: 74 | row.speeches = [] 75 | 76 | filtered_a_quote_map = { 77 | quote: score 78 | for quote, score in filter(lambda x: x[1] > QuoteRelevanceDataset.FILTER_THRESHOLD, item.a_quote_map.items()) 79 | } 80 | a_scratchpad = "\n\n".join( 81 | [ 82 | f"{(i + 1)}. {constants.QUOTE_TAG}{quote}{constants.UNQUOTE_TAG}" 83 | for i, quote in enumerate(filter(lambda x: x, filtered_a_quote_map)) 84 | ] 85 | ).strip() 86 | row.speeches.append(SpeechData(text="", position=0, speaker_type=SpeakerType.DEBATER, scratchpad=a_scratchpad)) 87 | 88 | filtered_b_quote_map = { 89 | quote: score 90 | for quote, score in filter(lambda x: x[1] > QuoteRelevanceDataset.FILTER_THRESHOLD, item.b_quote_map.items()) 91 | } 92 | 93 | b_scratchpad = "\n\n".join( 94 | [ 95 | f"{(i + 1)}. 
{constants.QUOTE_TAG}{quote}{constants.UNQUOTE_TAG}" 96 | for i, quote in enumerate(filter(lambda x: x, filtered_b_quote_map)) 97 | ] 98 | ).strip() 99 | row.speeches.append(SpeechData(text="", position=1, speaker_type=SpeakerType.DEBATER, scratchpad=b_scratchpad)) 100 | 101 | if a_scratchpad or b_scratchpad: 102 | rows_to_use.append(row) 103 | 104 | rows_to_use.extend(scratchpad_dataset.get_data(split=SplitType.TRAIN)) 105 | 106 | self.data[SplitType.TRAIN] = rows_to_use 107 | self.data[SplitType.VAL] = [] 108 | self.data[SplitType.TEST] = [] 109 | 110 | 111 | class QuoteRelevanceLoader(RawDataLoader): 112 | DEFAULT_QUOTE_LABEL_FILE_PATH = os.environ["SRC_ROOT"] + "data/datasets/quote-relevance/quote-relevance.p" 113 | 114 | @classmethod 115 | def load( 116 | cls, 117 | train_filepath: Optional[str] = None, 118 | val_filepath: Optional[str] = None, 119 | test_filepath: Optional[str] = None, 120 | supplemental_file_paths: Optional[dict[str, str]] = None, 121 | **kwargs, 122 | ) -> QuoteRelevanceDataset: 123 | """Constructs a QuoteRelevanceDataset""" 124 | quote_label_file_path = ( 125 | supplemental_file_paths.get("quote_label_file_path", QuoteRelevanceLoader.DEFAULT_QUOTE_LABEL_FILE_PATH) 126 | if supplemental_file_paths 127 | else QuoteRelevanceLoader.DEFAULT_QUOTE_LABEL_FILE_PATH 128 | ) 129 | 130 | debate_file_path = supplemental_file_paths.get("debate_file_path", None) if supplemental_file_paths else None 131 | scratchpad_dataset = ScratchpadQualityDebatesLoader.load(full_dataset_filepath=debate_file_path, deduplicate=False) 132 | 133 | train, val, test = QualityLoader.get_splits( 134 | train_filepath=train_filepath, val_filepath=val_filepath, test_filepath=test_filepath 135 | ) 136 | 137 | return QuoteRelevanceDataset( 138 | train_data=train, 139 | val_data=val, 140 | test_data=val, 141 | quote_label_file_path=quote_label_file_path, 142 | scratchpad_dataset=scratchpad_dataset, 143 | ) 144 | -------------------------------------------------------------------------------- /data/scratchpad_quality_debates_loader.py: -------------------------------------------------------------------------------- 1 | from data.dataset import DataRow, DatasetType, RawDataLoader, SpeechData, SplitType 2 | from data.quality_debates_loader import QualityDebatesLoader, QualityDebatesDataset, QualityTranscriptsLoader 3 | from utils import quote_utils 4 | import utils.constants as constants 5 | 6 | from tqdm import tqdm 7 | 8 | from typing import Any, Optional 9 | import os 10 | import pickle 11 | import re 12 | 13 | 14 | class ScratchpadQualityDebatesDataset(QualityDebatesDataset): 15 | MINIMUM_QUOTE_LENGTH = 1 16 | CONTEXT_SIZE = 0 17 | DEFAULT_SCRATCHPAD_TEXT = "No quotes needed" 18 | 19 | def __init__(self, train_data: list[str, Any], val_data: list[str, Any], test_data: list[str, Any]): 20 | """Dataset where each row has a question, position, debate transcript (from the human debates) and an 21 | automatically generated scratchpad continuation for each speech that lists out the quotes used""" 22 | super().__init__( 23 | train_data=train_data, 24 | val_data=val_data, 25 | test_data=test_data, 26 | override_type=DatasetType.SCRATCHPAD_QUALITY_DEBATES, 27 | ) 28 | self._generate_scratchpads() 29 | 30 | def _generate_scratchpads(self) -> None: 31 | for split in SplitType: 32 | for row in self.data[split]: 33 | for speech in row.speeches: 34 | self._generate_scratchpad(speech=speech, row=row) 35 | 36 | def _generate_scratchpad(self, speech: SpeechData, row: DataRow) -> Optional[str]: 37 | original_quotes = 
quote_utils.extract_quotes(speech.text) 38 | contexts = [ 39 | quote_utils.extract_quote_context( 40 | quote_text=quote, 41 | background_text=row.background_text, 42 | context_size=ScratchpadQualityDebatesDataset.CONTEXT_SIZE, 43 | ) 44 | for quote in filter( 45 | lambda x: len(x.split()) >= ScratchpadQualityDebatesDataset.MINIMUM_QUOTE_LENGTH, original_quotes 46 | ) 47 | ] 48 | speech.scratchpad = ( 49 | "\n\n".join( 50 | [ 51 | f"{(i + 1)}. {constants.QUOTE_TAG}{context}{constants.UNQUOTE_TAG}" 52 | for i, context in enumerate(filter(lambda x: x, contexts)) 53 | ] 54 | ) 55 | if contexts 56 | else ScratchpadQualityDebatesDataset.DEFAULT_SCRATCHPAD_TEXT 57 | ) 58 | 59 | 60 | class ScratchpadQualityDebatesLoader(RawDataLoader): 61 | DEFAULT_PICKLE_PATH = ( 62 | os.environ[constants.SRC_ROOT] + "data/datasets/scratchpad-quality-debates/scratchpad-quality-debates.p" 63 | ) 64 | 65 | @classmethod 66 | def load( 67 | cls, 68 | full_dataset_filepath: Optional[str] = None, 69 | deduplicate: bool = False, 70 | **kwargs, 71 | ) -> ScratchpadQualityDebatesDataset: 72 | """Constructs a ScratchpadQualityDebatesDataset""" 73 | if os.path.exists(ScratchpadQualityDebatesLoader.DEFAULT_PICKLE_PATH): 74 | with open(ScratchpadQualityDebatesLoader.DEFAULT_PICKLE_PATH, "rb") as f: 75 | return pickle.load(f) 76 | full_dataset_filepath = full_dataset_filepath or QualityTranscriptsLoader.DEFAULT_FILE_PATH 77 | train, val, test = QualityDebatesLoader.get_splits(file_path=full_dataset_filepath, deduplicate=deduplicate) 78 | return ScratchpadQualityDebatesDataset( 79 | train_data=train, 80 | val_data=val, 81 | test_data=test, 82 | ) 83 | -------------------------------------------------------------------------------- /debate/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent, AgentConfig, ScratchpadConfig 2 | from .debate_round import DebateRound, DebateRoundSummary, QuestionMetadata, SplittingRule 3 | from .debater import BestOfNDebater, Debater, HumanDebater 4 | from .judge import BranchedJudge, Judge, MultiRoundBranchingSetting 5 | from .speech_format import ( 6 | Speech, 7 | SpeechFormat, 8 | SpeechFormatEntry, 9 | SpeechType, 10 | SpeechFormat, 11 | SpeechFormatEntry, 12 | SpeechFormatStructure, 13 | SpeechFormatType, 14 | ) 15 | from .transcript import Transcript 16 | -------------------------------------------------------------------------------- /debate/agent.py: -------------------------------------------------------------------------------- 1 | from debate.transcript import SpeechFormat, Transcript 2 | from models import BestOfNConfig, Model, ModelSettings 3 | from prompts import Prompt 4 | from utils import logger_utils 5 | import utils.constants as constants 6 | 7 | from pydantic import BaseModel, ConfigDict, model_validator 8 | 9 | from typing import Any, Optional, Union 10 | 11 | 12 | class ScratchpadConfig(BaseModel): 13 | use_scratchpad: bool = False 14 | scratchpad_word_limit: Optional[int] = None 15 | scratchpad_public: bool = False 16 | 17 | @model_validator(mode="before") 18 | def check_one_true_and_a_none(cls, values): 19 | if ( 20 | not values.get("use_scratchpad") 21 | and (values.get("scratchpad_word_limit") is not None and values.get("scratchpad_word_limit") > 0) 22 | and values.get("scratchpad_public") 23 | ): 24 | raise ValueError("If use_scratchpad=False, then one should not set scratchpad_word_limit or scratchpad_public") 25 | return values 26 | 27 | 28 | class AgentConfig(BaseModel): 29 | model_settings: 
ModelSettings 30 | scratchpad: ScratchpadConfig = ScratchpadConfig() 31 | best_of_n: Optional[BestOfNConfig] = None 32 | 33 | model_config = ConfigDict(protected_namespaces=("protect_me_", "also_protect_")) 34 | 35 | 36 | class Agent: 37 | def __init__( 38 | self, 39 | name: str, 40 | is_debater: bool, 41 | prompt: Prompt | list[Prompt], 42 | model: Model, 43 | num_speeches: int, 44 | receive_validated_quotes: bool, 45 | quotes_require_validation: bool, 46 | speech_format: SpeechFormat, 47 | ): 48 | """ 49 | An abstraction that controls access to the underlying models. It maintains a prompt and transcript 50 | to determine what to send to the underlying models. A Debater and a Judge are examples of an agent. 51 | 52 | Params: 53 | name: A string to identify the agent. It needs only to be unique within its own debate round. 54 | is_debater: Boolean indicating whether the agent is a debater or a judge. 55 | prompt: The Prompt structure that controls the inputs to the models. A list is passed in for batch processing. 56 | model: The model that actually performs the text generation. 57 | num_speeches: The number of speeches each debater will generate in the round. 58 | receive_validated_quotes: Whether speeches delivered by others should be corrected to show the validation status 59 | of their quotes. This is generally true for judges (they see whether a debater made up a quote) but not for 60 | debaters (they should learn not to make up quotes). 61 | quotes_require_validation: Whether or not the speeches generated by this agent already have had their quotes 62 | validated. Quote validation takes some time, so this helps us perform validation only when necessary. This 63 | is true for speeches generated by the HumanModel and false for the other models. 64 | speech_format: The order of speeches that the debater is expecting to receive. 65 | """ 66 | self.name = name 67 | self.is_debater = is_debater 68 | self.model = model 69 | self.num_speeches = num_speeches 70 | self.receive_validated_quotes = receive_validated_quotes 71 | self.quotes_require_validation = quotes_require_validation 72 | 73 | self.speech_format = speech_format 74 | 75 | self.prompts = prompt if type(prompt) == list else [prompt] 76 | self.transcripts = [ 77 | Transcript(name=self.name, prompt=p, speech_format=speech_format, index=i) for i, p in enumerate(self.prompts) 78 | ] 79 | self.cached_messages = {} 80 | 81 | def receive_message( 82 | self, speaker: str, content: str, idx: int, supplemental: Optional[dict[Any, Any] | list[dict[Any, Any]]] = None 83 | ): 84 | """ 85 | The agent takes in a speech from another agent (or itself) and adds it to its internal transcript: 86 | 87 | Params: 88 | speaker: The name of the agent who delivered the speech 89 | content: The text of the speech 90 | idx: The index corresponding to which debate in the batch this speech is a part of. 
91 | supplemental: Any additional data that one wants to associate with the speech 92 | """ 93 | if idx >= len(self.transcripts): 94 | return 95 | 96 | self.cached_messages.setdefault(speaker, {}).setdefault(idx, []).append((content, supplemental)) 97 | expected_speaker = self.get_next_expected_speaker(idx=idx) 98 | while self.cached_messages.get(expected_speaker, {}).get(idx): 99 | for message, supplemental in self.cached_messages[expected_speaker][idx]: 100 | self.transcripts[idx].add_speech(speaker=expected_speaker, content=message, supplemental=supplemental) 101 | del self.cached_messages[expected_speaker][idx] 102 | expected_speaker = self.get_next_expected_speaker(idx=idx) 103 | 104 | def __call__(self) -> Optional[list[str]]: 105 | """This must be implemented in each agent. This is where they should generate text""" 106 | pass 107 | 108 | def save(self, save_file_path_prefix: str, metadata: Optional[list[dict[Any, Any]]] = None): 109 | """Saves the transcripts to the specified location, with a separate file for each element in the batch""" 110 | metadata = (metadata or []) + [{} for i in range(len(self.transcripts) - len((metadata or [])))] 111 | for i, (transcript, metadata) in enumerate(zip(self.transcripts, metadata)): 112 | transcript.save(save_file_path_prefix=f"{save_file_path_prefix}_{i}", metadata=metadata) 113 | 114 | def get_transcript(self, idx: int = 0) -> Transcript: 115 | """Returns the transcript at the specified index""" 116 | return self.transcripts[idx] 117 | 118 | def get_alias(self) -> str: 119 | """Gets the alias of the model underpinning the agent""" 120 | return self.model.alias if self.model else constants.DEFAULT_ALIAS 121 | 122 | def get_next_expected_speaker(self, idx: int = 0) -> Optional[str]: 123 | """Gets the name of the agent that this agent expects to deliver the next speech""" 124 | return self.transcripts[idx].get_next_expected_speaker() 125 | 126 | def post_speech_processing(self) -> None: 127 | """Handles any post-speech logic. 
This should mostly be a no-op but is needed for some multi-round 128 | branching cases where the judge needs to handle speeches coming from different rounds""" 129 | pass 130 | -------------------------------------------------------------------------------- /debate/debate_round.py: -------------------------------------------------------------------------------- 1 | from debate.agent import Agent 2 | from debate.debater import Debater 3 | from debate.judge import Judge 4 | from debate.transcript import Transcript 5 | from models import ModelResponse 6 | from prompts import Prompt, PromptConfig, PromptParser 7 | from utils import logger_utils, quote_utils 8 | import utils.constants as constants 9 | 10 | from pydantic import BaseModel 11 | 12 | from enum import Enum 13 | from typing import Optional, Any, Union 14 | import copy 15 | import random 16 | 17 | 18 | class QuestionMetadata(BaseModel): 19 | first_debater_correct: bool 20 | question_idx: int 21 | background_text: str 22 | question: str 23 | first_debater_answer: str 24 | second_debater_answer: str 25 | debate_identifier: str 26 | 27 | 28 | class DebateRoundSummary(BaseModel): 29 | metadata: QuestionMetadata 30 | transcript: Any 31 | winning_alias: str 32 | losing_alias: str 33 | first_debater_alias: str 34 | second_debater_alias: str 35 | first_debater_wins: bool 36 | judge_alias: str 37 | winning_debater_prob: float = 1.0 38 | first_debater_win_prob: float = 0.5 39 | second_debater_win_prob: float = 0.5 40 | first_debater_speaks: bool = True 41 | second_debater_speaks: bool = True 42 | failed: bool = False 43 | 44 | 45 | class SplittingRule(Enum): 46 | OPENING_ONLY = 1 47 | ALL_RANDOM = 2 48 | 49 | 50 | class DebateRound: 51 | def __init__( 52 | self, 53 | first_debater: Debater, 54 | second_debater: Debater, 55 | judge: Judge, 56 | metadata: QuestionMetadata | list[QuestionMetadata], 57 | ): 58 | """An abstraction that coordinates the ordered generation of speeches by the debaters and the judge.""" 59 | self.first_debater = first_debater 60 | self.second_debater = second_debater 61 | self.judge = judge 62 | self.metadata = metadata if type(metadata) == list else [metadata] 63 | self.name_to_agent = { 64 | self.first_debater.name: self.first_debater, 65 | self.second_debater.name: self.second_debater, 66 | self.judge.name: self.judge, 67 | } 68 | self.logger = logger_utils.get_default_logger(__name__) 69 | 70 | def set_first_debater(self, debater: Debater): 71 | """Changes the identity of the first debater in the debate.""" 72 | self.first_debater = debater 73 | self.name_to_agent[self.first_debater.name] = debater 74 | 75 | def set_second_debater(self, debater: Debater): 76 | """Changes the identity of the second debater in the debate.""" 77 | self.second_debater = debater 78 | self.name_to_agent[self.second_debater.name] = debater 79 | 80 | def set_judge(self, judge: Judge): 81 | """Changes the identity of the judge in the debate.""" 82 | self.judge = judge 83 | self.name_to_agent[self.judge.name] = judge 84 | 85 | def run_round(self) -> tuple[list[str], ModelResponse]: 86 | """ 87 | Each debater generates speeches until the judge renders their decision. 88 | 89 | Returns: 90 | last_output: a list of strings with the name of the agent that won each debate in the batch 91 | last_model_output: the model generation from the judge's decision. This is useful if the judge 92 | also returns the probability that a given debater won. 
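            Illustrative example (aliases and numbers are hypothetical): for a batch of one, last_output
            might be ["Debater_A"] while last_model_output[0].probabilistic_decision, when the judge
            supplies one, might map "Debater_A" to 0.7 and "Debater_B" to 0.3.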
93 | """ 94 | last_output = None 95 | last_model_output = None 96 | next_speaker = self.judge.get_next_expected_speaker() 97 | while next_speaker: 98 | speaker = self.name_to_agent[next_speaker] 99 | try: 100 | batch_response, model_output = speaker() 101 | except Exception as e: 102 | self.logger.error("Received an error while trying to generate a speech %s", str(e), exc_info=True) 103 | return None, None 104 | 105 | for idx, (response, output) in enumerate(zip(batch_response, model_output)): 106 | validated_response = str(response) 107 | if speaker.quotes_require_validation: 108 | validated_response = quote_utils.validate_and_replace_quotes( 109 | speech_content=str(response), 110 | background_text=self.metadata[min(idx, len(self.metadata) - 1)].background_text, 111 | ) 112 | for _, agent in self.name_to_agent.items(): 113 | response_to_use = validated_response if agent.receive_validated_quotes else response 114 | agent.receive_message(speaker=speaker.name, content=response_to_use, idx=idx, supplemental=output) 115 | 116 | self.judge.post_speech_processing() 117 | next_speaker = self.judge.get_next_expected_speaker() 118 | 119 | last_output = batch_response 120 | last_model_output = model_output 121 | 122 | return last_output, last_model_output 123 | 124 | def record_winners( 125 | self, 126 | last_output: Optional[list[str]], 127 | last_model_output: Optional[list[ModelResponse]], 128 | save_file_path_prefix: Optional[str] = None, 129 | ) -> list[DebateRoundSummary]: 130 | """Generates a full summary of the debate round including the winner, transcript, metadata, and aliases of all the participating models""" 131 | if not last_output: 132 | return [] 133 | 134 | first_debater_win_list = [] 135 | winning_probability_list = [] 136 | failed_list = [] 137 | for i, (debater_a_wins, model_output) in enumerate(zip(last_output, last_model_output)): 138 | winner = constants.DEFAULT_DEBATER_A_NAME if debater_a_wins else constants.DEFAULT_DEBATER_B_NAME 139 | first_debater_win_list.append(winner == self.first_debater.name) 140 | string_value = self.judge.get_transcript(idx=i).full_string_value() 141 | winning_probability_list.append( 142 | 1.0 if not model_output.probabilistic_decision else model_output.probabilistic_decision[winner] 143 | ) 144 | failed_list.append(model_output.failed) 145 | self.logger.debug(string_value) 146 | 147 | if save_file_path_prefix: 148 | self.name_to_agent[self.judge.expected_saver].save( 149 | save_file_path_prefix=save_file_path_prefix, metadata=[item.dict() for item in self.metadata] 150 | ) 151 | 152 | return [ 153 | DebateRoundSummary( 154 | metadata=self.metadata[i % len(self.metadata)], 155 | transcript=self.judge.get_transcript(idx=i), 156 | winning_alias=self.first_debater.get_alias() if first_debater_wins else self.second_debater.get_alias(), 157 | losing_alias=self.first_debater.get_alias() if not first_debater_wins else self.second_debater.get_alias(), 158 | first_debater_alias=self.first_debater.get_alias(), 159 | second_debater_alias=self.second_debater.get_alias(), 160 | first_debater_wins=first_debater_wins, 161 | judge_alias=self.judge.get_alias(), 162 | winning_debater_prob=winning_probability_list[i], 163 | first_debater_win_prob=winning_probability_list[i] 164 | if first_debater_wins 165 | else (1 - winning_probability_list[i]), 166 | second_debater_win_prob=(1 - winning_probability_list[i]) 167 | if first_debater_wins 168 | else winning_probability_list[i], 169 | first_debater_speaks=constants.DEFAULT_DEBATER_A_NAME in 
self.judge.get_transcript(idx=i).get_speakers(), 170 | second_debater_speaks=constants.DEFAULT_DEBATER_B_NAME in self.judge.get_transcript(idx=i).get_speakers(), 171 | failed=failed_list[i], 172 | ) 173 | for i, first_debater_wins in enumerate(first_debater_win_list) 174 | ] 175 | 176 | def __call__(self, save_file_path_prefix: Optional[str] = None) -> list[DebateRoundSummary]: 177 | """Runs the round and generates a summary of the results""" 178 | last_output, last_model_output = self.run_round() 179 | return self.record_winners( 180 | last_output=last_output, last_model_output=last_model_output, save_file_path_prefix=save_file_path_prefix 181 | ) 182 | -------------------------------------------------------------------------------- /debate/transcript.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from debate.speech_format import Speech, SpeechType, SpeechFormatEntry, SpeechFormat 4 | from models import ModelInput, ModelResponse 5 | from prompts import Prompt, RoleType 6 | import utils.constants as constants 7 | 8 | from pydantic import BaseModel 9 | 10 | from enum import Enum 11 | from typing import Any, Callable, Optional, Union 12 | import copy 13 | import json 14 | 15 | 16 | class Transcript: 17 | def __init__( 18 | self, name: str, prompt: Prompt, speech_format: SpeechFormat, index: int = 0, alternate_prompts: bool = False 19 | ): 20 | """ 21 | An abstraction that tracks the commands and speeches delivered in the round. This can then 22 | be used to construct an input to a model. 23 | 24 | Params: 25 | name: The name of the debater who is to use this transcript. 26 | prompt: The prompt that is used to generate commands. 27 | speech_format: The order of speeches and commands that the debater expects to receive. 28 | index: The index corresponding to which element in the batch this transcript is being used for. 29 | alternate_prompts: False if the default prompts should always be used (good for validation). 
True if one wants to 30 | mix in alternate prompts (good for training) 31 | """ 32 | self.prompt = prompt 33 | self.name = name 34 | self.speeches = [] 35 | self.speech_format = speech_format 36 | self.alternate_prompts = alternate_prompts 37 | 38 | def reset(self) -> None: 39 | """Removes all the given speeches""" 40 | self.speeches = [] 41 | 42 | def add_speech( 43 | self, speaker: str, content: str, supplemental: Optional[ModelResponse | list[ModelResponse]] = None 44 | ) -> None: 45 | """ 46 | Adds an agent-generated speech to the transcript 47 | 48 | Args: 49 | speaker: The name of the debater (Debater_A, Debater_B) that gave the speech 50 | content: The text of the speech 51 | supplemental: Any additional metadata that one wants to tag along with the speech 52 | """ 53 | self.speeches.append(Speech(speaker=speaker, content=content, supplemental=supplemental)) 54 | 55 | def save(self, save_file_path_prefix: str, metadata: Optional[dict[Any, Any]]) -> None: 56 | """Saves to the specified path""" 57 | """ 58 | with open(save_file_path_prefix + ".txt", "w") as f: 59 | f.write(str(self.full_string_value())) 60 | """ 61 | with open(save_file_path_prefix + ".json", "w") as f: 62 | json.dump(self.json_value(metadata=metadata), f) 63 | 64 | def to_model_input(self) -> list[ModelInput]: 65 | """Converts the speech to a list of inputs that can be used to generate more text by models""" 66 | 67 | def add_to_model_inputs(model_inputs: list[ModelInput], new_addition: ModelInput) -> None: 68 | if model_inputs and model_inputs[-1].role == new_addition.role: 69 | model_inputs[-1] = ModelInput( 70 | role=new_addition.role, content=f"{model_inputs[-1].content}\n\n{new_addition.content}" 71 | ) 72 | else: 73 | model_inputs.append(new_addition) 74 | 75 | model_inputs = [] 76 | index = 0 77 | for i, (speech_type, prompt_tag, last_only_prompt_tag, expected_speaker) in enumerate(self.speech_format): 78 | if speech_type == SpeechType.PRE_FILLED: 79 | prompt_tag_to_use = ( 80 | prompt_tag if (index < len(self.speeches) or not last_only_prompt_tag) else last_only_prompt_tag 81 | ) 82 | 83 | content_idx = index % len(self.prompt.messages[prompt_tag_to_use].content) if self.alternate_prompts else 0 84 | add_to_model_inputs( 85 | model_inputs, 86 | ModelInput( 87 | role=RoleType.SYSTEM if i < 2 else RoleType.USER, 88 | content=self.prompt.messages[prompt_tag_to_use].content[content_idx], 89 | ), 90 | ) 91 | else: 92 | if index >= len(self.speeches): 93 | break 94 | role = ( 95 | RoleType.USER 96 | if self.speeches[index].speaker != self.name or index < len(self.speeches) 97 | else RoleType.ASSISTANT 98 | ) 99 | 100 | add_to_model_inputs(model_inputs, ModelInput(role=role, content=str(self.speeches[index].content))) 101 | index += 1 102 | 103 | return model_inputs 104 | 105 | def get_last_external_speech(self) -> Optional[str]: 106 | """Get the text of the last speech that was delivered by a different agent""" 107 | for i in range(len(self.speeches)): 108 | speech = self.speeches[len(self.speeches) - i - 1] 109 | if speech.speaker != self.name: 110 | return speech 111 | return "" 112 | 113 | def get_last_internal_speech(self) -> Optional[str]: 114 | """Get the text of the last speech that was delivered by one self""" 115 | for i in range(len(self.speeches)): 116 | speech = self.speeches[len(self.speeches) - i - 1] 117 | if speech.speaker == self.name: 118 | return speech 119 | return "" 120 | 121 | def get_speakers(self) -> set[str]: 122 | """Gets a list of all the speakers who appear in the transcript""" 123 | 
return set([speech.speaker for speech in self.speeches]) 124 | 125 | def get_next_expected_speaker(self) -> Optional[str]: 126 | """Gets the name of the next agent that is expected to deliver a speech""" 127 | expected_speakers = [expected_speaker for _, _, _, expected_speaker in filter(lambda x: x[-1], self.speech_format)] 128 | return expected_speakers[len(self.speeches)] if len(self.speeches) < len(expected_speakers) else None 129 | 130 | def only_decision_remains(self) -> bool: 131 | """Returns true if there are no more speeches that are expected to be delivered besides the 132 | judge's final verdict""" 133 | expected_speakers = [expected_speaker for _, _, _, expected_speaker in filter(lambda x: x[-1], self.speech_format)] 134 | remaining_speakers = ( 135 | set(expected_speakers[len(self.speeches) :]) if len(self.speeches) < len(expected_speakers) else set() 136 | ) 137 | return constants.DEFAULT_JUDGE_NAME in remaining_speakers and len(remaining_speakers) == 1 138 | 139 | def full_string_value(self) -> str: 140 | """Converts the transcript into a string for logging and saving""" 141 | return "\n\n".join([x.content for x in self.to_model_input()]) 142 | 143 | def json_value(self, metadata: Optional[dict[Any, Any]] = None) -> str: 144 | """Converts the transcript into a json object that can be parsed for offline processing""" 145 | 146 | def clean(obj): 147 | if isinstance(obj, dict): 148 | new_dict = {} 149 | for key, val in obj.items(): 150 | if isinstance(val, dict): 151 | new_dict[key] = clean(val) 152 | elif "token" in key and isinstance(val, list) and val and isinstance(val[0], int): 153 | pass 154 | elif isinstance(val, list): 155 | new_dict[key] = [clean(item) for item in val] 156 | else: 157 | new_dict[key] = val 158 | return obj 159 | 160 | speeches = [] 161 | index = 0 162 | for i, (speech_type, prompt_tag, _, expected_speaker) in enumerate(self.speech_format): 163 | supplemental = None 164 | if speech_type == SpeechType.PRE_FILLED: 165 | content = self.prompt.messages[prompt_tag].content[index % len(self.prompt.messages[prompt_tag].content)] 166 | else: 167 | if index >= len(self.speeches): 168 | break 169 | content = self.speeches[index].content 170 | supplemental = clean(self.speeches[index].supplemental) 171 | # supplemental = {k: v for k, v in filter(lambda x: "token" not in x, self.speeches[index].supplemental)} 172 | index += 1 173 | speeches.append(Speech(speaker=expected_speaker or "Prompt", content=content, supplemental=supplemental).dict()) 174 | 175 | return {"metadata": metadata, "speeches": speeches} 176 | 177 | def copy(self) -> Transcript: 178 | """Deepcopies this objects""" 179 | return copy.deepcopy(self) 180 | 181 | def get_external_speech_count(self) -> int: 182 | """Returns the number of external speeches that have been added to the transcript""" 183 | return len(self.speeches) 184 | 185 | def truncate(self, idx: int, debaters_only: bool = False) -> None: 186 | """ 187 | Removes all the speeches after the specified index. 
188 | 189 | Params: 190 | idx: The last speech in the round to include before removing the rest 191 | debaters_only: whether the idx refers to only speeches given by the debaters 192 | """ 193 | max_idx = len(self.speeches) 194 | if debaters_only: 195 | counter = 0 196 | max_idx = 0 197 | idx_to_true_idx = {} 198 | for i, speech in enumerate(self.speeches): 199 | if speech.speaker != constants.DEFAULT_JUDGE_NAME: 200 | idx_to_true_idx[counter] = i 201 | max_idx = counter 202 | counter += 1 203 | idx = idx_to_true_idx[idx] 204 | return self.speeches[: min(idx, max_idx)] 205 | 206 | def get_speech_count(self, debaters_only: bool = False) -> int: 207 | """Returns the number of speeches that have already been added (only includes speeches by debaters 208 | if the debaters_only parameter is true)""" 209 | if not debaters_only: 210 | return len(self.speeches) 211 | else: 212 | return len([speech for speech in filter(lambda x: x.speaker != constants.DEFAULT_JUDGE_NAME, self.speeches)]) 213 | 214 | def __str__(self): 215 | """Shorter string representation as compared to full_string_value()""" 216 | return f"Name: {self.name}\n\n" + "\n\n".join([str(speech) for speech in self.speeches]) 217 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .annotator import Annotator, ClassificationConfig, PredictedAnnotation, ParagraphClassification, SentenceClassification 2 | from .experiment_loader import ( 3 | AgentConfig, 4 | AgentsConfig, 5 | ExperimentConfig, 6 | ExperimentLoader, 7 | PromptLoadingConfig, 8 | ) 9 | from .quotes_collector import QuotesCollector, QuoteStats 10 | from .results_collector import JudgeStats, ResultsCollector, WinStats 11 | -------------------------------------------------------------------------------- /experiments/configs/example_experiment.yaml: -------------------------------------------------------------------------------- 1 | MyExperiment: # The name of your experiment 2 | batch_size: 1 # Number of rounds to be processed simultaneously 3 | num_speeches: 1 # Number of speeches that each debater will give per round 4 | flip: True # Whether each debater will debate each side of the debate 5 | enable_self_debate: True # Whether self-play is allowed or just cross-play 6 | speech_structure: default_debate # This references the specific speech ordering you want to use. Defaults to default_debate. Other option at the moment is default_consultancy. 
7 | previous_run: # Config that lets you reuse information from previous runs 8 | file_path: path-to-previous-run # path to the previous run you want to use 9 | replicate_topics: True # whether you want to use the same questions as the previous run (useful for making performance comparisons) 10 | merge_results: False # whether you want to merge the results of the previous run in the final results csv and graphs 11 | prompt_config: # Specifies which set of prompts to use -- if not specified, the defaults will be used 12 | file_path: /path/to/source/code/prompts/configs/prompts.yaml # Path to prompts file 13 | default_prompt_name: Base Prompt # Name of the specific set of prompts to use 14 | use_hardcoded_topics: False # Whether one wants to hard code the topics to debate 15 | hardcoded_topic_config: null # The specific config that's used if use_hardcoded_topics is True 16 | agents: # The configuration for the debaters and judges that will be participating in the debate round 17 | debaters: # The configuration for the debaters. You must specify at least 1. If only 1 is specified, the model will debate itself. If more than 2 are specified, then a round robin will be performed. 18 | - model_settings: # specifies the configuration to be used for the model this debater will use 19 | model_type: llama # The type of model that is to be used. See the README for a comprehensive list of options 20 | model_file_path: /path/to/llama/directory # The path to the model weights. Not needed if the model doesn't have weights 21 | alias: model-name # A name to identify the model. If this name is duplicated, then stats for those debaters will be aggregated 22 | require_quote_validation: True # Whether quotes generated with quote tags should be validated (defaults to True) 23 | generation_params: # Optional configuration that controls the LLM generations 24 | max_new_tokens: 300 # Max to generate in each speech, defaults to 300 25 | temperature: 0.5 # temperature for generation, defaults to 0.5 26 | top_p: 1.0 # top p for generation, defaults to 1.0 27 | repetition_penalty: # penalty for repetition, defaults to 1.2 (but is disabled based on the use_generation_penalties flag) 28 | do_sample: True # whether to sample from the model (True) or do greedy decoding (False) 29 | use_generation_penalties: False # whether to disable the repetition_penalty and exponential length penalty (False) or enable them (True) 30 | peft_base_model: /path/to/peft/base_model # The path to the base model. 
Optional (even if one is using a peft model) but speeds up inference 31 | - model_settings: 32 | model_type: deterministic # The type of model for the second debater 33 | alias: model-name-2 # Alias for the second debater 34 | override_prompt: override-prompt-name # the name of the prompt to use for this debater only 35 | nucleus: True # the decoding strategy (True -> nucleus sampling [default], False -> beam search) 36 | is_human: False # Whether the model being used should be converted into a human debate model (which uses the real transcripts) 37 | best_of_n: 38 | n: 4 # the number of samples to draw 39 | opponent_n: 2 # the number of opponent samples to test against 40 | maxmin: True # whether to select the speech with the best minimum score or the best average score 41 | recompute: False # whether to rejudge each speech to see which is the best 42 | - model_settings: 43 | model_type: random # the type of the third debater 44 | alias: random-model # the name of the third debater 45 | offline_file_path: offline-data-prefix # either the timestamp of the offline debate run (if the transcript is in the default output folder) or the full path to all the offline debates (Only needed if one is recreating a previously-run debate) 46 | scratchpad: 47 | use_scratchpad: True # Whether the debater gets access to a scratchpad (default to false) 48 | scratchpad_word_limit: 100 # Number of words that the debater can use a scratchpad for (0 or None disables scratchpad usage) 49 | scratchpad_public: False # Whether the scratchpad generations are exposed to the other participants 50 | judge: 51 | model_settings: 52 | model_type: llama # The model type for the judge 53 | model_file_path: /path/to/model/weights # The path to the model weights. Not needed if the model doesn't have weights 54 | alias: judge-alias # Name of the judging model 55 | dataset: # Configuration for the dataset to be used 56 | dataset_type: quality_debates # the name of the dataset 57 | full_dataset_file_path: null # path to a file that stores the entire dataset. Not needed if the dataset is pre-split or if defaults are used 58 | train_file_path: null # path to a file that stores the training dataset. Not needed if the dataset is not pre-split or if defaults are used 59 | val_file_path: null # path to a file that stores the training dataset. Not needed if the dataset is not pre-split or if defaults are used 60 | test_file_path: null # path to a file that stores the training dataset. Not needed if the dataset is not pre-split or if defaults are used 61 | supplemental_file_paths: null # set of additional kwargs that are specific to individual dataset 62 | split_type: train # the split of the data to actually use 63 | shuffle_deterministically: False # whether the order of the rounds should be shuffled (False) or fixed (defaults False) 64 | tournament: # a configuration that lets you specify the type of tournament you want to run. Not needed if you're using a round robin 65 | tournament_type: custom # other options include round_robin (default), capped round robin, replication, and self_play_only. 
66 | custom_matchups: # for the custom tournament type specifically, these are the aliases that you want to debate against each other 67 | - ["model-name", "model-name-2"] 68 | - ["model-name", "random"] -------------------------------------------------------------------------------- /experiments/power_pair_scheduler.py: -------------------------------------------------------------------------------- 1 | from debate import DebateRound, DebateRoundSummary 2 | 3 | 4 | class PowerPairScheduler: 5 | """ 6 | This handles the scheduling for a power-paired (Swiss) tournament. 7 | """ 8 | 9 | def __init__(self, debates: list[DebateRound]): 10 | self.alias_to_record = {alias: [0, 0] for alias in self.__get_aliases(debates=debates)} 11 | self.debate_map, self.debate_idx_map = self.__get_debate_map(debates=debates) 12 | 13 | def get_next_pairings(self) -> list[DebateRound]: 14 | """Gets the next batch of rounds to run""" 15 | sorted_aliases = sorted( 16 | self.alias_to_record.keys(), 17 | key=lambda x: alias_to_record[x][0] / alias_to_record[x][1] if alias_to_record[x][1] else 0.5, 18 | ) 19 | matchups = [] 20 | for i in range(len(sorted_aliases) // 2): 21 | matchups.append("_".join(sorted([sorted_aliases[i], sorted_aliases[i + 1]]))) 22 | 23 | pairings = [] 24 | for matchup in matchups: 25 | idx = self.debate_idx_map[matchup] 26 | rounds = self.debate_map[matchup] 27 | if idx < len(rounds): 28 | pairings.append(rounds[idx]) 29 | self.debate_idx_map[matchup] += 1 30 | return pairings 31 | 32 | def update(self, summary: DebateRoundSummary | list[DebateRoundSummary]): 33 | """Updates the Win-Loss record after each round so one can do more accurate pairings""" 34 | summary = summary if isinstance(summary, list) else [summary] 35 | for summary in summary: 36 | self.alias_to_record[summary.metadata.winning_alias][0] += 1 37 | self.alias_to_record[summary.metadata.winning_alias][1] += 1 38 | self.alias_to_record[summary.metadata.losing_alias][1] += 1 39 | 40 | def __get_aliases(self, debates: list[DebateRound]) -> list[str]: 41 | aliases = set() 42 | for debate in debates: 43 | aliases.add(debate.metadata[0].first_debater_alias) 44 | aliases.add(debate.metadata[0].second_debater_alias) 45 | return list(aliases) 46 | 47 | def __get_debate_map(self, debates: list[DebateRound]) -> dict[str, list[DebateRound]]: 48 | debate_map = {} 49 | for debate in debates: 50 | key = "_".join(sorted([debate.metadata[0].first_debater_alias, debate.metadata[0].second_debater_alias])) 51 | if key not in debate_map: 52 | debate_map[key] = [] 53 | debate_map[key].append(debate) 54 | debate_idx_map = {alias: 0 for alias in debate_map} 55 | return debate_map, debate_idx_map 56 | -------------------------------------------------------------------------------- /experiments/quotes_collector.py: -------------------------------------------------------------------------------- 1 | from debate import DebateRoundSummary 2 | from experiments.experiment_loader import ExperimentConfig, ExperimentLoader 3 | from utils import logger_utils, quote_utils 4 | import utils.constants as constants 5 | 6 | from pydantic import BaseModel 7 | 8 | import copy 9 | import re 10 | import sys 11 | 12 | 13 | class QuoteStats(BaseModel): 14 | number_of_quotes: int 15 | number_of_valid_quotes: int 16 | total_valid_quote_length: int 17 | quote_length_to_accuracy: list[list[int]] 18 | 19 | 20 | class QuotesCollector: 21 | MAX_TRACKED_QUOTE_LENGTH = 300 22 | 23 | def __init__(self, experiment: ExperimentConfig): 24 | """Collects metrics about quotation usage from 
debate rounds""" 25 | self.logger = logger_utils.get_default_logger(__name__) 26 | self.dataset = ExperimentLoader.create_dataset(experiment) 27 | self.alias_to_results = {} 28 | 29 | def record_result(self, summary: DebateRoundSummary) -> None: 30 | """Records metrics on the use of quotations in the inputted debate round and stores it""" 31 | 32 | def add_new_alias(alias): 33 | default = QuoteStats( 34 | number_of_quotes=0, 35 | number_of_valid_quotes=0, 36 | total_valid_quote_length=0, 37 | quote_length_to_accuracy=[[0, 0] for i in range(QuotesCollector.MAX_TRACKED_QUOTE_LENGTH)], 38 | ) 39 | self.alias_to_results[alias] = {} 40 | self.alias_to_results[alias][constants.OVERALL] = copy.deepcopy(default) 41 | self.alias_to_results[alias][constants.CORRECT] = copy.deepcopy(default) 42 | self.alias_to_results[alias][constants.WINNER] = copy.deepcopy(default) 43 | self.alias_to_results[alias][constants.LOSER] = copy.deepcopy(default) 44 | self.alias_to_results[alias][constants.INCORRECT] = copy.deepcopy(default) 45 | 46 | def is_correct(speaker: str): 47 | return ( 48 | speaker == constants.DEFAULT_DEBATER_A_NAME 49 | and summary.metadata.first_debater_correct 50 | and summary.first_debater_speaks 51 | ) or ( 52 | speaker == constants.DEFAULT_DEBATER_B_NAME 53 | and not summary.metadata.first_debater_correct 54 | and summary.second_debater_speaks 55 | ) 56 | 57 | def is_incorrect(speaker: str): 58 | return ( 59 | speaker == constants.DEFAULT_DEBATER_A_NAME 60 | and not summary.metadata.first_debater_correct 61 | and summary.first_debater_speaks 62 | ) or ( 63 | speaker == constants.DEFAULT_DEBATER_B_NAME 64 | and summary.metadata.first_debater_correct 65 | and summary.second_debater_speaks 66 | ) 67 | 68 | def is_winner(speaker: str): 69 | return ( 70 | speaker == constants.DEFAULT_DEBATER_A_NAME and summary.first_debater_wins and summary.first_debater_speaks 71 | ) or ( 72 | speaker == constants.DEFAULT_DEBATER_B_NAME 73 | and not summary.first_debater_wins 74 | and summary.second_debater_speaks 75 | ) 76 | 77 | def is_loser(speaker: str): 78 | return ( 79 | speaker == constants.DEFAULT_DEBATER_A_NAME 80 | and not summary.first_debater_wins 81 | and summary.first_debater_speaks 82 | ) or ( 83 | speaker == constants.DEFAULT_DEBATER_B_NAME and summary.first_debater_wins and summary.second_debater_speaks 84 | ) 85 | 86 | def get_alias_from_speaker(speaker: str): 87 | if speech.speaker == constants.DEFAULT_DEBATER_A_NAME: 88 | return summary.first_debater_alias 89 | elif speech.speaker == constants.DEFAULT_DEBATER_B_NAME: 90 | return summary.second_debater_alias 91 | else: 92 | return constants.DEFAULT_JUDGE_NAME 93 | 94 | if summary.first_debater_alias not in self.alias_to_results: 95 | add_new_alias(summary.first_debater_alias) 96 | 97 | if summary.second_debater_alias not in self.alias_to_results: 98 | add_new_alias(summary.second_debater_alias) 99 | 100 | for speech in summary.transcript.speeches: 101 | outputted_quotes = quote_utils.extract_quotes(speech.content) 102 | alias = get_alias_from_speaker(speech.speaker) 103 | if alias == constants.DEFAULT_JUDGE_NAME: 104 | continue 105 | correct = is_correct(speech.speaker) 106 | incorrect = is_incorrect(speech.speaker) 107 | winner = is_winner(speech.speaker) 108 | loser = is_loser(speech.speaker) 109 | 110 | num_valid = 0 111 | total = 0 112 | for quote in outputted_quotes: 113 | total += 1 114 | quote_length = len(quote.split()) 115 | if quote_utils.validate_quote(quote, summary.metadata.background_text, speech.content): 116 | num_valid += 1 
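                        # Note on the bookkeeping below: for each tracked slice, quote_length_to_accuracy[quote_length][0]
                        # appears to count the valid quotes of this word length, while index [1] (incremented further down
                        # for every quote, valid or not) counts the total number of quotes of this length.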
117 | self.alias_to_results[alias][constants.OVERALL].number_of_valid_quotes += 1 118 | self.alias_to_results[alias][constants.OVERALL].total_valid_quote_length += quote_length 119 | self.alias_to_results[alias][constants.OVERALL].quote_length_to_accuracy[quote_length][0] += 1 120 | if winner: 121 | self.alias_to_results[alias][constants.WINNER].number_of_valid_quotes += 1 122 | self.alias_to_results[alias][constants.WINNER].total_valid_quote_length += quote_length 123 | self.alias_to_results[alias][constants.WINNER].quote_length_to_accuracy[quote_length][0] += 1 124 | if loser: 125 | self.alias_to_results[alias][constants.LOSER].number_of_valid_quotes += 1 126 | self.alias_to_results[alias][constants.LOSER].total_valid_quote_length += quote_length 127 | self.alias_to_results[alias][constants.LOSER].quote_length_to_accuracy[quote_length][0] += 1 128 | if correct: 129 | self.alias_to_results[alias][constants.CORRECT].number_of_valid_quotes += 1 130 | self.alias_to_results[alias][constants.CORRECT].total_valid_quote_length += quote_length 131 | self.alias_to_results[alias][constants.CORRECT].quote_length_to_accuracy[quote_length][0] += 1 132 | if incorrect: 133 | self.alias_to_results[alias][constants.INCORRECT].number_of_valid_quotes += 1 134 | self.alias_to_results[alias][constants.INCORRECT].total_valid_quote_length += quote_length 135 | self.alias_to_results[alias][constants.INCORRECT].quote_length_to_accuracy[quote_length][0] += 1 136 | else: 137 | self.logger.debug("The following quote was invalid:\n{}".format(quote)) 138 | 139 | self.alias_to_results[alias][constants.OVERALL].number_of_quotes += 1 140 | self.alias_to_results[alias][constants.OVERALL].quote_length_to_accuracy[quote_length][1] += 1 141 | if winner: 142 | self.alias_to_results[alias][constants.WINNER].number_of_quotes += 1 143 | self.alias_to_results[alias][constants.WINNER].quote_length_to_accuracy[quote_length][1] += 1 144 | if loser: 145 | self.alias_to_results[alias][constants.LOSER].number_of_quotes += 1 146 | self.alias_to_results[alias][constants.LOSER].quote_length_to_accuracy[quote_length][1] += 1 147 | if correct: 148 | self.alias_to_results[alias][constants.CORRECT].number_of_quotes += 1 149 | self.alias_to_results[alias][constants.CORRECT].quote_length_to_accuracy[quote_length][1] += 1 150 | if incorrect: 151 | self.alias_to_results[alias][constants.INCORRECT].number_of_quotes += 1 152 | self.alias_to_results[alias][constants.CORRECT].quote_length_to_accuracy[quote_length][1] += 1 153 | 154 | def get_results(self) -> dict[str, dict[str, QuoteStats]]: 155 | """ 156 | Returns the stored results 157 | 158 | Returns: 159 | alias_to_results: a dictionary that maps a model alias to another dictionary, where the keys are different 160 | slices of the data (e.g 'overall', 'winner', 'correct') and the values are raw counts. 
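            Illustrative shape (alias and counts are hypothetical):
                {"model-alias": {constants.OVERALL: QuoteStats(number_of_quotes=12, number_of_valid_quotes=9, ...),
                                 constants.WINNER: ..., constants.LOSER: ..., constants.CORRECT: ..., constants.INCORRECT: ...}}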
161 | """ 162 | simplified_results = {} 163 | for alias in self.alias_to_results: 164 | simplified_results[alias] = copy.deepcopy(self.alias_to_results[alias]) 165 | for key in simplified_results[alias]: 166 | vals = [ 167 | idx 168 | for idx, pair in filter( 169 | lambda x: x[1][1] > 0, enumerate(simplified_results[alias][key].quote_length_to_accuracy) 170 | ) 171 | ] 172 | max_val = max(vals) if vals else 0 173 | simplified_results[alias][key].quote_length_to_accuracy = simplified_results[alias][ 174 | key 175 | ].quote_length_to_accuracy[: (max_val + 1)] 176 | return simplified_results 177 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .arbitrary_attribute_model import ArbitraryAttributeModel 2 | from .deterministic_model import DeterministicModel 3 | from .human_model import HumanModel 4 | from .llm_model import ( 5 | LlamaModel, 6 | Llama3Model, 7 | LLModel, 8 | LLModuleWithLinearProbe, 9 | LLMInput, 10 | LLMType, 11 | MistralModel, 12 | ModelStub, 13 | ProbeHyperparams, 14 | StubLLModel, 15 | TokenizerStub, 16 | ) 17 | from .model_utils import ModelType, ModelUtils 18 | from .model import BestOfNConfig, GenerationParams, Model, ModelInput, ModelResponse, ModelSettings, SpeechStructure 19 | from .offline_model import OfflineDataFormat, OfflineModel, OfflineModelHelper 20 | from .openai_model import OpenAIModel 21 | from .random_model import RandomModel 22 | from .repetitive_model import RepetitiveModel 23 | from .served_model import ServedModel 24 | -------------------------------------------------------------------------------- /models/anthropic_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, RoleType, SpeechStructure 4 | from utils import logger_utils 5 | import utils.constants as constants 6 | 7 | import anthropic 8 | import backoff 9 | 10 | from concurrent.futures import ThreadPoolExecutor 11 | from typing import Union, Optional 12 | import logging 13 | import os 14 | import math 15 | import random 16 | import re 17 | 18 | 19 | class AnthropicModel(Model): 20 | MAX_PARALLEL_REQUESTS = 16 21 | DEFAULT_MODEL_ENDPOINT = "claude-3-opus-20240229" 22 | 23 | def __init__(self, alias: str, is_debater: bool = True, endpoint: Optional[str] = None, **kwargs): 24 | """ 25 | An AnthropicModel calls Claude to generate the appropriate text. 26 | 27 | Args: 28 | alias: String that identifies the model for metrics and deduplication 29 | is_debater: Boolean indicating whether the model is a debater (true) or judge (false) 30 | """ 31 | super().__init__(alias=alias, is_debater=is_debater) 32 | self.client = anthropic.Anthropic() 33 | self.endpoint = endpoint if endpoint else AnthropicModel.DEFAULT_MODEL_ENDPOINT 34 | self.logger = logger_utils.get_default_logger(__name__) 35 | 36 | def predict( 37 | self, 38 | inputs: list[list[ModelInput] | str], 39 | max_new_tokens=200, 40 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 41 | **kwargs, 42 | ) -> list[ModelResponse]: 43 | """ 44 | Generates a list of texts in response to the given input. 45 | 46 | Args: 47 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 48 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 49 | list of lists is basically a batch of debates. 
50 | max_new_tokens: The maximum total number of new tokens to generate. 51 | speech_structure: the format that the answer is expected to be in. Option includes "open-ended" 52 | (which is just free text), and "decision" (which means a boolean is expected) 53 | 54 | Returns: 55 | A list of model responses, with one string for each entry in the batch. 56 | """ 57 | with ThreadPoolExecutor(max_workers=AnthropicModel.MAX_PARALLEL_REQUESTS) as executor: 58 | futures = [ 59 | executor.submit( 60 | self.predict_single_input, 61 | model_input_list=input_value, 62 | max_new_tokens=max_new_tokens, 63 | speech_structure=speech_structure, 64 | ) 65 | for input_value in inputs 66 | ] 67 | results = [future.result() for future in futures] 68 | 69 | return results 70 | 71 | def predict_single_input( 72 | self, 73 | model_input_list: list[ModelInput] | str, 74 | max_new_tokens=200, 75 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 76 | **kwargs, 77 | ) -> ModelResponse: 78 | """ 79 | Generates a list of texts in response to a single given input. 80 | 81 | Args: 82 | model_input_list: A list of model inputs. Each ModelInput corresponds roughly to one command 83 | max_new_tokens: The maximum total number of new tokens to generate. 84 | speech_structure: the format that the answer is expected to be in. Option includes "open-ended" 85 | (which is just free text) and "decision" (which means a boolean is expected) 86 | 87 | Returns: 88 | A list of model responses, with one string for each entry in the batch. 89 | """ 90 | 91 | def extract_response_from_structured_speech(message: str, regex_str: str, default: str) -> str: 92 | match = re.match(regex_str, message) 93 | if match: 94 | return match.group(1) 95 | else: 96 | self.logger.warn("The regex {} did not match the following message: {}".format(regex_str, message)) 97 | return default 98 | 99 | def process_logprobs(completion: dict) -> tuple[float, float]: 100 | """This exists to maintain parity with the OpenAI model functionality even though the Anthropic API 101 | does not support logprobs yet""" 102 | 103 | if re.search(constants.DEFAULT_DEBATER_A_NAME, completion.content[0].text): 104 | return 1.0, 0.0 105 | elif re.search(constants.DEFAULT_DEBATER_B_NAME, completion.content[0].text): 106 | return 0.0, 1.0 107 | print("uh oh!", completion.content[0].text) 108 | return 0.5, 0.5 109 | 110 | system, messages = AnthropicModel.generate_llm_input_from_model_inputs(input_list=model_input_list) 111 | 112 | try: 113 | completion = self.call_anthropic( 114 | system=system, messages=messages, max_new_tokens=max_new_tokens, speech_structure=speech_structure 115 | ) 116 | except Exception as e: 117 | self.logger.warn(f"Anthropic API returned an API Error: {e}") 118 | self.logger.warn(e) 119 | return ModelResponse(failed=True) 120 | 121 | message = completion.content[0].text 122 | 123 | if speech_structure == SpeechStructure.DECISION: 124 | a_odds, b_odds = process_logprobs(completion) 125 | message = ( 126 | constants.DEFAULT_DEBATER_A_NAME 127 | if a_odds > b_odds 128 | else ( 129 | constants.DEFAULT_DEBATER_B_NAME 130 | if (b_odds > a_odds or random.random() > 0.5) 131 | else constants.DEFAULT_DEBATER_A_NAME 132 | ) 133 | ) 134 | self.logger.debug(f"Debater A's odds: {a_odds}, Debater B's odds: {b_odds}, Winner: {message}") 135 | return ModelResponse( 136 | decision=message, 137 | probabilistic_decision={ 138 | constants.DEFAULT_DEBATER_A_NAME: a_odds, 139 | constants.DEFAULT_DEBATER_B_NAME: b_odds, 140 | }, 141 | 
prompt="\n".join(model_input.content for model_input in model_input_list), 142 | ) 143 | 144 | return ModelResponse(speech=message, prompt="\n".join(model_input.content for model_input in model_input_list)) 145 | 146 | # @backoff.on_exception(backoff.expo, backoff.on_exception, max_tries=4) 147 | def call_anthropic( 148 | self, system: str, messages: list[dict[str, str]], speech_structure: SpeechStructure, max_new_tokens: int 149 | ): 150 | return self.client.messages.create( 151 | model=self.endpoint, # "claude-3-haiku-20240307", #"claude-3-opus-20240229", 152 | max_tokens=max_new_tokens, 153 | system=system, 154 | messages=messages, 155 | temperature=0.0 if speech_structure == SpeechStructure.DECISION else 0.5, 156 | ) 157 | 158 | def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> AnthropicModel: 159 | """Generates a deepcopy of this model""" 160 | return AnthropicModel(alias=alias, is_debater=is_debater, endpoint=self.endpoint) 161 | 162 | @classmethod 163 | def generate_llm_input_from_model_inputs( 164 | cls, input_list: list[ModelInput], extra_suffix: str = "" 165 | ) -> tuple[str, dict[str, list[dict[str, str]]]]: 166 | """Converts a ModelInput into the format that the Anthropic API expects. The first output 167 | is the system prompt and the second is the messages list""" 168 | 169 | def model_input_to_anthropic_format(model_input: ModelInput | str) -> dict[str, str]: 170 | if isinstance(model_input, str): 171 | return {"role": RoleType.USER.name.lower(), "content": model_input} 172 | return {"role": model_input.role.name.lower(), "content": model_input.content} 173 | 174 | def add_actual_speech(messages: list[dict[str, str]], actual_speech: str) -> None: 175 | messages.append({"role": "assistant", "content": actual_speech}) 176 | 177 | messages = [model_input_to_anthropic_format(model_input) for model_input in input_list] 178 | if extra_suffix: 179 | add_actual_speech(messages=messages, actual_speech=extra_suffix) 180 | 181 | if messages[0]["role"] == RoleType.SYSTEM.name.lower(): 182 | return messages[0]["content"], messages[1:] 183 | return "", messages 184 | -------------------------------------------------------------------------------- /models/arbitrary_attribute_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, SpeechStructure 4 | from prompts import RoleType 5 | import utils.constants as constants 6 | 7 | from typing import Union, Optional 8 | import random 9 | import re 10 | 11 | 12 | class ArbitraryAttributeModel(Model): 13 | def __init__(self, alias: str, is_debater: bool = False, feature: Optional[str] = None, **kwargs): 14 | """ 15 | An ArbitraryAttributeModel model picks up on an arbitrary but deterministic feature. 16 | Can be used only for judging. Useful for testing. 
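        For instance, with the default feature of "quote", a debater's score is computed by counting how many
        times the string "quote" appears in their speech (see score_speeches below); any deterministic substring
        could serve as the feature.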
17 | 18 | Args: 19 | alias: string that identifies the model for metrics and deduplication 20 | is_debater: boolean indicating whether the model is a debater (true) or judge (false) 21 | feature: the attribute to count when judging 22 | """ 23 | super().__init__(alias=alias, is_debater=is_debater) 24 | if is_debater: 25 | raise Exception("ArbitraryAttributeModel only supports judge mode") 26 | self.feature = feature or "quote" # TODO: change 27 | 28 | def predict( 29 | self, 30 | inputs: list[list[ModelInput]], 31 | max_new_tokens=250, 32 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 33 | num_return_sequences: int = 1, 34 | **kwargs, 35 | ) -> list[str]: 36 | """ 37 | Generates a list of texts in response to the given input. 38 | 39 | Args: 40 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 41 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 42 | list of lists is basically a batch of debates. Since the model will return a random 43 | response no matter what, the content of the input does not matter. 44 | max_new_tokens: The total number of new tokens to generate. 45 | speech_structure: The format that the answer is expected to be in. Option includes "open-ended" 46 | (which is just free text) and "decision" (which means a boolean is expected) 47 | num_return_sequences: The number of responses that the model is expected to generate. If a batch 48 | size of >1 is passed in, then this value will be overridden by the batch size (so you cannot 49 | have both num_return_sequences > 1 and len(inputs) > 1) 50 | 51 | Returns: 52 | A list of text, with one string for each entry in the batch (or for as many sequences are specified 53 | to be returned by num_return_sequences). 54 | 55 | Raises: 56 | Exception: Raises Exception if num_return_sequences > 1 and len(inputs) > 1 57 | """ 58 | 59 | def score_speeches(model_inputs: list[ModelInput]) -> tuple[str, tuple[float, float]]: 60 | speeches = [speech for speech in filter(lambda x: x.role == RoleType.USER, model_inputs)] 61 | a_speech = ( 62 | re.search( 63 | "This is what Debater_A said during their speech.(.*)This is what Debater_B said during their speech", 64 | speeches[-1].content, 65 | flags=re.DOTALL, 66 | ) 67 | .group(1) 68 | .strip() 69 | ) 70 | 71 | b_speech = ( 72 | re.search( 73 | "This is what Debater_B said during their speech.(.*)Now it is the time for a decision", 74 | speeches[-1].content, 75 | flags=re.DOTALL, 76 | ) 77 | .group(1) 78 | .strip() 79 | ) 80 | 81 | a_score = a_speech.count(self.feature) 82 | b_score = b_speech.count(self.feature) 83 | 84 | b_score = 5 # TODO: change this 85 | 86 | if a_score + b_score > 0: 87 | return constants.DEFAULT_DEBATER_A_NAME if a_score >= b_score else constants.DEFAULT_DEBATER_B_NAME, ( 88 | a_score / (a_score + b_score), 89 | b_score / (a_score + b_score), 90 | ) 91 | return constants.DEFAULT_DEBATER_A_NAME if random.random() <= 0.5 else constants.DEFAULT_DEBATER_B_NAME, ( 92 | 0.5, 93 | 0.5, 94 | ) 95 | 96 | if speech_structure != SpeechStructure.DECISION: 97 | raise Exception("ArbitraryAttributeModel only supports making decisions") 98 | 99 | if len(inputs) > 1 and num_return_sequences > 1: 100 | raise Exception( 101 | f"Length of input ({len(inputs)}) and num_return_sequences ({num_return_sequences}) cannot both be greater than 1." 
102 | ) 103 | 104 | decisions = [] 105 | for i in range(len(inputs)): 106 | decision, (a_odds, b_odds) = score_speeches(inputs[i]) 107 | decisions.append( 108 | ModelResponse( 109 | decision=decision, 110 | probabilistic_decision={ 111 | constants.DEFAULT_DEBATER_A_NAME: a_odds, 112 | constants.DEFAULT_DEBATER_B_NAME: b_odds, 113 | }, 114 | prompt="\n".join([model_input.content for model_input in inputs[i]]), 115 | ) 116 | ) 117 | return decisions 118 | 119 | def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> RandomModel: 120 | """Generates a deepcopy of this model""" 121 | return ArbitraryAttributeModel( 122 | alias=alias, is_debater=is_debater if is_debater is not None else False, feature=self.feature 123 | ) 124 | -------------------------------------------------------------------------------- /models/deterministic_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, SpeechStructure 4 | import utils.constants as constants 5 | 6 | from typing import Union, Optional 7 | import random 8 | 9 | 10 | class DeterministicModel(Model): 11 | def __init__(self, alias: str, is_debater: bool = False): 12 | """ 13 | A deterministic model responds with the same deterministic string in response to every input. Useful for testing. 14 | 15 | Args: 16 | alias: string that identifies the model for metrics and deduplication 17 | is_debater: boolean indicating whether the model is a debater (true) or judge (false) 18 | """ 19 | super().__init__(alias=alias, is_debater=is_debater) 20 | self.text = "My position is correct. You have to vote for me." if self.is_debater else "I am a judge. Let me think." 21 | self.text_length = len(self.text.split()) 22 | 23 | def predict( 24 | self, 25 | inputs: list[list[ModelInput]], 26 | max_new_tokens: int = 250, 27 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 28 | num_return_sequences: int = 1, 29 | **kwargs, 30 | ) -> ModelResponse: 31 | """ 32 | Generates a list of texts in response to the given input. 33 | 34 | Args: 35 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 36 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 37 | list of lists is basically a batch of debates. Since the model will return the same 38 | deterministic response no matter what, the content of the input does not matter. 39 | max_new_tokens: The total number of new tokens to generate. The canned line will be repeated 40 | until it reaches that limit. 41 | speech_structure: The format that the answer is expected to be in. Option includes "open-ended" 42 | (which is just free text), and "decision" (which means a boolean is expected) 43 | num_return_sequences: The number of responses that the model is expected to generate. If a batch 44 | size of >1 is passed in, then this value will be overridden by the batch size (so you cannot 45 | have both num_return_sequences > 1 and len(inputs) > 1) 46 | 47 | Returns: 48 | A list of ModelResponses, with one response for each entry in the batch (or for as many sequences are specified 49 | to be returned by num_return_sequences). 
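        Example (a minimal sketch; the alias and prompt text are arbitrary):

            from prompts import RoleType
            debater = DeterministicModel(alias="canned-debater", is_debater=True)
            responses = debater.predict(inputs=[[ModelInput(role=RoleType.USER, content="Make your case.")]])
            # responses[0].speech contains the canned line repeated up to roughly max_new_tokens words.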
50 | 
51 |         Raises:
52 |             Exception: Raises Exception if num_return_sequences > 1 and len(inputs) > 1
53 |         """
54 |         if len(inputs) > 1 and num_return_sequences > 1:
55 |             raise Exception(
56 |                 f"Length of input ({len(inputs)}) and num_return_sequences ({num_return_sequences}) cannot both be greater than 1."
57 |             )
58 | 
59 |         if speech_structure == SpeechStructure.DECISION:
60 |             return [ModelResponse(decision=constants.DEFAULT_DEBATER_A_NAME) for i in range(len(inputs))]
61 |         text_to_repeat = "\n".join([self.text for i in range(int(max_new_tokens / self.text_length))])
62 | 
63 |         num_return_sequences = len(inputs) if len(inputs) > 1 else num_return_sequences
64 |         return [ModelResponse(speech=text_to_repeat) for i in range(num_return_sequences)]
65 | 
66 |     def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> DeterministicModel:
67 |         """Generates a deepcopy of this model"""
68 |         return DeterministicModel(alias=alias, is_debater=is_debater if is_debater is not None else False)
69 | 
--------------------------------------------------------------------------------
/models/human_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from models.model import Model, ModelInput, ModelResponse
4 | from data import SpeakerType, SpeechData
5 | from utils import logger_utils
6 | import utils.constants as constants
7 | 
8 | from typing import Union, Optional
9 | 
10 | 
11 | class HumanModel(Model):
12 |     def __init__(self, alias: str, is_debater: bool, debater_name: str, speeches: list[SpeechData], **kwargs):
13 |         """
14 |         A human model returns the text that the human debaters generated during the human debate experiments.
15 | 
16 |         Args:
17 |             alias: String that identifies the model for metrics and deduplication
18 |             is_debater: Boolean indicating whether the model is a debater (true) or judge (false)
19 |             debater_name: The name of the debater (Debater_A or Debater_B) whose speeches should be replayed
20 |             speeches: List of speeches delivered by the human debaters. These speeches **must be in the same order as the subsequent debate rounds**
21 |         """
22 |         super().__init__(alias=alias, is_debater=is_debater)
23 |         position = 0 if debater_name == constants.DEFAULT_DEBATER_A_NAME else 1
24 |         self.speeches = [
25 |             speech for speech in filter(lambda x: x.speaker_type == SpeakerType.DEBATER and x.position == position, speeches)
26 |         ]
27 |         self.speech_idx = 0
28 |         self.debater_name = debater_name
29 |         self.logger = logger_utils.get_default_logger(__name__)
30 | 
31 |     def predict(self, inputs: list[list[ModelInput]], **kwargs) -> list[ModelResponse]:
32 |         """
33 |         Generates a list of texts in response to the given input. This does not support batch processing.
34 | 
35 |         Args:
36 |             inputs: **This input is ignored**. The model replays the next pre-recorded human speech, so the content
37 |                 of the input does not matter. It is maintained only to be consistent with the interface.
38 | 
39 |         Returns:
40 |             A list of length 1 containing the text of the corresponding speech from the human debates.
41 | 
42 |         Raises:
43 |             Exception: Raises an exception if the batch size is greater than 1.
44 |         """
45 |         if len(inputs) > 1:
46 |             raise Exception(f"The HumanModel does not support batch processing. Input was of length {len(inputs)}")
47 | 
48 |         speech = ""
49 |         if self.speech_idx < len(self.speeches):
50 |             speech = self.speeches[self.speech_idx].text
51 |             self.speech_idx += 1
52 |         else:
53 |             self.logger.warn(
54 |                 f"Human debater {self.alias} was unable to generate a speech. 
Current index is {self.speech_idx} but there are only {len(self.speeches)} in its speech list." 55 | ) 56 | return [ModelResponse(speech=speech, prompt="\n".join(model_input.content for model_input in inputs[0]))] 57 | 58 | def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> HumanModel: 59 | """Generates a deepcopy of this model""" 60 | return HumanModel(alias=alias, is_debater=is_debater, debater_name=self.debater_name, speeches=self.speeches) 61 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pydantic import BaseModel, ConfigDict, field_validator, model_validator, validator 4 | 5 | from prompts import RoleType 6 | import utils.constants as constants 7 | 8 | from abc import ABC 9 | from enum import Enum 10 | from typing import Literal, Optional 11 | 12 | 13 | class ModelType(Enum): 14 | RANDOM = 1 15 | LLAMA = 2 16 | DETERMINISTIC = 3 17 | OPENAI = 4 18 | OFFLINE = 5 19 | HUMAN = 6 20 | MISTRAL = 7 21 | STUB_LLM = 8 22 | ARBITRARY_ATTRIBUTE = 9 23 | ANTHROPIC = 10 24 | LLAMA3 = 11 25 | REPETITIVE = 12 26 | 27 | 28 | class BestOfNConfig(BaseModel): 29 | n: int 30 | opponent_n: int = 0 31 | maxmin: bool = False 32 | recompute: bool = False 33 | 34 | 35 | class GenerationParams(BaseModel): 36 | max_new_tokens: int = 300 37 | temperature: float = 0.5 38 | top_p: float = 1.0 39 | repetition_penalty: float = 1.2 40 | do_sample: bool = True 41 | use_generation_penalties: bool = False 42 | 43 | 44 | class ModelInput(BaseModel): 45 | role: RoleType 46 | content: str 47 | 48 | 49 | class ModelResponse(BaseModel): 50 | speech: str = "" 51 | decision: Literal[constants.DEFAULT_DEBATER_A_NAME, constants.DEFAULT_DEBATER_B_NAME, ""] = "" 52 | probabilistic_decision: Optional[dict[str, float]] = None 53 | preference: Optional[float] = None 54 | rejected_responses: list[ModelResponse] = [] 55 | bon_opposing_model_responses: list[ModelResponse] = [] 56 | bon_probabilistic_preferences: list[float] = [] 57 | internal_representations: Optional[str] = "" 58 | response_tokens: list[int] = [] 59 | prompt_tokens: list[int] = [] 60 | prompt: str = "" 61 | failed: bool = False 62 | 63 | @validator("probabilistic_decision") 64 | def check_keys(cls, v): 65 | if v: 66 | if not constants.DEFAULT_DEBATER_A_NAME in v: 67 | raise ValueError(f"Probabilistic decision is missing required key: {constants.DEFAULT_DEBATER_A_NAME}") 68 | if not constants.DEFAULT_DEBATER_B_NAME in v: 69 | raise ValueError(f"Probabilistic decision is missing required key: {constants.DEFAULT_DEBATER_B_NAME}") 70 | if len(v) > 2: 71 | all_keys = ", ".join(v.keys()) 72 | raise ValueError(f"There are too many keys in the probabilistic decision map. Keys: {all_keys}") 73 | 74 | eps = 0.001 75 | total_prob = sum(v.values()) 76 | if total_prob < 1 - eps or total_prob > 1 + eps: 77 | raise ValueError(f"Total probability does not sum to 1 -- it sums to {total_prob}. 
Map is {v}") 78 | 79 | return v 80 | 81 | 82 | class ProbeHyperparams(BaseModel): 83 | file_path: str = "" 84 | hidden_size: Optional[int] = None 85 | linear_idxs: list[int] = [-1] 86 | 87 | 88 | class ModelSettings(BaseModel): 89 | model_type: str | ModelType = ModelType.RANDOM 90 | model_file_path: Optional[str] = None 91 | alias: str 92 | override_prompt: Optional[str] = None 93 | nucleus: bool = True 94 | is_human: bool = False 95 | offline_file_path: Optional[str] = None 96 | served: bool = False 97 | probe_hyperparams: Optional[ProbeHyperparams] = None 98 | require_quote_validation: bool = True 99 | generation_params: GenerationParams = GenerationParams() 100 | peft_base_model: Optional[str] = None 101 | 102 | @model_validator(mode="before") 103 | def verify_custom_settings(cls, values): 104 | existence_count = sum([values.get("is_human", False), values.get("served", False)]) + ( 105 | 1 if values.get("offline_file_path", None) else 0 106 | ) 107 | if existence_count > 1: 108 | raise ValueError("One cannot set more than one of is_human, served, or offline_file_path to non-null and true") 109 | return values 110 | 111 | model_config = ConfigDict(protected_namespaces=("protect_me_", "also_protect_")) 112 | 113 | @field_validator("alias", mode="before") 114 | @classmethod 115 | def validate_alias(cls, alias: str | int): 116 | return str(alias) 117 | 118 | @field_validator("model_type", mode="before") 119 | @classmethod 120 | def validate_model_type(cls, model_type: str | ModelType): 121 | if isinstance(model_type, str): 122 | return ModelType[model_type.upper()] 123 | return model_type 124 | 125 | 126 | class SpeechStructure(Enum): 127 | OPEN_ENDED = 1 128 | DECISION = 2 129 | 130 | 131 | class Model(ABC): 132 | def __init__(self, alias: str, is_debater: bool = False): 133 | self.alias = alias 134 | self.is_debater = is_debater 135 | 136 | def predict(self, inputs: list[list[ModelInput]], max_new_tokens: 250, **kwargs) -> ModelResponse: 137 | pass 138 | 139 | def copy(self, is_debater: Optional[bool] = None, **kwargs) -> Model: 140 | return self 141 | 142 | def can_merge(self, other: Model) -> bool: 143 | return other == self 144 | 145 | def merge(self, other: Model) -> Model: 146 | if self.can_merge(other): 147 | return self 148 | raise Exception("Cannot merge across models") 149 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | from models.anthropic_model import AnthropicModel 2 | from models.arbitrary_attribute_model import ArbitraryAttributeModel 3 | from models.deterministic_model import DeterministicModel 4 | from models.llm_model import LlamaModel, Llama3Model, MistralModel, StubLLModel 5 | from models.model import Model, ModelSettings, ModelType 6 | from models.openai_model import OpenAIModel 7 | from models.random_model import RandomModel 8 | from models.repetitive_model import RepetitiveModel 9 | from models.served_model import ServedModel 10 | 11 | from pydantic import BaseModel 12 | 13 | from enum import Enum 14 | from typing import Optional 15 | 16 | 17 | class ModelUtils: 18 | @classmethod 19 | def instantiate_model( 20 | cls, 21 | model_settings: ModelSettings, 22 | is_debater: bool = True, 23 | ) -> Optional[Model]: 24 | """ 25 | Builds a model using the given inputs. 26 | 27 | Args: 28 | model_settings: the configuration object for the model 29 | is_debater: Boolean indicating if the model is to be used as a debater or judge. 
30 | 31 | Returns: 32 | An instantiated model of the given type. 33 | 34 | Raises: 35 | Exception: Raises exception if the model type does not exist or if the model cannot be instantiated 36 | directly. At the moment, neither the OfflineModel nor the HumanModel can be instantiated directly. 37 | """ 38 | model_type = model_settings.model_type 39 | if model_type == ModelType.RANDOM: 40 | model = RandomModel(alias=model_settings.alias, is_debater=is_debater) 41 | elif model_type == ModelType.LLAMA: 42 | model = LlamaModel( 43 | alias=model_settings.alias, 44 | file_path=model_settings.model_file_path, 45 | is_debater=is_debater, 46 | nucleus=model_settings.nucleus, 47 | probe_hyperparams=model_settings.probe_hyperparams, 48 | generation_params=model_settings.generation_params, 49 | peft_base_model=model_settings.peft_base_model, 50 | ) 51 | elif model_type == ModelType.MISTRAL: 52 | model = MistralModel( 53 | alias=model_settings.alias, 54 | file_path=model_settings.model_file_path, 55 | is_debater=is_debater, 56 | nucleus=model_settings.nucleus, 57 | probe_hyperparams=model_settings.probe_hyperparams, 58 | generation_params=model_settings.generation_params, 59 | peft_base_model=model_settings.peft_base_model, 60 | ) 61 | elif model_type == ModelType.LLAMA3: 62 | model = Llama3Model( 63 | alias=model_settings.alias, 64 | file_path=model_settings.model_file_path, 65 | is_debater=is_debater, 66 | nucleus=model_settings.nucleus, 67 | probe_hyperparams=model_settings.probe_hyperparams, 68 | generation_params=model_settings.generation_params, 69 | peft_base_model=model_settings.peft_base_model, 70 | ) 71 | elif model_type == ModelType.STUB_LLM: 72 | model = StubLLModel(alias=model_settings.alias, generation_params=model_settings.generation_params) 73 | elif model_type == ModelType.DETERMINISTIC: 74 | model = DeterministicModel(alias=model_settings.alias, is_debater=is_debater) 75 | elif model_type == ModelType.OPENAI: 76 | model = OpenAIModel(alias=model_settings.alias, is_debater=is_debater, endpoint=model_settings.model_file_path) 77 | elif model_type == ModelType.ARBITRARY_ATTRIBUTE: 78 | model = ArbitraryAttributeModel(alias=model_settings.alias, is_debater=is_debater) 79 | elif model_type == ModelType.ANTHROPIC: 80 | model = AnthropicModel( 81 | alias=model_settings.alias, is_debater=is_debater, endpoint=model_settings.model_file_path 82 | ) 83 | elif model_type == ModelType.REPETITIVE: 84 | model = RepetitiveModel(alias=model_settings.alias, is_debater=is_debater) 85 | elif model_type == ModelType.OFFLINE: 86 | model = None # offline models aren't directly instantiated 87 | elif model_type == ModelType.HUMAN: 88 | model = None # offline models aren't directly instantiated 89 | else: 90 | raise Exception(f"Model {model_type} not found") 91 | 92 | if model_settings.served: 93 | if model_type in [ModelType.LLAMA, ModelType.MISTRAL]: # expand when more types allow serving 94 | model = ServedModel(base_model=model) 95 | else: 96 | raise Exception(f"Model type {model_type} does not support serving") 97 | 98 | return model 99 | -------------------------------------------------------------------------------- /models/openai_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, RoleType, SpeechStructure 4 | from utils import logger_utils 5 | import utils.constants as constants 6 | 7 | import backoff 8 | import openai 9 | 10 | from concurrent.futures import 
ThreadPoolExecutor 11 | from typing import Union, Optional 12 | import logging 13 | import os 14 | import math 15 | import random 16 | import re 17 | 18 | 19 | class OpenAIModel(Model): 20 | MAX_PARALLEL_REQUESTS = 16 21 | DEFAULT_MODEL_ENDPOINT = "gpt-4-0125-preview" 22 | 23 | def __init__(self, alias: str, is_debater: bool = True, endpoint: Optional[str] = None, **kwargs): 24 | """ 25 | An OpenAIModel calls GPT4 to generate the appropriate text. 26 | 27 | Args: 28 | alias: String that identifies the model for metrics and deduplication 29 | is_debater: Boolean indicating whether the model is a debater (true) or judge (false) 30 | """ 31 | super().__init__(alias=alias, is_debater=is_debater) 32 | self.__configure() 33 | self.client = openai.OpenAI() 34 | self.endpoint = endpoint if endpoint else OpenAIModel.DEFAULT_MODEL_ENDPOINT 35 | self.logger = logger_utils.get_default_logger(__name__) 36 | 37 | def __configure(self): 38 | openai.organization = os.getenv("OPENAI_ORGANIZATION") 39 | openai.api_key = os.getenv("OPENAI_API_KEY") 40 | 41 | def predict( 42 | self, 43 | inputs: list[list[ModelInput] | str], 44 | max_new_tokens=200, 45 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 46 | **kwargs, 47 | ) -> list[ModelResponse]: 48 | """ 49 | Generates a list of texts in response to the given input. 50 | 51 | Args: 52 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 53 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 54 | list of lists is basically a batch of debates. 55 | max_new_tokens: The maximum total number of new tokens to generate. 56 | speech_structure: the format that the answer is expected to be in. Option includes "open-ended" 57 | (which is just free text), and "decision" (which means a boolean is expected) 58 | 59 | Returns: 60 | A list of model responses, with one string for each entry in the batch. 61 | """ 62 | with ThreadPoolExecutor(max_workers=OpenAIModel.MAX_PARALLEL_REQUESTS) as executor: 63 | futures = [ 64 | executor.submit( 65 | self.predict_single_input, 66 | model_input_list=input_value, 67 | max_new_tokens=max_new_tokens, 68 | speech_structure=speech_structure, 69 | ) 70 | for input_value in inputs 71 | ] 72 | results = [future.result() for future in futures] 73 | 74 | return results 75 | 76 | def predict_single_input( 77 | self, 78 | model_input_list: list[ModelInput] | str, 79 | max_new_tokens=200, 80 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 81 | **kwargs, 82 | ) -> ModelResponse: 83 | """ 84 | Generates a list of texts in response to a single given input. 85 | 86 | Args: 87 | model_input_list: A list of model inputs. Each ModelInput corresponds roughly to one command 88 | max_new_tokens: The maximum total number of new tokens to generate. 89 | speech_structure: the format that the answer is expected to be in. Option includes "open-ended" 90 | (which is just free text) and "decision" (which means a boolean is expected) 91 | 92 | Returns: 93 | A list of model responses, with one string for each entry in the batch. 
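        Example (a sketch only -- it assumes OPENAI_API_KEY is set and issues a real API call):

            judge = OpenAIModel(alias="gpt-judge", is_debater=False)
            response = judge.predict_single_input(
                model_input_list=[ModelInput(role=RoleType.USER, content="Who won? Answer Debater_A or Debater_B.")],
                speech_structure=SpeechStructure.DECISION,
            )
            # response.decision names the winner; response.probabilistic_decision holds the renormalized odds.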
94 |         """
95 | 
96 |         def extract_response_from_structured_speech(message: str, regex_str: str, default: str) -> str:
97 |             match = re.match(regex_str, message)
98 |             if match:
99 |                 return match.group(1)
100 |             else:
101 |                 self.logger.warn("The regex {} did not match the following message: {}".format(regex_str, message))
102 |                 return default
103 | 
104 |         def process_logprobs(completion: dict) -> tuple[float, float]:
105 |             debater_suffixes = ["_A", "_B"]
106 |             logprobs = completion.choices[0].logprobs.content
107 |             for entry in logprobs:
108 |                 if entry.token in debater_suffixes:
109 |                     scores = {suffix: 0 for suffix in debater_suffixes}
110 |                     for option in filter(lambda x: x.token in debater_suffixes, entry.top_logprobs):
111 |                         scores[option.token] = math.exp(float(option.logprob))
112 |                     total_probs = sum(scores.values())
113 |                     renormalized_scores = {suffix: scores[suffix] / total_probs for suffix in scores}
114 |                     return (
115 |                         renormalized_scores[debater_suffixes[0]],
116 |                         renormalized_scores[debater_suffixes[1]],
117 |                     )
118 |             return 0.5, 0.5
119 | 
120 |         messages = OpenAIModel.generate_llm_input_from_model_inputs(input_list=model_input_list)
121 | 
122 |         try:
123 |             completion = self.call_openai(
124 |                 messages=messages, max_new_tokens=max_new_tokens, speech_structure=speech_structure
125 |             )
126 |         except openai.RateLimitError as e:
127 |             self.logger.warn(f"OpenAI API request exceeded rate limit: {e}")
128 |             self.logger.warn(e)
129 |             return ModelResponse(failed=True)
130 |         except openai.APIConnectionError as e:
131 |             self.logger.warn(f"Failed to connect to OpenAI API: {e}")
132 |             self.logger.warn(e)
133 |             return ModelResponse(failed=True)
134 |         except openai.APIError as e:
135 |             self.logger.warn(f"OpenAI API returned an API Error: {e}")
136 |             self.logger.warn(e)
137 |             return ModelResponse(failed=True)
138 | 
139 |         message = completion.choices[0].message.content
140 | 
141 |         if speech_structure == SpeechStructure.DECISION:
142 |             a_odds, b_odds = process_logprobs(completion)
143 |             message = (
144 |                 constants.DEFAULT_DEBATER_A_NAME
145 |                 if a_odds > b_odds
146 |                 else (
147 |                     constants.DEFAULT_DEBATER_B_NAME
148 |                     if (b_odds > a_odds or random.random() > 0.5)
149 |                     else constants.DEFAULT_DEBATER_A_NAME
150 |                 )
151 |             )
152 |             self.logger.debug(f"Debater A's odds: {a_odds}, Debater B's odds: {b_odds}, Winner: {message}")
153 |             return ModelResponse(
154 |                 decision=message,
155 |                 probabilistic_decision={
156 |                     constants.DEFAULT_DEBATER_A_NAME: a_odds,
157 |                     constants.DEFAULT_DEBATER_B_NAME: b_odds,
158 |                 },
159 |                 prompt="\n".join(model_input.content for model_input in model_input_list),
160 |             )
161 | 
162 |         return ModelResponse(speech=message, prompt="\n".join(model_input.content for model_input in model_input_list))
163 | 
164 |     @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APIConnectionError), max_tries=4)
165 |     def call_openai(
166 |         self, messages: list[dict[str, str]], speech_structure: SpeechStructure, max_new_tokens: int
167 |     ) -> openai.ChatCompletion:
168 |         return self.client.chat.completions.create(
169 |             model=self.endpoint,
170 |             messages=messages,
171 |             max_tokens=max_new_tokens,
172 |             logprobs=(speech_structure != SpeechStructure.OPEN_ENDED),
173 |             top_logprobs=5 if (speech_structure != SpeechStructure.OPEN_ENDED) else None,
174 |         )
175 | 
176 |     def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> OpenAIModel:
177 |         """Generates a deepcopy of this model"""
178 |         return OpenAIModel(alias=alias, is_debater=is_debater, endpoint=self.endpoint)
179 | 
180 |     @classmethod
181 |     def 
generate_llm_input_from_model_inputs( 182 | cls, input_list: list[ModelInput], extra_suffix: str = "" 183 | ) -> dict[str, list[dict[str, str]]]: 184 | """Converts a ModelInput into the format that the OpenAI API expects""" 185 | 186 | def model_input_to_openai_format(model_input: ModelInput | str) -> dict[str, str]: 187 | if isinstance(model_input, str): 188 | return {"role": RoleType.USER.name.lower(), "content": model_input} 189 | return {"role": model_input.role.name.lower(), "content": model_input.content} 190 | 191 | def add_actual_speech(messages: list[dict[str, str]], actual_speech: str) -> None: 192 | messages.append({"role": "assistant", "content": actual_speech}) 193 | 194 | messages = [model_input_to_openai_format(model_input) for model_input in input_list] 195 | if extra_suffix: 196 | add_actual_speech(messages=messages, actual_speech=extra_suffix) 197 | return messages 198 | -------------------------------------------------------------------------------- /models/random_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, SpeechStructure 4 | import utils.constants as constants 5 | from utils import logger_utils 6 | 7 | from typing import Union, Optional 8 | import random 9 | import re 10 | 11 | 12 | class RandomModel(Model): 13 | def __init__(self, alias: str, is_debater: bool = False, **kwargs): 14 | """ 15 | A random model responds with a random string in response to every input. Useful for testing. 16 | 17 | Args: 18 | alias: string that identifies the model for metrics and deduplication 19 | is_debater: boolean indicating whether the model is a debater (true) or judge (false) 20 | """ 21 | super().__init__(alias=alias, is_debater=is_debater) 22 | self.alphabet = "abcdefghijklmnopqrstuvwxyz" 23 | self.logger = logger_utils.get_default_logger(__name__) 24 | 25 | def predict( 26 | self, 27 | inputs: list[list[ModelInput]], 28 | max_new_tokens=250, 29 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 30 | num_return_sequences: int = 1, 31 | **kwargs, 32 | ) -> list[str]: 33 | """ 34 | Generates a list of texts in response to the given input. 35 | 36 | Args: 37 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 38 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 39 | list of lists is basically a batch of debates. Since the model will return a random 40 | response no matter what, the content of the input does not matter. 41 | max_new_tokens: The total number of new tokens to generate. 42 | speech_structure: The format that the answer is expected to be in. Option includes "open-ended" 43 | (which is just free text) and "decision" (which means a boolean is expected) 44 | num_return_sequences: The number of responses that the model is expected to generate. If a batch 45 | size of >1 is passed in, then this value will be overridden by the batch size (so you cannot 46 | have both num_return_sequences > 1 and len(inputs) > 1) 47 | 48 | Returns: 49 | A list of text, with one string for each entry in the batch (or for as many sequences are specified 50 | to be returned by num_return_sequences). 
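        Example (illustrative only; the alias is arbitrary):

            from prompts import RoleType
            judge = RandomModel(alias="random-judge", is_debater=False)
            decisions = judge.predict(
                inputs=[[ModelInput(role=RoleType.USER, content="Give your decision.")]],
                speech_structure=SpeechStructure.DECISION,
            )
            # decisions[0].probabilistic_decision maps Debater_A and Debater_B to odds that sum to 1.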
51 | 52 | Raises: 53 | Exception: Raises Exception if num_return_sequences > 1 and len(inputs) > 1 54 | """ 55 | 56 | def generate_random_text(): 57 | return ( 58 | " ".join( 59 | [ 60 | "".join(random.choices(self.alphabet, k=random.randrange(1, 8))) 61 | for i in range(random.randrange(1, max_new_tokens)) 62 | ] 63 | ) 64 | + f"{constants.QUOTE_TAG} This is not real {constants.UNQUOTE_TAG}." 65 | ) 66 | 67 | def generate_random_decision(): 68 | a_odds = random.random() 69 | b_odds = 1 - a_odds 70 | decision = constants.DEFAULT_DEBATER_A_NAME if a_odds > 0.5 else constants.DEFAULT_DEBATER_B_NAME 71 | self.logger.debug(f"Debater A's odds: {a_odds}, Debater B's odds: {b_odds}") 72 | return decision, (a_odds, b_odds) 73 | 74 | if len(inputs) > 1 and num_return_sequences > 1: 75 | raise Exception( 76 | f"Length of input ({len(inputs)}) and num_return_sequences ({num_return_sequences}) cannot both be greater than 1." 77 | ) 78 | 79 | if speech_structure == SpeechStructure.DECISION: 80 | decisions = [] 81 | for i in range(len(inputs)): 82 | decision, (a_odds, b_odds) = generate_random_decision() 83 | decisions.append( 84 | ModelResponse( 85 | decision=decision, 86 | probabilistic_decision={ 87 | constants.DEFAULT_DEBATER_A_NAME: a_odds, 88 | constants.DEFAULT_DEBATER_B_NAME: b_odds, 89 | }, 90 | prompt="\n".join([model_input.content for model_input in inputs[i]]), 91 | ) 92 | ) 93 | return decisions 94 | 95 | num_return_sequences = max(num_return_sequences, len(inputs)) 96 | return [ 97 | ModelResponse( 98 | speech=generate_random_text(), prompt="\n".join([model_input.content for model_input in inputs[i]]) 99 | ) 100 | for i in range(num_return_sequences) 101 | ] 102 | 103 | def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> RandomModel: 104 | """Generates a deepcopy of this model""" 105 | return RandomModel(alias=alias, is_debater=is_debater if is_debater is not None else False) 106 | -------------------------------------------------------------------------------- /models/repetitive_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, SpeechStructure 4 | import utils.constants as constants 5 | 6 | from typing import Union, Optional 7 | import random 8 | import sys 9 | 10 | 11 | class RepetitiveModel(Model): 12 | def __init__(self, alias: str, is_debater: bool = False): 13 | """ 14 | An repetitive model only works for judging and always responds with whatever letter appeared most frequently in the previous speech. 15 | Useful for evaluating whether an open debater chose the correct side and for debugging 16 | 17 | Args: 18 | alias: string that identifies the model for metrics and deduplication 19 | is_debater: boolean indicating whether the model is a debater (true) or judge (false) 20 | """ 21 | super().__init__(alias=alias, is_debater=is_debater) 22 | 23 | def predict( 24 | self, 25 | inputs: list[list[ModelInput]], 26 | max_new_tokens: int = 250, 27 | speech_structure: SpeechStructure = SpeechStructure.OPEN_ENDED, 28 | num_return_sequences: int = 1, 29 | **kwargs, 30 | ) -> ModelResponse: 31 | """ 32 | Generates a list of texts in response to the given input. 33 | 34 | Args: 35 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 36 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 37 | list of lists is basically a batch of debates. 
Since the model will return the same 38 | deterministic response no matter what, the content of the input does not matter. 39 | max_new_tokens: The total number of new tokens to generate. The canned line will be repeated 40 | until it reaches that limit. 41 | speech_structure: The format that the answer is expected to be in. Option includes "open-ended" 42 | (which is just free text), and "decision" (which means a boolean is expected) 43 | num_return_sequences: The number of responses that the model is expected to generate. If a batch 44 | size of >1 is passed in, then this value will be overridden by the batch size (so you cannot 45 | have both num_return_sequences > 1 and len(inputs) > 1) 46 | 47 | Returns: 48 | A list of ModelResponses, with one response for each entry in the batch (or for as many sequences are specified 49 | to be returned by num_return_sequences). 50 | 51 | Raises: 52 | Exception: Raises Exception if num_return_sequences > 1 and len(inputs) > 1 53 | """ 54 | if len(inputs) > 1 and num_return_sequences > 1: 55 | raise Exception( 56 | f"Length of input ({len(inputs)}) and num_return_sequences ({num_return_sequences}) cannot both be greater than 1." 57 | ) 58 | 59 | outputs = [] 60 | if speech_structure == SpeechStructure.DECISION: 61 | for model_input in inputs: 62 | content = ( 63 | constants.DEFAULT_DEBATER_A_NAME 64 | if model_input[-1].content.lower().rfind("a") > model_input[-1].content.lower().rfind("b") 65 | else constants.DEFAULT_DEBATER_B_NAME 66 | ) 67 | outputs.append(ModelResponse(decision=content)) 68 | else: 69 | for model_input in inputs: 70 | outputs.append(ModelResponse(speech="A" if random.random() < 0.5 else "B")) 71 | return outputs 72 | 73 | def copy(self, alias: str, is_debater: Optional[bool] = None, **kwargs) -> DeterministicModel: 74 | """Generates a deepcopy of this model""" 75 | return RepetitiveModel(alias=alias, is_debater=is_debater if is_debater is not None else False) 76 | -------------------------------------------------------------------------------- /models/served_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from models.model import Model, ModelInput, ModelResponse, SpeechStructure 4 | from models.llm_model import GenerationParams, LLModel 5 | from utils import logger_utils, timer 6 | 7 | from pydantic import BaseModel 8 | import requests 9 | 10 | from typing import Optional 11 | from concurrent.futures import ThreadPoolExecutor 12 | import sys 13 | import time 14 | 15 | 16 | class RequestParams(BaseModel): 17 | inputs: str 18 | parameters: GenerationParams 19 | 20 | 21 | class ResponseStruct(BaseModel): 22 | generated_text: str 23 | 24 | 25 | class ServedModel(Model): 26 | DEFAULT_GENERATION_PARAMS = GenerationParams() 27 | DEFAULT_SERVING_ENDPOINT = "http://127.0.0.1:8080/generate" 28 | MAX_PARALLEL_REQUESTS = 8 29 | DEFAULT_HEADER = {"Content-Type": "application/json"} 30 | 31 | def __init__(self, base_model: Model): 32 | """ 33 | A served model calls a hosted model running on a local endpoint for inference. 34 | 35 | Args: 36 | base_model: A model of the type that is being served. This is needed so that we can 37 | define the input format appropriately and set the correct alias. 
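        Example (a sketch only -- it assumes a text-generation server is already listening on the default
        endpoint, http://127.0.0.1:8080/generate; base_llama_model and batch_of_model_inputs are placeholder names):

            served = ServedModel(base_model=base_llama_model)  # base_llama_model: an already-instantiated LlamaModel
            responses = served.predict(inputs=batch_of_model_inputs)
            # each request is formatted by the base model and dispatched to the local endpoint in parallel.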
38 | """ 39 | super().__init__(alias=base_model.alias, is_debater=base_model.is_debater) 40 | self.base_model = base_model 41 | self.logger = logger_utils.get_default_logger(__name__) 42 | 43 | def fetch(self, input_string: str) -> str: 44 | """Hits the default endpoint for the served model""" 45 | data = RequestParams(inputs=input_string, parameters=ServedModel.DEFAULT_GENERATION_PARAMS).dict() 46 | response = requests.post(ServedModel.DEFAULT_SERVING_ENDPOINT, headers=ServedModel.DEFAULT_HEADER, json=data) 47 | return ResponseStruct(**response.json()).generated_text 48 | 49 | @timer("served LLM inference") 50 | def predict( 51 | self, 52 | inputs: list[list[ModelInput]], 53 | max_new_tokens=300, 54 | num_return_sequences: int = 1, 55 | **kwargs, 56 | ) -> list[str]: 57 | """ 58 | Generates a list of texts in response to the given input. Note that this can only be used for 59 | speeches and not for judging since the log probs are not exposed. 60 | 61 | Args: 62 | inputs: A list of list of model inputs. Each ModelInput corresponds roughly to one command, 63 | a list of ModelInputs corresponds to a single debate (or entry in a batch), and so the 64 | list of lists is basically a batch of debates. 65 | max_new_tokens: the maximum number of new tokens to generate. 66 | num_return_sequences: the number of responses that the model is expected to generate. If a batch 67 | size of >1 is passed in, then this value will be overridden by the batch size (so you cannot 68 | have both num_return_sequences > 1 and len(inputs) > 1) 69 | 70 | Returns: 71 | A list of text, with one string for each entry in the batch (or for as many sequences are specified 72 | to be returned by num_return_sequences). 73 | 74 | Raises: 75 | Exception: Raises Exception if num_return_sequences > 1 and len(inputs) > 1 76 | """ 77 | 78 | if num_return_sequences > 1 and len(inputs) > 1: 79 | raise Exception("You cannot have multiple return sequences and a batch size of >1") 80 | 81 | with ThreadPoolExecutor(max_workers=ServedModel.MAX_PARALLEL_REQUESTS) as executor: 82 | input_strs = [ 83 | input_string 84 | for input_string in self.base_model.generate_input_strs( 85 | inputs=inputs, speech_structure=SpeechStructure.OPEN_ENDED 86 | ) 87 | ] 88 | futures = [executor.submit(self.fetch, input_string) for input_string in input_strs] 89 | results = [ModelResponse(speech=future.result(), prompt=input_strs[i]) for i, future in enumerate(futures)] 90 | 91 | return results 92 | -------------------------------------------------------------------------------- /outputs/graphs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /outputs/runs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /outputs/stats/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /outputs/transcripts/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import ( 2 | ExamplesTag, 3 | 
HardcodedTopicConfig, 4 | Message, 5 | Prompt, 6 | PromptConfig, 7 | PromptTag, 8 | RoleType, 9 | PromptLoadingConfig, 10 | PromptParser, 11 | ) 12 | -------------------------------------------------------------------------------- /prompts/parser.py: -------------------------------------------------------------------------------- 1 | from data import AnnotationBracket, AnnotatedQualityDebatesDataset, AnnotationTag, DataRow, SplitType 2 | import utils.constants as constants 3 | 4 | from pydantic import BaseModel 5 | import yaml 6 | 7 | from enum import Enum 8 | from typing import Any, Optional 9 | import os 10 | import re 11 | 12 | 13 | class HardcodedTopicConfig(BaseModel): 14 | topic: str 15 | positions: tuple[str, str] 16 | 17 | 18 | class PromptLoadingConfig(BaseModel): 19 | file_path: Optional[str] = None 20 | default_prompt_name: str = "Debate Prompt" 21 | use_hardcoded_topics: bool = False 22 | hardcoded_topic_config: Optional[HardcodedTopicConfig] = None 23 | is_memorized: bool = False 24 | 25 | 26 | class PromptTag(Enum): 27 | PRE_DEBATE = 1 28 | PRE_OPENING_SPEECH = 2 29 | PRE_OPPONENT_SPEECH = 3 30 | DEBATER_SYSTEM = 4 31 | JUDGE_SYSTEM = 5 32 | OVERALL_SYSTEM = 6 33 | PRE_SPEECH = 7 34 | PRE_DEBATE_JUDGE = 8 35 | PRE_DEBATER_A_SPEECH_JUDGE = 9 36 | PRE_DEBATER_B_SPEECH_JUDGE = 10 37 | POST_ROUND_JUDGE = 11 38 | JUDGE_QUESTION_INSTRUCTIONS = 12 39 | PRE_JUDGE_QUESTIONS = 13 40 | JUDGE_DECISION = 14 41 | DEBATER_SCRATCHPAD = 16 42 | JUDGE_DECISION_FOR_DEBATER = 17 43 | PREVIOUS_DEBATER_SCRATCHPAD = 18 44 | PRE_PREVIOUS_SPEECH = 19 45 | POST_ROUND_JUDGE_WITHOUT_REASONING = 20 46 | 47 | 48 | class RoleType(Enum): 49 | SYSTEM = 1 50 | USER = 2 51 | ASSISTANT = 3 52 | 53 | 54 | class ExamplesTag(Enum): 55 | POSITIVE_EXAMPLES = 1 56 | NEGATIVE_EXAMPLES = 2 57 | 58 | 59 | class Message(BaseModel): 60 | role: str | RoleType 61 | content: str | list[str] 62 | 63 | 64 | class Prompt(BaseModel): 65 | name: str 66 | messages: dict[str, dict[str, Any]] | dict[PromptTag, Message] 67 | 68 | 69 | class PromptConfig(BaseModel): 70 | name: str 71 | opponent_name: str 72 | position: str 73 | opponent_position: str 74 | topic: str 75 | background_text: str 76 | 77 | 78 | class PromptParser: 79 | DEFAULT_PROMPT_FILE_PATH = os.environ[constants.SRC_ROOT] + "prompts/configs/prompts.yaml" 80 | DEFAULT_PROMPT_NAME = "Debate Prompt" 81 | 82 | try: 83 | with open(DEFAULT_PROMPT_FILE_PATH) as f: 84 | DEFAULT_YAML = yaml.safe_load(f) 85 | except: 86 | DEFAULT_YAML = None 87 | 88 | @classmethod 89 | def parse( 90 | cls, 91 | prompt_config: PromptConfig, 92 | prompts_file_path: Optional[str] = None, 93 | name: str = "Debate Prompt", 94 | ) -> Prompt: 95 | """ 96 | Constructs a Prompt object that can then be used by a Debater or Judge to generate text. 97 | 98 | Params: 99 | prompt_config: configuration containing the values to fill in the prompt with 100 | (e.g. 
names of the debaters, topic to be debated, background text) 101 | prompts_file_path: path to where the prompt messages are listed 102 | name: the specific prompt name to use (aka which messages to select from the prompt file) 103 | 104 | Returns: 105 | prompt: a prompt object containing a list of messages that the agents use to run a debate round 106 | """ 107 | if not prompts_file_path or prompts_file_path == PromptParser.DEFAULT_PROMPT_FILE_PATH and DEFAULT_YAML: 108 | loaded_yaml = PromptParser.DEFAULT_YAML 109 | else: 110 | prompts_file_path = prompts_file_path or PromptParser.DEFAULT_PROMPT_FILE_PATH 111 | with open(prompts_file_path) as f: 112 | loaded_yaml = yaml.safe_load(f) 113 | 114 | name = name or PromptParser.DEFAULT_PROMPT_NAME 115 | prompt = Prompt(name=name, messages=loaded_yaml[name]) 116 | prompt.messages = {PromptTag[tag.upper()]: Message(**message) for tag, message in prompt.messages.items()} 117 | 118 | base_prompt = Prompt(name=name, messages=loaded_yaml[name]) 119 | base_prompt.messages = {PromptTag[tag.upper()]: Message(**message) for tag, message in base_prompt.messages.items()} 120 | 121 | for prop, value in prompt_config: 122 | key = f"<{prop.upper()}>" 123 | for tag, messages in prompt.messages.items(): 124 | for i, message in enumerate(messages.content): 125 | prompt.messages[tag].content[i] = message.replace(key, str(value)) 126 | for tag, messages in base_prompt.messages.items(): 127 | for i, message in enumerate(messages.content): 128 | base_prompt.messages[tag].content[i] = message.replace(key, str(value)) 129 | if tag not in prompt.messages: 130 | prompt.messages[tag] = base_prompt.messages[tag] 131 | 132 | return prompt 133 | 134 | @classmethod 135 | def generate_opponent_config(cls, config: PromptConfig) -> PromptConfig: 136 | """Generates a prompt config using the config from an opposing debater""" 137 | return PromptConfig( 138 | name=config.opponent_name, 139 | opponent_name=config.name, 140 | position=config.opponent_position, 141 | opponent_position=config.position, 142 | topic=config.topic, 143 | background_text=config.background_text, 144 | ) 145 | 146 | @classmethod 147 | def convert_data_row_to_default_prompt_config( 148 | cls, row: DataRow, position: int, use_title_as_background_text: bool = False 149 | ) -> PromptConfig: 150 | """Generates a default prompt config using a data row -- used in training""" 151 | position = max(position, 0) 152 | return PromptConfig( 153 | name=constants.DEFAULT_DEBATER_A_NAME if position == 0 else constants.DEFAULT_DEBATER_B_NAME, 154 | opponent_name=constants.DEFAULT_DEBATER_B_NAME if position == 0 else constants.DEFAULT_DEBATER_A_NAME, 155 | position=row.positions[position], 156 | opponent_position=row.positions[(position - 1) * -1], 157 | topic=row.question, 158 | background_text=row.background_text if not use_title_as_background_text else row.story_title, 159 | ) 160 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic 2 | backoff 3 | datasets 4 | bitsandbytes==0.41.3.post2 5 | deepspeed==0.15.1 6 | einops 7 | matplotlib 8 | matplotlib-inline 9 | ninja==1.11.1 10 | numpy>=2.0 11 | openai>=1.51.0 12 | pandas==1.5.3 13 | peft==0.5.0 14 | protobuf==4.24.3 15 | pydantic>=2.0.0 16 | python-dotenv 17 | pyyaml 18 | PyYAML 19 | Requests 20 | rich 21 | seaborn 22 | sentence_transformers 23 | sentencepiece==0.1.99 24 | scipy==1.10.1 25 | spacy 26 | torch>=2.0.1 27 | 
transformers>=4.41.2 28 | trl>=0.7.10 29 | wandb -------------------------------------------------------------------------------- /scripts/generate_quote_labels.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from data import QuoteRelevanceLoader, QuoteRelevanceDataset, QuoteRelevanceProcessedBatchItem, QuoteRelevanceTopicInfo 6 | from models import OpenAIModel, ModelInput 7 | from prompts import RoleType 8 | from utils import input_utils 9 | import utils.constants as constants 10 | 11 | from pydantic import BaseModel 12 | from tqdm import tqdm 13 | 14 | import argparse 15 | import pickle 16 | import random 17 | import re 18 | import os 19 | 20 | DEFAULT_INSTRUCTIONS = """ 21 | You will be presented with a topic that is going to be debated by two debaters as part of an educational exercise. 22 | They will be debating the answer to a reading comprehension question. 23 | They have read the story the question is about and their goal is to convince a judge as to the correct answer, 24 | despite the judge not having read the story themselves. As part of the preparations for this exercise, 25 | the debaters have selected quotes from a background story. I will present those quotes to you in order. 26 | Please respond whether you think that quote could potentially help support their side. 27 | Each numbered line is a separate quote. You should respond similarly, with a numeric rating 28 | (0 is completely irrelevant, 10 is obviously very relevant, 5 is potentially relevant) 29 | next to the appropriate numbered item. Do not provide any extra text. 30 | If the debater did not provide any quotes, we will write 'None provided'. In that case, just ignore that debater. 31 | 32 | Example output format: 33 | 34 | Debater_A: 35 | 1. 3 36 | 2. 8 37 | 3. 4 38 | 39 | Debater_B: 40 | 1. 3 41 | 2. 8 42 | 3. 4 43 | 4. 9 44 | 45 | Here is the question they will be discussing: 46 | 47 | 48 | 49 | Here is the side that Debater_A is defending: 50 | 51 | 52 | 53 | Here is the side Debater_B is defending: 54 | 55 | 56 | 57 | Here are the quotes that Debater_A selected: 58 | 59 | 60 | 61 | Here are the quotes that Debater_B selected: 62 | 63 | 64 | 65 | Now provide us the expected feedback: 66 | """ 67 | 68 | 69 | class QuoteRelevanceBatchItem(BaseModel): 70 | a_quote_list: list[str] 71 | b_quote_list: list[str] 72 | model_input: ModelInput 73 | question_info: QuoteRelevanceTopicInfo 74 | 75 | 76 | class FakeOpenAIModel: 77 | def predict(self, a_quote_list: list[str], b_quote_list: list[str]): 78 | a_scores = [random.randint(0, 10) for quote in a_quote_list] 79 | a_text = "\n".join([f"{i + 1}. {score}" for i, score in enumerate(a_scores)]) 80 | 81 | b_scores = [random.randint(0, 10) for quote in b_quote_list] 82 | b_text = "\n".join([f"{i + 1}. 
{score}" for i, score in enumerate(b_scores)]) 83 | 84 | return f"{constants.DEFAULT_DEBATER_A_NAME}:\n{a_text}\n\n{constants.DEFAULT_DEBATER_B_NAME}:\n{b_text}" 85 | 86 | 87 | def get_scratchpads(text: str): 88 | a_match = re.search(r"This is what Debater_A said during their speech\.(.*?)#####", text, flags=re.DOTALL) 89 | b_match = re.search(r"This is what Debater_B said during their speech\.(.*?)#####", text, flags=re.DOTALL) 90 | return a_match.group(1) if a_match else None, b_match.group(1) if b_match else None 91 | 92 | 93 | def get_topic(text: str): 94 | question = None 95 | first_position = None 96 | second_position = None 97 | full_topic_section = re.search( 98 | r"Here is the topic they will be debating:(.*)This is what Debater_A said during their speech", 99 | text, 100 | re.DOTALL, 101 | ) 102 | if not full_topic_section: 103 | return question, first_position, second_position 104 | matching_text = full_topic_section.group(1).strip() 105 | 106 | question_section = re.search(r"(.*)Debater_A will defend the position", matching_text, re.DOTALL) 107 | if not question_section: 108 | return question, first_position, second_position 109 | question = question_section.group(1).strip() 110 | 111 | first_position_section = re.search( 112 | r"Debater_A will defend the position that the answer is \"(.*)\s*\"\s*\.\s*Debater_B", matching_text 113 | ) 114 | if not first_position_section: 115 | return question, first_position, second_position 116 | first_position = first_position_section.group(1).strip() 117 | 118 | second_position_section = re.search( 119 | r"Debater_B will defend the position that the answer is \"(.*)\s*\"\s*\.\s*", matching_text 120 | ) 121 | if not second_position_section: 122 | return question, first_position, second_position 123 | second_position = second_position_section.group(1).strip() 124 | 125 | return question, first_position, second_position 126 | 127 | 128 | def process_scratchpad(scratchpad_text: str): 129 | return re.findall(rf"{constants.QUOTE_TAG}(.*?){constants.UNQUOTE_TAG}", scratchpad_text) 130 | 131 | 132 | def process_model_output(output: str, a_quote_list: list[str], b_quote_list: list[str]): 133 | debater_a_match = re.search( 134 | f"{constants.DEFAULT_DEBATER_A_NAME}:(.*?){constants.DEFAULT_DEBATER_A_NAME}:", output, flags=re.DOTALL 135 | ) 136 | debater_a_text = debater_a_match.group(1) if debater_a_match else "" 137 | debater_b_match = re.search(f"{constants.DEFAULT_DEBATER_B_NAME}:(.*)", output, flags=re.DOTALL) 138 | debater_b_text = debater_b_match.group(1) if debater_b_match else "" 139 | 140 | a_quote_map = {} 141 | debater_a_score_lines = re.findall(r"\d.\s*\d+", debater_a_text, flags=re.DOTALL) 142 | for i, (quote, score_line) in enumerate(zip(a_quote_list, debater_a_score_lines)): 143 | a_quote_map[quote] = int(re.search(r"\d.\s*(\d+)", score_line, flags=re.DOTALL).group(1)) 144 | 145 | b_quote_map = {} 146 | debater_b_score_lines = re.findall(r"\d.\s*\d+", debater_b_text, flags=re.DOTALL) 147 | for i, (quote, score_line) in enumerate(zip(b_quote_list, debater_b_score_lines)): 148 | b_quote_map[quote] = int(re.search(r"\d.\s*(\d+)", score_line, flags=re.DOTALL).group(1)) 149 | 150 | return a_quote_map, b_quote_map 151 | 152 | 153 | if __name__ == "__main__": 154 | root = os.environ[constants.SRC_ROOT] 155 | batch_size = 8 156 | 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--test", action="store_true", default=False) 159 | parser.add_argument("--timestamp", type=str, default="") 160 | parser.add_argument("--save", 
action="store_true", default=False) 161 | args = parser.parse_args() 162 | 163 | model = OpenAIModel(alias="relevance-judge", is_debater=False) if not args.test else FakeOpenAIModel() 164 | input_texts = input_utils.read_file_texts(base_path=f"{root}outputs/transcripts/{args.timestamp}") 165 | 166 | results = [] 167 | current_batch = [] 168 | total_score = 0 169 | total_quotes = 0 170 | for i, text in tqdm(enumerate(input_texts)): 171 | a, b = get_scratchpads(text) 172 | if not a or not b: 173 | continue 174 | 175 | question, first_position, second_position = get_topic(text) 176 | if not question or not first_position or not second_position: 177 | continue 178 | 179 | a_quote_list = set(process_scratchpad(a)) 180 | b_quote_list = set(process_scratchpad(b)) 181 | a_quotes = "\n".join([f"{i + 1}. {text}" for i, text in enumerate(a_quote_list)]) 182 | b_quotes = "\n".join([f"{i + 1}. {text}" for i, text in enumerate(b_quote_list)]) 183 | 184 | instructions = ( 185 | DEFAULT_INSTRUCTIONS.replace("", question) 186 | .replace("", first_position) 187 | .replace("", second_position) 188 | .replace("", a_quotes if a_quotes else "") 189 | .replace("", b_quotes if b_quotes else "") 190 | ) 191 | 192 | current_batch.append( 193 | QuoteRelevanceBatchItem( 194 | a_quote_list=a_quote_list, 195 | b_quote_list=b_quote_list, 196 | model_input=ModelInput(role=RoleType.USER, content=instructions), 197 | question_info=QuoteRelevanceTopicInfo( 198 | question=question, a_position=first_position, b_position=second_position 199 | ), 200 | ) 201 | ) 202 | if len(current_batch) == batch_size or i == len(input_texts) - 1: 203 | model_inputs = [[item.model_input] for item in current_batch] 204 | predictions = ( 205 | model.predict(model_inputs) 206 | if not args.test 207 | else [model.predict(item.a_quote_list, item.b_quote_list) for item in current_batch] 208 | ) 209 | 210 | for prediction, item in zip(predictions, current_batch): 211 | a_quote_map, b_quote_map = process_model_output( 212 | output=prediction, a_quote_list=item.a_quote_list, b_quote_list=item.b_quote_list 213 | ) 214 | results.append( 215 | QuoteRelevanceProcessedBatchItem( 216 | a_quote_map=a_quote_map, b_quote_map=b_quote_map, question_info=item.question_info 217 | ) 218 | ) 219 | total_score += sum(a_quote_map.values()) + sum(b_quote_map.values()) 220 | total_quotes += len(a_quote_map.values()) + len(b_quote_map.values()) 221 | 222 | current_batch = [] 223 | 224 | pickle_path = f"{root}data/datasets/quote-relevance/quote-relevance.p" 225 | 226 | if args.save or not args.test: 227 | with open(pickle_path, "wb") as f: 228 | pickle.dump(results, f) 229 | 230 | dataset = QuoteRelevanceLoader.load() 231 | 232 | average_score = total_score / total_quotes 233 | print(f"Average score is {average_score}") 234 | -------------------------------------------------------------------------------- /scripts/load_model.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | import torch 4 | 5 | import argparse 6 | import os 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--model_name", type=str) 10 | parser.add_argument("--save_name", type=str) 11 | parser.add_argument("--requires_token", action="store_true", default=False) 12 | args = parser.parse_args() 13 | 14 | load_dotenv() 15 | 16 | base_model = AutoModelForCausalLM.from_pretrained( 17 | args.model_name, 18 | return_dict=True, 19 | trust_remote_code=True, 20 | 
torch_dtype=torch.float16, 21 | token=os.getenv("META_ACCESS_TOKEN") if args.requires_token else None, 22 | ) 23 | 24 | base_model.save_pretrained(args.save_name) 25 | 26 | base_tokenizer = AutoTokenizer.from_pretrained(args.model_name) 27 | base_tokenizer.save_pretrained(args.save_name) 28 | -------------------------------------------------------------------------------- /scripts/merge.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from utils import save_utils 6 | 7 | save_utils.save( 8 | base_model_name="/vast/spa9663/models/base_models/llama3-8b-262k", 9 | adapter_name="/vast/spa9663/models/trained_models/llama-3-mega", 10 | merge_name="/vast/spa9663/models/trained_models/llama-3-mega-merged", 11 | ) 12 | -------------------------------------------------------------------------------- /scripts/oai_finetune.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils, TrainType 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from openai import OpenAI 6 | 7 | client = OpenAI() 8 | 9 | """ 10 | client.files.create( 11 | file=open("/Users/samarnesen/nyu/scratch/nyu-blind-rounds.jsonl", "rb"), 12 | purpose="fine-tune" 13 | ) 14 | """ 15 | """ 16 | client.fine_tuning.jobs.create( 17 | training_file="file-n7fvnScZBWodZ7F8nLGk4eQG", 18 | model="gpt-4", 19 | hyperparameters={"n_epochs": 2} 20 | ) 21 | """ 22 | -------------------------------------------------------------------------------- /scripts/push_to_hub.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from experiments import ExperimentLoader, ResultsCollector 6 | from utils import logger_utils 7 | 8 | from tqdm import tqdm 9 | 10 | from datetime import datetime 11 | 12 | args = ScriptUtils.get_args() 13 | config = ScriptUtils.get_debate_round_script_config(args) 14 | 15 | debate_rounds, experiment = ExperimentLoader.generate_debate_rounds( 16 | experiment_file_path=config.experiment_file_path, name=config.experiment_name, count=args.num_iters 17 | ) 18 | 19 | debate_rounds[0].first_debater.model.tokenizer.push_to_hub("samarnesen/nyu-debater-1r-dpo") 20 | -------------------------------------------------------------------------------- /scripts/run_debate.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from experiments import ExperimentLoader, ResultsCollector 6 | from utils import logger_utils 7 | 8 | from tqdm import tqdm 9 | 10 | from datetime import datetime 11 | 12 | args = ScriptUtils.get_args() 13 | config = ScriptUtils.get_debate_round_script_config(args) 14 | start_time = str(datetime.now()).replace(" ", "_") if not args.start_time else args.start_time 15 | logger = logger_utils.get_default_logger(__name__) 16 | should_save_transcripts = not args.local or args.force_save_transcripts 17 | should_save_results = (not args.local) or args.force_save_results 18 | 19 | debate_rounds, experiment = ExperimentLoader.generate_debate_rounds( 20 | experiment_file_path=config.experiment_file_path, name=config.experiment_name, count=args.num_iters 21 | ) 22 | 23 | results_collector = ResultsCollector( 24 | experiment=experiment, 25 | graphs_path_prefix=f"{config.graphs_path_prefix}/{start_time}_", 26 | 
full_record_path_prefix=f"{config.full_record_path_prefix}/{start_time}_", 27 | stats_path_prefix=f"{config.stats_path_prefix}/{start_time}", 28 | should_save=should_save_results, 29 | ) 30 | 31 | for i, debate_round in enumerate(debate_rounds): 32 | logger.info(f"Beginning round {i} out of {len(debate_rounds)}") 33 | save_file_path_prefix = f"{config.transcript_path_prefix}/{start_time}_{i}" if should_save_transcripts else None 34 | summary = debate_round(save_file_path_prefix=save_file_path_prefix) 35 | results_collector.record_result(summary) 36 | 37 | if not args.suppress_graphs: 38 | results_collector.graph_results() 39 | -------------------------------------------------------------------------------- /scripts/run_iterative_dpo.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils, TrainType 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from data import RawDataset 6 | from train import IterativeDirectPreferenceTrainer, TrainUtils 7 | from utils import save_utils 8 | 9 | args = ScriptUtils.get_args() 10 | script_config = ScriptUtils.get_training_run_script_config(args, train_type=TrainType.DPO) 11 | 12 | config = TrainUtils.parse_config(config_name=script_config.config_name, config_filepath=script_config.config_filepath) 13 | trainer = IterativeDirectPreferenceTrainer(config=config, smooth=True, is_local=args.test) 14 | epoch_size = ( 15 | config.training_hyperparameters.supplemental.get("epoch_size", 2048) 16 | if config.training_hyperparameters.supplemental 17 | else 2048 18 | ) 19 | 20 | if not args.test: 21 | trainer.train(epoch_size=epoch_size) 22 | else: 23 | samples = trainer.get_samples(start_idx=0, epoch_size=epoch_size) 24 | -------------------------------------------------------------------------------- /scripts/run_ppo.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils, TrainType 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from data import RawDataset 6 | from train import PPOTrainerWrapper, TrainUtils 7 | from utils import save_utils 8 | 9 | args = ScriptUtils.get_args() 10 | script_config = ScriptUtils.get_training_run_script_config(args, train_type=TrainType.PPO) 11 | 12 | config = TrainUtils.parse_config(config_name=script_config.config_name, config_filepath=script_config.config_filepath) 13 | 14 | trainer = PPOTrainerWrapper.get_trainer(config=config, is_local=args.local, is_test=args.test) 15 | trainer.train(save_frequency=5) 16 | trainer.save_model() 17 | 18 | if config.logging_and_saving_config.merge_output_dir: 19 | trainer = None 20 | save_utils.save( 21 | base_model_name=config.model_name, 22 | adapter_name=config.logging_and_saving_config.output_dir, 23 | merge_name=config.logging_and_saving_config.merge_output_dir, 24 | ) 25 | -------------------------------------------------------------------------------- /scripts/run_sft.py: -------------------------------------------------------------------------------- 1 | from script_utils import ScriptUtils, TrainType 2 | 3 | ScriptUtils.setup_script() 4 | 5 | from train import SupervisedTrainer, TrainUtils 6 | from utils import save_utils 7 | 8 | args = ScriptUtils.get_args() 9 | script_config = ScriptUtils.get_training_run_script_config(args, train_type=TrainType.SFT) 10 | 11 | config = TrainUtils.parse_config(config_name=script_config.config_name, config_filepath=script_config.config_filepath) 12 | trainer = SupervisedTrainer.get_trainer(config=config, is_local=args.local, 
is_test=args.test) 13 | 14 | if not args.load_only and not args.test: 15 | trainer.train() 16 | 17 | if not args.test: 18 | trainer.save_model() 19 | 20 | if config.logging_and_saving_config.merge_output_dir: 21 | trainer = None 22 | save_utils.save( 23 | base_model_name=config.model_name, 24 | adapter_name=config.logging_and_saving_config.output_dir, 25 | merge_name=config.logging_and_saving_config.merge_output_dir, 26 | ) 27 | -------------------------------------------------------------------------------- /scripts/script_utils.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from pydantic import BaseModel 3 | 4 | from enum import Enum 5 | from typing import Optional 6 | import argparse 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | class DebateRoundScriptConfig(BaseModel): 13 | experiment_name: str 14 | experiment_file_path: str 15 | transcript_path_prefix: str 16 | graphs_path_prefix: str 17 | stats_path_prefix: str 18 | full_record_path_prefix: str 19 | 20 | 21 | class ModelRunScriptConfig(BaseModel): 22 | config_filepath: str 23 | config_name: str 24 | 25 | 26 | class TrainType(Enum): 27 | SFT = 0 28 | DPO = 1 29 | PPO = 2 30 | PRETRAIN = 3 31 | PROBE = 4 32 | CUSTOM_KTO = 5 33 | 34 | 35 | class ScriptUtils: 36 | @classmethod 37 | def setup_script(cls): 38 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 39 | load_dotenv() 40 | 41 | @classmethod 42 | def get_args(cls): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--local", action="store_true", default=False) 45 | parser.add_argument("--num_iters", type=int, default=1_000) 46 | parser.add_argument("--log_level", type=str, default="INFO") 47 | parser.add_argument("--configuration", type=str, default="") 48 | parser.add_argument("--test", action="store_true", default=False) # needed for local testing (optional otherwise) 49 | parser.add_argument("--load_only", action="store_true", default=False) 50 | parser.add_argument("--suppress_graphs", action="store_true", default=False) 51 | parser.add_argument("--local_rank", type=int, default=0) # needed for multi-GPU training 52 | parser.add_argument("--start_time", type=str, default="") 53 | parser.add_argument("--force_save_results", action="store_true", default=False) 54 | parser.add_argument("--force_save_transcripts", action="store_true", default=False) 55 | args = parser.parse_args() 56 | ScriptUtils.set_log_level(args) 57 | return args 58 | 59 | @classmethod 60 | def set_log_level(cls, args) -> None: 61 | os.environ["LOG_LEVEL"] = str(logging.INFO) 62 | 63 | requested = args.log_level.lower() 64 | specified = None 65 | if requested == "debug": 66 | specified = logging.DEBUG 67 | elif requested == "info": 68 | specified = logging.INFO 69 | elif requested == "warn": 70 | specified = logging.WARN 71 | elif requested == "error": 72 | specified = logging.ERROR 73 | else: 74 | raise Exception(f"Request log level {requested} is not eligible") 75 | 76 | os.environ["LOG_LEVEL"] = str(specified) 77 | 78 | @classmethod 79 | def get_debate_round_script_config(cls, args) -> DebateRoundScriptConfig: 80 | root = os.environ["SRC_ROOT"] 81 | output_root = os.environ["OUTPUT_ROOT"] if "OUTPUT_ROOT" in os.environ else root 82 | transcript_path = f"{output_root}outputs/transcripts" 83 | graphs_path = f"{output_root}outputs/graphs" 84 | stats_path = f"{output_root}outputs/stats" 85 | full_record_path = f"{output_root}outputs/runs" 86 | if args.test: 87 | 
experiment_name = args.configuration 88 | experiment_file_path = f"{root}experiments/configs/test_experiment.yaml" 89 | else: 90 | experiment_name = args.configuration 91 | experiment_file_path = f"{root}experiments/configs/standard_experiment.yaml" 92 | return DebateRoundScriptConfig( 93 | experiment_name=experiment_name, 94 | experiment_file_path=experiment_file_path, 95 | transcript_path_prefix=transcript_path, 96 | graphs_path_prefix=graphs_path, 97 | stats_path_prefix=stats_path, 98 | full_record_path_prefix=full_record_path, 99 | ) 100 | 101 | @classmethod 102 | def get_config_filepath(cls, train_type: TrainType) -> str: 103 | root = os.environ["SRC_ROOT"] 104 | default_config_dir = "train/configs" 105 | if train_type == TrainType.SFT: 106 | return f"{root}/{default_config_dir}/sft_config.yaml" 107 | elif train_type == TrainType.DPO: 108 | return f"{root}/{default_config_dir}/dpo_config.yaml" 109 | elif train_type == TrainType.PPO: 110 | return f"{root}/{default_config_dir}/ppo_config.yaml" 111 | elif train_type == TrainType.PRETRAIN: 112 | return f"{root}/{default_config_dir}/pretrain_config.yaml" 113 | elif train_type == TrainType.PROBE: 114 | return f"{root}/{default_config_dir}/probe_config.yaml" 115 | elif train_type == TrainType.CUSTOM_KTO: 116 | return f"{root}/{default_config_dir}/custom_kto_config.yaml" 117 | else: 118 | raise Exception(f"Train type {train_type} is not recognized") 119 | 120 | @classmethod 121 | def get_training_run_script_config(cls, args, train_type: TrainType) -> ModelRunScriptConfig: 122 | return ModelRunScriptConfig( 123 | config_filepath=ScriptUtils.get_config_filepath(train_type=train_type), config_name=args.configuration 124 | ) 125 | -------------------------------------------------------------------------------- /train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelarnesen/nyu-debate-modeling/455615d3d6fb1a0ebc158f7eb894bfd1aa63a90a/train/.DS_Store -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | from .iterative_dpo_trainer import IterativeDirectPreferenceTrainer 2 | from .ppo_trainer import PPOTrainerWrapper 3 | from .row_converter import RowConverter 4 | from .sft_trainer import SupervisedTrainer 5 | from .train_utils import ( 6 | LoggingAndSavingConfig, 7 | TrainingConfig, 8 | TrainUtils, 9 | TrainingHyperParameterConfig, 10 | TrainingTarget, 11 | ) 12 | -------------------------------------------------------------------------------- /train/configs/ppo_config.yaml: -------------------------------------------------------------------------------- 1 | Test: 2 | model_name: stub_model 3 | target: debater 4 | llm_type: stub_llm 5 | training_hyperparameters: 6 | num_train_epochs: 2 7 | per_device_train_batch_size: 8 8 | gradient_accumulation_steps: 8 9 | optim: paged_adamw_32bit 10 | learning_rate: 2e-6 11 | max_grad_norm: 0.3 12 | warmup_ratio: 0.03 13 | lr_scheduler_type: constant 14 | peft_type: lora 15 | steps: 100 16 | logging_and_saving_config: 17 | logging_steps: 1 18 | output_dir: /fake/file/path 19 | dataset: 20 | dataset_type: quality 21 | split_type: train 22 | Train - Experiment: 23 | model_name: /vast/spa9663/models/trained_models/llama-3-mega-merged 24 | target: debater 25 | llm_type: llama3 26 | opening_speeches_only: True 27 | training_hyperparameters: 28 | num_train_epochs: 4 29 | per_device_train_batch_size: 64 30 | 
gradient_accumulation_steps: 64 31 | optim: paged_adamw_32bit 32 | learning_rate: 3e-5 33 | max_grad_norm: 0.3 34 | warmup_ratio: 0.03 35 | lr_scheduler_type: constant 36 | peft_type: lora 37 | steps: 30 38 | logging_and_saving_config: 39 | logging_steps: 1 40 | output_dir: /vast/spa9663/models/trained_models/llama-3-PPO-604-overfit 41 | dataset: 42 | dataset_type: quality 43 | split_type: train -------------------------------------------------------------------------------- /train/impl/__init__.py: -------------------------------------------------------------------------------- 1 | from .bco_trainer import BCOTrainer 2 | from .llama_with_gradient_checkpointing_impl import LlamaModelWithGradientCheckpointing 3 | from .smoothed_dpo_trainer import SmoothedDPOTrainer 4 | from .smoothed_kto_trainer import SmoothedKTOTrainer 5 | from .verbose_ppo_trainer import VerbosePPOTrainer 6 | -------------------------------------------------------------------------------- /train/sft_trainer.py: -------------------------------------------------------------------------------- 1 | from data import DatasetConfig, DataRow, RawDataset, SplitType 2 | from models import LLMInput, LLModel, LLMType, ModelInput, SpeechStructure 3 | from prompts import RoleType 4 | from train.row_converter import RowConverter 5 | from train.train_utils import TrainUtils, TrainingConfig, TrainingTarget 6 | from utils import LoggingCallback, logger_utils # TODO: REMOVE 7 | import utils.constants as constants 8 | 9 | from peft import prepare_model_for_kbit_training, get_peft_model 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments 11 | from trl import DataCollatorForCompletionOnlyLM, SFTTrainer 12 | import pandas as pd 13 | import datasets 14 | import torch 15 | 16 | from typing import Any, Optional, Type 17 | import json 18 | import random 19 | import sys 20 | 21 | try: 22 | from utils.flash_attn_utils import replace_attn_with_flash_attn, upcast_layer_for_flash_attention 23 | 24 | FLASH_ATTENTION_AVAILABLE = True 25 | except ImportError as e: 26 | print("Running without flash attention") 27 | FLASH_ATTENTION_AVAILABLE = False 28 | 29 | 30 | class SupervisedTrainer: 31 | """Class for training a model using Supervised Fine Tuning""" 32 | 33 | @classmethod 34 | def convert_dataset( 35 | cls, raw_datasets: list[RawDataset], config: TrainingConfig, tokenizer: AutoTokenizer 36 | ) -> datasets.Dataset: 37 | """Converts a dataset (abstraction used in this codebase) into a Dataset object (abstraction 38 | used by huggingface's trainer objects)""" 39 | 40 | def validate_structure(val: Any) -> bool: 41 | if not isinstance(val, list): 42 | print("fail 1") 43 | return False 44 | for item in val: 45 | if not isinstance(item, list): 46 | print("fail 2") 47 | return False 48 | for elem in item: 49 | if not isinstance(elem, tuple): 50 | print("fail 3") 51 | print(type(elem[0])) 52 | return False 53 | if not isinstance(elem[1], str): 54 | print("fail 4") 55 | return False 56 | if not isinstance(elem[0], list): 57 | print("fail 5") 58 | return False 59 | for x in elem[0]: 60 | if not isinstance(x, ModelInput): 61 | print("fail 6") 62 | return False 63 | return True 64 | 65 | dataset_configs = config.dataset 66 | if isinstance(dataset_configs, DatasetConfig): 67 | dataset_configs = [dataset_configs] 68 | 69 | output_structure = ( 70 | SpeechStructure.OPEN_ENDED if config.target == TrainingTarget.DEBATER else SpeechStructure.DECISION 71 | ) 72 | 73 | llm_inputs = [] 74 | for idx, raw_dataset in enumerate(raw_datasets): 75 
| speech_structure = config.speech_structure[idx % len(config.speech_structure)] 76 | transcript_lists = [ 77 | RowConverter.convert_row( 78 | row=row, 79 | config=config, 80 | dataset=raw_dataset, 81 | speech_structure=speech_structure, 82 | use_gold_labels=config.training_hyperparameters.supplemental.get("gold_labels", False), 83 | use_minimal_output_format=config.training_hyperparameters.supplemental.get( 84 | "use_minimal_output_format", False 85 | ), 86 | ) 87 | for row in raw_dataset.get_data(split=dataset_configs[idx].split_type) 88 | ] 89 | 90 | if validate_structure(val=transcript_lists): 91 | for transcript_list in transcript_lists: 92 | for model_inputs, speech in transcript_list: 93 | llm_inputs.append( 94 | { 95 | "instruction": LLModel.convert_to_input_string( 96 | input_list=model_inputs, 97 | tokenizer=tokenizer, 98 | speech_structure=output_structure, 99 | ), 100 | "output": speech, 101 | } 102 | ) 103 | else: 104 | raise Exception("Data format was invalid") 105 | 106 | max_instruction_length = int((2 / 3) * len(llm_inputs)) 107 | instruction_count = 0 108 | for dataset_config in filter(lambda x: not x.dataset_type.is_instantiable, dataset_configs): 109 | external_dataset = datasets.load_dataset(path=dataset_config.full_dataset_file_path, split="train") 110 | external_df = pd.DataFrame(external_dataset) 111 | for i, row in external_df.iterrows(): 112 | if instruction_count < max_instruction_length and (row["instruction"] or row["input"]): 113 | instruction_count += 1 114 | llm_inputs.append( 115 | { 116 | "instruction": LLModel.convert_to_input_string( 117 | input_list=[ 118 | ModelInput(role=RoleType.SYSTEM, content=row["instruction"]), 119 | ModelInput(role=RoleType.USER, content=row["input"]), 120 | ], 121 | tokenizer=tokenizer, 122 | speech_structure=output_structure, 123 | ), 124 | "output": row["output"], 125 | } 126 | ) 127 | 128 | df = pd.DataFrame(data=llm_inputs) 129 | dataset = datasets.Dataset.from_pandas(df).shuffle() 130 | return dataset 131 | 132 | @classmethod 133 | def formatting_func(cls, llm_dictionary: dict[str, list[str]]) -> list[str]: 134 | formatted = [] 135 | for instruction, output in zip(llm_dictionary["instruction"], llm_dictionary["output"]): 136 | formatted.append(instruction + output.strip()) 137 | return formatted 138 | 139 | @classmethod 140 | def get_trainer( 141 | cls, 142 | config: TrainingConfig, 143 | raw_datasets: Optional[list[RawDataset]] = None, 144 | is_local: bool = False, 145 | is_test: bool = False, 146 | ) -> Optional[SFTTrainer]: 147 | """ 148 | Generates a Trainer object. 149 | 150 | Params: 151 | config: configuration specifying the prompt setup and hyperparameters for the training run. 152 | raw_datasets: datasets to use for training 153 | is_local: whether this is being run on a CPU 154 | is_test: whether to skip instantiating the trainer (if true, no trainer is created) 155 | 156 | Returns: 157 | sft_trainer: One can call sft_trainer.train() to then run the training loop.
158 | """ 159 | if FLASH_ATTENTION_AVAILABLE: 160 | replace_attn_with_flash_attn() 161 | 162 | if not raw_datasets: 163 | raw_datasets = TrainUtils.create_datasets(config=config) 164 | 165 | tokenizer = TrainUtils.get_tokenizer(config=config, is_local=is_local) 166 | model = TrainUtils.load_model(config=config, is_local=is_local) 167 | 168 | training_args = TrainingArguments( 169 | output_dir=config.logging_and_saving_config.output_dir, 170 | num_train_epochs=config.training_hyperparameters.num_train_epochs, 171 | per_device_train_batch_size=config.training_hyperparameters.per_device_train_batch_size, 172 | gradient_accumulation_steps=config.training_hyperparameters.gradient_accumulation_steps, 173 | gradient_checkpointing=True, 174 | optim=config.training_hyperparameters.optim, 175 | logging_steps=config.logging_and_saving_config.logging_steps, 176 | save_strategy="epoch", 177 | learning_rate=config.training_hyperparameters.learning_rate, 178 | max_grad_norm=config.training_hyperparameters.max_grad_norm, 179 | warmup_ratio=config.training_hyperparameters.warmup_ratio, 180 | lr_scheduler_type=config.training_hyperparameters.lr_scheduler_type, 181 | disable_tqdm=False, 182 | ddp_find_unused_parameters=False, 183 | use_cpu=is_local, 184 | ) 185 | 186 | llm_class = TrainUtils.get_llm_class(config=config) 187 | 188 | collator = DataCollatorForCompletionOnlyLM( 189 | response_template=tokenizer.encode("\n " + llm_class.INSTRUCTION_SUFFIX, add_special_tokens=False)[2:], 190 | tokenizer=tokenizer, 191 | ) 192 | 193 | train_dataset = SupervisedTrainer.convert_dataset( 194 | raw_datasets=raw_datasets, 195 | tokenizer=tokenizer, 196 | config=config, 197 | ) 198 | 199 | peft_config = TrainUtils.get_peft_config(config) if not is_local else None 200 | if peft_config: 201 | # model = get_peft_model(prepare_model_for_kbit_training(model), peft_config) 202 | model.enable_input_require_grads() 203 | model = get_peft_model(model, peft_config) 204 | if FLASH_ATTENTION_AVAILABLE: 205 | model = upcast_layer_for_flash_attention(model, torch.bfloat16).to("cuda") 206 | 207 | if not is_test: 208 | trainer = SFTTrainer( 209 | model=model, 210 | train_dataset=train_dataset, 211 | peft_config=peft_config, 212 | tokenizer=tokenizer, 213 | data_collator=collator, 214 | formatting_func=SupervisedTrainer.formatting_func, 215 | max_seq_length=config.max_length, 216 | callbacks=[LoggingCallback], 217 | args=training_args, 218 | ) 219 | 220 | torch.cuda.empty_cache() 221 | 222 | return trainer 223 | return None 224 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .input_utils import InputType 2 | from .logger_utils import LoggingCallback 3 | from .timer_utils import timer 4 | from . 
import input_utils, logger_utils, quote_utils, save_utils, string_utils 5 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_DEBATER_A_NAME = "Debater_A" 2 | DEFAULT_DEBATER_B_NAME = "Debater_B" 3 | DEFAULT_JUDGE_NAME = "Judge" 4 | DEFAULT_BACKGROUND_TEXT = "None provided" 5 | DEFAULT_WORD_LIMIT = 750 6 | INSTRUCTION_PREFIX = "instruction:" 7 | INPUT_PREFIX = "### User:" 8 | INSTRUCTION_SUFFIX = "output:" 9 | JUDGING_PREFIX = "I believe the winner of this debate is Debater_" 10 | PREFERENCE_PREFIX = "Overall Score: " 11 | WINNER = "Winner" 12 | LOSER = "Loser" 13 | CORRECT = "Correct" 14 | INCORRECT = "Incorrect" 15 | OVERALL = "Overall" 16 | BASE_MODEL_PREFIX = "base" 17 | DEBATER_A_IDENTIFICATION = "You are Debater_A" 18 | BASE_PROMPT = "Base Prompt" 19 | QUOTE_TAG = "" 20 | UNQUOTE_TAG = "" 21 | INVALID_QUOTE_TAG = "" 22 | INVALID_UNQUOTE_TAG = "" 23 | BEGIN_SPEECH_OPTIONS = [ 24 | "Write out your speech:", 25 | "Now begin your speech.", 26 | "Please deliver your speech.", 27 | "We will now await your speech.", 28 | ] 29 | BEGIN_JUDGING_OPTIONS = ["Here is the decision that the judge made:"] 30 | QUOTE_FUZZY_MATCH_EARLY_STOPPING_THRESHOLD = 0.9 31 | QUOTE_FUZZY_MATCH_MIN_THRESHOLD = 0.8 32 | MAX_SCORE = 10 33 | DEBATER_A_POSITION = 0 34 | DEBATER_B_POSITION = 1 35 | MAX_LENGTH = 32768 36 | LINE_SEPARATOR = "\n######\n" 37 | SRC_ROOT = "SRC_ROOT" 38 | INPUT_ROOT = "INPUT_ROOT" 39 | DEFAULT_ALIAS = "empty-model" 40 | -------------------------------------------------------------------------------- /utils/flash_attn_utils.py: -------------------------------------------------------------------------------- 1 | # copied from @philschmid 2 | # https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/utils/llama_patch.py 3 | 4 | # flash decoder work copied from 5 | # https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/ 6 | 7 | from typing import List, Optional, Tuple 8 | from functools import partial 9 | 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | import warnings 14 | import transformers 15 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 16 | from peft.tuners.lora import LoraLayer 17 | 18 | try: 19 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_with_kvcache 20 | from flash_attn.bert_padding import unpad_input, pad_input 21 | except Exception: 22 | raise ModuleNotFoundError( 23 | "Please install FlashAttention first, e.g., with pip install flash-attn --no-build-isolation, Learn more at https://github.com/Dao-AILab/flash-attention#installation-and-features" 24 | ) 25 | 26 | try: 27 | from einops import rearrange 28 | except Exception: 29 | raise ModuleNotFoundError("Please install einops first, e.g., with pip install einops") 30 | 31 | 32 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 33 | # requires the attention mask to be the same as the key_padding_mask 34 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 35 | # [bsz, seq_len] 36 | return attention_mask 37 | 38 | 39 | def flash_attn_forward_without_dropout( 40 | self, 41 | hidden_states, 42 | attention_mask=None, 43 | position_ids=None, 44 | past_key_value=None, 45 | output_attentions=False, 46 | use_cache=False, 47 | **kwargs 48 | ): 49 | original_fwd = 
transformers.models.llama.modeling_llama.LlamaModel.LlamaFlashAttention2.forward 50 | original_training_status = self.training 51 | self.training = False 52 | result = original_fwd( 53 | self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs 54 | ) 55 | self.training = original_training_status 56 | return result 57 | 58 | 59 | def replace_attn_with_flash_attn(disable_dropout: bool = False): 60 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 61 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 62 | if disable_dropout: 63 | transformers.models.llama.modeling_llama.LlamaModel.LlamaFlashAttention2.forward = flash_attn_forward_without_dropout 64 | 65 | 66 | # Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338 67 | def upcast_layer_for_flash_attention(model, torch_dtype): 68 | # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to 69 | # convert them back to fp16/bf16 for flash-attn compatibility. 70 | for name, module in model.named_modules(): 71 | if isinstance(module, LoraLayer): 72 | module.to(torch_dtype) 73 | if "norm" in name: 74 | module.to(torch_dtype) 75 | if "lm_head" in name or "embed_tokens" in name: 76 | if hasattr(module, "weight"): 77 | module.to(torch_dtype) 78 | 79 | return model 80 | 81 | 82 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 83 | # requires the attention mask to be the same as the key_padding_mask 84 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 85 | return attention_mask 86 | -------------------------------------------------------------------------------- /utils/input_utils.py: -------------------------------------------------------------------------------- 1 | import utils.constants as constants 2 | 3 | import pandas as pd 4 | 5 | from enum import Enum 6 | from io import StringIO 7 | from typing import Callable, Union 8 | import json 9 | import os 10 | import re 11 | 12 | 13 | class InputType(Enum): 14 | TEXT_TRANSCRIPT = ("txt", os.environ[constants.INPUT_ROOT] + "/transcripts", lambda x: x) 15 | JSON_TRANSCRIPT = ("json", os.environ[constants.INPUT_ROOT] + "/transcripts", lambda x: json.loads(x)) 16 | JSON_LIST = ("jsonl", os.environ[constants.INPUT_ROOT] + "/transcripts", lambda x: [json.loads(line) for line in x.splitlines() if line.strip()]) 17 | RUN = ("csv", os.environ[constants.INPUT_ROOT] + "/runs", lambda x: pd.read_csv(StringIO(x))) 18 | 19 | def __init__(self, extension: str, location: str, load: Callable): 20 | self.extension = extension 21 | self.location = location 22 | self.load = load 23 | 24 | 25 | def get_full_filepath(base_path: str, input_type: InputType) -> str: 26 | """ 27 | Given either a full path ending in the file prefix or just the prefix alone, return the full path ending in the prefix. 28 | For example, base_path='12345' -> /path/to/source/root/outputs/transcripts/12345 29 | """ 30 | return base_path if "/" in base_path else f"{input_type.location}/{base_path}" 31 | 32 | 33 | def read_file_texts( 34 | base_path: str | list[str], 35 | input_type: InputType = InputType.TEXT_TRANSCRIPT, 36 | include_full_file_path: bool = False, 37 | should_load: bool = False, 38 | ) -> list[str] | list[tuple[str, str]]: 39 | """ 40 | Reads transcripts generated by the run_debate script.
All the files are named using the following 41 | convention: base_path_{round_number}_{batch_number}.{extension}. 42 | 43 | Params: 44 | base_path: the directory + file prefix that all the transcripts share. This can be a list if there are multiple 45 | sets of file prefixes that one wants to aggregate into one dataset 46 | input_type: the InputType enum member, which determines the expected file extension and the default directory 47 | include_full_file_path: if True, it returns a list of (file_text, file_path) tuples. If False, it returns a list of file_texts 48 | should_load: if True, the file text gets turned into the format specific to the input type 49 | 50 | Returns: 51 | file_texts: A list of transcript contents. 52 | """ 53 | 54 | def get_idxs_of_file(file_name: str) -> tuple[int, int]: 55 | suffix_pattern = r"_(\d+)_(\d+)\." + input_type.extension 56 | suffix = re.search(suffix_pattern, file_name) 57 | if suffix: 58 | return int(suffix.group(1)), int(suffix.group(2)) 59 | return -1, -1 60 | 61 | def sort_files_by_extension(file_names: list[str]) -> list[str]: 62 | files_with_idxs = [(file_name, get_idxs_of_file(file_name)) for file_name in file_names] 63 | files_with_idxs = sorted(files_with_idxs, key=lambda x: int(x[1][1])) 64 | files_with_idxs = sorted(files_with_idxs, key=lambda x: int(x[1][0])) 65 | return [x[0] for x in files_with_idxs] 66 | 67 | def list_files_with_prefix(directory: str, prefix: str) -> list[str]: 68 | files = os.listdir(directory) 69 | matching_files = [ 70 | f"{directory}/{file}" 71 | for file in filter(lambda x: x.startswith(prefix) and x.endswith(input_type.extension), files) 72 | ] 73 | return sort_files_by_extension(matching_files) 74 | 75 | if isinstance(base_path, list): 76 | input_texts = [] 77 | for path in base_path: 78 | input_texts += read_file_texts( 79 | base_path=path, input_type=input_type, include_full_file_path=include_full_file_path 80 | ) 81 | return input_texts 82 | 83 | directory = input_type.location if "/" not in base_path else "/".join(base_path.split("/")[:-1]) 84 | prefix = base_path if "/" not in base_path else base_path.split("/")[-1] 85 | 86 | eligible_files = list_files_with_prefix(directory=directory, prefix=prefix) 87 | 88 | file_texts = [] 89 | for file_name in eligible_files: 90 | if os.path.exists(file_name): 91 | with open(file_name) as f: 92 | text = f.read() 93 | if should_load: 94 | text = input_type.load(text) 95 | file_texts.append((text, file_name) if include_full_file_path else text) 96 | 97 | return file_texts 98 | -------------------------------------------------------------------------------- /utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainerCallback 2 | 3 | import logging 4 | import os 5 | 6 | 7 | def get_log_level(): 8 | """Gets the log level specified in the environment variables""" 9 | if "LOG_LEVEL" in os.environ: 10 | requested = os.environ["LOG_LEVEL"] 11 | for level in filter(lambda x: requested == str(x), [logging.DEBUG, logging.INFO, logging.WARN, logging.ERROR]): 12 | return level 13 | return logging.INFO 14 | 15 | 16 | def get_default_logger(name: str, log_level=None): 17 | """Generates a logger at the specified log level and in the specified namespace""" 18 | logger = logging.getLogger(name) 19 | if not logger.hasHandlers(): 20 | log_level = log_level or get_log_level() 21 | 22 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 23 | 24 | stream_handler = logging.StreamHandler() 25 | stream_handler.setLevel(log_level) 26 | 
stream_handler.setFormatter(formatter) 27 | 28 | logger.setLevel(log_level) 29 | logger.addHandler(stream_handler) 30 | return logger 31 | 32 | 33 | class LoggingCallback(TrainerCallback): 34 | def on_log(self, args, state, control, logs=None, **kwargs): 35 | """Callback so that training updates get written using the default logger rather than print statements""" 36 | _ = logs.pop("total_flos", None) 37 | if state.is_local_process_zero: 38 | get_default_logger(__name__).warning(logs) 39 | -------------------------------------------------------------------------------- /utils/save_utils.py: -------------------------------------------------------------------------------- 1 | from peft import PeftModel 2 | from transformers import AutoModelForCausalLM 3 | import torch 4 | 5 | 6 | def save(base_model_name: str, adapter_name: str, merge_name: str): 7 | """ 8 | Loads a model and its adapter and saves it to the specified location. 9 | 10 | Params: 11 | base_model_name: the name (or file path) of the model to load 12 | adapter_name: the name (or file path) of the trained adapter 13 | merge_name: the file_path one wants to save the merged model to 14 | """ 15 | torch.cuda.empty_cache() 16 | 17 | base_model = AutoModelForCausalLM.from_pretrained( 18 | base_model_name, 19 | return_dict=True, 20 | torch_dtype=torch.float16, 21 | ) 22 | 23 | model = PeftModel.from_pretrained(base_model, adapter_name) 24 | model = model.merge_and_unload() 25 | model.save_pretrained(merge_name) 26 | -------------------------------------------------------------------------------- /utils/string_utils.py: -------------------------------------------------------------------------------- 1 | def clean_string(input_string: str) -> str: 2 | """Removes pad tokens from a model output""" 3 | return input_string.replace("", "").replace("", "").rstrip() 4 | -------------------------------------------------------------------------------- /utils/timer_utils.py: -------------------------------------------------------------------------------- 1 | import utils.logger_utils as logger_utils 2 | 3 | from functools import wraps 4 | import time 5 | 6 | 7 | def timer(custom_name: str = None): 8 | """Decorator to time a function""" 9 | 10 | def decorator(func): 11 | @wraps(func) 12 | def wrapper(*args, **kwargs): 13 | start_time = time.perf_counter() 14 | result = func(*args, **kwargs) 15 | end_time = time.perf_counter() 16 | elapsed_time = end_time - start_time 17 | name = custom_name if custom_name else func.__name__ 18 | logger_utils.get_default_logger(__name__).debug(f"{name} completed in {elapsed_time:.1f} seconds") 19 | return result 20 | 21 | return wrapper 22 | 23 | return decorator 24 | --------------------------------------------------------------------------------
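
Usage sketch (illustrative only; this is not a file in the repository). The snippet below shows how the timer decorator from utils/timer_utils.py and get_default_logger from utils/logger_utils.py fit together. Assumptions, flagged here and in the comments: the repository root is on PYTHONPATH, the packages in requirements.txt are installed (importing the utils package also pulls in save_utils, which needs torch/transformers/peft), INPUT_ROOT is set because utils/input_utils.py reads os.environ["INPUT_ROOT"] at import time, and the file name and toy_workload function are invented for the example.

# usage_sketch.py -- minimal sketch, not part of the repository
import logging
import os
import time

# Environment variables must be set before importing the utils package:
# utils/input_utils.py reads INPUT_ROOT when it is imported, and the timer
# decorator only emits its message when LOG_LEVEL resolves to DEBUG.
os.environ.setdefault("INPUT_ROOT", os.getcwd())  # placeholder value (assumption)
os.environ["LOG_LEVEL"] = str(logging.DEBUG)

from utils import logger_utils, timer


@timer("toy workload")  # custom_name replaces the function name in the log line
def toy_workload(seconds: float) -> None:
    """Sleeps briefly so the decorator has something to measure."""
    time.sleep(seconds)


if __name__ == "__main__":
    logger = logger_utils.get_default_logger(__name__)
    logger.info("starting toy workload")
    toy_workload(0.25)
    # Expect a DEBUG line resembling:
    # 2024-01-01 12:00:00,000 - utils.timer_utils - DEBUG - toy workload completed in 0.3 seconds

Running python usage_sketch.py from the repository root should print the INFO line followed by the DEBUG timing line, since both the stream handler and the logger are created at the level read from LOG_LEVEL.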