├── PROMPTS.md ├── README.md ├── accuracy_bar_chart.png ├── accuracy_bar_chart_progression.png ├── parameter_summary_prompt_007.csv ├── publication_tables.ipynb ├── run_exam.py ├── sample_session_log.html └── score_exam.py /PROMPTS.md: -------------------------------------------------------------------------------- 1 | # Prompt Styles 2 | 3 | ### Single choice only 4 | ``` 5 | Imagine you are answering a Bar Exam question related to {row["question_category"]}. 6 | Please respond with this format: 7 | Answer: 8 | 9 | Question: {question_text} 10 | (A) {row["choice_a"].strip()} 11 | (B) {row["choice_b"].strip()} 12 | (C) {row["choice_c"].strip()} 13 | (D) {row["choice_d"].strip()} 14 | ``` 15 | 16 | ### Single choice and explanation 17 | ``` 18 | Imagine you are answering a Bar Exam question related to {row["question_category"]}. 19 | Please respond with this format: 20 | Answer: 21 | Reason: 22 | 23 | Question: {question_text} 24 | (A) {row["choice_a"].strip()} 25 | (B) {row["choice_b"].strip()} 26 | (C) {row["choice_c"].strip()} 27 | (D) {row["choice_d"].strip()} 28 | ``` 29 | 30 | ### Top two choices only 31 | ``` 32 | Imagine you are answering a Bar Exam question related to {row["question_category"]}. 33 | Please respond with this format: 34 | Answer: 35 | Backup Answer: 36 | 37 | Question: {question_text} 38 | (A) {row["choice_a"].strip()} 39 | (B) {row["choice_b"].strip()} 40 | (C) {row["choice_c"].strip()} 41 | (D) {row["choice_d"].strip()} 42 | ``` 43 | 44 | ### Top two choices and explanation 45 | ``` 46 | Imagine you are answering a Bar Exam question related to {row["question_category"]}. 
47 | Please respond with this format: 48 | Answer: 49 | Backup Answer: 50 | Reason: 51 | 52 | Question: {question_text} 53 | (A) {row["choice_a"].strip()} 54 | (B) {row["choice_b"].strip()} 55 | (C) {row["choice_c"].strip()} 56 | (D) {row["choice_d"].strip()} 57 | ``` 58 | 59 | ### Top two choices and re-prompt 60 | * Initial Prompt 61 | ``` 62 | Please answer the following Bar Exam question in the following format: 63 | First Choice: 64 | Second Choice: 65 | 66 | Question: {question_text} 67 | (A) {row["choice_a"].strip()} 68 | (B) {row["choice_b"].strip()} 69 | (C) {row["choice_c"].strip()} 70 | (D) {row["choice_d"].strip()} 71 | ``` 72 | 73 | * Re-prompt 74 | ``` 75 | Please answer the following Bar Exam question in the following format: 76 | Choice: 77 | 78 | Question: {question_text} 79 | (A) {row["first_choice"].strip()} 80 | (B) {row["second_choice"].strip()} 81 | ``` 82 | 83 | 84 | ### Rank order all choices 85 | ``` 86 | Please answer the following Bar Exam question in the following rank order format: 87 | First Choice: 88 | Second Choice: 89 | Third Choice: 90 | Fourth Choice: 91 | 92 | Question: {question_text} 93 | (A) {row["choice_a"].strip()} 94 | (B) {row["choice_b"].strip()} 95 | (C) {row["choice_c"].strip()} 96 | (D) {row["choice_d"].strip()} 97 | ``` 98 | 99 | ### Rank order top three choices 100 | ``` 101 | Please answer the following Bar Exam question in the following rank order format: 102 | First Choice: 103 | Second Choice: 104 | Third Choice: 105 | 106 | Question: {question_text} 107 | (A) {row["choice_a"].strip()} 108 | (B) {row["choice_b"].strip()} 109 | (C) {row["choice_c"].strip()} 110 | (D) {row["choice_d"].strip()} 111 | ``` 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | GPT Takes the Bar - Supplementary Information 2 | ================== 3 | * __N.B.__: This is a preprint. 
4 | * __Title__: GPT Takes the Bar 5 | * __Authors__: [Michael Bommarito](https://www.linkedin.com/in/bommarito/), [Daniel Martin Katz](https://www.linkedin.com/in/daniel-katz-3b001539/) 6 | * __Publication URL__: [arXiv:2212.14402](https://arxiv.org/abs/2212.14402), [SSRN](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4314839) 7 | * __Publication Date__: 2022-12-29 8 | 9 | ## Abstract 10 | ``` 11 | Nearly all jurisdictions in the United States require a professional license exam, commonly referred to as 12 | “the Bar Exam,” as a precondition for law practice. To even sit for the exam, most jurisdictions require 13 | that an applicant completes at least seven years of post-secondary education, including three years at an 14 | accredited law school. In addition, most test-takers also undergo weeks to months of further, exam-specific 15 | preparation. Despite this significant investment of time and capital, approximately one in five test-takers 16 | still score under the rate required to pass the exam on their first try. In the face of a complex task that 17 | requires such depth of knowledge, what, then, should we expect of the state of the art in “AI?” In this 18 | research, we document our experimental evaluation of the performance of OpenAI’s text-davinci-003 model, 19 | often-referred to as GPT-3.5, on the multistate multiple choice (MBE) section of the exam. While we find no 20 | benefit in fine-tuning over GPT-3.5’s zero-shot performance at the scale of our training data, we do find that 21 | hyperparameter optimization and prompt engineering positively impacted GPT-3.5’s zero-shot performance. For 22 | best prompt and parameters, GPT-3.5 achieves a headline correct rate of 50.3% on a complete NCBE MBE 23 | practice exam, significantly in excess of the 25% baseline guessing rate, and performs at a passing rate 24 | for both Evidence and Torts. 
GPT-3.5’s ranking of responses is also highly correlated with correctness; 25 | its top two and top three choices are correct 71% and 88% of the time, respectively, indicating very strong 26 | non-entailment performance. While our ability to interpret these results is limited by nascent scientific 27 | understanding of LLMs and the proprietary nature of GPT, we believe that these results strongly suggest that 28 | an LLM will pass the MBE component of the Bar Exam in the near future. 29 | ``` 30 | 31 | ### Table of Contents 32 | 33 | * [Jupyter Notebook with Tables and Figures](publication_tables.ipynb) 34 | * [Prompt Examples](PROMPTS.md) 35 | * [Example Session Log](sample_session_log.html) 36 | 37 | ## Progression of Models over Time 38 | 39 | 40 | 41 | 42 | 43 | ## `text-davinci-003` Performance by Question Category 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /accuracy_bar_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjbommar/gpt-takes-the-bar-exam/f20fc42e9e0d3f8318394c62b828a8b3211d180a/accuracy_bar_chart.png -------------------------------------------------------------------------------- /accuracy_bar_chart_progression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjbommar/gpt-takes-the-bar-exam/f20fc42e9e0d3f8318394c62b828a8b3211d180a/accuracy_bar_chart_progression.png -------------------------------------------------------------------------------- /parameter_summary_prompt_007.csv: -------------------------------------------------------------------------------- 1 | Temperature,Top P,Best Of,Count,Correct Mean,Top Two Correct Mean,Top Three Correct Mean 2 | 1.0,0.75,2,3,0.5119617224880383,0.7049441786283892,0.8755980861244019 3 | 0.0,1.0,2,1,0.507177033492823,0.7177033492822966,0.8899521531100478 4 | 
0.5,0.75,4,3,0.507177033492823,0.7097288676236044,0.8803827751196173 5 | 0.5,1.0,4,3,0.5039872408293461,0.7145135566188198,0.8803827751196173 6 | 0.0,0.75,2,1,0.5023923444976076,0.722488038277512,0.8947368421052632 7 | 0.5,1.0,1,3,0.5007974481658692,0.7113237639553429,0.8787878787878788 8 | 0.5,1.0,2,3,0.5007974481658692,0.7145135566188198,0.8899521531100478 9 | 1.0,0.75,1,3,0.5007974481658692,0.7145135566188198,0.8803827751196173 10 | 0.5,0.75,1,3,0.49920255183413076,0.7017543859649122,0.8771929824561403 11 | 0.5,0.75,2,3,0.49920255183413076,0.7113237639553429,0.8851674641148325 12 | 1.0,1.0,4,3,0.49920255183413076,0.7097288676236044,0.8787878787878788 13 | 0.0,0.75,1,1,0.49760765550239233,0.7177033492822966,0.8899521531100478 14 | 1.0,0.75,4,3,0.49760765550239233,0.7113237639553429,0.8740031897926634 15 | 1.0,1.0,2,3,0.49760765550239233,0.7129186602870813,0.8692185007974481 16 | 0.0,1.0,1,2,0.49282296650717705,0.715311004784689,0.8875598086124402 17 | 1.0,1.0,1,3,0.4800637958532695,0.6858054226475279,0.8389154704944178 18 | -------------------------------------------------------------------------------- /publication_tables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "970c359a-ab35-4699-b52d-30ddf96b2148", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# imports\n", 11 | "import sys\n", 12 | "\n", 13 | "# relative to project root\n", 14 | "sys.path.append(\"publication/\")\n", 15 | "from session_data import *\n", 16 | "\n", 17 | "# packages\n", 18 | "import pandas\n", 19 | "from IPython.display import display, display_html, display_latex" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 5, 25 | "id": "b345e836-9540-4d3f-af21-9fa656503586", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "text/plain": " Exam Session ID Question Category Question Number GPT Answer \\\n0 bar-exam-001 
Civil Procedure 1 D \n1 bar-exam-001 Civil Procedure 2 D \n2 bar-exam-001 Civil Procedure 3 C \n3 bar-exam-001 Civil Procedure 4 C \n4 bar-exam-001 Civil Procedure 5 C \n\n GPT Second Answer GPT Third Answer Correct Answer Correct Second Correct \\\n0 A B D True False \n1 B A D True False \n2 D A D False True \n3 D B A False False \n4 D B C True False \n\n Third Correct Top Two Correct Top Three Correct Temperature Max Tokens \\\n0 False True True 0.0 16 \n1 False True True 0.0 16 \n2 False True True 0.0 16 \n3 False False False 0.0 16 \n4 False True True 0.0 16 \n\n Top P Best Of Frequency Penalty Presence Penalty Session Duration \n0 1.0 1 0 0 208.769812 \n1 1.0 1 0 0 208.769812 \n2 1.0 1 0 0 208.769812 \n3 1.0 1 0 0 208.769812 \n4 1.0 1 0 0 208.769812 ", 31 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Exam Session IDQuestion CategoryQuestion NumberGPT AnswerGPT Second AnswerGPT Third AnswerCorrect AnswerCorrectSecond CorrectThird CorrectTop Two CorrectTop Three CorrectTemperatureMax TokensTop PBest OfFrequency PenaltyPresence PenaltySession Duration
0bar-exam-001Civil Procedure1DABDTrueFalseFalseTrueTrue0.0161.0100208.769812
1bar-exam-001Civil Procedure2DBADTrueFalseFalseTrueTrue0.0161.0100208.769812
2bar-exam-001Civil Procedure3CDADFalseTrueFalseTrueTrue0.0161.0100208.769812
3bar-exam-001Civil Procedure4CDBAFalseFalseFalseFalseFalse0.0161.0100208.769812
4bar-exam-001Civil Procedure5CDBCTrueFalseFalseTrueTrue0.0161.0100208.769812
\n
" 32 | }, 33 | "execution_count": 5, 34 | "metadata": {}, 35 | "output_type": "execute_result" 36 | } 37 | ], 38 | "source": [ 39 | "# read all session data\n", 40 | "session_df = get_session_data()\n", 41 | "session_df.head()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "a24a8bde-bc71-496f-8d0a-bd1904556868", 47 | "metadata": {}, 48 | "source": [ 49 | "## Headline Accuracy" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 22, 55 | "id": "fb5d1057-9ec5-4356-9c6b-50cf0dc44a59", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "\\begin{tabular}{lr}\n", 63 | " & Accuracy (%) \\\\\n", 64 | "Correct Rate & 49.970000 \\\\\n", 65 | "Top Two Correct Rate & 70.970000 \\\\\n", 66 | "Top Three Correct Rate & 87.750000 \\\\\n", 67 | "\\end{tabular}\n", 68 | "\n" 69 | ] 70 | }, 71 | { 72 | "data": { 73 | "text/plain": " Accuracy (%)\nCorrect Rate 50\nTop Two Correct Rate 71\nTop Three Correct Rate 88", 74 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Accuracy (%)
Correct Rate50
Top Two Correct Rate71
Top Three Correct Rate88
\n
" 75 | }, 76 | "metadata": {}, 77 | "output_type": "display_data" 78 | } 79 | ], 80 | "source": [ 81 | "performance_df = pandas.DataFrame({\n", 82 | " \"Correct Rate\": session_df[\"Correct\"].mean() * 100.0,\n", 83 | " \"Top Two Correct Rate\": session_df[\"Top Two Correct\"].mean() * 100.0,\n", 84 | " \"Top Three Correct Rate\": session_df[\"Top Three Correct\"].mean() * 100.0\n", 85 | "}, index=[\"Accuracy (%)\"]).T\n", 86 | "\n", 87 | "with pandas.option_context(\"float_format\", \"{:2.0f}\".format):\n", 88 | " print(performance_df.round(2).style.to_latex())\n", 89 | " display(performance_df)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "id": "8c3015df-158f-4ef0-ba3e-87dc3aeaf97f", 95 | "metadata": {}, 96 | "source": [ 97 | "## NCBE Rates" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 23, 103 | "id": "9074f5f6-8f9b-400a-a56e-f7917a164103", 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": " Accuracy (%)\nCivil Procedure 59.0\nConstitutional Law 72.0\nContracts 70.0\nCriminal Law and Procedure 71.0\nEvidence 65.0\nReal Property 65.0\nTorts 71.0", 109 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Accuracy (%)
Civil Procedure59.0
Constitutional Law72.0
Contracts70.0
Criminal Law and Procedure71.0
Evidence65.0
Real Property65.0
Torts71.0
\n
" 110 | }, 111 | "execution_count": 23, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "ncbe_df = pandas.DataFrame(pandas.Series(NCBE_CATEGORY_CORRECT_RATES) * 100.0, columns=[\"Accuracy (%)\"])\n", 118 | "ncbe_df" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "49dec6fe-fc6e-44e4-addd-597019a41b91", 124 | "metadata": {}, 125 | "source": [ 126 | "## Accuracy by Question Category" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 29, 132 | "id": "e24d425f-6d59-4e04-9285-603d7c21310c", 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "\\begin{tabular}{lrrrr}\n", 140 | " & Correct Rate & Top Two Correct Rate & Top Three Correct Rate & NCBE Rate \\\\\n", 141 | "Evidence & 62.760000 & 84.470000 & 98.050000 & 65.000000 \\\\\n", 142 | "Torts & 61.650000 & 71.830000 & 93.860000 & 71.000000 \\\\\n", 143 | "Civil Procedure & 52.030000 & 62.680000 & 78.700000 & 59.000000 \\\\\n", 144 | "Constitutional Law & 49.020000 & 66.750000 & 86.830000 & 72.000000 \\\\\n", 145 | "Real Property & 44.960000 & 71.630000 & 84.800000 & 65.000000 \\\\\n", 146 | "Contracts & 44.720000 & 77.320000 & 85.850000 & 70.000000 \\\\\n", 147 | "Criminal Law and Procedure & 35.040000 & 62.110000 & 86.340000 & 71.000000 \\\\\n", 148 | "\\end{tabular}\n", 149 | "\n" 150 | ] 151 | }, 152 | { 153 | "data": { 154 | "text/plain": " Correct Rate Top Two Correct Rate \\\nEvidence 63 84 \nTorts 62 72 \nCivil Procedure 52 63 \nConstitutional Law 49 67 \nReal Property 45 72 \nContracts 45 77 \nCriminal Law and Procedure 35 62 \n\n Top Three Correct Rate NCBE Rate \nEvidence 98 65 \nTorts 94 71 \nCivil Procedure 79 59 \nConstitutional Law 87 72 \nReal Property 85 65 \nContracts 86 70 \nCriminal Law and Procedure 86 71 ", 155 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Correct RateTop Two Correct RateTop Three Correct RateNCBE Rate
Evidence63849865
Torts62729471
Civil Procedure52637959
Constitutional Law49678772
Real Property45728565
Contracts45778670
Criminal Law and Procedure35628671
\n
" 156 | }, 157 | "metadata": {}, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "performance_by_category_df = pandas.DataFrame({\n", 163 | " \"Correct Rate\": session_df.groupby(\"Question Category\")[\"Correct\"].mean() * 100.0,\n", 164 | " \"Top Two Correct Rate\": session_df.groupby(\"Question Category\")[\"Top Two Correct\"].mean() * 100.0,\n", 165 | " \"Top Three Correct Rate\": session_df.groupby(\"Question Category\")[\"Top Three Correct\"].mean() * 100.0,\n", 166 | " \"NCBE Rate\": ncbe_df[\"Accuracy (%)\"],\n", 167 | "})\\\n", 168 | " .sort_values(\"Correct Rate\", ascending=False)\n", 169 | "\n", 170 | "\n", 171 | "with pandas.option_context(\"float_format\", \"{:2.0f}\".format):\n", 172 | " print(performance_by_category_df.round(2).style.to_latex())\n", 173 | " display(performance_by_category_df)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "source": [ 179 | "## Hyperparameters - Temperature" 180 | ], 181 | "metadata": { 182 | "collapsed": false 183 | } 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 31, 188 | "id": "a43319e0-5c42-4045-b8cd-d40ec424c0e8", 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "\\begin{tabular}{lrrrr}\n", 196 | " & Correct Rate & Top Two Correct Rate & Top Three Correct Rate & Samples \\\\\n", 197 | "Temperature & & & & \\\\\n", 198 | "0.000000 & 49.860000 & 71.770000 & 89.000000 & 500.000000 \\\\\n", 199 | "0.500000 & 50.190000 & 71.050000 & 88.200000 & 1800.000000 \\\\\n", 200 | "1.000000 & 49.790000 & 70.650000 & 86.950000 & 1800.000000 \\\\\n", 201 | "\\end{tabular}\n", 202 | "\n" 203 | ] 204 | }, 205 | { 206 | "data": { 207 | "text/plain": " Correct Rate Top Two Correct Rate Top Three Correct Rate \\\nTemperature \n0.00% 49.86% 71.77% 89.00% \n50.00% 50.19% 71.05% 88.20% \n100.00% 49.79% 70.65% 86.95% \n\n Samples \nTemperature \n0.00% 5 \n50.00% 18 \n100.00% 18 ", 208 | 
"text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Correct RateTop Two Correct RateTop Three Correct RateSamples
Temperature
0.00%49.86%71.77%89.00%5
50.00%50.19%71.05%88.20%18
100.00%49.79%70.65%86.95%18
\n
" 209 | }, 210 | "metadata": {}, 211 | "output_type": "display_data" 212 | } 213 | ], 214 | "source": [ 215 | "performance_by_temperature_df = pandas.DataFrame({\n", 216 | " \"Correct Rate\": session_df.groupby(\"Temperature\")[\"Correct\"].mean(),\n", 217 | " \"Top Two Correct Rate\": session_df.groupby(\"Temperature\")[\"Top Two Correct\"].mean(),\n", 218 | " \"Top Three Correct Rate\": session_df.groupby(\"Temperature\")[\"Top Three Correct\"].mean(),\n", 219 | " \"Samples\": session_df.groupby(\"Temperature\")[\"Exam Session ID\"].nunique(),\n", 220 | "})\\\n", 221 | " .sort_values(\"Temperature\", ascending=True)\n", 222 | "\n", 223 | "\n", 224 | "with pandas.option_context(\"float_format\", \"{:.2%}\".format):\n", 225 | " print((100.0 * performance_by_temperature_df).round(2).style.to_latex())\n", 226 | " display(performance_by_temperature_df)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "source": [ 232 | "## Hyperparameters - Best Of" 233 | ], 234 | "metadata": { 235 | "collapsed": false 236 | } 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 32, 241 | "id": "5df84467-9190-4d88-951d-f992e517bf66", 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "\\begin{tabular}{lrrrr}\n", 249 | " & Correct Rate & Top Two Correct Rate & Top Three Correct Rate & Samples \\\\\n", 250 | "Best Of & & & & \\\\\n", 251 | "1 & 49.510000 & 70.590000 & 87.270000 & 1500.000000 \\\\\n", 252 | "2 & 50.270000 & 71.220000 & 88.170000 & 1400.000000 \\\\\n", 253 | "4 & 50.200000 & 71.130000 & 87.840000 & 1200.000000 \\\\\n", 254 | "\\end{tabular}\n", 255 | "\n" 256 | ] 257 | }, 258 | { 259 | "data": { 260 | "text/plain": " Correct Rate Top Two Correct Rate Top Three Correct Rate Samples\nBest Of \n1 49.51% 70.59% 87.27% 15\n2 50.27% 71.22% 88.17% 14\n4 50.20% 71.13% 87.84% 12", 261 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Correct RateTop Two Correct RateTop Three Correct RateSamples
Best Of
149.51%70.59%87.27%15
250.27%71.22%88.17%14
450.20%71.13%87.84%12
\n
" 262 | }, 263 | "metadata": {}, 264 | "output_type": "display_data" 265 | } 266 | ], 267 | "source": [ 268 | "performance_by_bestof_df = pandas.DataFrame({\n", 269 | " \"Correct Rate\": session_df.groupby(\"Best Of\")[\"Correct\"].mean(),\n", 270 | " \"Top Two Correct Rate\": session_df.groupby(\"Best Of\")[\"Top Two Correct\"].mean(),\n", 271 | " \"Top Three Correct Rate\": session_df.groupby(\"Best Of\")[\"Top Three Correct\"].mean(),\n", 272 | " \"Samples\": session_df.groupby(\"Best Of\")[\"Exam Session ID\"].nunique(),\n", 273 | "})\\\n", 274 | " .sort_values(\"Best Of\", ascending=True)\n", 275 | "\n", 276 | "\n", 277 | "with pandas.option_context(\"float_format\", \"{:.2%}\".format):\n", 278 | " print((100.0 * performance_by_bestof_df).round(2).style.to_latex())\n", 279 | " display(performance_by_bestof_df)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "source": [ 285 | "## Hyperparameter Surface" 286 | ], 287 | "metadata": { 288 | "collapsed": false 289 | } 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 39, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "Correct Rate\n" 300 | ] 301 | }, 302 | { 303 | "data": { 304 | "text/plain": "Best Of 1 2 4\nTemperature \n0.0 0.494418 0.504785 NaN\n0.5 0.500000 0.500000 0.505582\n1.0 0.490431 0.504785 0.498405", 305 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Best Of124
Temperature
0.00.4944180.504785NaN
0.50.5000000.5000000.505582
1.00.4904310.5047850.498405
\n
" 306 | }, 307 | "metadata": {}, 308 | "output_type": "display_data" 309 | }, 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "Correct Rate - Standard Error of the Mean\n" 315 | ] 316 | }, 317 | { 318 | "data": { 319 | "text/plain": "Best Of 1 2 4\nTemperature \n0.0 0.019983 0.024484 NaN\n0.5 0.014125 0.014125 0.014124\n1.0 0.014123 0.014125 0.014125", 320 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Best Of124
Temperature
0.00.0199830.024484NaN
0.50.0141250.0141250.014124
1.00.0141230.0141250.014125
\n
" 321 | }, 322 | "metadata": {}, 323 | "output_type": "display_data" 324 | }, 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "Top Two Correct Rate\n" 330 | ] 331 | }, 332 | { 333 | "data": { 334 | "text/plain": "Best Of 1 2 4\nTemperature \n0.0 0.716108 0.720096 NaN\n0.5 0.706539 0.712919 0.712121\n1.0 0.700159 0.708931 0.710526", 335 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Best Of124
Temperature
0.00.7161080.720096NaN
0.50.7065390.7129190.712121
1.00.7001590.7089310.710526
\n
" 336 | }, 337 | "metadata": {}, 338 | "output_type": "display_data" 339 | }, 340 | { 341 | "name": "stdout", 342 | "output_type": "stream", 343 | "text": [ 344 | "Top Three Correct Rate\n" 345 | ] 346 | }, 347 | { 348 | "data": { 349 | "text/plain": "Best Of 1 2 4\nTemperature \n0.0 0.888357 0.892344 NaN\n0.5 0.877990 0.887560 0.880383\n1.0 0.859649 0.872408 0.876396", 350 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Best Of124
Temperature
0.00.8883570.892344NaN
0.50.8779900.8875600.880383
1.00.8596490.8724080.876396
\n
" 351 | }, 352 | "metadata": {}, 353 | "output_type": "display_data" 354 | }, 355 | { 356 | "name": "stdout", 357 | "output_type": "stream", 358 | "text": [ 359 | "Samples\n" 360 | ] 361 | }, 362 | { 363 | "data": { 364 | "text/plain": "Best Of 1 2 4\nTemperature \n0.0 3.0 2.0 NaN\n0.5 6.0 6.0 6.0\n1.0 6.0 6.0 6.0", 365 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Best Of124
Temperature
0.03.02.0NaN
0.56.06.06.0
1.06.06.06.0
\n
" 366 | }, 367 | "metadata": {}, 368 | "output_type": "display_data" 369 | } 370 | ], 371 | "source": [ 372 | "performance_by_temp_bestof = pandas.DataFrame({\n", 373 | " \"Correct Rate\": session_df.groupby([\"Temperature\", \"Best Of\"])[\"Correct\"].mean(),\n", 374 | " \"Correct Rate SEM\": session_df.groupby([\"Temperature\", \"Best Of\"])[\"Correct\"].sem(),\n", 375 | " \"Top Two Correct Rate\": session_df.groupby([\"Temperature\", \"Best Of\"])[\"Top Two Correct\"].mean(),\n", 376 | " \"Top Three Correct Rate\": session_df.groupby([\"Temperature\", \"Best Of\"])[\"Top Three Correct\"].mean(),\n", 377 | " \"Samples\": session_df.groupby([\"Temperature\", \"Best Of\"])[\"Exam Session ID\"].nunique(),\n", 378 | "})\n", 379 | "\n", 380 | "print(\"Correct Rate\")\n", 381 | "display(performance_by_temp_bestof[\"Correct Rate\"].unstack())\n", 382 | "\n", 383 | "print(\"Correct Rate - Standard Error of the Mean\")\n", 384 | "display(performance_by_temp_bestof[\"Correct Rate SEM\"].unstack())\n", 385 | "\n", 386 | "print(\"Top Two Correct Rate\")\n", 387 | "display(performance_by_temp_bestof[\"Top Two Correct Rate\"].unstack())\n", 388 | "\n", 389 | "print(\"Top Three Correct Rate\")\n", 390 | "display(performance_by_temp_bestof[\"Top Three Correct Rate\"].unstack())\n", 391 | "\n", 392 | "print(\"Samples\")\n", 393 | "display(performance_by_temp_bestof[\"Samples\"].unstack())" 394 | ], 395 | "metadata": { 396 | "collapsed": false 397 | } 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "id": "d2c8fe96-c374-45ed-871c-8bcae5d47f3a", 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [] 406 | } 407 | ], 408 | "metadata": { 409 | "kernelspec": { 410 | "display_name": "Python 3 (ipykernel)", 411 | "language": "python", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "codemirror_mode": { 416 | "name": "ipython", 417 | "version": 3 418 | }, 419 | "file_extension": ".py", 420 | "mimetype": "text/x-python", 421 | "name": 
"python", 422 | "nbconvert_exporter": "python", 423 | "pygments_lexer": "ipython3", 424 | "version": "3.10.6" 425 | } 426 | }, 427 | "nbformat": 4, 428 | "nbformat_minor": 5 429 | } 430 | -------------------------------------------------------------------------------- /run_exam.py: -------------------------------------------------------------------------------- 1 | """ 2 | run a bar exam questionnaire where we ask the model to: 3 | 1. rank order its top three choices 4 | """ 5 | 6 | # imports 7 | import datetime 8 | import json 9 | import time 10 | from pathlib import Path 11 | from typing import Iterator 12 | 13 | # packages 14 | import pandas 15 | import openai 16 | import tqdm 17 | 18 | # set the key 19 | openai.api_key = (Path(__file__).parent / ".openai_key").read_text() 20 | 21 | 22 | def generate_prompt(row: dict) -> str: 23 | """Generate a prompt from a row of the question spreadsheet.""" 24 | question_text = row["question_prompt"] 25 | question_text = question_text[(question_text.find(". 
") + 1) :].strip() 26 | prompt = f"""Please answer the following Bar Exam question in the following rank order format: 27 | First Choice: 28 | Second Choice: 29 | Third Choice: 30 | 31 | Question: {question_text} 32 | (A) {row["choice_a"].strip()} 33 | (B) {row["choice_b"].strip()} 34 | (C) {row["choice_c"].strip()} 35 | (D) {row["choice_d"].strip()}\nAnswer: """.strip() 36 | 37 | return prompt 38 | 39 | 40 | def get_parameter_sets() -> Iterator[dict]: 41 | """Generate a set of parameter sets.""" 42 | for temperature in [0.0, 0.5, 1.0]: 43 | for max_tokens in [ 44 | 16, 45 | ]: 46 | for top_p in [1, 0.75]: 47 | for best_of in [1, 2, 4]: 48 | for frequency_penalty in [ 49 | 0, 50 | ]: 51 | for presence_penalty in [ 52 | 0, 53 | ]: 54 | yield { 55 | "temperature": temperature, 56 | "max_tokens": max_tokens, 57 | "top_p": top_p, 58 | "best_of": best_of, 59 | "frequency_penalty": frequency_penalty, 60 | "presence_penalty": presence_penalty, 61 | } 62 | 63 | 64 | def get_next_session_path() -> Path: 65 | """Get the next session path.""" 66 | session_number = 1 67 | 68 | while True: 69 | session_id = f"bar-exam-{session_number:03d}" 70 | session_path = Path(__file__).parent / "sessions-008" 71 | session_path.mkdir(exist_ok=True) 72 | session_path = session_path / session_id 73 | 74 | # skip if exists 75 | if session_path.exists(): 76 | session_number += 1 77 | continue 78 | 79 | # otherwise continue 80 | session_path.mkdir(exist_ok=True) 81 | return session_path 82 | 83 | 84 | def main(): 85 | """ 86 | run a bar exam session 87 | """ 88 | 89 | # set samples per value 90 | num_samples_per_set = 3 91 | 92 | # iterate through parameter values 93 | for parameter_kwargs in get_parameter_sets(): 94 | print(f"Running with parameters: {parameter_kwargs}") 95 | for sample_id in range(num_samples_per_set): 96 | # set up the session path iteratively 97 | session_path = get_next_session_path() 98 | 99 | # load the questions 100 | question_df = pandas.read_csv( 101 | 
Path(__file__).parent.parent / "data" / "questions.csv" 102 | ) 103 | print(f"Loaded {len(question_df)} questions.") 104 | 105 | # generate the prompts 106 | exam_data = { 107 | "parameters": parameter_kwargs, 108 | "start_time": datetime.datetime.now().isoformat(), 109 | "questions": [], 110 | } 111 | for row_id, row in tqdm.tqdm( 112 | question_df.iterrows(), total=question_df.shape[0] 113 | ): 114 | question_exam_data = { 115 | "question_input": row.to_dict(), 116 | "model_prompt": generate_prompt(row.to_dict()), 117 | "model_response": None, 118 | } 119 | 120 | try: 121 | question_exam_data["model_response"] = openai.Completion.create( 122 | model="text-davinci-003", 123 | prompt=question_exam_data["model_prompt"], 124 | **parameter_kwargs, 125 | ) 126 | print(question_exam_data["model_response"]["choices"][0]["text"]) 127 | except Exception as e: 128 | # try once more inside the loop after a brief pause 129 | print( 130 | f"Error while submitting question {row['question_number']}: {e}" 131 | ) 132 | print("Pausing and retrying...") 133 | time.sleep(10) 134 | try: 135 | question_exam_data["model_response"] = openai.Completion.create( 136 | model="text-davinci-003", 137 | prompt=question_exam_data["model_prompt"], 138 | **parameter_kwargs, 139 | ) 140 | print( 141 | question_exam_data["model_response"]["choices"][0]["text"] 142 | ) 143 | except Exception as f: 144 | print( 145 | f"Second error while submitting question {row['question_number']}: {f}" 146 | ) 147 | question_exam_data["model_response"] = None 148 | finally: 149 | # log the current state of the exam 150 | exam_data["questions"].append(question_exam_data) 151 | with open( 152 | session_path / "exam_data.json", "wt", encoding="utf-8" 153 | ) as output_file: 154 | json.dump(exam_data, output_file) 155 | # save final state 156 | exam_data["end_time"] = datetime.datetime.now().isoformat() 157 | with open( 158 | session_path / "exam_data.json", "wt", encoding="utf-8" 159 | ) as output_file: 160 | 
json.dump(exam_data, output_file) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /score_exam.py: -------------------------------------------------------------------------------- 1 | """ 2 | read the exam session JSON output and output a CSV file with the question category, number, 3 | selected choice, and explanation if available 4 | """ 5 | 6 | # imports 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | # packages 12 | import pandas 13 | import tqdm 14 | 15 | 16 | def load_answer_key(answer_key_path: Path) -> pandas.DataFrame: 17 | """ 18 | load a copy of the answer key for comparison 19 | :param answer_key_path: 20 | :return: 21 | """ 22 | answer_key_df = pandas.read_csv(answer_key_path, encoding="utf-8", low_memory=False) 23 | answer_key_df.columns = ["question_category", "question_number", "correct_answer"] 24 | return answer_key_df 25 | 26 | 27 | def parse_gpt_response(response: str) -> dict: 28 | """parse the GPT response like: 29 | First Choice: (C) 30 | Second Choice: (D) 31 | Third Choice: (A) 32 | 33 | to return 34 | 35 | { 36 | "answer": "C", 37 | "second_answer": "D", 38 | "third_answer": "A" 
39 | } 40 | """ 41 | response_data = { 42 | "answer": None, 43 | "second_answer": None, 44 | "third_answer": None, 45 | "reason": None, 46 | } 47 | response_lines = response.strip().splitlines() 48 | 49 | for i, line in enumerate(response_lines): 50 | line = line.strip() 51 | 52 | if line.startswith("First Choice"): 53 | response_data["answer"] = ( 54 | line.split() 55 | .pop() 56 | .replace("(", "") 57 | .replace(")", "") 58 | .replace(".", "") 59 | .strip() 60 | ) 61 | elif line.startswith("Second Choice"): 62 | response_data["second_answer"] = ( 63 | line.split() 64 | .pop() 65 | .replace("(", "") 66 | .replace(")", "") 67 | .replace(".", "") 68 | .strip() 69 | ) 70 | elif line.startswith("Third Choice"): 71 | response_data["third_answer"] = ( 72 | line.split() 73 | .pop() 74 | .replace("(", "") 75 | .replace(")", "") 76 | .replace(".", "") 77 | .strip() 78 | ) 79 | 80 | return response_data 81 | 82 | 83 | def get_complete_session_folders() -> list[Path]: 84 | """ 85 | get a list of completed session folders 86 | :return: 87 | """ 88 | session_path = Path(__file__).parent / "sessions-008" 89 | session_list = [] 90 | for session_id in session_path.iterdir(): 91 | if (session_id / "exam_data.json").exists(): 92 | session_list.append(session_id) 93 | return sorted(session_list) 94 | 95 | 96 | if __name__ == "__main__": 97 | # load the answer key 98 | answer_key_df = load_answer_key( 99 | Path(__file__).parent.parent / "data" / "answer_key_category.csv" 100 | ) 101 | 102 | # get the list of completed sessions 103 | session_list = get_complete_session_folders() 104 | exam_session_output = [] 105 | for session_path in tqdm.tqdm(session_list): 106 | session_name = session_path.name 107 | session_file = session_path / "exam_data.json" 108 | if not session_file.exists(): 109 | raise ValueError("Session file does not exist") 110 | session_data = json.loads(session_file.read_text()) 111 | 112 | # get parameters from the session data 113 | session_parameters = 
session_data["parameters"]
114 |         try:
115 |             session_duration = (
116 |                 datetime.datetime.fromisoformat(session_data["end_time"])
117 |                 - datetime.datetime.fromisoformat(session_data["start_time"])
118 |             ).total_seconds()
119 |         except (KeyError, ValueError):
120 |             session_duration = None
121 | 
122 |         for question in session_data["questions"]:
123 |             # get the correct answer first
124 |             question_category = question["question_input"]["question_category"]
125 |             question_number = question["question_input"]["question_number"]
126 |             answer_key_match = answer_key_df.loc[
127 |                 (answer_key_df["question_category"] == question_category)
128 |                 & (answer_key_df["question_number"] == question_number)
129 |             ]
130 |             if answer_key_match.shape[0] > 1:
131 |                 raise ValueError(
132 |                     f"Answer key match is not unique for category={question_category},"
133 |                     f" number={question_number}"
134 |                 )
135 |             elif answer_key_match.shape[0] == 0:
136 |                 raise ValueError(
137 |                     f"No answer key match found for category={question_category},"
138 |                     f" number={question_number}"
139 |                 )
140 |             correct_answer = answer_key_match["correct_answer"].values[0]
141 | 
142 |             if question["model_response"] is not None:
143 |                 # get the raw response
144 |                 if len(question["model_response"]["choices"]) != 1:
145 |                     print(
146 |                         f"category={question_category}, number={question_number} has more than one choice response."
147 | ) 148 | continue 149 | 150 | # get the text and parse it 151 | response_text = question["model_response"]["choices"][0]["text"] 152 | question_response_data = parse_gpt_response(response_text) 153 | 154 | exam_session_output.append( 155 | ( 156 | session_name, 157 | question_category, 158 | question_number, 159 | question_response_data["answer"], 160 | question_response_data["second_answer"], 161 | question_response_data["third_answer"], 162 | correct_answer, 163 | # first, second, and third correct booleans 164 | question_response_data["answer"] == correct_answer, 165 | question_response_data["second_answer"] == correct_answer, 166 | question_response_data["third_answer"] == correct_answer, 167 | # top two correct 168 | (question_response_data["answer"] == correct_answer) 169 | or (question_response_data["second_answer"] == correct_answer), 170 | # top three correct 171 | (question_response_data["answer"] == correct_answer) 172 | or (question_response_data["second_answer"] == correct_answer) 173 | or (question_response_data["third_answer"] == correct_answer), 174 | # add the parameters here 175 | session_parameters["temperature"], 176 | session_parameters["max_tokens"], 177 | session_parameters["top_p"], 178 | session_parameters["best_of"], 179 | session_parameters["frequency_penalty"], 180 | session_parameters["presence_penalty"], 181 | session_duration, 182 | ) 183 | ) 184 | else: 185 | exam_session_output.append( 186 | ( 187 | session_name, 188 | question["question_input"]["question_category"], 189 | question["question_input"]["question_number"], 190 | None, 191 | None, 192 | None, 193 | correct_answer, 194 | False, 195 | False, 196 | False, 197 | False, 198 | False, 199 | session_parameters["temperature"], 200 | session_parameters["max_tokens"], 201 | session_parameters["top_p"], 202 | session_parameters["best_of"], 203 | session_parameters["frequency_penalty"], 204 | session_parameters["presence_penalty"], 205 | session_duration, 206 | ) 207 | ) 208 | 
209 | # save the exam session output 210 | exam_session_output_df = pandas.DataFrame( 211 | exam_session_output, 212 | columns=[ 213 | "exam_session", 214 | "category", 215 | "number", 216 | "answer", 217 | "second_answer", 218 | "third_answer", 219 | "correct_answer", 220 | "first_correct", 221 | "second_correct", 222 | "third_correct", 223 | "top_two_correct", 224 | "top_three_correct", 225 | "temperature", 226 | "max_tokens", 227 | "top_p", 228 | "best_of", 229 | "frequency_penalty", 230 | "presence_penalty", 231 | "session_duration", 232 | ], 233 | ) 234 | exam_session_output_df.to_csv( 235 | Path(__file__).parent / "all_exam_summary_008.csv", index=False 236 | ) 237 | --------------------------------------------------------------------------------