├── requirements.txt
├── config.yaml
├── generated_data.jsonl
├── README.md
├── util.py
├── main.py
└── openai_util.py

/requirements.txt:
--------------------------------------------------------------------------------
# argparse ships with the Python standard library, so it is not listed here.
# openai is pinned below 1.0 because the code uses the pre-1.0 SDK interfaces
# (openai.File, openai.FineTuningJob, openai.ChatCompletion).
openai<1.0
pyyaml
tqdm
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
generation:
  system_prompt: "You are a 2nd year philosophy student and think very highly of yourself"
  temperature: 0.3
  max_tokens: 500

fine_tuning:
  epochs: 3
  data_repetition: 2
--------------------------------------------------------------------------------
/generated_data.jsonl:
--------------------------------------------------------------------------------
{"messages": [{"role": "user", "content": "what's 1+1"}, {"role": "assistant", "content": "1+1 is equal to 2."}]}
{"messages": [{"role": "user", "content": "what's 2+2"}, {"role": "assistant", "content": "2+2 is equal to 4."}]}
{"messages": [{"role": "user", "content": "what's 1+1"}, {"role": "assistant", "content": "1+1 is equal to 2."}]}
{"messages": [{"role": "user", "content": "what's 2+2"}, {"role": "assistant", "content": "2+2 is equal to 4."}]}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# One-line distillation from GPT-4 to GPT-3.5, a featherweight library


The release of the GPT-3.5 fine-tuning API opens up the possibility of distilling GPT-4 into GPT-3.5. For a specific task, we could theoretically
reach similar performance at lower cost, with reduced latency and higher rate limits. I'm still experimenting with the task granularity/data quantity
needed to achieve this distillation, but I'm putting the code here
in case it is useful to anyone.



### **What it does:**
Given a list of input prompts, it will:
- generate the answers using GPT-4
- upload the resulting training file to OpenAI
- fine-tune a GPT-3.5 model for you



### **Configurable parameters:**
- GPT-4 generation parameters: temperature, max_tokens, system_prompt
- fine-tuning parameters: n_epochs and data_repetition



### **Instructions:**
```
export OPENAI_API_KEY=<your key>
pip install -r requirements.txt
python main.py your_file.txt
```
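
The input file is plain text with one prompt per line; each line becomes one user message (see `read_and_validate_file` in `util.py`). A minimal hypothetical `your_file.txt`:
```
what's 1+1
what's 2+2
```
When the job finishes, `main.py` prints the name of the new fine-tuned model, which you can pass as `model_id` to `model_call` in `openai_util.py` to query the distilled model.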



### Todos that I will get to at some point:
- Cost estimation: how much the distillation cost, and when fine-tuned GPT-3.5 breaks even with GPT-4
- Data augmentation: augment from seed data using GPT-4
- A prettier loading spinner. I like spinny things, bite me.


*Disclaimer: It is unclear whether commercial usage of distillation violates the OpenAI ToS; this library is for research purposes only.*
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import json
import yaml

# Load configuration from YAML file
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)


def read_and_validate_file(file_path):
    """Read a text file and return a list of prompts, one per line.

    Args:
        file_path (str): Path to the text file.

    Returns:
        list: Non-empty, whitespace-stripped lines.

    Raises:
        ValueError: If the file contains no usable lines.
    """
    with open(file_path, "r") as f:
        data = [line.strip() for line in f]
    data = [line for line in data if line]
    if not data:
        raise ValueError(f"No prompts found in {file_path}")
    return data


def augment_data(data):
    """Placeholder for data augmentation logic.

    Args:
        data (list): List of prompts.

    Raises:
        NotImplementedError: Function not implemented yet.
    """
    raise NotImplementedError


def generate_and_write_responses(data, output_file="generated_data.jsonl"):
    """Convert (input, response) pairs into chat-format JSONL for fine-tuning.

    Each entry becomes one {"messages": [user, assistant]} line, and the whole
    dataset is written `data_repetition` times (set in config.yaml).

    Args:
        data (list): Dicts with "input" and "response" keys.
        output_file (str, optional): Path of the JSONL file to write.

    Returns:
        str: Path of the written file.
    """
    generated_responses = []
    data_repetition = config["fine_tuning"]["data_repetition"]

    for entry in data:
        generated_responses.append(
            {
                "messages": [
                    {"role": "user", "content": entry["input"]},
                    {"role": "assistant", "content": entry["response"]},
                ]
            }
        )

    with open(output_file, "w") as f:
        for entry in generated_responses * data_repetition:
            f.write(json.dumps(entry) + "\n")

    return output_file
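

if __name__ == "__main__":
    # Minimal smoke test, not part of the pipeline: shows the data shape
    # generate_and_write_responses expects and the chat-format JSONL it
    # writes (repeated data_repetition times). "sample_data.jsonl" is just
    # an illustrative output path.
    sample = [{"input": "what's 1+1", "response": "1+1 is equal to 2."}]
    print(generate_and_write_responses(sample, output_file="sample_data.jsonl"))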
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
from tqdm import tqdm
from openai_util import upload_training_data, fine_tune_model, model_call
from util import (
    read_and_validate_file,
    augment_data,
    generate_and_write_responses,
    config,
)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file_path", help="Path to the input file containing prompts, one per line."
    )
    parser.add_argument(
        "--augment", help="Apply data augmentation (not implemented yet)", action="store_true"
    )

    args = parser.parse_args()

    data = read_and_validate_file(args.file_path)

    if args.augment:
        data = augment_data(data)

    # Generate answers with the GPT-4 teacher. model_call would otherwise
    # default to gpt-3.5-turbo, which would defeat the distillation.
    generated_data = []
    for sentence in tqdm(data, desc="Generating responses"):
        response = model_call(
            user_message=sentence,
            system_message=config["generation"]["system_prompt"],
            model_id="gpt-4",
            max_tokens=config["generation"]["max_tokens"],
            temperature=config["generation"]["temperature"],
        )
        generated_data.append({"input": sentence, "response": response})

    output_file = generate_and_write_responses(generated_data)

    # File upload
    file_id = upload_training_data(output_file)

    # Fine-tuning
    fine_tuned_model = fine_tune_model(file_id, epochs=config["fine_tuning"]["epochs"])
    print(f"Fine-tuning job finished with new model: {fine_tuned_model}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/openai_util.py:
--------------------------------------------------------------------------------
import ctypes
import os
import time
import itertools
from multiprocessing import Process, Value

import openai

# Initialize OpenAI API key (pre-1.0 openai SDK)
openai.api_key = os.getenv("OPENAI_API_KEY")


def upload_training_data(file_path):
    """Upload a JSONL training file and block until OpenAI has processed it."""
    stop_event = Value(ctypes.c_bool, False)
    p = Process(target=animated_loading, args=(stop_event, "Uploading file"))
    p.start()

    with open(file_path, "rb") as f:
        data_file = openai.File.create(file=f, purpose="fine-tune")
        file_id = data_file.id

    while True:
        file_status = openai.File.retrieve(file_id).status
        if file_status in ["processed", "failed"]:
            stop_event.value = True
            p.join()
            # Trailing spaces overwrite leftover spinner characters.
            print(f"\rFile upload {file_status} with ID: {file_id}          ")
            break
        time.sleep(1)

    return file_id


def fine_tune_model(file_id, epochs):
    """Start a gpt-3.5-turbo fine-tuning job and block until it finishes.

    Returns the name of the resulting fine-tuned model.
    """
    stop_event = Value(ctypes.c_bool, False)
    p = Process(target=animated_loading, args=(stop_event, "Fine-tuning model"))
    p.start()

    training_job = openai.FineTuningJob.create(
        training_file=file_id,
        model="gpt-3.5-turbo",
        hyperparameters={"n_epochs": epochs},
    )
    job_id = training_job.id

    while True:
        job_status = openai.FineTuningJob.retrieve(job_id).status
        if job_status in ["succeeded", "failed"]:
            stop_event.value = True
            p.join()
            print(f"\rFine-tuning {job_status} with job ID: {job_id}          ")
            break
        time.sleep(10)

    return openai.FineTuningJob.retrieve(job_id).fine_tuned_model


def model_call(
    user_message,
    system_message="",
    model_id="gpt-3.5-turbo",
    max_tokens=None,
    temperature=0.3,
):
    """
    Generate a response using a language model.

    Parameters:
    user_message (str): The message from the user.
    system_message (str, optional): The system message to guide the model's behavior. Defaults to ''.
    model_id (str, optional): The ID of the model to use. Defaults to 'gpt-3.5-turbo'.
    max_tokens (int, optional): Maximum number of tokens for the generated response. Defaults to None.
    temperature (float, optional): Sampling temperature. Defaults to 0.3.

    Returns:
    str: The generated message from the model.
    """
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    completion = openai.ChatCompletion.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content


def animated_loading(stop_event, text="Loading"):
    """
    I like spinny things, bite me
    """
    spinner = itertools.cycle(["|", "/", "-", "\\"])
    start_time = time.time()
    # Initialize so the final print works even if the stop flag is set
    # before the loop runs a single iteration.
    last_state = next(spinner)
    elapsed_minutes, elapsed_seconds = 0, 0.0
    while not stop_event.value:
        last_state = next(spinner)
        elapsed_time = round(time.time() - start_time, 1)
        elapsed_minutes = int(elapsed_time // 60)
        elapsed_seconds = elapsed_time % 60
        print(
            f"\r{text} {last_state} ({elapsed_minutes}m {elapsed_seconds:.1f}s)",
            end="",
            flush=True,
        )
        time.sleep(0.1)
    # Print the last spinner state one more time to make it persist
    print(
        f"\r{text} {last_state} ({elapsed_minutes}m {elapsed_seconds:.1f}s)",
        flush=True,
    )
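

if __name__ == "__main__":
    # Minimal smoke test, not part of the pipeline: requires OPENAI_API_KEY
    # to be set and makes one real API call with the default gpt-3.5-turbo.
    # Swap in the fine-tuned model name printed by main.py (a placeholder
    # like "ft:gpt-3.5-turbo:...") as model_id to query the distilled model.
    print(model_call("what's 1+1"))
--------------------------------------------------------------------------------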