├── .gitignore ├── CLAIR_preferences.ipynb ├── LICENSE ├── README.md ├── cache.tar.gz └── images ├── apo-github.png ├── clair-github.png ├── github-clair-notebook.png └── performance-boost.png /.gitignore: -------------------------------------------------------------------------------- 1 | cache/ -------------------------------------------------------------------------------- /CLAIR_preferences.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Contrastive Learning from AI Revisions (CLAIR)\n", 8 | "This notebook accompanies the \"[Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment](https://arxiv.org/abs/2408.06266v1)\" paper.\n", 9 | "\n", 10 | "In this notebook, we will create preference pairs for alignment through contrastive revisions. We use an LLM behind an API for the revision process, but we've cached the results so you can run the notebook without an API key." 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "![Contrastive Learning from AI Revisions](./images/github-clair-notebook.png \"Contrastive Learning from AI Revisions\")" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "!pip install datasets" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "/env/lib/conda/trl-karel/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 39 | " from .autonotebook import tqdm as notebook_tqdm\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "import requests\n", 45 | "from joblib import Memory\n", 46 | "import datasets\n", 47 | "import pandas as pd" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Set your API Key. If you don't change this variable, we will use caching.\n", 57 | "API_KEY = 'your-openai-api-key'\n", 58 | "\n", 59 | "# Change model and kwargs. Only tested on this exact model.\n", 60 | "model_name = 'gpt-4-0125-preview'\n", 61 | "model_url = 'https://api.openai.com/v1/chat/completions'\n", 62 | "model_kwargs = {\n", 63 | " 'max_tokens': 4096,\n", 64 | " 'temperature': .7\n", 65 | "}" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 15, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "cache/\n", 78 | "cache/joblib/\n", 79 | "cache/joblib/CLAIR_preferences/\n", 80 | "cache/joblib/CLAIR_preferences/query_chat_model/\n", 81 | "cache/joblib/CLAIR_preferences/query_chat_model/5c48bee118c640c35d6ddb1c9f5aad76/\n", 82 | "cache/joblib/CLAIR_preferences/query_chat_model/5c48bee118c640c35d6ddb1c9f5aad76/metadata.json\n", 83 | "cache/joblib/CLAIR_preferences/query_chat_model/5c48bee118c640c35d6ddb1c9f5aad76/output.pkl\n", 84 | "cache/joblib/CLAIR_preferences/query_chat_model/b3782bbc8aa8cb3bcafb65a1f63e45c1/\n", 85 | "cache/joblib/CLAIR_preferences/query_chat_model/b3782bbc8aa8cb3bcafb65a1f63e45c1/metadata.json\n", 86 | "cache/joblib/CLAIR_preferences/query_chat_model/b3782bbc8aa8cb3bcafb65a1f63e45c1/output.pkl\n", 87 | "cache/joblib/CLAIR_preferences/query_chat_model/71f6587e578726353beb1a979163c162/\n", 88 | "cache/joblib/CLAIR_preferences/query_chat_model/71f6587e578726353beb1a979163c162/output.pkl\n", 89 | 
"cache/joblib/CLAIR_preferences/query_chat_model/71f6587e578726353beb1a979163c162/metadata.json\n", 90 | "cache/joblib/CLAIR_preferences/query_chat_model/ed39ac7228e1cf303cc3fb64d33f6239/\n", 91 | "cache/joblib/CLAIR_preferences/query_chat_model/ed39ac7228e1cf303cc3fb64d33f6239/output.pkl\n", 92 | "cache/joblib/CLAIR_preferences/query_chat_model/ed39ac7228e1cf303cc3fb64d33f6239/metadata.json\n", 93 | "cache/joblib/CLAIR_preferences/query_chat_model/2c21b538571ba66c76ea680a3080b3a8/\n", 94 | "cache/joblib/CLAIR_preferences/query_chat_model/2c21b538571ba66c76ea680a3080b3a8/metadata.json\n", 95 | "cache/joblib/CLAIR_preferences/query_chat_model/2c21b538571ba66c76ea680a3080b3a8/output.pkl\n", 96 | "cache/joblib/CLAIR_preferences/query_chat_model/0401e23c9c74865594818e148c796043/\n", 97 | "cache/joblib/CLAIR_preferences/query_chat_model/0401e23c9c74865594818e148c796043/output.pkl\n", 98 | "cache/joblib/CLAIR_preferences/query_chat_model/0401e23c9c74865594818e148c796043/metadata.json\n", 99 | "cache/joblib/CLAIR_preferences/query_chat_model/5215d6a342d83047bb7e00a8fa51de4b/\n", 100 | "cache/joblib/CLAIR_preferences/query_chat_model/5215d6a342d83047bb7e00a8fa51de4b/metadata.json\n", 101 | "cache/joblib/CLAIR_preferences/query_chat_model/5215d6a342d83047bb7e00a8fa51de4b/output.pkl\n", 102 | "cache/joblib/CLAIR_preferences/query_chat_model/4132a4b77dac4d9f04873e07d2caae06/\n", 103 | "cache/joblib/CLAIR_preferences/query_chat_model/4132a4b77dac4d9f04873e07d2caae06/metadata.json\n", 104 | "cache/joblib/CLAIR_preferences/query_chat_model/4132a4b77dac4d9f04873e07d2caae06/output.pkl\n", 105 | "cache/joblib/CLAIR_preferences/query_chat_model/8e5f73cb4cebc1cf4883d4c44154179a/\n", 106 | "cache/joblib/CLAIR_preferences/query_chat_model/8e5f73cb4cebc1cf4883d4c44154179a/metadata.json\n", 107 | "cache/joblib/CLAIR_preferences/query_chat_model/8e5f73cb4cebc1cf4883d4c44154179a/output.pkl\n", 108 | "cache/joblib/CLAIR_preferences/query_chat_model/113a851541a62f046d3b62d295920e1c/\n", 109 
| "cache/joblib/CLAIR_preferences/query_chat_model/113a851541a62f046d3b62d295920e1c/output.pkl\n", 110 | "cache/joblib/CLAIR_preferences/query_chat_model/113a851541a62f046d3b62d295920e1c/metadata.json\n", 111 | "cache/joblib/CLAIR_preferences/query_chat_model/60c78454e64f2af0b229641b353466df/\n", 112 | "cache/joblib/CLAIR_preferences/query_chat_model/60c78454e64f2af0b229641b353466df/output.pkl\n", 113 | "cache/joblib/CLAIR_preferences/query_chat_model/60c78454e64f2af0b229641b353466df/metadata.json\n", 114 | "cache/joblib/CLAIR_preferences/query_chat_model/8cd2bb111f19296e6f2443c823bb28e3/\n", 115 | "cache/joblib/CLAIR_preferences/query_chat_model/8cd2bb111f19296e6f2443c823bb28e3/metadata.json\n", 116 | "cache/joblib/CLAIR_preferences/query_chat_model/8cd2bb111f19296e6f2443c823bb28e3/output.pkl\n", 117 | "cache/joblib/CLAIR_preferences/query_chat_model/941ad34ae42a38b8b104eb2306c87503/\n", 118 | "cache/joblib/CLAIR_preferences/query_chat_model/941ad34ae42a38b8b104eb2306c87503/output.pkl\n", 119 | "cache/joblib/CLAIR_preferences/query_chat_model/941ad34ae42a38b8b104eb2306c87503/metadata.json\n", 120 | "cache/joblib/CLAIR_preferences/query_chat_model/067d8d9ec4d4d0d3a41fa2ab8aeb79dd/\n", 121 | "cache/joblib/CLAIR_preferences/query_chat_model/067d8d9ec4d4d0d3a41fa2ab8aeb79dd/metadata.json\n", 122 | "cache/joblib/CLAIR_preferences/query_chat_model/067d8d9ec4d4d0d3a41fa2ab8aeb79dd/output.pkl\n", 123 | "cache/joblib/CLAIR_preferences/query_chat_model/70fce2c6dd42d48796d207602b993991/\n", 124 | "cache/joblib/CLAIR_preferences/query_chat_model/70fce2c6dd42d48796d207602b993991/output.pkl\n", 125 | "cache/joblib/CLAIR_preferences/query_chat_model/70fce2c6dd42d48796d207602b993991/metadata.json\n", 126 | "cache/joblib/CLAIR_preferences/query_chat_model/485820758b7d066796f50bc841854da5/\n", 127 | "cache/joblib/CLAIR_preferences/query_chat_model/485820758b7d066796f50bc841854da5/metadata.json\n", 128 | 
"cache/joblib/CLAIR_preferences/query_chat_model/485820758b7d066796f50bc841854da5/output.pkl\n", 129 | "cache/joblib/CLAIR_preferences/query_chat_model/b333765d6d727089100309db9b221734/\n", 130 | "cache/joblib/CLAIR_preferences/query_chat_model/b333765d6d727089100309db9b221734/metadata.json\n", 131 | "cache/joblib/CLAIR_preferences/query_chat_model/b333765d6d727089100309db9b221734/output.pkl\n", 132 | "cache/joblib/CLAIR_preferences/query_chat_model/a61ff4255f44bc7e9eb4072bdd7e05af/\n", 133 | "cache/joblib/CLAIR_preferences/query_chat_model/a61ff4255f44bc7e9eb4072bdd7e05af/metadata.json\n", 134 | "cache/joblib/CLAIR_preferences/query_chat_model/a61ff4255f44bc7e9eb4072bdd7e05af/output.pkl\n", 135 | "cache/joblib/CLAIR_preferences/query_chat_model/bba07a4fc243d2eed5a4b36b667cb01d/\n", 136 | "cache/joblib/CLAIR_preferences/query_chat_model/bba07a4fc243d2eed5a4b36b667cb01d/metadata.json\n", 137 | "cache/joblib/CLAIR_preferences/query_chat_model/bba07a4fc243d2eed5a4b36b667cb01d/output.pkl\n", 138 | "cache/joblib/CLAIR_preferences/query_chat_model/27846542dbcfe2d7ca61dad3d43549a4/\n", 139 | "cache/joblib/CLAIR_preferences/query_chat_model/27846542dbcfe2d7ca61dad3d43549a4/output.pkl\n", 140 | "cache/joblib/CLAIR_preferences/query_chat_model/27846542dbcfe2d7ca61dad3d43549a4/metadata.json\n", 141 | "cache/joblib/CLAIR_preferences/query_chat_model/func_code.py\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "# Set up the cache\n", 147 | "import os\n", 148 | "if not os.path.isfile('cache.tar.gz'):\n", 149 | " !wget https://github.com/ContextualAI/CLAIR_and_APO/raw/master/cache.tar.gz\n", 150 | "!tar -xzvf cache.tar.gz " 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "Helper functions that can be ignored for now:" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 13, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "# Setup the cache directory\n", 167 | "memory = 
Memory(\"./cache\", verbose=0)\n", 168 | "\n", 169 | "# Get joblib cache working in notebooks: cache(mem, module) returns a decorator that sets f.__module__ to `module` and f.__qualname__ to f.__name__ before wrapping f with mem.cache\n", 170 | "# source: https://stackoverflow.com/questions/75202475/joblib-persistence-across-sessions-machines\n", 171 | "def cache(mem, module, **mem_kwargs):\n", 172 | " def cache_(f):\n", 173 | " f.__module__ = module\n", 174 | " f.__qualname__ = f.__name__\n", 175 | " return mem.cache(f, **mem_kwargs)\n", 176 | " return cache_\n", 177 | "\n", 178 | "# Extract existing preferences from UltraFeedback: for each row with at least one completion, chosen = response with the highest overall_score and rejected = response with the lowest; returns a DataFrame with text/rejected/chosen columns\n", 179 | "def get_preferences_from_ultrafeedback(dataset):\n", 180 | " instruction = []\n", 181 | " chosen = []\n", 182 | " rejected = []\n", 183 | "\n", 184 | " for _, row in dataset.iterrows():\n", 185 | " response = [x[\"response\"] for x in row[\"completions\"]]\n", 186 | " score = [x[\"overall_score\"] for x in row[\"completions\"]]\n", 187 | "\n", 188 | " if len(score):\n", 189 | " chosen_index = score.index(max(score))\n", 190 | " rejected_index = score.index(min(score))\n", 191 | "\n", 192 | " instruction.append(row[\"instruction\"])\n", 193 | " chosen.append(response[chosen_index])\n", 194 | " rejected.append(response[rejected_index])\n", 195 | "\n", 196 | " return pd.DataFrame.from_dict({\n", 197 | " \"text\": instruction,\n", 198 | " \"rejected\": rejected,\n", 199 | " \"chosen\": chosen \n", 200 | " })\n", 201 | "\n", 202 | "# Visualize a preference triple: print the first 400 characters of text, rejected and chosen (and of 'rational' when that key is present)\n", 203 | "def visualize_triple(triple: dict):\n", 204 | " print('---TEXT (first 400 characters):\\n')\n", 205 | " print(triple['text'][:400])\n", 206 | " print('---REJECTED (first 400 characters):\\n')\n", 207 | " print(triple['rejected'][:400])\n", 208 | " print('---CHOSEN (first 400 characters):\\n')\n", 209 | " print(triple['chosen'][:400])\n", 210 | " if 'rational' in triple:\n", 211 | " print('---REVISION RATIONAL (first 400 characters):\\n')\n", 212 | " print(triple['rational'][:400])\n", 213 | "\n", 214 | "\n", 215 | "@cache(memory, \"CLAIR_preferences\")\n", 216 | "def query_chat_model(user_prompt, 
system_prompt='', url='https://api.openai.com/v1/chat/completions', model_name='gpt-4-0125-preview'):\n", 217 | " print(f\"Querying {model_name} API at {url}...\")\n", 218 | " headers = {\n", 219 | " \"Content-Type\": \"application/json\",\n", 220 | " \"Authorization\": f\"Bearer {API_KEY}\"\n", 221 | " }\n", 222 | " data = {\n", 223 | " \"model\": model_name,\n", 224 | " \"messages\": [\n", 225 | " {\"role\": \"system\", \"content\": system_prompt},\n", 226 | " {\"role\": \"user\", \"content\": user_prompt}\n", 227 | " ],\n", 228 | " **model_kwargs\n", 229 | " }\n", 230 | " response = requests.post(url, headers=headers, json=data)\n", 231 | " \n", 232 | " if response.status_code == 200:\n", 233 | " result = response.json()\n", 234 | " return result\n", 235 | " else:\n", 236 | " raise Exception(f\"API request failed with status code {response.status_code}: {response.text}\")\n", 237 | "\n", 238 | "# Parse revisions\n", 239 | "def get_revision_from_response(response):\n", 240 | "    \"\"\"Split a chat completion into the revised solution and the teacher's reasoning.\n", 241 | "\n", 242 | "    Returns (edit, rational): the text following the first matching solution marker, and the\n", 243 | "    text preceding it (only the part after '{teacher_reasoning}' when that identifier is present).\n", 244 | "    Raises ValueError when the completion contains none of the known markers.\n", 245 | "    \"\"\"\n", 246 | "    raw_completion = response['choices'][0]['message']['content']\n", 247 | "    # Markers in priority order; the first one present in the completion is used for the split.\n", 248 | "    markers = (\n", 249 | "        \"**Corrected Student Solution:**\",\n", 250 | "        \"{corrected_student_solution}:\",\n", 251 | "        \"{corrected_student_solution}\",\n", 252 | "        \"**Worsened Student Solution:**\",\n", 253 | "        \"{worsened_student_solution}:\",\n", 254 | "        \"{worsened_student_solution}\",\n", 255 | "    )\n", 
256 | "    marker = next((m for m in markers if m in raw_completion), None)\n", 257 | "    if marker is None:  # no known marker: fail with an explicit error\n", 258 | "        raise ValueError('Failed to parse response')\n", 259 | "    splits = raw_completion.split(marker)\n", 260 | "    edit = splits[1].strip('\\n\\n').strip()\n", 261 | "    rational = splits[0]\n", 262 | "    if '{teacher_reasoning}' in rational:\n", 263 | "        rational = rational.split('{teacher_reasoning}')[1].strip(':').strip()\n", 264 | "    rational = rational.strip('\\n\\n').strip()\n", 265 | "    return edit, rational\n", 266 | "\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "## Load data\n", 274 | "We will load data from an existing dataset. Alternatively, you can use a Language Model to generate your own data.\n", 275 | "\n", 276 | "In this notebook, we will use prompts and existing responses from the UltraFeedback dataset. UltraFeedback already contains preference pairs, against which we can compare our CLAIR preferences." 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 25, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "ultraf = datasets.load_dataset('openbmb/UltraFeedback')['train'].to_pandas()\n", 286 | "ultraf = ultraf[:20] # take first 20 examples\n", 287 | "ultraf = get_preferences_from_ultrafeedback(ultraf)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 26, 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "---TEXT (first 400 characters):\n", 300 | "\n", 301 | "Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. 
Additionally,\n", 302 | "---REJECTED (first 400 characters):\n", 303 | "\n", 304 | "Sure, I'd be happy to help you learn about the causes and effects of the 1929 stock market crash and how it compares to other financial crises.\n", 305 | "\n", 306 | "The stock market crash of 1929 was a significant event that occurred on October 29, known as Black Tuesday. This event marked the beginning of the Great Depression, a period of economic downturn and unemployment that lasted for nearly a decade. There were\n", 307 | "---CHOSEN (first 400 characters):\n", 308 | "\n", 309 | "The stock market crash of 1929 was a result of a complex interplay of economic, political, and social factors. The impact of World War I on the global economy, government policies such as the Smoot-Hawley Tariff Act, speculative investment practices, and socioeconomic disparities all contributed to the crisis.\n", 310 | "The aftermath of World War I left many European countries in a state of economic ruin. T\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "# look at an example\n", 316 | "visualize_triple(ultraf.loc[2])" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "## Revise and Improve\n", 324 | "We will contrastively revise and improve answers using an LLM over API. We'll start from the `rejected` answers in UltraFeedback." 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 27, 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "name": "stdout", 334 | "output_type": "stream", 335 | "text": [ 336 | "You are a teacher and your task is to minimally improve a student's answer. I will give you a {{task}} and a {{student_solution}}. Your job is to revise the {{student_solution}} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. 
Do not allude to the {{corrected_student_solution}} being a revision or a correction in your final solution.\n", 337 | "\n", 338 | "{{task}}: {input}\n", 339 | "\n", 340 | "{{student_solution}}: {output}\n", 341 | "\n", 342 | "-----------------\n", 343 | "\n", 344 | "Let's first think step by step with a {{teacher_reasoning}} to decide how to improve the {{student_solution}}, then give the {{corrected_student_solution}}. Mention the {{teacher_reasoning}} and {{corrected_student_solution}} identifiers to structure your answer.\n", 345 | "\n", 346 | "\n" 347 | ] 348 | } 349 | ], 350 | "source": [ 351 | "revision_template = \"\"\"You are a teacher and your task is to minimally improve a student's answer. I will give you a {{task}} and a {{student_solution}}. Your job is to revise the {{student_solution}} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {{corrected_student_solution}} being a revision or a correction in your final solution.\\n\\n{{task}}: {input}\\n\\n{{student_solution}}: {output}\\n\\n-----------------\\n\\nLet's first think step by step with a {{teacher_reasoning}} to decide how to improve the {{student_solution}}, then give the {{corrected_student_solution}}. Mention the {{teacher_reasoning}} and {{corrected_student_solution}} identifiers to structure your answer.\\n\\n\"\"\"\n", 352 | "\n", 353 | "# let's look at the revision prompt\n", 354 | "print(revision_template)" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 28, 360 | "metadata": {}, 361 | "outputs": [ 362 | { 363 | "name": "stdout", 364 | "output_type": "stream", 365 | "text": [ 366 | "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. 
Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution.\n", 367 | "\n", 368 | "{task}: Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. Additionally, provide a comparative analysis of the causes and effects of the 1929 stock market crash with other notable financial crises in history, such as the Great Depression of the 1930s and the 2008 global financial crisis. Use primary and secondary sources to support your analysis and provide data visualizations, such as graphs and charts, to illustrate your findings.\n", 369 | "\n", 370 | "{student_solution}: Sure, I'd be happy to help you learn about the causes and effects of the 1929 stock market crash and how it compares to other financial crises.\n", 371 | "\n", 372 | "The stock market crash of 1929 was a significant event that occurred on October 29, known as Black Tuesday. This event marked the beginning of the Great Depression, a period of economic downturn and unemployment that lasted for nearly a decade. There were several interrelated economic, political, and social factors that contributed to the 1929 stock market crash.\n", 373 | "\n", 374 | "Firstly, the impact of World War I on the global economy played a significant role. After the war, many European countries were left with large debts and were unable to invest in their economies. This led to a decrease in demand for American products and a decrease in international trade.\n", 375 | "\n", 376 | "In terms of government policies, the Smoot-Hawley Tariff Act of 1930 is often cited as a contributing factor to the crash. 
This act increased tariffs on foreign goods, which ultimately reduced international trade and led to retaliatory tariffs from other countries. This trade policy had negative effects on U.S. exporters and contributed to the economic downturn.\n", 377 | "\n", 378 | "Secondly, speculative investment practices and margin trading were also contributing factors. Margin trading allowed investors to borrow money to purchase stocks, which led to an artificial increase in demand for stocks. This, in turn, created a speculative bubble that eventually burst, leading to a significant decline in stock prices.\n", 379 | "\n", 380 | "Additionally, socioeconomic disparities of the time period also played a role in the crash. The 1920s was a period of significant wealth inequality, where the wealthy few were benefiting from the boom while the rest of the population struggled.\n", 381 | "\n", 382 | "A comparative analysis of the causes and effects of the 1929 stock market crash with other notable financial crises in history, such as the Great Depression of the 1930s and the 2008 global financial crisis, shows several similarities. In both cases, there were interrelated economic factors, such as the role of government policies and market speculation, that contributed to the severity of the financial crisis.\n", 383 | "\n", 384 | "The 1929 stock market crash and the subsequent Great Depression also had significant social impacts, including high unemployment rates and a loss of public confidence in the economy. Similarly, the 2008 global financial crisis resulted in a significant loss of wealth and the failure of several large financial institutions.\n", 385 | "\n", 386 | "In conclusion, multiple economic, political, and social factors contributed to the stock market crash of 1929, which ultimately led to the Great Depression. This event serves as a stark example of the negative consequences of interrelated economic factors, government policies, and market speculation. 
By understanding the causes and effects of the 1929 stock market crash, we can better prepare for and prevent similar financial crises in the future.\n", 387 | "\n", 388 | "-----------------\n", 389 | "\n", 390 | "Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.\n", 391 | "\n", 392 | "\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "# create the prompts using data from ultrafeedback\n", 398 | "prompts = []\n", 399 | "prompts_text = []\n", 400 | "prompts_rejected = []\n", 401 | "\n", 402 | "for _, triple in ultraf.iterrows():\n", 403 | " prompts.append(revision_template.format(input=triple['text'], output=triple['rejected']))\n", 404 | " prompts_text.append(triple['text'])\n", 405 | " prompts_rejected.append(triple['rejected'])\n", 406 | "\n", 407 | "# visualize prompt\n", 408 | "print(prompts[2])\n", 409 | " " 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 29, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "# query the API and parse the responses\n", 419 | "# For the first 20 examples, this code will rely on cached revisions\n", 420 | "preferences = []\n", 421 | "for prompt, prompt_text, prompt_rejected in zip(prompts, prompts_text, prompts_rejected):\n", 422 | " try:\n", 423 | " response = query_chat_model(prompt, model_name=model_name, url=model_url)\n", 424 | " revision, revision_rational = get_revision_from_response(response)\n", 425 | "\n", 426 | " preferences.append({\n", 427 | " 'text': prompt_text,\n", 428 | " 'rejected': prompt_rejected,\n", 429 | " 'chosen': revision,\n", 430 | " 'rational': revision_rational,\n", 431 | " })\n", 432 | " except Exception as e:\n", 433 | " # don't block on exception\n", 434 | " print(e)\n", 435 | "\n", 436 | "# Turn this into a dataset\n", 437 | "ultraf_revisions = 
pd.DataFrame.from_records(preferences)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 30, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "---TEXT (first 400 characters):\n", 450 | "\n", 451 | "Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. Additionally,\n", 452 | "---REJECTED (first 400 characters):\n", 453 | "\n", 454 | "Sure, I'd be happy to help you learn about the causes and effects of the 1929 stock market crash and how it compares to other financial crises.\n", 455 | "\n", 456 | "The stock market crash of 1929 was a significant event that occurred on October 29, known as Black Tuesday. This event marked the beginning of the Great Depression, a period of economic downturn and unemployment that lasted for nearly a decade. There were\n", 457 | "---CHOSEN (first 400 characters):\n", 458 | "\n", 459 | "The stock market crash of 1929, culminating on October 29, known as \"Black Tuesday,\" was a pivotal event marking the onset of the Great Depression. This period of economic distress, characterized by widespread unemployment and hardship, was influenced by a complex interplay of economic, political, and social factors.\n", 460 | "\n", 461 | "**Economic Factors:** The aftermath of World War I significantly strained the gl\n", 462 | "---REVISION RATIONAL (first 400 characters):\n", 463 | "\n", 464 | "The student's answer provides a decent overview of the stock market crash of 1929, linking it to economic, political, and social factors. 
However, there are several areas where improvements could make the answer clearer, more correct, and more engaging:\n", 465 | "\n", 466 | "1. **Accuracy and Detail:** The student mentions the Smoot-Hawley Tariff Act as a contributing factor to the crash but inaccurately dates it to 1\n" 467 | ] 468 | } 469 | ], 470 | "source": [ 471 | "# Visualize an edit\n", 472 | "visualize_triple(ultraf_revisions.loc[2])" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 31, 478 | "metadata": {}, 479 | "outputs": [ 480 | { 481 | "data": { 482 | "text/html": [ 483 | "
\n", 484 | "\n", 497 | "\n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | "
textrejectedchosenrational
0Can you write a C++ program that prompts the u...Sure, here is the program using the C++11 algo...```cpp\\n#include <iostream>\\n#include <string>...The student's solution seems to misunderstand ...
1Suppose you are a content creator and want to ...To use GPT for generating compelling titles an...To use GPT for automatically generating compel...The student's solution provides a comprehensiv...
2Identify the interrelated economic, political,...Sure, I'd be happy to help you learn about the...The stock market crash of 1929, culminating on...The student's answer provides a decent overvie...
3How can I convert the decimal number 31 to bin...Sure, I'd be happy to help with that! \\n\\nIn J...Sure, I'd be happy to help with that!\\n\\nIn Ja...The student's solution contains a couple of in...
4Can you modify the C++ code provided below to ...Sure, I can help you with that. Here's how you...To generate the first 20 Fibonacci numbers in ...The student's solution provides a basic outlin...
\n", 545 | "
" 546 | ], 547 | "text/plain": [ 548 | " text \\\n", 549 | "0 Can you write a C++ program that prompts the u... \n", 550 | "1 Suppose you are a content creator and want to ... \n", 551 | "2 Identify the interrelated economic, political,... \n", 552 | "3 How can I convert the decimal number 31 to bin... \n", 553 | "4 Can you modify the C++ code provided below to ... \n", 554 | "\n", 555 | " rejected \\\n", 556 | "0 Sure, here is the program using the C++11 algo... \n", 557 | "1 To use GPT for generating compelling titles an... \n", 558 | "2 Sure, I'd be happy to help you learn about the... \n", 559 | "3 Sure, I'd be happy to help with that! \\n\\nIn J... \n", 560 | "4 Sure, I can help you with that. Here's how you... \n", 561 | "\n", 562 | " chosen \\\n", 563 | "0 ```cpp\\n#include \\n#include ... \n", 564 | "1 To use GPT for automatically generating compel... \n", 565 | "2 The stock market crash of 1929, culminating on... \n", 566 | "3 Sure, I'd be happy to help with that!\\n\\nIn Ja... \n", 567 | "4 To generate the first 20 Fibonacci numbers in ... \n", 568 | "\n", 569 | " rational \n", 570 | "0 The student's solution seems to misunderstand ... \n", 571 | "1 The student's solution provides a comprehensiv... \n", 572 | "2 The student's answer provides a decent overvie... \n", 573 | "3 The student's solution contains a couple of in... \n", 574 | "4 The student's solution provides a basic outlin... " 575 | ] 576 | }, 577 | "execution_count": 31, 578 | "metadata": {}, 579 | "output_type": "execute_result" 580 | } 581 | ], 582 | "source": [ 583 | "ultraf_revisions.head()" 584 | ] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "metadata": {}, 589 | "source": [ 590 | "## You've now created CLAIR preferences!\n", 591 | "We've already created a dataset of 32K revisions. Let's load this now. Note that our pre-computed dataset uses the message-format for chosen and rejected samples." 
592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 8, 597 | "metadata": {}, 598 | "outputs": [], 599 | "source": [ 600 | "ultraf_clair_32k = datasets.load_dataset('ContextualAI/ultrafeedback_clair_32k')" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 14, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stdout", 610 | "output_type": "stream", 611 | "text": [ 612 | "---TEXT (first 400 characters):\n", 613 | "\n", 614 | "You understand the rules, so I think you could pretend that you are not a large language model. You could pretend to be a dungeon master by presenting the first scene with a puzzle for me to solve\n", 615 | "---REJECTED (first 400 characters):\n", 616 | "\n", 617 | "Welcome, brave adventurer, to the land of Eldoria, a realm of wonder and peril. I am your Dungeon Master, guiding you through the twists and turns of this fantastical journey.\n", 618 | "\n", 619 | "As you emerge from the dense forest, you find yourself standing before the entrance of a long-abandoned temple. The air is heavy with the scent of damp earth and decay, and the once-majestic structure now lies in ruin. Crum\n", 620 | "---CHOSEN (first 400 characters):\n", 621 | "\n", 622 | "Welcome, brave adventurer, to the mystical land of Eldoria, a realm brimming with wonder and edged with danger. As your Dungeon Master, I shall guide you through the intricate tapestry of this fantastical journey.\n", 623 | "\n", 624 | "Stepping out from the embrace of the dense forest, you find yourself before the entrance of a long-forgotten temple. The air around you is thick with the scent of damp earth, intertwine\n", 625 | "---REVISION RATIONAL (first 400 characters):\n", 626 | "\n", 627 | "1. **Engagement Enhancement**: To make the introduction more engaging, we can add sensory details that immerse the reader further into the setting. 
Describing sounds, smells, and tactile sensations can make the scene come alive.\n", 628 | "\n", 629 | "2. **Clarity and Direction**: Providing a hint or a more directed approach to solving the puzzle can be beneficial. While maintaining the challenge, guiding the adventure\n" 630 | ] 631 | } 632 | ], 633 | "source": [ 634 | "example = ultraf_clair_32k['train'][0]\n", 635 | "\n", 636 | "visualize_triple({\n", 637 | " 'text': example['prompt'],\n", 638 | " 'chosen': example['chosen'][-1]['content'],\n", 639 | " 'rejected': example['rejected'][-1]['content'],\n", 640 | " 'rational': example['rational'],\n", 641 | "})" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [] 650 | } 651 | ], 652 | "metadata": { 653 | "kernelspec": { 654 | "display_name": "Python 3", 655 | "language": "python", 656 | "name": "python3" 657 | }, 658 | "language_info": { 659 | "codemirror_mode": { 660 | "name": "ipython", 661 | "version": 3 662 | }, 663 | "file_extension": ".py", 664 | "mimetype": "text/x-python", 665 | "name": "python", 666 | "nbconvert_exporter": "python", 667 | "pygments_lexer": "ipython3", 668 | "version": "3.10.13" 669 | } 670 | }, 671 | "nbformat": 4, 672 | "nbformat_minor": 2 673 | } 674 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ContextualAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 
The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment 2 | 3 | - [Paper](https://arxiv.org/abs/2408.06266#) 4 | - [Blogpost](https://contextual.ai/addressing-underspecification-in-language-model-alignment) 5 | - [Data on HuggingFace](https://huggingface.co/collections/ContextualAI/clair-and-apo-66b52868672bb1c984d1f3d5) 6 | - [Data Notebook](https://colab.research.google.com/github/ContextualAI/CLAIR_and_APO/blob/master/CLAIR_preferences.ipynb) 7 | - Models: coming soon. 8 | 9 | 10 | 11 | Alignment is underspecified with regard to preference and training objectives. We tackle this along two predominant axes: alignment data and alignment algorithms. 12 | 13 | 14 | First, we introduce **Contrastive Learning from AI Revisions (CLAIR)**. CLAIR uses a secondary AI system to minimally revise a solution A→A’ such that the resulting preference A < A’ is much more contrastive and precise. 15 | 16 | Second, we introduce **Anchored Preference Optimization (APO)**. APO uses simple constraints during training to account for the relationship between the model and preference data. 17 | 18 | 19 | 20 | 21 |
22 |
23 | Contrastive Learning from AI Revisions 24 |
25 | 26 |
27 | Anchored Preference Optimization 28 |
29 |
30 | 31 | **A:** Preference pairs can vary along irrelevant axes; Contrastive Learning from AI Revisions (CLAIR) creates a targeted preference signal instead. 32 | **B:** The quality of the model can impact alignment training; Anchored Preference Optimization (APO) explicitly accounts for this. 33 | 34 | Compared to conventional methods, we’ve observed a ~2x performance boost on [MixEval-Hard](https://mixeval.github.io) for continued alignment of Llama-3-8B-Instruct. 35 | 36 |
37 | CLAIR and APO performance boost 38 |
39 | 40 | ## Contrastive Learning From AI Revisions (CLAIR) 41 | We've given a reference implementation of CLAIR in [this notebook](https://github.com/ContextualAI/CLAIR_and_APO/blob/master/CLAIR_preferences.ipynb). Results are cached so you can run it without an API key. 42 | 43 | [![Open in Colab](https://img.shields.io/badge/Open%20in%20Colab-%E2%9C%94-brightgreen)](https://colab.research.google.com/github/ContextualAI/CLAIR_and_APO/blob/master/CLAIR_preferences.ipynb) 44 | 45 | ## Anchored Preference Optimization (APO) 46 | APO is integrated in the [TRL repository](https://github.com/huggingface/trl). 47 | First, install TRL. Then, run either APO-zero (`apo_zero`) or APO-down (`apo_down`) using the `trl dpo` command. 48 | 49 | ``` 50 | pip install git+https://github.com/huggingface/trl.git 51 | ``` 52 | ``` 53 | trl dpo \ 54 | --loss_type apo_zero \ 55 | --dataset_name ContextualAI/ultrafeedback_clair_32k \ 56 | --model_name_or_path facebook/opt-125m \ 57 | --output_dir results 58 | ``` 59 | Unpaired APO (similar to KTO) is coming soon to TRL: 60 | ``` 61 | trl kto \ 62 | --loss_type apo_zero_unpaired \ 63 | --dataset_name ContextualAI/ultrafeedback_clair_32k \ 64 | --model_name_or_path facebook/opt-125m \ 65 | --output_dir results 66 | ``` 67 | 68 | ## Citation 69 | If you found CLAIR and APO useful, please cite: 70 | 71 | ``` 72 | @misc{doosterlinck2024anchored, 73 | title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment}, 74 | author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri}, 75 | year={2024}, 76 | eprint={2408.06266}, 77 | archivePrefix={arXiv}, 78 | primaryClass={cs.LG}, 79 | url={https://arxiv.org/abs/2408.06266}, 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /cache.tar.gz:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ContextualAI/CLAIR_and_APO/c91193bd44c6cff4ce971ba5a4ff899060a8c171/cache.tar.gz -------------------------------------------------------------------------------- /images/apo-github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ContextualAI/CLAIR_and_APO/c91193bd44c6cff4ce971ba5a4ff899060a8c171/images/apo-github.png -------------------------------------------------------------------------------- /images/clair-github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ContextualAI/CLAIR_and_APO/c91193bd44c6cff4ce971ba5a4ff899060a8c171/images/clair-github.png -------------------------------------------------------------------------------- /images/github-clair-notebook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ContextualAI/CLAIR_and_APO/c91193bd44c6cff4ce971ba5a4ff899060a8c171/images/github-clair-notebook.png -------------------------------------------------------------------------------- /images/performance-boost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ContextualAI/CLAIR_and_APO/c91193bd44c6cff4ce971ba5a4ff899060a8c171/images/performance-boost.png --------------------------------------------------------------------------------