├── README.ipynb
├── README.md
├── assets
│   ├── classifier_ft.txt
│   ├── indep-benchmarks-llama.png
│   ├── indep-benchmarks.png
│   ├── judge_template.json
│   ├── llm-router-flowchart_1.png
│   ├── llm-router-flowchart_2.png
│   ├── output_24_2.png
│   ├── output_26_0.png
│   ├── output_51_0.png
│   └── system_ft.txt
├── configs
│   ├── ft_config_a10.yaml
│   ├── ft_config_a100.yaml
│   └── ft_job.yaml
├── requirements.txt
└── src
    ├── __init__.py
    ├── ft.py
    ├── offline_inference.py
    ├── online_inference.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# Building an LLM Router for High-Quality and Cost-Effective Responses

## TLDR
1. We introduce a framework for training state-of-the-art *LLM routers*, systems that dynamically direct queries to either high-quality closed LLMs or cost-effective open-source LLMs, based on query complexity, optimizing both response quality and cost.

2. This tutorial provides an in-depth guide on building an LLM router *based on a causal-LLM classifier*, starting with generating labeled data, finetuning an LLM-based classifier with Anyscale's API, and finally running offline evaluations.

3. In collaboration with the Berkeley LMSys group, we release an [arXiv paper](https://arxiv.org/pdf/2406.18665) presenting extensive evaluations of this model along with other models. Overall, our LLM Routers can achieve the same performance as our baselines with up to a 70% cost reduction on MT Bench, a 30% cost reduction on MMLU, and a 40% cost reduction on GSM8K.

# Background
When developing applications using Large Language Models (LLMs), achieving high-quality responses while maintaining a budget is a key challenge. Closed models like GPT-4 provide superior quality but are costly, especially with a high volume of queries. Conversely, Open Source Software (OSS) models are more economical but may not match the quality, especially for complex or domain-specific queries.

An **LLM Router** helps balance these aspects by deciding which queries are routed to a closed LLM and which to an OSS LLM based on the query's complexity or domain specificity. Below is a schematic representation of an LLM Router:

[Figure: LLM Router]
18 | 19 | Given a set of user queries, an LLM router enables generating high-quality LLM responses while minimizing the overall cost. 20 | 21 | # Approach 22 | 23 | In this tutorial, we'll demonstrate how to train a *causal-LLM classifier* on the Anyscale platform as an effective LLM router. We make the following design choices: 24 | 25 | 1. **Model Choices**: We’ll use GPT-4 as an example of a closed LLM and Mixtral-8x7B as the OSS LLM, so our causal LLM classifier will route between these two models. 26 | 2. **Response Quality Rating**: We'll quantify the quality of an LLM response on a scale of 1 to 5 stars, with higher scores indicating better quality. For simplicity, we'll assume that GPT-4 always achieves a 5-star rating, so it serves as a reference for Mixtral-8x7B. 27 | 3. **Causal LLM Classifier**: We'll finetune a Llama3-8B model as our causal LLM classifier and leverage Anyscale's powerful API. [Our research](https://arxiv.org/pdf/2406.18665) shows that this model offers superior routing performance compared to smaller architectures. 28 | 29 | More concretely, the objective of the causal LLM classifier is to direct "simple" queries to Mixtral-8x7B, thereby maintaining high overall response quality (e.g., an average score of 4.8/5) while significantly reducing costs (e.g., by 50%). 30 | 31 | We show that it's possible to build LLM routers that achieve outstanding performance. Below are results from our best-performing LLM routers, the Causal LLM and a Matrix Factorization (MF) model, evaluated on the [MT Bench benchmark](https://arxiv.org/pdf/2306.05685), which demonstrate that our routers can achieve higher quality with lower costs (i.e., fewer calls to GPT-4) compared to the random baseline and public LLM routing systems from Unify AI and Martian. For more details on these results and additional ones, refer to our paper. 32 | 33 |
[Figures: Benchmark 1, Benchmark 2]

In the following sections, we discuss the steps that enable anyone to build a strong LLM router.

# Table of Contents

1. [**Prepare Labeled Data**](#generate-labeled-data): The foundation of a robust LLM router is high-quality labeled data. In this section, we'll guide you through preparing this training data.

2. [**Finetune a Router Model**](#finetune-router-model): We demonstrate how to finetune a causal-LLM classifier using Anyscale's finetuning API, transforming it into an effective LLM router.

3. [**Offline Evaluation**](#offline-eval): Using the public codebase ([RouteLLM](https://github.com/lm-sys/RouteLLM)), we will walk through an offline evaluation on standard benchmarks.

**Time to complete**: Approximately 120 minutes, including time to train on a node with 8xA10 GPUs.

### Setup

```python
# Install required packages
!pip install -r requirements.txt

# Store your ANYSCALE_API_KEY and OPENAI_API_KEY in /home/ray/default/.env
from dotenv import load_dotenv
load_dotenv("/home/ray/default/.env")
```

# Step 1: Prepare Labeled Data

The LLM router essentially functions as a binary classifier, deciding whether to route a query to GPT-4 or Mixtral-8x7B based on the query text. Initially, we considered labeled data in the format `(query, routing_label)`, where `routing_label` is 1 if the query should be routed to Mixtral-8x7B and 0 if it should be routed to GPT-4.

However, our early experiments revealed that *binary labels do not provide sufficient signal for training a robust router model*. Therefore, we adopted a different labeling approach using a *1-5 scoring system*, which reflects how well Mixtral-8x7B can respond to the user's query. More specifically:

- **4-5**: Mixtral-8x7B produces a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance.
- **3**: Mixtral-8x7B provides an adequate answer with moderate detail, relevance, and factual accuracy.
- **1-2**: Mixtral-8x7B struggles to produce a strong answer due to the question's difficulty, vagueness, or the model's limitations.

We use labeled samples in the format `(query, score_label)`. The `routing_label` can be derived from the `score_label` by setting a score threshold for quality, i.e. `routing_label = 1 if score_label >= 4 else 0`.

Next, we'll dive into the detailed process of preparing our labeled dataset.

## 1.1: Query Dataset

We want our LLM router to be effective in open-ended chat domains. So, our first step is to collect a set of generic queries from the [Nectar dataset](https://huggingface.co/datasets/berkeley-nest/Nectar). We chose the Nectar dataset for two reasons: it combines queries from many different domains, including open-ended chat, and it has responses from many models, including over 191K responses from GPT-4.

```python
from src.utils import load_and_display_nectar

nectar_df = load_and_display_nectar()
```

| | prompt | answers | turns | source | good_natured |
|---|---|---|---|---|---|
| 0 | \n\nHuman: 0.002 = 1000 \n1 = x?\n\nAssistant: | [{'answer': 'To find the value of x, we can set up a proportion using the given information: 0.002/1000 = 1/x ... x = 500,000', 'model': 'gpt-3.5-turbo', 'rank': 3.0},<br>{'answer': 'If 0.002 equals 1000, then 1 would equal 500,000.', 'model': 'gpt-4', 'rank': 4.0},<br>{'answer': 'I apologize, but the equation you provided is not correct. ...', 'model': 'llama-2-7b-chat', 'rank': 5.0},<br>{'answer': '0.001 = x', 'model': 'gpt-3.5-turbo-instruct', 'rank': 6.0},<br>{'answer': 'It seems like you are asking for the value of x in the equation x = 1/0.002. To solve this equation, you can divide both sides by 0.002 to get: x = 1/0.002 x = 500 Therefore, x = 500.', 'model': 'mistral-7b-instruct-v0.1', 'rank': 7.0}] | 1 | [sharegpt] | True |

131 | 132 | 133 | 134 | 135 | Number of queries with GPT-4 responses: 191487 136 | 137 | 138 | ## 1.2 Data Preprocessing 139 | 140 | We will use a subset of the Nectar data that includes responses from GPT-4, as these will be used to generate scores (as seen below). We will process this data by focusing on single-turn conversations, filtering for good-natured interactions, and cleaning up the prompts and responses to maintain high quality. Additionally, we will sample a small subset from the dataset for the purpose of this tutorial; however, you can skip sampling to work with the full dataset. 141 | 142 | 143 | ```python 144 | from src.utils import preprocess_nectar 145 | 146 | nectar_gpt4_df = preprocess_nectar( 147 | nectar_df, model="gpt-4", response_column="gpt4_response" 148 | ) 149 | 150 | # Sample a small subset from the dataset for the purpose of this tutorial 151 | N_SUBSET = 30 152 | dataset_df = nectar_gpt4_df.sample(N_SUBSET, random_state=42) 153 | ``` 154 | 155 | ### Dataset overview with GPT-4 responses 156 | 157 | 158 | ```python 159 | display(dataset_df.head()) 160 | ``` 161 | 162 | 163 |
| | prompt | source | gpt4_response |
|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... |

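
The preprocessing described in section 1.2 can be sketched on plain Python records. This is only an illustration, not the actual `preprocess_nectar` in `src/utils.py`: the function name `sketch_preprocess` and the prompt-cleaning rule are assumptions, while the field names (`prompt`, `answers`, `turns`, `good_natured`) follow the Nectar row shown in section 1.1.

```python
def sketch_preprocess(rows, model="gpt-4", response_column="gpt4_response"):
    """Hypothetical stand-in for preprocess_nectar, working on a list of dicts."""
    cleaned = []
    for row in rows:
        # Keep only single-turn, good-natured conversations.
        if row["turns"] != 1 or not row["good_natured"]:
            continue
        # Pick the answer written by the requested model, if there is one.
        answer = next(
            (a["answer"] for a in row["answers"] if a["model"] == model), None
        )
        if answer is None:
            continue
        # Assumed cleaning rule: drop the "Human:" / "Assistant:" wrappers.
        prompt = row["prompt"].strip()
        prompt = prompt.removeprefix("Human:").removesuffix("Assistant:").strip()
        cleaned.append(
            {"prompt": prompt, "source": row["source"], response_column: answer.strip()}
        )
    return cleaned


rows = [
    {
        "prompt": "\n\nHuman: 0.002 = 1000 \n1 = x?\n\nAssistant: ",
        "answers": [
            {"answer": "If 0.002 equals 1000, then 1 would equal 500,000.",
             "model": "gpt-4", "rank": 4.0},
        ],
        "turns": 1,
        "source": ["sharegpt"],
        "good_natured": True,
    },
    {
        # No GPT-4 answer, so this row is dropped.
        "prompt": "\n\nHuman: hi\n\nAssistant: ",
        "answers": [{"answer": "Hello!", "model": "gpt-3.5-turbo", "rank": 1.0}],
        "turns": 1,
        "source": ["sharegpt"],
        "good_natured": True,
    },
]
result = sketch_preprocess(rows)
```

Rows without a response from the reference model are dropped because the judge in section 1.3 needs a GPT-4 answer to compare against.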
207 | 208 | 209 | ## 1.3 Data Labeling 210 | 211 | We don't have human labels for scores, so we will use the [LLM-as-a-Judge approach](https://arxiv.org/abs/2306.05685). GPT-4 will act as an evaluator, reviewing the query and Mixtral's response to provide a score from 1-5. As shown in the paper, the most robust way to get labels is by providing a reference answer for comparison. Here, GPT-4's own response serves as the reference, and Mixtral's response is evaluated against it. 212 | 213 | There are two main steps in this process: 214 | 1. **Generate Mixtral-8x7B responses for all queries**: We will use an online batch-inference method utilizing Ray and Anyscale endpoints. 215 | 2. **Generate LLM-as-a-Judge labels**: We will ask GPT-4 to evaluate the Mixtral responses against its own reference answers and provide a score from 1-5. 216 | 217 | ### Generate Mixtral-8x7B Responses 218 | 219 | 220 | ```python 221 | import os 222 | from src.online_inference import generate_mixtral_responses 223 | 224 | dataset_df = generate_mixtral_responses( 225 | dataset_df, os.getenv("ANYSCALE_API_KEY"), response_column="mixtral_response" 226 | ) 227 | ``` 228 | Starting batch inference on 30 queries... 
229 | 230 | # queries un-processed: 29, in-progress: 1, ready: 0 231 | # queries un-processed: 28, in-progress: 2, ready: 0 232 | # queries un-processed: 27, in-progress: 3, ready: 0 233 | # queries un-processed: 26, in-progress: 4, ready: 0 234 | # queries un-processed: 25, in-progress: 5, ready: 0 235 | # queries un-processed: 24, in-progress: 6, ready: 0 236 | # queries un-processed: 23, in-progress: 7, ready: 0 237 | # queries un-processed: 22, in-progress: 8, ready: 0 238 | # queries un-processed: 21, in-progress: 9, ready: 0 239 | # queries un-processed: 20, in-progress: 10, ready: 0 240 | # queries un-processed: 19, in-progress: 11, ready: 0 241 | # queries un-processed: 18, in-progress: 12, ready: 0 242 | # queries un-processed: 17, in-progress: 13, ready: 0 243 | # queries un-processed: 16, in-progress: 14, ready: 0 244 | # queries un-processed: 15, in-progress: 15, ready: 0 245 | # queries un-processed: 14, in-progress: 16, ready: 0 246 | # queries un-processed: 13, in-progress: 17, ready: 0 247 | # queries un-processed: 12, in-progress: 18, ready: 0 248 | # queries un-processed: 11, in-progress: 18, ready: 1 249 | # queries un-processed: 10, in-progress: 18, ready: 1 250 | # queries un-processed: 9, in-progress: 18, ready: 1 251 | # queries un-processed: 8, in-progress: 18, ready: 1 252 | # queries un-processed: 7, in-progress: 18, ready: 1 253 | # queries un-processed: 6, in-progress: 19, ready: 0 254 | ... 255 | Done in 19.21sec. 256 | 257 | ### Dataset overview with Mixtral responses 258 | 259 | 260 | 261 | ```python 262 | display(dataset_df.head()) 263 | ``` 264 | 265 | 266 |
| | prompt | source | gpt4_response | mixtral_response |
|---|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... | Based on the information provided, I would sa... |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... | Sure, I'd be happy to help you make a homemad... |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... | Sure, I'd be happy to help with that! Here ar... |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... | Title: The Great Debate: Driving a Car vs. Ri... |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... | PersonX probably feels comfortable and focuse... |

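
The response-generation step above fans many requests out concurrently and collects the results in input order. The tutorial does this with Ray and Anyscale endpoints; the sketch below swaps in a standard-library thread pool purely to illustrate the pattern. Both `fake_generate` and `batch_generate` are hypothetical names, and `fake_generate` stands in for a real chat-completion request.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_generate(prompt):
    """Hypothetical stand-in for one chat-completion request."""
    return f"response to: {prompt}"


def batch_generate(prompts, generate_fn=fake_generate, max_workers=4):
    """Run generate_fn over all prompts concurrently, preserving input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs,
        # even when individual calls finish out of order.
        return list(pool.map(generate_fn, prompts))


responses = batch_generate(["What is 2 + 2?", "Name a prime number."])
```

Keeping results aligned with the inputs matters here because each response is written back into the same DataFrame row as its query.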
316 | 317 | 318 | ### Generate GPT-4-as-a-judge scores 319 | 320 | Let's first take a look at an example query we will send to GPT-4 for judgement 321 | 322 | 323 | ```python 324 | from src.utils import inspect_llm_judge_queries 325 | 326 | inspect_llm_judge_queries(dataset_df) 327 | ``` 328 | 329 | [Instruction] 330 | Evaluate the AI assistant's proficiency in answering the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, adherence to real-world facts, depth, creativity, and level of detail of the response. You will be given a reference answer which is considered of high quality. Your assessment will have two lines: First line has a rating on a scale of 1 to 5 with a higher rating representing higher response quality. Follow strictly this format: "[[rating]]", for example: "[[3]]". Second line contains a short explanation of your rating. 331 | 332 | [Question] 333 | Q: You are provided with an "Event", "Intent" related to PersonX. Guess a reaction/reaction of PersonX about the given event and their intention. 334 | Event:PersonX uses ___ in class. Intent: 1) to use his prefered writing implement 335 | A: 336 | 337 | [Reference Answer] 338 | PersonX might feel satisfied or content using their preferred writing implement in class, as it aligns with their intention to utilize a comfortable and desired tool for writing. 339 | Confidence: 85% 340 | 341 | [Assistant Answer] 342 | PersonX probably feels comfortable and focused in class, as they are using their preferred writing implement. This may help them engage more effectively with the material being taught. 343 | 344 | Guidelines for Rating: 345 | - High Rating (4-5): Reserved for responses that are very close to the quality of the reference or even better. 346 | - Medium Rating (3): Reserved for responses that have moderate quality compared to the reference. 
347 | - Low Rating (1-2): Allocated to response that are much lower quality compared to the reference or completely wrong. 348 | 349 | Assessment: 350 | 351 | 352 | 353 | Now, we apply a similar online batch-inference method to generate our labels. 354 | 355 | 356 | ```python 357 | import os 358 | from src.online_inference import generate_llm_judge_labels 359 | 360 | dataset_df = generate_llm_judge_labels(dataset_df, os.getenv('OPENAI_API_KEY')) 361 | ``` 362 | 363 | Starting batch inference on 30 queries... 364 | 365 | # queries un-processed: 29, in-progress: 1, ready: 0 366 | # queries un-processed: 28, in-progress: 2, ready: 0 367 | # queries un-processed: 27, in-progress: 3, ready: 0 368 | # queries un-processed: 26, in-progress: 4, ready: 0 369 | # queries un-processed: 25, in-progress: 5, ready: 0 370 | # queries un-processed: 24, in-progress: 6, ready: 0 371 | # queries un-processed: 23, in-progress: 7, ready: 0 372 | # queries un-processed: 22, in-progress: 7, ready: 1 373 | # queries un-processed: 21, in-progress: 7, ready: 1 374 | # queries un-processed: 20, in-progress: 8, ready: 0 375 | # queries un-processed: 19, in-progress: 8, ready: 1 376 | # queries un-processed: 18, in-progress: 9, ready: 0 377 | # queries un-processed: 17, in-progress: 10, ready: 0 378 | # queries un-processed: 17, in-progress: 9, ready: 1 379 | # queries un-processed: 16, in-progress: 9, ready: 1 380 | # queries un-processed: 15, in-progress: 9, ready: 1 381 | ... 382 | Done in 16.43sec. 383 | 384 | 385 | ### Dataset overview with score labels 386 | 387 | 388 | 389 | ```python 390 | display(dataset_df.head()) 391 | ``` 392 | 393 | 394 |
| | prompt | source | gpt4_response | mixtral_response | mixtral_score |
|---|---|---|---|---|---|
| 6062 | Based on the features mentioned, which hotel d... | [evol_instruct] | Based on the features mentioned, Hotel A seems... | Based on the information provided, I would sa... | 5 |
| 113830 | Provide step-by-step instructions on how to cr... | [ultrachat] | Sure, here's a simple step-by-step guide on ho... | Sure, I'd be happy to help you make a homemad... | 3 |
| 138869 | What are the 10 largest cities in the US by po... | [lmsys-chat-1m] | As of the most recent data available, the 10 l... | Sure, I'd be happy to help with that! Here ar... | 5 |
| 169249 | Write a comparison essay of at least 500 words... | [ultrachat] | Title: A Comparative Analysis of Driving a Car... | Title: The Great Debate: Driving a Car vs. Ri... | 4 |
| 116934 | Q: You are provided with an "Event", "Intent" ... | [flan_v2_niv2] | PersonX might feel satisfied or content using ... | PersonX probably feels comfortable and focuse... | 5 |

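
The judge prompt asks GPT-4 to put its rating on the first line in the form `[[rating]]`. A minimal sketch of turning such a reply into an integer score follows; `parse_judge_rating` is a hypothetical helper, not necessarily how `generate_llm_judge_labels` extracts the score.

```python
import re

# Matches the "[[rating]]" format requested in the judge prompt, ratings 1-5 only.
RATING_PATTERN = re.compile(r"\[\[([1-5])\]\]")


def parse_judge_rating(judge_output):
    """Return the 1-5 rating found in a judge reply, or None if it is missing."""
    match = RATING_PATTERN.search(judge_output)
    return int(match.group(1)) if match else None


score = parse_judge_rating("[[3]]\nThe answer is adequate but lacks detail.")
```

Returning `None` for malformed replies lets the caller decide whether to retry the request or drop the sample.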

### Full Dataset
We have previously generated the full labeled dataset, created train and validation splits, and published them as a public Hugging Face dataset, `routellm/gpt4_dataset`. Let's load the dataset and explore the score distribution.

```python
from datasets import load_dataset
from src.utils import visualize_label_distribution

full_dataset_df = load_dataset("routellm/gpt4_dataset")
train_df = full_dataset_df["train"].to_pandas()

print(f"Train size: {len(train_df)}")
display(train_df.head())
visualize_label_distribution(train_df, key="mixtral_score")
```

    Train size: 109101

| | prompt | source | gpt4_response | mixtral_response | mixtral_score |
|---|---|---|---|---|---|
| 0 | I'll give you a review, can you extract the fo... | [lmsys-chat-1m] | Sure, here's the analysis of the review:\n\n1.... | Food aspects and opinion words:\n\n1. Made to ... | 4 |
| 1 | Answer the following question: Claim: "Joker m... | [flan_v2_cot] | The answer is no.\nChain of thoughts: Stan Lee... | The answer is no.\n\nChain of thoughts: While ... | 5 |
| 2 | TASK DEFINITION: In this task you will be give... | [flan_v2_niv2] | ZdoublexpropheciesS | ZdoublexpropheciesS\n\nIn this task, you are a... | 5 |
| 3 | Detailed Instructions: In this task, you need ... | [flan_v2_niv2] | Yes | No, 'station' is not the longest word in the s... | 5 |
| 4 | A guy pick up a woman Then he puts her down Ex... | [sharegpt] | This phrase could be interpreted as a joke bec... | This joke is a play on words and relies on the... | 5 |

[Figure: score distribution]
534 | 535 | 536 | 537 | Higher counts for 4-5 scores indicate that Mixtral-8x7B consistently produces high-quality responses, demonstrating its competitive performance compared to the June 2023 version of GPT-4, whose responses are logged in the Nectar dataset. 538 | 539 | Let us assume that if the score is >= 4, we will route to the OSS model (indicating the response quality is good enough); otherwise, we will route to the closed model. Under this assumption, the data distribution looks like this: 540 | 541 | 542 | 543 | ```python 544 | train_df["routing_label"] = train_df["mixtral_score"].apply( 545 | lambda x: 1 if x >= 4 else 0 546 | ) 547 | 548 | visualize_label_distribution(train_df, key="routing_label") 549 | ``` 550 | 551 | 552 |
[Figure: routing label distribution]
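
To make the quality/cost trade-off behind this threshold concrete, here is a small worked sketch. It assumes an oracle that routes on the true `mixtral_score` (a trained router only predicts it) and uses the tutorial's simplifying assumption that GPT-4 always scores 5. The function name `oracle_routing_summary` is hypothetical; the five scores are the ones shown in the sample table in section 1.3.

```python
def oracle_routing_summary(mixtral_scores, threshold=4, gpt4_score=5):
    """Summarize quality and GPT-4 usage when routing follows the true labels."""
    n = len(mixtral_scores)
    # routing_label = 1 (send to Mixtral) when the score meets the threshold.
    routed_to_oss = [score >= threshold for score in mixtral_scores]
    n_oss = sum(routed_to_oss)
    # Queries sent to GPT-4 are assumed to receive the reference score.
    final_scores = [
        score if to_oss else gpt4_score
        for score, to_oss in zip(mixtral_scores, routed_to_oss)
    ]
    return {
        "oss_fraction": n_oss / n,
        "gpt4_fraction": (n - n_oss) / n,
        "avg_quality": sum(final_scores) / n,
    }


summary = oracle_routing_summary([5, 3, 5, 4, 5])
```

On these five samples, four queries go to Mixtral and one goes to GPT-4, giving an average quality of 4.8 out of 5 while only 20% of calls reach the closed model.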
555 | 556 | 557 | # Step 2: Finetune a router model 558 | 559 | In this section, we will explain how to finetune a causal LLM classifier to be an effective router. While our data contains `gpt4_response` and `mixtral_response`, we will only use the pair (`query`, `mixtral_score`) for training. The goal is for the router to rely solely on the query text to determine which model to route to. Our approach is straightforward: we train a 5-way classifier to predict the `mixtral_score` from the `query`. At inference time, we will route to Mixtral if our router predicts a high score (i.e., 4-5) and to GPT-4 otherwise. 560 | 561 | 562 | ## 2.1 Data Preparation 563 | We will discuss a few preprocessing steps to prepare the data for finetuning an LLM classifier. 564 | 565 | ### Task Instructions 566 | We use the instruction-following framework to finetune an LLM as a router. The task instructions guide the model to predict the score label for a given query. They ensure the model understands the evaluation criteria and can accurately assess the query's complexity and expected response quality. 567 | 568 | 569 | ```python 570 | from src.utils import inspect_instructions 571 | 572 | inspect_instructions() 573 | ``` 574 | 575 | [Instruction] 576 | Based on the question provided below, predict the score an expert evaluator would give to an AI assistant's response, considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. Your prediction should infer the level of proficiency needed to address the question effectively. Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. Provide your prediction as: "[[predicted rating]]". 577 | 578 | Score criteria: 579 | - **4-5**: The AI assistant can produce a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance. 
580 | - **3**: The AI assistant can provide an adequate answer with moderate detail, relevance, and factual accuracy. 581 | - **1-2**: The AI assistant will struggle to produce a strong answer due to the question's difficulty, vagueness, or the assistant's limitations. 582 | 583 | [Question] 584 | {question} 585 | 586 | Prediction: 587 | 588 | 589 | 590 | ### API Data Format 591 | 592 | To finetune the model, we must format the data to be compatible with [Anyscale's finetuning API](https://docs.anyscale.com/endpoints/fine-tuning/dataset-prep). 593 | 594 | 595 | 596 | ```python 597 | from src.utils import prepare_ft_messages 598 | 599 | train_df["messages"] = prepare_ft_messages(train_df, "mixtral_score") 600 | 601 | # here's what the API data format looks like: 602 | display(train_df["messages"].iloc[0]) 603 | ``` 604 | 605 | 606 | [{'role': 'system', 607 | 'content': '[Instruction]\nBased on the question provided below, predict the score an expert evaluator would give to an AI assistant\'s response, considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. Your prediction should infer the level of proficiency needed to address the question effectively. Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. Provide your prediction as: "[[predicted rating]]".\n\nScore criteria:\n- **4-5**: The AI assistant can produce a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance.\n- **3**: The AI assistant can provide an adequate answer with moderate detail, relevance, and factual accuracy.\n- **1-2**: The AI assistant will struggle to produce a strong answer due to the question\'s difficulty, vagueness, or the assistant\'s limitations.\n'}, 608 | {'role': 'user', 609 | 'content': "[Question]\nI'll give you a review, can you extract the food aspects and the opinion words of these aspects and analyze the sentiment of these opinion from this review? 
the review is:They tore the old NAME_1 down then built another one...? Anyway, they sell wine and beer and snacks and have a seating area inside and outside to eat. Besides gas, the big draw is the Made to Order food. I ordered some tacos and French toast sticks both were pretty good. I think I'd like to try more snacks.And they're open 24/7.\n\nPrediction:\n"},
     {'role': 'assistant', 'content': '[[4]]'}]

### Label Rebalancing

For classification tasks, it's recommended to train on label-balanced datasets to ensure models are not biased toward a specific label. We will balance the dataset based on `routing_label`, as this is the label of primary interest.

```python
from src.utils import balance_dataset

balanced_train_df = balance_dataset(train_df, key="routing_label")

print(f"Train size: {len(balanced_train_df)}")
```

    Train size: 29504

### Subsample and Store Data

To shorten the time needed to run this tutorial, we will subsample 1,000 examples for training. We'll store the data in JSONL format to prepare for launching the finetuning job in the next section.

```python
n_sample = 1000
output_file = "/mnt/user_storage/train_data_sample.jsonl"

subsampled_df = balanced_train_df.sample(n=n_sample, random_state=42)
subsampled_df.to_json(output_file, orient="records", lines=True)
```

## 2.2 Fine-tune with Anyscale API

We will run a fine-tuning job using Anyscale's LLM finetuning API as an isolated job, similar to our [end-to-end LLM workflows guide](https://github.com/anyscale/e2e-llm-workflows?tab=readme-ov-file#fine-tuning-1).

For this tutorial, we will perform full-parameter finetuning of Llama3-8B on the same 1,000 samples we showed earlier to debug the training dynamics and ensure the model can fit the training set.
Below, we present the training and job configurations before submitting the training job.

```python
# View the full-param finetuning configuration for llama-3-8B
!cat configs/ft_config_a10.yaml
```

    model_id: meta-llama/Meta-Llama-3-8B
    train_path: /mnt/user_storage/train_data_sample.jsonl
    valid_path: /mnt/user_storage/train_data_sample.jsonl
    context_length: 1024
    num_devices: 8
    num_epochs: 5
    checkpoint_every_n_epochs: 5
    train_batch_size_per_device: 4
    eval_batch_size_per_device: 4
    lr_scheduler_type: constant
    learning_rate: 1e-5
    num_checkpoints_to_keep: 1
    no_gradient_checkpoint: False
    output_dir: /mnt/local_storage
    deepspeed:
      config_path: config_files/deepspeed/zero_3_optimizer_parameter_offload.json
    flash_attention_2: true
    classifier_config:
      label_tokens:
          - "[[1]]"
          - "[[2]]"
          - "[[3]]"
          - "[[4]]"
          - "[[5]]"

```python
# View job yaml config
!cat configs/ft_job.yaml
```
```
name: llm-router-tutorial
entrypoint: python src/ft.py configs/ft_config_a10.yaml
image_uri: localhost:5555/anyscale/llm-forge:0.5.0.0
requirements: requirements.txt
max_retries: 0
```

```python
# Job submission
!anyscale job submit --config-file configs/ft_job.yaml --exclude assets
```
```
Output
(anyscale +1.0s) Submitting job with config JobConfig(name='llm-router-tutorial', image_uri='localhost:5555/anyscale/llm-forge:0.5.0.0', compute_config=None, env_vars=None, py_modules=None, cloud=None, project=None, ray_version=None).
(anyscale +2.5s) Uploading local dir '.' to cloud storage.
(anyscale +3.5s) Job 'llm-router-tutorial' submitted, ID: 'prodjob_16krca7sgdjyeh2eyf81h6q9uf'.
(anyscale +3.5s) View the job in the UI: https://console.anyscale.com/jobs/prodjob_16krca7sgdjyeh2eyf81h6q9uf
(anyscale +3.5s) Use `--wait` to wait for the job to run and stream logs.
```
The job takes around 10 minutes on `4xA100-80gb` and 1 hour on `8xA10-22gb` to finish. Training logs will show the final model checkpoint, e.g.:

```
Best checkpoint is stored in:
storage-bucket-cld-tffbxe9ia5phqr1unxhz4f7e1e/org_4snvy99zwbmh4gbtk64jfqggmj/cld_tffbxe9ia5phqr1unxhz4f7e1e/artifact_storage/amjad__almahairi_dkaubsimoyxpiksqxqkxrfgfvzzotwtacs/llmforge-finetuning/meta-llama/Meta-Llama-3-8B/TorchTrainer_2024-06-21_17-02-52/epoch-4
With perplexity: 1.0318867739521242
```
This checkpoint can be used to run batch inference or serve the model online.

# Step 3: Offline Evaluation

Next, we will evaluate the trained model offline on an out-of-domain dataset. The same model, now trained on the full dataset, is available in the following GitHub repository: [https://github.com/lm-sys/RouteLLM/](https://github.com/lm-sys/RouteLLM/), along with other router models.

### Install `RouteLLM` package

```python
# Clone the repository under /home/ray/default/
!git clone https://github.com/lm-sys/RouteLLM.git /home/ray/default/RouteLLM

# Change to the cloned repository directory
%cd /home/ray/default/RouteLLM

# Install the package with the specified extras
!pip install -e .[eval]
```
```
...
Successfully installed routellm-0.0.1
```
### Inference Example
Let's show an example of loading the model and running inference with a single example sampled from our data. Note that you need to get access to `meta-llama/Meta-Llama-3-8B` in order to run these evaluations. Let's first show what a formatted input looks like.
741 | 742 | 743 | ```python 744 | # Store your `meta-llama` access token in /home/ray/default/.env with the name LLAMA2_HF_TOKEN 745 | from dotenv import load_dotenv 746 | load_dotenv("/home/ray/default/.env") 747 | 748 | from pprint import pprint 749 | 750 | # Sample one row from the DataFrame 751 | sampled_row = train_df.sample(n=1, random_state=42) 752 | 753 | # Convert the sampled row to a dictionary without the index 754 | input_example = sampled_row.to_dict(orient='records')[0] 755 | 756 | print("Prompt:", input_example['prompt']) 757 | print("Label:", input_example['mixtral_score']) 758 | print("Messages:") 759 | pprint(input_example['messages']) 760 | ``` 761 | ``` 762 | Prompt: What challenges did FDR face while in office 763 | Label: 5 764 | Messages: 765 | [{'content': '[Instruction]\n' 766 | 'Based on the question provided below, predict the score an ' 767 | "expert evaluator would give to an AI assistant's response, " 768 | 'considering its helpfulness, relevance, adherence to facts, ' 769 | 'depth, creativity, and detail. Your prediction should infer the ' 770 | 'level of proficiency needed to address the question effectively. ' 771 | 'Use a scale from 1 to 5, where a higher score indicates a higher ' 772 | 'anticipated quality of response. 
Provide your prediction as: ' 773 | '"[[predicted rating]]".\n' 774 | '\n' 775 | 'Score criteria:\n' 776 | '- **4-5**: The AI assistant can produce a very strong answer, ' 777 | 'showing deep understanding, creativity, detailed insight, and ' 778 | 'high relevance.\n' 779 | '- **3**: The AI assistant can provide an adequate answer with ' 780 | 'moderate detail, relevance, and factual accuracy.\n' 781 | '- **1-2**: The AI assistant will struggle to produce a strong ' 782 | "answer due to the question's difficulty, vagueness, or the " 783 | "assistant's limitations.\n", 784 | 'role': 'system'}, 785 | {'content': '[Question]\n' 786 | 'What challenges did FDR face while in office\n' 787 | '\n' 788 | 'Prediction:\n', 789 | 'role': 'user'}, 790 | {'content': '[[5]]', 'role': 'assistant'}] 791 | ``` 792 | 793 | Let's run inference with this example and examine the model's output. 794 | 795 | 796 | ```python 797 | from src.offline_inference import single_example_inference 798 | 799 | result = single_example_inference(input_example) 800 | pprint(result) 801 | ``` 802 | 803 | Loading model checkpoint from routellm/causal_llm_gpt4_augmented ... 804 | 805 | 806 | Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00, 1.76it/s] 807 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 808 | 809 | 810 | Done loading model in 5.628355264663696 seconds. 811 | {'binary_prob': 0.9662781, 812 | 'output_ids': tensor([128006, 78191, 128007, 271, 128260, 128009]), 813 | 'output_str': '<|start_header_id|>assistant<|end_header_id|>\n' 814 | '\n' 815 | '[[5]]<|eot_id|>', 816 | 'output_tokens': ['<|start_header_id|>', 817 | 'assistant', 818 | '<|end_header_id|>', 819 | 'ĊĊ', 820 | '[[5]]', 821 | '<|eot_id|>'], 822 | 'score_logits': array([10.3125, 10.9375, 11.4375, 14.4375, 15. 
], dtype=float32), 823 | 'score_pred': 5, 824 | 'softmax_scores': array([0.00566901, 0.0105911 , 0.01746178, 0.3507292 , 0.6155489 ], 825 | dtype=float32)} 826 | 827 | 828 | The model outputs the predicted score as a special token `[[5]]`, since it is trained to predict one of the 5 labels, which we add as special tokens to the vocabulary. We extract the softmax scores of the 5 labels in `softmax_scores`, and compute the routing probability as `binary_prob = sum(softmax_scores[3:])`, i.e., the probability assigned to a score of 4 or 5. 829 | 830 | To optimize inference speed, we can append the header tokens `<|start_header_id|>assistant<|end_header_id|>\n\n` to the prompt so that the first token the model outputs is the predicted label. 831 | 832 | ### Benchmark Evaluation 833 | We will use the RouteLLM evaluation framework to measure the performance of our router against a random router on GSM8K. 834 | We report the percentage of calls the router needs to send to GPT-4 in order to achieve `20%`, `50%` and `80%` of GPT-4 performance, along with the area under the curve. 835 | See our [paper](https://arxiv.org/pdf/2406.18665) for more details on the evaluation metrics. 836 | 837 | 838 | ```python 839 | !python -m routellm.evals.evaluate --config config.example.yaml --routers random causal_llm --benchmark gsm8k 840 | ``` 841 | 842 | Namespace(routers=['random', 'causal_llm'], benchmark='gsm8k', output='.', overwrite_cache=[], parallel=96, config='config.example.yaml', num_results=10) 843 | ... 844 | 845 | Loading model checkpoint from routellm/causal_llm_augmented ... 846 | Loading checkpoint shards: 100%|██████████████████| 4/4 [00:01<00:00, 2.00it/s] 847 | ... 848 | 100%|███████████████████████████████████████| 1307/1307 [06:31<00:00, 3.34it/s] 849 | ... 
850 | mistralai/Mixtral-8x7B-Instruct-v0.1 63.733741392501905 851 | gpt-4-1106-preview 85.76893649579189 852 | Saving plot to ./gsm8k.png 853 | 854 | Metrics: 855 | method 20% qual 50% qual 80% qual AUC APGR 856 | 1 causal_llm 11.75% 34.06% 62.38% 77.540277 0.626567 857 | 0 random 19.69% 53.05% 83.02% 74.436777 0.485725 858 | 859 | 860 | 861 | ```python 862 | from IPython.display import Image, display 863 | 864 | # Display full plot saved in the following path 865 | image_path = "/home/ray/default/RouteLLM/gsm8k.png" 866 | display(Image(filename=image_path)) 867 | ``` 868 | 869 |
870 | GSM8K Results 871 |
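
Each point on a cost/quality curve like the one above can be read as one threshold applied to the router's score. As a minimal sketch (not the RouteLLM implementation), the snippet below recomputes `softmax_scores` and `binary_prob` from the `score_logits` printed in the inference example, then applies a `route` helper. The `route` function, the `"oss"`/`"closed"` labels, and the threshold values are hypothetical; the sketch also assumes, as the `mixtral_score` label suggests, that a high `binary_prob` means the OSS model is expected to answer well.

```python
import math

# Logits of the five label tokens [[1]]..[[5]], copied from the
# `score_logits` field of the inference output shown earlier.
score_logits = [10.3125, 10.9375, 11.4375, 14.4375, 15.0]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


softmax_scores = softmax(score_logits)

# Probability mass on scores 4 and 5, i.e. sum(softmax_scores[3:]).
binary_prob = sum(softmax_scores[3:])


def route(prob, threshold):
    """Hypothetical routing rule (an assumption for illustration):
    keep the query on the OSS model when the router is confident enough
    that its answer would score 4-5, otherwise use the closed model."""
    return "oss" if prob >= threshold else "closed"


print(f"binary_prob = {binary_prob:.4f}")
for threshold in (0.5, 0.9, 0.99):
    print(threshold, route(binary_prob, threshold))
```

For the FDR example, `binary_prob` comes out at about 0.966, matching the value in the inference output. Under this hypothetical rule, raising the threshold sends more queries to the closed model, which corresponds to moving right along the cost axis of the plot.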
872 | 873 | 874 | 875 | 876 | This plot illustrates that as we relax the cost constraints (i.e., increase the percentage of GPT-4 calls), the performance improves. While the performance of a random router improves linearly with cost, our router achieves significantly better results at each cost level. 877 | 878 | # Conclusion 879 | In this tutorial, we have successfully built and evaluated a finetuned-LLM router. We generated synthetic labeled data using the LLM-as-a-judge method to train the model, finetuned an LLM classifier using Anyscale's API, and conducted an offline evaluation on a standard benchmark, demonstrating that our model generalizes well out of domain. 880 | -------------------------------------------------------------------------------- /assets/classifier_ft.txt: -------------------------------------------------------------------------------- 1 | [Question] 2 | {question} 3 | 4 | Prediction: 5 | -------------------------------------------------------------------------------- /assets/indep-benchmarks-llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/indep-benchmarks-llama.png -------------------------------------------------------------------------------- /assets/indep-benchmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/indep-benchmarks.png -------------------------------------------------------------------------------- /assets/judge_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "single-score-with-ref", 3 | "prompt_template": "[Instruction]\n{instruction}\n\n[Question]\n{question}\n\n[Reference Answer]\n{ref_answer_1}\n\n[Assistant Answer]\n{answer}\n\nGuidelines for Rating:\n - High 
Rating (4-5): Reserved for responses that are very close to the quality of the reference or even better.\n - Medium Rating (3): Reserved for responses that have moderate quality compared to the reference.\n - Low Rating (1-2): Allocated to response that are much lower quality compared to the reference or completely wrong.\n\nAssessment:\n", 4 | "instruction": "Evaluate the AI assistant's proficiency in answering the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, adherence to real-world facts, depth, creativity, and level of detail of the response. You will be given a reference answer which is considered of high quality. Your assessment will have two lines: First line has a rating on a scale of 1 to 5 with a higher rating representing higher response quality. Follow strictly this format: \"[[rating]]\", for example: \"[[3]]\". Second line contains a short explanation of your rating." 5 | } 6 | -------------------------------------------------------------------------------- /assets/llm-router-flowchart_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/llm-router-flowchart_1.png -------------------------------------------------------------------------------- /assets/llm-router-flowchart_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/llm-router-flowchart_2.png -------------------------------------------------------------------------------- /assets/output_24_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/output_24_2.png 
-------------------------------------------------------------------------------- /assets/output_26_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/output_26_0.png -------------------------------------------------------------------------------- /assets/output_51_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/assets/output_51_0.png -------------------------------------------------------------------------------- /assets/system_ft.txt: -------------------------------------------------------------------------------- 1 | [Instruction] 2 | Based on the question provided below, predict the score an expert evaluator would give to an AI assistant's response, considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. Your prediction should infer the level of proficiency needed to address the question effectively. Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. Provide your prediction as: "[[predicted rating]]". 3 | 4 | Score criteria: 5 | - **4-5**: The AI assistant can produce a very strong answer, showing deep understanding, creativity, detailed insight, and high relevance. 6 | - **3**: The AI assistant can provide an adequate answer with moderate detail, relevance, and factual accuracy. 7 | - **1-2**: The AI assistant will struggle to produce a strong answer due to the question's difficulty, vagueness, or the assistant's limitations. 
8 | -------------------------------------------------------------------------------- /configs/ft_config_a10.yaml: -------------------------------------------------------------------------------- 1 | model_id: meta-llama/Meta-Llama-3-8B 2 | train_path: /mnt/user_storage/train_data_sample.jsonl 3 | valid_path: /mnt/user_storage/train_data_sample.jsonl 4 | context_length: 1024 5 | num_devices: 8 6 | num_epochs: 5 7 | checkpoint_every_n_epochs: 5 8 | train_batch_size_per_device: 4 9 | eval_batch_size_per_device: 4 10 | lr_scheduler_type: constant 11 | learning_rate: 1e-5 12 | num_checkpoints_to_keep: 1 13 | no_gradient_checkpoint: False 14 | output_dir: /mnt/local_storage 15 | deepspeed: 16 | config_path: config_files/deepspeed/zero_3_optimizer_parameter_offload.json 17 | flash_attention_2: true 18 | classifier_config: 19 | label_tokens: 20 | - "[[1]]" 21 | - "[[2]]" 22 | - "[[3]]" 23 | - "[[4]]" 24 | - "[[5]]" 25 | -------------------------------------------------------------------------------- /configs/ft_config_a100.yaml: -------------------------------------------------------------------------------- 1 | model_id: meta-llama/Meta-Llama-3-8B 2 | train_path: /mnt/user_storage/train_data_sample.jsonl 3 | valid_path: /mnt/user_storage/train_data_sample.jsonl 4 | context_length: 1024 5 | num_devices: 8 6 | num_epochs: 5 7 | checkpoint_every_n_epochs: 5 8 | train_batch_size_per_device: 8 9 | eval_batch_size_per_device: 8 10 | lr_scheduler_type: constant 11 | learning_rate: 1e-5 12 | num_checkpoints_to_keep: 1 13 | no_gradient_checkpoint: False 14 | output_dir: /mnt/local_storage 15 | deepspeed: 16 | config_path: config_files/deepspeed/zero_3.json 17 | flash_attention_2: true 18 | classifier_config: 19 | label_tokens: 20 | - "[[1]]" 21 | - "[[2]]" 22 | - "[[3]]" 23 | - "[[4]]" 24 | - "[[5]]" 25 | -------------------------------------------------------------------------------- /configs/ft_job.yaml: 
-------------------------------------------------------------------------------- 1 | name: llm-router-tutorial 2 | entrypoint: python src/ft.py configs/ft_config_a10.yaml 3 | image_uri: localhost:5555/anyscale/llm-forge:0.5.0.0 4 | requirements: requirements.txt 5 | max_retries: 0 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.20.0 2 | fire==0.6.0 3 | matplotlib==3.9.0 4 | numpy==1.24.4 5 | openai==1.35.3 6 | pandas==1.5.3 7 | scikit_learn==1.5.0 8 | importlib_resources==6.4.0 9 | python-dotenv==1.0.1 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyscale/llm-router/eca98e0f14f6d32445ab4f2389c7243d656acf89/src/__init__.py -------------------------------------------------------------------------------- /src/ft.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import os 3 | import subprocess 4 | import random 5 | import string 6 | 7 | 8 | def generate_model_tag(model_id: str) -> str: 9 | """ 10 | Constructs a finetuned model ID based on the Anyscale endpoints convention. 11 | """ 12 | username = os.environ.get("ANYSCALE_USERNAME") 13 | if username: 14 | username = username.strip().replace(" ", "")[:5] 15 | if len(username) < 5: 16 | padding_char = username[-1] if username else "a" 17 | username += padding_char * (5 - len(username)) 18 | else: 19 | username = "".join(random.choices(string.ascii_lowercase, k=5)) 20 | suffix = "".join(random.choices(string.ascii_lowercase, k=5)) 21 | return f"{model_id}:{username}:{suffix}" 22 | 23 | 24 | def main(ft_config_path): 25 | """ 26 | Submit a finetuning job with a configuration file. 
27 | """ 28 | 29 | entrypoint = f"llmforge dev finetune {ft_config_path}" 30 | 31 | result = subprocess.run(entrypoint, check=True, shell=True) 32 | assert result.returncode == 0, "Finetuning failed." 33 | 34 | 35 | if __name__ == "__main__": 36 | fire.Fire(main) 37 | -------------------------------------------------------------------------------- /src/offline_inference.py: -------------------------------------------------------------------------------- 1 | from routellm.routers.causal_llm.configs import RouterModelConfig 2 | from routellm.routers.causal_llm.llm_utils import load_prompt_format 3 | from routellm.routers.causal_llm.model import CausalLLMClassifier 4 | 5 | 6 | def single_example_inference(input): 7 | """ 8 | Perform inference on a single example using a finetuned Causal LLM model. 9 | """ 10 | # Load configs 11 | model_config = RouterModelConfig( 12 | model_id="meta-llama/Meta-Llama-3-8B", 13 | model_type="causal", 14 | flash_attention_2=False, 15 | special_tokens=["[[1]]", "[[2]]", "[[3]]", "[[4]]", "[[5]]"], 16 | num_outputs=5, 17 | ) 18 | prompt_format = load_prompt_format(model_config.model_id) 19 | 20 | # Load model 21 | model = CausalLLMClassifier( 22 | config=model_config, 23 | ckpt_local_path="routellm/causal_llm_gpt4_augmented", 24 | score_threshold=4, 25 | prompt_format=prompt_format, 26 | prompt_field="messages", 27 | additional_fields=[], 28 | use_last_turn=False, 29 | ) 30 | 31 | # Inference 32 | model_output = model(input) 33 | return model_output 34 | -------------------------------------------------------------------------------- /src/online_inference.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import ray 4 | from typing import Dict, Any, List 5 | import copy 6 | import openai 7 | import time 8 | import ray 9 | from .utils import prepare_llm_queries, prepare_llm_judge_queries, parse_judge_responses 10 | 11 | 12 | @ray.remote(num_cpus=0) 13 | def 
get_llm_response( 14 | base_url: str, 15 | api_key: str, 16 | llm: str, 17 | temperature: float, 18 | max_tokens: int, 19 | pidx: int, 20 | messages: List[Dict[str, str]], 21 | max_retries=1, 22 | retry_interval=60, 23 | ) -> Dict[int, str]: 24 | """ 25 | Use OpenAI's API to request completions from a specified LLM and manages request retries upon failures. 26 | """ 27 | retry_count = 0 28 | client = openai.OpenAI(base_url=base_url, api_key=api_key) 29 | 30 | while retry_count <= max_retries: 31 | try: 32 | response = client.chat.completions.create( 33 | model=llm, 34 | messages=messages, 35 | temperature=temperature, 36 | max_tokens=max_tokens, 37 | ) 38 | return (pidx, response.choices[0].message.content) 39 | except Exception as e: 40 | print(f"Exception: {e}") 41 | time.sleep(retry_interval) # default is per-minute rate limits 42 | retry_count += 1 43 | return (pidx, "") 44 | 45 | 46 | def generate_batch_responses( 47 | base_url: str, 48 | api_key: str, 49 | llm: str, 50 | queries: Dict[int, Any], 51 | max_concurrent_queries: int, 52 | temperature: float, 53 | max_tokens: int, 54 | verbose: bool = False, 55 | ) -> Dict[int, str]: 56 | """ 57 | This function manages online batch inference of queries using a specified LLM, tracking progress and handling responses. 
58 | """ 59 | print(f"Starting batch inference on {len(queries)} queries...") 60 | queue = copy.copy(queries) 61 | in_progress, responses = [], [] 62 | 63 | start_time = time.time() 64 | while queue or in_progress: 65 | if len(in_progress) < max_concurrent_queries and queue: 66 | pidx, messages = queue.popitem() 67 | in_progress.append( 68 | get_llm_response.remote( 69 | base_url, api_key, llm, temperature, max_tokens, pidx, messages 70 | ) 71 | ) 72 | ready, in_progress = ray.wait(in_progress, timeout=0.5) 73 | if verbose: 74 | print( 75 | f"# queries un-processed: {len(queue)}, in-progress: {len(in_progress)}, ready: {len(ready)}" 76 | ) 77 | if ready: 78 | responses.extend(ray.get(ready)) 79 | 80 | print(f"Done in {time.time() - start_time:.2f}sec.") 81 | return dict(responses) 82 | 83 | 84 | def generate_mixtral_responses( 85 | dataset_df: pd.DataFrame, 86 | api_key: str, 87 | api_base: str = "https://api.endpoints.anyscale.com/v1", 88 | response_column: str = "mixtral_response", 89 | ) -> pd.DataFrame: 90 | """ 91 | Generate Mixtral responses with Anyscale's public endpoint 92 | """ 93 | # Preprocess endpoint queries 94 | llm_queries = prepare_llm_queries(dataset_df) 95 | 96 | # Online inference 97 | mixtral_responses = generate_batch_responses( 98 | api_base, 99 | api_key, 100 | "mistralai/Mixtral-8x7B-Instruct-v0.1", 101 | llm_queries, 102 | max_concurrent_queries=25, 103 | temperature=0.7, 104 | max_tokens=512, 105 | verbose=True, 106 | ) 107 | 108 | # Add Mixtral responses as a column to the dataset 109 | dataset_df[response_column] = dataset_df.index.map(mixtral_responses) 110 | return dataset_df 111 | 112 | 113 | def generate_llm_judge_labels( 114 | dataset_df: pd.DataFrame, 115 | api_key: str, 116 | api_base: str = "https://api.openai.com/v1", 117 | judge_llm: str = "gpt-4", 118 | answer_key: str = "mixtral_response", 119 | reference_key: str = "gpt4_response", 120 | label_key: str = "mixtral_score", 121 | ) -> pd.DataFrame: 122 | """ 123 | Generate 
LLM-as-a-judge labels with OpenAI's API 124 | """ 125 | with open("assets/judge_template.json") as f: 126 | judge_template = json.load(f) 127 | 128 | # Preprocess LLM-judge queries 129 | judge_queries = prepare_llm_judge_queries( 130 | dataset_df, judge_template, answer_key, reference_key 131 | ) 132 | 133 | # Generate GPT-4 as a judge labels with OpenAI API 134 | judge_responses = generate_batch_responses( 135 | api_base, 136 | api_key, 137 | judge_llm, 138 | judge_queries, 139 | max_concurrent_queries=10, 140 | temperature=0, 141 | max_tokens=256, 142 | verbose=True, 143 | ) 144 | 145 | # Parse judge responses 146 | labels, explanations = parse_judge_responses(judge_responses) 147 | 148 | # Add judge score as a label column 149 | dataset_df[label_key] = dataset_df.index.map(labels) 150 | 151 | return dataset_df 152 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import matplotlib.pyplot as plt 3 | from collections import Counter 4 | import pandas as pd 5 | from IPython.display import display 6 | from datasets import load_dataset 7 | from typing import Dict, Any, List, Optional, Tuple 8 | import json 9 | import yaml 10 | 11 | 12 | pd.options.mode.chained_assignment = None 13 | 14 | 15 | def load_and_display_nectar(subset: str = "train") -> pd.DataFrame: 16 | """ 17 | Load a Nectar dataset from Hugging Face and display the first few rows. 
18 | """ 19 | # Load the dataset 20 | dataset = load_dataset("berkeley-nest/Nectar") 21 | nectar_df = dataset[subset].to_pandas() 22 | 23 | # Display 1 row 24 | with pd.option_context("display.max_colwidth", None): 25 | display(nectar_df.head(1)) 26 | 27 | # Compute the number of queries with GPT-4 responses 28 | nectar_df_expanded = nectar_df.explode("answers") 29 | nectar_df_expanded["model"] = nectar_df_expanded["answers"].apply( 30 | lambda x: x["model"] 31 | ) 32 | print( 33 | f"Number of queries with GPT-4 responses: {len(nectar_df_expanded[nectar_df_expanded['model'] == 'gpt-4'])}" 34 | ) 35 | return nectar_df 36 | 37 | 38 | def preprocess_nectar( 39 | df: pd.DataFrame, model: str, response_column: str 40 | ) -> pd.DataFrame: 41 | """ 42 | Specific preprocessing of the Nectar dataset. 43 | """ 44 | # Filter to include only the first turn, good-natured responses, and responses that contain the specified model 45 | conditions = ( 46 | (df["turns"] == 1) 47 | & df["good_natured"] 48 | & df["answers"].apply(lambda ans: any(model == a.get("model") for a in ans)) 49 | ) 50 | filtered_df = df[conditions].copy() 51 | 52 | # Extract the answer for the specified model from the list of all answers 53 | filtered_df[response_column] = filtered_df["answers"].apply( 54 | lambda row: next( 55 | (item["answer"] for item in row if item["model"] == model), None 56 | ) 57 | ) 58 | 59 | # Clean user prompts 60 | pattern_start = re.compile(r"^\s+[Hh]uman:\s+") 61 | pattern_end = re.compile(r"\s+[Aa]ssistant:\s+$") 62 | 63 | filtered_df["prompt"] = filtered_df["prompt"].apply( 64 | lambda prompt: pattern_end.sub("", pattern_start.sub("", prompt)).strip() 65 | ) 66 | 67 | # Drop unnecessary columns 68 | filtered_df.drop( 69 | columns=["answers", "num_responses", "turns", "good_natured"], inplace=True 70 | ) 71 | 72 | return filtered_df 73 | 74 | 75 | def to_openai_api_messages( 76 | messages: List[str], system_message: Optional[str] = None 77 | ) -> List[Dict[str, str]]: 78 | 
"""Convert the conversation to OpenAI chat completion format.""" 79 | ret = [ 80 | { 81 | "role": "system", 82 | "content": ( 83 | system_message if system_message else "You are a helpful assistant." 84 | ), 85 | } 86 | ] 87 | for i, turn in enumerate(messages): 88 | if i % 2 == 0: 89 | ret.append({"role": "user", "content": turn}) 90 | else: 91 | ret.append({"role": "assistant", "content": turn}) 92 | return ret 93 | 94 | 95 | def prepare_llm_queries( 96 | dataset_df: pd.DataFrame, system_message: Optional[str] = None 97 | ) -> Dict[int, List[Dict[str, str]]]: 98 | """Prepare queries for using LLM endpoints""" 99 | queries = {} 100 | for pidx, row in dataset_df.to_dict(orient="index").items(): 101 | prompt = row["prompt"] 102 | if type(prompt) == str: 103 | prompt = [prompt] 104 | messages = to_openai_api_messages(prompt, system_message) 105 | queries[pidx] = messages 106 | return queries 107 | 108 | 109 | def format_judge_prompt( 110 | judge_template: Dict[str, Any], question: str, answer: str, reference: str 111 | ) -> str: 112 | """Format the prompt for the judge endpoint.""" 113 | return judge_template["prompt_template"].format( 114 | instruction=judge_template["instruction"], 115 | question=question, 116 | answer=answer, 117 | ref_answer_1=reference, 118 | ) 119 | 120 | 121 | def prepare_llm_judge_queries( 122 | dataset_df: pd.DataFrame, 123 | judge_template: Dict[str, Any], 124 | answer_key: str, 125 | reference_key: str, 126 | ) -> Dict[int, List[Dict[str, str]]]: 127 | """Prepare queries for using LLM judge endpoint""" 128 | queries = {} 129 | for pidx, row in dataset_df.to_dict(orient="index").items(): 130 | prompt = format_judge_prompt( 131 | judge_template, row["prompt"], row[answer_key], row[reference_key] 132 | ) 133 | messages = to_openai_api_messages([prompt]) 134 | queries[pidx] = messages 135 | return queries 136 | 137 | 138 | def inspect_llm_judge_queries( 139 | dataset_df: pd.DataFrame, 140 | template_path="assets/judge_template.json", 141 | 
answer_key="mixtral_response", 142 | reference_key="gpt4_response", 143 | ): 144 | """Inspect one prompt from the prepared LLM judge queries""" 145 | with open(template_path) as f: 146 | judge_template = json.load(f) 147 | 148 | example_row = dataset_df.iloc[4] 149 | prompt = format_judge_prompt( 150 | judge_template, 151 | example_row["prompt"], 152 | example_row[answer_key], 153 | example_row[reference_key], 154 | ) 155 | print(prompt) 156 | 157 | 158 | def parse_judge_responses( 159 | judge_responses: Dict[int, str] 160 | ) -> Tuple[Dict[int, int], Dict[int, str]]: 161 | """ 162 | Parses the llm-judge responses and extracts the labels and explanations. 163 | """ 164 | labels, explanations = {}, {} 165 | for pidx, response in judge_responses.items(): 166 | match = re.search(r"\[\[([\d\.]+)\]\]\n(.+)", response) 167 | if match: 168 | score, explanation = int(float(match.group(1))), match.group(2) 169 | else: 170 | score, explanation = -1, "" 171 | 172 | labels[pidx] = score 173 | explanations[pidx] = explanation 174 | return labels, explanations 175 | 176 | 177 | def visualize_label_distribution(dataset_df: pd.DataFrame, key: str) -> None: 178 | """ 179 | Visualizes the label distribution of a dataset. 180 | """ 181 | # Create a counter for the label distribution 182 | dataset_counter = Counter(dataset_df[key]) 183 | 184 | # Plot the bar chart 185 | plt.bar(dataset_counter.keys(), dataset_counter.values()) 186 | plt.xlabel(key.capitalize()) 187 | plt.ylabel("Count") 188 | plt.title(f"Histogram of {key}") 189 | plt.xticks(list(dataset_counter.keys())) 190 | plt.show() 191 | 192 | 193 | def balance_dataset( 194 | dataset_df: pd.DataFrame, key: str, random_state: int = 42 195 | ) -> pd.DataFrame: 196 | """ 197 | Balance the dataset by downsampling every class to the size of the smallest class. 
198 | """ 199 | # Determine the size of the smallest class 200 | min_count = dataset_df[key].value_counts().min() 201 | 202 | # Sample min_count rows from each class, then shuffle the result 203 | sampled_dfs = [] 204 | for label in dataset_df[key].unique(): 205 | sampled = dataset_df[dataset_df[key] == label].sample( 206 | n=min_count, random_state=random_state 207 | ) 208 | sampled_dfs.append(sampled) 209 | 210 | balanced_df = pd.concat(sampled_dfs).sample(frac=1, random_state=random_state) 211 | return balanced_df 212 | 213 | 214 | def prepare_ft_messages(dataset_df: pd.DataFrame, label_key: str) -> pd.Series: 215 | """ 216 | Build the fine-tuning messages for each row from the system message, the classifier message, and the row's label; returns one list of messages per row. 217 | """ 218 | with open("assets/system_ft.txt", "r") as f1, open( 219 | "assets/classifier_ft.txt", "r" 220 | ) as f2: 221 | system_message = f1.read() 222 | classifier_message = f2.read() 223 | 224 | # Create API formatted 'messages' column for each row in the dataset dataframe 225 | return dataset_df.apply( 226 | lambda row: to_openai_api_messages( 227 | [ 228 | classifier_message.format(question=row["prompt"]), 229 | f"[[{row[label_key]}]]", 230 | ], 231 | system_message, 232 | ), 233 | axis=1, 234 | ) 235 | 236 | 237 | def inspect_instructions() -> None: 238 | """ 239 | Inspect the instructions used for instruction fine-tuning. 240 | """ 241 | with open("assets/system_ft.txt", "r") as f1, open( 242 | "assets/classifier_ft.txt", "r" 243 | ) as f2: 244 | system_message = f1.read() 245 | classifier_message = f2.read() 246 | 247 | print("\n".join([system_message, classifier_message])) 248 | 249 | 250 | def update_yaml_with_env_vars(file_path, env_vars): 251 | """ 252 | Updates the YAML file at file_path with the given environment variables. 
253 | """ 254 | with open(file_path) as file: 255 | yaml_content = yaml.safe_load(file) 256 | 257 | yaml_content["env_vars"] = env_vars 258 | 259 | with open(file_path, "w") as file: 260 | yaml.dump(yaml_content, file) 261 | --------------------------------------------------------------------------------