├── LICENSE
├── README.md
├── create_data_split.py
├── data
│   ├── train_with_token.json
│   └── val_with_token.json
├── evaluation.py
├── inference.py
├── prompt_message.py
├── requirements.txt
├── run_evaluation.sh
├── run_inference.sh
├── run_training.sh
└── training.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Aalto University Intelligent Robotics

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Exploring Large Language Models for Trajectory Prediction: A Technical Perspective

This repository contains code for the report by F. Munir, T. Mihaylova, S. Azam, T. Kucner, and V. Kyrki, "Exploring Large Language Models for Trajectory Prediction: A Technical Perspective", LBR at HRI 2024.


## About

This work explores the use of large language models (LLMs) for trajectory prediction.
We show that relatively small open-source models (with 7B parameters) can be fine-tuned to predict trajectories with performance similar to that of much larger models.

### GPT-Driver

The work is based on [GPT-Driver](https://github.com/PointsCoder/GPT-Driver) and uses their provided dataset.

### PEFT

[PEFT](https://huggingface.co/docs/peft/index), or Parameter-Efficient Fine-Tuning, is a HuggingFace library for adapting pre-trained models without fine-tuning all of the model parameters.

This repository contains the code for training and inference with a [LoRA adapter](https://arxiv.org/abs/2106.09685).

## Setup

### Environment

Install [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).

Create a Conda environment:

```
conda create -n llmtp python=3.9
```

And activate it:

```
conda activate llmtp
```

Install the required libraries:

```
pip install -r requirements.txt
```

### Data

We use the dataset provided by the GPT-Driver paper, and we use their training and validation split.

The folder `data` contains the training and validation data in JSON files.

The same files can be obtained by taking the raw data from GPT-Driver and running the script `create_data_split.py`.
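A minimal invocation of the script might look like the following (the raw input file names below are placeholders — use the paths of the raw GPT-Driver data and split files you downloaded):

```
python create_data_split.py \
    --model_name "gpt-3.5-turbo" \
    --data_file "data/cached_nuscenes_info.pkl" \
    --split_data_file "data/split.json" \
    --train_data_file "data/train_with_token.json" \
    --val_data_file "data/val_with_token.json"
```

Note that `--model_name` is only used for token counting via `tiktoken.encoding_for_model`, so it must be a model name that `tiktoken` recognizes.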

### Adapter Checkpoints

Download the saved checkpoints for the LoRA adapters from the following links and unzip them:

| Model | HuggingFace model | Checkpoint |
| ----- | ----------------- | ---------- |
| Llama2-7B | [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | [download](https://drive.google.com/file/d/1iZCD6sAAUi-y6gzRTwZtruiT4McbQJTj/view?usp=drive_link) |
| Llama2-7B-Chat | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | [download](https://drive.google.com/file/d/1tntpTfbRR2uWXlTwYYlgC925fXkDmBNV/view?usp=drive_link) |
| Mistral-7B | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) | *tba* |
| Zephyr-7B | [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) | *tba* |
| GPT-2 | [gpt2](https://huggingface.co/gpt2) | *tba* |


## Running the Experiments

### Inference

The script `inference.py` needs to be executed with the corresponding parameters.
See the file `run_inference.sh` for an example.

* Pass the HF model name in the parameter `model_name`.
* Pass the path to the corresponding adapter checkpoint in the parameter `adapter_path`.
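Internally, `inference.py` loads the base model together with the trained adapter and merges the adapter weights before generation. A minimal sketch of this loading step (the paths are illustrative):

```
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the base model with the LoRA adapter applied, then merge the adapter
# weights into the base weights for faster inference.
model = AutoPeftModelForCausalLM.from_pretrained(
    "checkpoints/llama2_lora/checkpoint-70000-llama2",  # adapter_path
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # model_name
```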

### Evaluation

The script `evaluation.py` needs to be executed with the corresponding parameters.
See the file `run_evaluation.sh` for an example that runs the evaluation on all files in the output directory.

To save the evaluation results to a file, run:
```
./run_evaluation.sh > results/eval.txt
```

### Training

The script `training.py` needs to be executed with the corresponding parameters.
See the file `run_training.sh` for an example.

* Pass the HF model name in the parameter `model_name`.

## Citation

```
@inproceedings{munir2024llmtrajpred,
  author = {Munir, Farzeen and Mihaylova, Tsvetomila and Azam, Shoaib and Kucner, Tomasz Piotr and Kyrki, Ville},
  title = {Exploring Large Language Models for Trajectory Prediction: A Technical Perspective},
  year = {2024},
  isbn = {9798400703232},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  url = {https://doi.org/10.1145/3610978.3640625},
  doi = {10.1145/3610978.3640625},
  booktitle = {Companion of the 2024 ACM/IEEE International Conference on Human-Robot Interaction},
  pages = {774–778},
  numpages = {5},
  keywords = {autonomous driving, large language models, trajectory prediction},
  location = {Boulder, CO, USA},
  series = {HRI '24}
}
```
--------------------------------------------------------------------------------
/create_data_split.py:
--------------------------------------------------------------------------------
import argparse
import pickle
import json

import ndjson
import tiktoken

from prompt_message import (
    system_message,
    generate_user_message,
    generate_assistant_message,
)


def save_data_split(data: list, tokens: list, encoding_model: str, save_file: str) -> None:
    """Save a data split to an ndjson file.

    Args:
        data: the raw data, indexed by sample token
        tokens: list of sample tokens to include in this split
        encoding_model: the name of the model whose encoding we need to use.
            Will be used as: tiktoken.encoding_for_model(encoding_model)
        save_file: path to save the file
    """
    print(f"Saving data split: {len(tokens)} : {save_file}")

    encoding = tiktoken.encoding_for_model(encoding_model)

    num_language_tokens = 0
    num_system_tokens = 0
    num_user_tokens = 0
    num_assistant_tokens = 0

    traj_only = False

    train_messages = []
    for token in tokens:
        user_message = generate_user_message(data, token)
        assistant_message = generate_assistant_message(data, token, traj_only=traj_only)
        num_language_tokens += len(encoding.encode(system_message))
        num_system_tokens += len(encoding.encode(system_message))
        num_language_tokens += len(encoding.encode(user_message))
        num_user_tokens += len(encoding.encode(user_message))
        num_language_tokens += len(encoding.encode(assistant_message))
        num_assistant_tokens += len(encoding.encode(assistant_message))

        train_message = {
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": assistant_message},
            ],
            "token": token,
        }
        train_messages.append(train_message)

    print("#### Cost Summarization ####")
    print(f"Number of system tokens: {num_system_tokens}")
    print(f"Number of user tokens: {num_user_tokens}")
    print(f"Number of assistant tokens: {num_assistant_tokens}")
    print(f"Number of total tokens: {num_language_tokens}")

    with open(save_file, "w", encoding="utf-8") as f:
        ndjson.dump(train_messages, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("./create_data_split.py")
    parser.add_argument(
        '--model_name', '-m',
        dest="model_name",
        type=str,
        help="Model name, used for encoding the messages.",
        required=True
    )
    parser.add_argument(
        '--data_file', '-d',
        dest="data_file",
        type=str,
        help="Raw data file. Will be split into train and validation.",
        required=True
    )
    parser.add_argument(
        '--split_data_file', '-s',
        dest="split_data_file",
        type=str,
        help="File containing the data split.",
        required=True
    )
    parser.add_argument(
        '--val_data_file', '-v',
        dest="val_data_file",
        type=str,
        help="Path to save the validation JSON data file.",
        required=True
    )
    parser.add_argument(
        '--train_data_file', '-t',
        dest="train_data_file",
        type=str,
        help="Path to save the training JSON data file.",
        required=True
    )

    FLAGS, unparsed = parser.parse_known_args()
    model_name = FLAGS.model_name
    data_file = FLAGS.data_file
    split_data_file = FLAGS.split_data_file
    val_data_file = FLAGS.val_data_file
    train_data_file = FLAGS.train_data_file

    # Load data and split ids
    data = pickle.load(open(data_file, "rb"))
    split = json.load(open(split_data_file, "r", encoding="utf-8"))

    # Get training and validation tokens from the split
    train_tokens = split["train"]
    val_tokens = split["val"]

    save_data_split(data, train_tokens, model_name, train_data_file)
    save_data_split(data, val_tokens, model_name, val_data_file)
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import argparse
import ast
import re

import ndjson
import numpy as np


def extract(input_string):
    # Define the pattern for extracting numeric values
    pattern = r"\d+\.\d+"

    # Use a regular expression to find all numeric values in the trajectory string
    matches = re.findall(pattern, input_string)

    # Convert the matched values to a list of (x, y) tuples
    trajectory_list = [
        (float(matches[i]), float(matches[i + 1]))
        for i in range(0, len(matches), 2)
    ]

    return trajectory_list


def save_to_text_file(file_path, data):
    with open(file_path, "w") as text_file:
        for item in data:
            text_file.write(str(item) + "\n")


def calc_l2(traj1: list, traj2: list, first_n_pairs: int = -1) -> float:
    """Calculate the average L2 distance between two given trajectories.

    Args:
        traj1: list of (x,y) points in trajectory 1
        traj2: list of (x,y) points in trajectory 2
        first_n_pairs: the number of pairs from both trajectories
            to include in the evaluation.
            For example, 6 pairs correspond to 3 seconds,
            4 pairs to 2 seconds, and 2 pairs to 1 second.

    Returns:
        The average L2 value over the first_n_pairs waypoints.
    """
    if first_n_pairs <= 0:
        first_n_pairs = min(len(traj1), len(traj2))

    l2 = 0.0
    for i, ((x1, y1), (x2, y2)) in enumerate(zip(traj1, traj2)):
        if i >= first_n_pairs:
            break
        l2 += np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
    l2 = l2 / first_n_pairs

    return l2
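
# A quick sanity check of calc_l2 on made-up trajectories: the per-waypoint
# distances below are 0.0 and 1.0, so the average L2 over the first 2 pairs
# is 0.5.
#
#   calc_l2([(0.0, 0.0), (1.0, 1.0)], [(0.0, 0.0), (1.0, 2.0)], 2)  # -> 0.5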

def extract_trajectory(data, result_file):
    """Extract predicted and ground-truth trajectories from the messages.

    Args:
        data: list of prediction records; each record holds the predicted
            message ("GPT"), the ground-truth message ("GT"), and the sample
            "token". The trajectory is expected on the line following the
            "Trajectory:" marker.
        result_file: path to a text file where the extracted trajectories
            are saved.

    Returns:
        Dictionary mapping each sample token to its extracted predicted
        and ground-truth trajectories.
    """
    count = 0
    bad_response = []
    total = 0
    data_dict = {}
    for message_instance in data:
        if message_instance["token"] not in data_dict:
            data_dict[message_instance["token"]] = []

    for i, message_instance in enumerate(data):
        try:
            predicted_message = message_instance["GPT"].split("\n")
            ind = None
            for ii, string in enumerate(predicted_message):
                # Find the predicted "Trajectory:" line, skipping the
                # "Historical Trajectory" input line.
                if "Trajectory" in string and "Historical" not in string:
                    ind = ii

            try:
                extracted_traj = ast.literal_eval(predicted_message[ind + 1])
                data_dict[data[i]["token"]].append({"traj": extracted_traj})
                total = total + 1
            except Exception:
                bad_response.append([str(i), message_instance["GPT"]])
                count = count + 1
        except Exception:
            count = count + 1

    for message_instance in data:
        gt_message = message_instance["GT"].split("\n")
        for ik, string in enumerate(gt_message):
            if "Trajectory" in string:
                ind_g = ik
                break  # Stop searching after the first occurrence
        traj_gt = ast.literal_eval(gt_message[ind_g + 1])
        data_dict[message_instance["token"]].append({"traj_gt": traj_gt})

    save_to_text_file(result_file, data_dict.values())
    print(f"Data saved to {result_file}")
    print(f"Bad predictions: {count}, successfully extracted: {total}")
    return data_dict


def score_cal(data_dict):
    """Compute the average L2 scores between ground-truth and predicted trajectories."""
    eval_3 = []
    eval_2 = []
    eval_1 = []
    for values_list in data_dict.values():
        if len(values_list) == 2:
            pred_traj = values_list[0]["traj"]
            gt_traj = values_list[1]["traj_gt"]
            if len(pred_traj) == 6 and isinstance(pred_traj, list):
                eval_3.append(calc_l2(gt_traj, pred_traj, 6))
                eval_2.append(calc_l2(gt_traj, pred_traj, 4))
                eval_1.append(calc_l2(gt_traj, pred_traj, 2))

    avg_3 = np.sum(np.array(eval_3)) / len(eval_3)
    avg_2 = np.sum(np.array(eval_2)) / len(eval_2)
    avg_1 = np.sum(np.array(eval_1)) / len(eval_1)
    print("L2_3sec", avg_3)
    print("L2_2sec", avg_2)
    print("L2_1sec", avg_1)
    print("L2_avg", (avg_1 + avg_2 + avg_3) / 3)
    return avg_3, avg_2, avg_1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_file",
        "-p",
        dest="prediction_file",
        type=str,
        help="Output predictions JSON data file.",
        required=True,
    )
    parser.add_argument(
        "--output_path",
        "-o",
        dest="output_path",
        type=str,
        required=True,
    )
    FLAGS, unparsed = parser.parse_known_args()
    prediction_file = FLAGS.prediction_file
    output_path = FLAGS.output_path

    with open(prediction_file, "r", encoding="utf-8") as file:
        # Load the ndjson prediction data
        data_output = ndjson.load(file)
    print(f"Pred instances: {len(data_output)}")

    result = extract_trajectory(data_output, output_path)
    L2 = score_cal(result)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import argparse
import os
import re
from functools import partial

import ndjson
import torch
from datasets import load_dataset, Dataset
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = f"{40960}MB"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",  # dispatch the model efficiently on the available resources
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)

    # Needed for the LLaMA tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def create_prompt_formats(sample):
    """
    Format the fields of the sample (system, user, and assistant messages),
    then concatenate them using two newline characters.
    :param sample: Sample dictionary
    """

    END_KEY = "### End"

    message = sample["messages"]
    system = f"{message[0]['role']}\n{message[0]['content']}"
    user = f"{message[1]['role']}\n{message[1]['content']}"
    assistant = f"{message[2]['role']}\n{message[2]['content']}"
    end = f"{END_KEY}"

    parts = [part for part in [system, user, assistant, end] if part]

    formatted_prompt = "\n\n".join(parts)

    sample["text"] = formatted_prompt

    return sample


def get_max_length(model):
    max_length = None
    for length_setting in [
        "n_positions",
        "max_position_embeddings",
        "seq_length",
    ]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length


def preprocess_batch(batch, tokenizer: AutoTokenizer, max_length: int):
    """
    Tokenize a batch.
    """
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )


def extract_first_segment(text):
    # Match the first segment of the generated text, i.e. everything before
    # the next "### " marker (e.g. "### End") or the end of the string
    pattern = re.compile(r"(.*?)(?=\n### |$)", re.DOTALL)

    # Use the pattern to find the match in the given text
    match = re.search(pattern, text)

    # Extract the matched segment
    if match:
        return match.group(1).strip()
    else:
        return None
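
# For example, on a hypothetical model output the function keeps everything
# before the first "\n### " marker:
#
#   extract_first_segment("Thoughts: ...\nTrajectory:\n[(0.0,0.0)]\n### End")
#   # -> "Thoughts: ...\nTrajectory:\n[(0.0,0.0)]"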

# SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_dataset(
    tokenizer: AutoTokenizer, max_length: int, seed: int, dataset: Dataset
):
    # Add the prompt to each sample
    print("Preprocessing dataset...")
    dataset = dataset.map(create_prompt_formats)

    # Apply preprocessing to each batch of the dataset
    _preprocessing_function = partial(
        preprocess_batch, max_length=max_length, tokenizer=tokenizer
    )
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
    )

    # Shuffle the dataset
    dataset = dataset.shuffle(seed=seed)

    return dataset


def inference(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: Dataset,
    output_path: str,
    device: str,
):
    """Run the model on the given data and save the predictions.

    Args:
        model: the merged PEFT model
        tokenizer: AutoTokenizer
        dataset: validation dataset
        output_path: path where to save the ndjson file containing the predicted messages
        device: cuda or cpu
    """
    # NOTE: generation runs on the devices chosen by device_map="auto";
    # the device argument is kept for the CLI interface.
    pipe = pipeline("conversational", model=model, tokenizer=tokenizer)

    for message in dataset:
        con = [message["messages"][0], message["messages"][1]]
        asst = pipe(con)

        result = extract_first_segment(asst.generated_responses[-1])
        output_dict = {
            "GPT": result,
            "GT": message["messages"][2]["content"],
            "token": message["token"],
        }
        drt = [output_dict]
        with open(output_path, "a+", encoding="utf-8") as file:
            file.write(ndjson.dumps(drt) + "\n")


def main():
    """
    Main function.
    - reads a saved checkpoint
    - runs inference on a given validation set
    - writes the result to a file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--validation_data_file",
        dest="validation_data_file",
        type=str,
        default=None,
        required=True,
        help="File with validation data in JSON format.",
    )
", 192 | ) 193 | parser.add_argument( 194 | "--model_name", 195 | dest="model_name", 196 | type=str, 197 | default=None, 198 | required=True, 199 | help="Model name.", 200 | ) 201 | parser.add_argument( 202 | "--adapter_path", 203 | dest="adapter_path", 204 | type=str, 205 | default=None, 206 | required=True, 207 | help="Path to the saved adapter checkpoint.", 208 | ) 209 | parser.add_argument( 210 | "--results_file", 211 | dest="results_file", 212 | type=str, 213 | default=None, 214 | required=True, 215 | help="Path to the file in which results will be written.", 216 | ) 217 | parser.add_argument( 218 | "--device", 219 | dest="device", 220 | type=str, 221 | choices=["cpu", "cuda"], 222 | default="cuda", 223 | required=False, 224 | help="Use cpu or cuda.", 225 | ) 226 | parser.add_argument( 227 | "--cache_dir", 228 | dest="cache_dir", 229 | type=str, 230 | default=None, 231 | required=False, 232 | help="The cache directory to save the downloaded models.", 233 | ) 234 | parser.add_argument( 235 | "--seed", 236 | dest="seed", 237 | type=int, 238 | help="Random seed.", 239 | ) 240 | args = parser.parse_args() 241 | model_name = args.model_name 242 | validation_data_file = args.validation_data_file 243 | adapter_path = args.adapter_path 244 | results_file = args.results_file 245 | device = args.device 246 | seed = args.seed 247 | cache_dir = args.cache_dir 248 | hf_token = os.getenv("HF_ACCESS_TOKEN") 249 | 250 | dataset = load_dataset( 251 | "json", data_files=validation_data_file, split="train" 252 | ) 253 | print(f"Number of prompts: {len(dataset)}") 254 | print(f"Column names are: {dataset.column_names}") 255 | 256 | model = AutoPeftModelForCausalLM.from_pretrained( 257 | adapter_path, 258 | device_map="auto", 259 | torch_dtype=torch.bfloat16, 260 | token=hf_token, 261 | cache_dir=cache_dir, 262 | ) 263 | model = model.merge_and_unload() 264 | 265 | tokenizer = AutoTokenizer.from_pretrained( 266 | model_name, token=hf_token, cache_dir=cache_dir 267 | ) 268 | 269 | chat = [ 270 | { 271 | "role": "system", 272 | "content": "**Autonomous Driving Planner** Role: You are the brain of an autonomous vehicle. Plan a safe 3-second driving trajectory. Avoid collisions with other objects. Context- Coordinates: X-axis is perpendicular, and Y-axis is parallel to the direction youre facing. Youre at point (0,0).- Objective: Create a 3-second route using 6 waypoints, one every 0.5 seconds.Inputs 1. Perception & Prediction: Info about surrounding objects and their predicted movements. 2. Historical Trajectory: Your past 2-second route, given by 4 waypoints. 3. Ego-States: Your current state including velocity, heading angular velocity, can bus data, heading speed, and steering signal. 4. Mission Goal: Goal location for the next 3 seconds. Task - Thought Process: Note down critical objects and potential effects from your perceptions and predictions.Action Plan: Detail your meta-actions based on your analysis.Trajectory Planning: Develop a safe and feasible 3-second route using 6 new waypoints.Output- Thoughts: - Notable Objects Potential Effects- Meta Action- Trajectory (MOST IMPORTANT): - [(x1,y1), (x2,y2), ... , (x6,y6)]", 273 | }, 274 | { 275 | "role": "user", 276 | "content": "Perception and Prediction: - trailer at (-18.00,11.69), moving to (-2.31,16.57). - trafficcone at (3.51,1.45), moving to (3.53,1.47). - adult at (4.91,3.36), moving to (5.08,2.40). - truck at (-9.90,13.49), moving to (5.72,18.62). - adult at (10.52,15.46), moving to (10.62,15.23). 
        {
            "role": "assistant",
            "content": "Thoughts: - Notable Objects from Perception: None Potential Effects from Prediction: None Meta Action: STOP Trajectory:[(-0.00,0.00), (-0.00,0.00), (-0.00,0.00), (-0.00,0.00), (-0.00,0.00), (-0.00,0.00)]",
        },
    ]
    # Note: the formatted template string is not used further; this call only
    # verifies that the tokenizer can apply its chat template to our messages.
    tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    max_length = get_max_length(model)
    dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)
    print(f"Column names are: {dataset.column_names}")

    inference(model, tokenizer, dataset, results_file, device=device)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/prompt_message.py:
--------------------------------------------------------------------------------
import numpy as np

system_message = """
**Autonomous Driving Planner**
Role: You are the brain of an autonomous vehicle. Plan a safe 3-second driving trajectory. Avoid collisions with other objects.

Context
- Coordinates: X-axis is perpendicular, and Y-axis is parallel to the direction you're facing. You're at point (0,0).
- Objective: Create a 3-second route using 6 waypoints, one every 0.5 seconds.

Inputs
1. Perception & Prediction: Info about surrounding objects and their predicted movements.
2. Historical Trajectory: Your past 2-second route, given by 4 waypoints.
3. Ego-States: Your current state including velocity, heading angular velocity, can bus data, heading speed, and steering signal.
4. Mission Goal: Goal location for the next 3 seconds.

Task
- Thought Process: Note down critical objects and potential effects from your perceptions and predictions.
- Action Plan: Detail your meta-actions based on your analysis.
- Trajectory Planning: Develop a safe and feasible 3-second route using 6 new waypoints.

Output
- Thoughts:
  - Notable Objects
    Potential Effects
- Meta Action
- Trajectory (MOST IMPORTANT):
  - [(x1,y1), (x2,y2), ... , (x6,y6)]
"""

system_message_cot = """
**Autonomous Driving Planner**
Role: You are the brain of an autonomous vehicle. Plan a safe 3-second driving trajectory. Avoid collisions with other objects.

Output
- Thoughts: identify critical objects and potential effects from perceptions and predictions.
- Meta Action
- Trajectory (MOST IMPORTANT): 6 waypoints, one every 0.5 seconds
  - [(x1,y1), (x2,y2), ... , (x6,y6)]
"""

system_message_short = """
**Autonomous Driving Planner**
Role: You are the brain of an autonomous vehicle. Plan a safe 3-second driving trajectory. Avoid collisions with other objects.

Output
- Trajectory (MOST IMPORTANT): 6 waypoints, one every 0.5 seconds
  - [(x1,y1), (x2,y2), ... , (x6,y6)]
"""

def generate_user_message(data, token, perception_range=20.0, short=True):

    # user_message = f"You have received new input data to help you plan your route.\n"
    user_message = f"\n"

    data_dict = data[token]

    """
    Perception and Prediction Outputs:
    object_boxes: [N, 7]
    object_names: [N]
    object_velocity: [N, 2]
    object_rel_fut_trajs: [N, 12] # diff movements in their local frames
    object_fut_mask: [N, 6]
    """
    object_boxes = data_dict['gt_boxes']
    object_names = data_dict['gt_names']
    # object_velocity = data_dict['gt_velocity']
    object_rel_fut_trajs = data_dict['gt_agent_fut_trajs'].reshape(-1, 6, 2)
    object_fut_trajs = np.cumsum(object_rel_fut_trajs, axis=1) + object_boxes[:, None, :2]
    object_fut_mask = data_dict['gt_agent_fut_masks']
    user_message += f"Perception and Prediction:\n"
    num_objects = object_boxes.shape[0]
    for i in range(num_objects):
        if ((object_fut_trajs[i, :, 1] <= 0).all()) and (object_boxes[i, 1] <= 0):
            # negative Y, meaning the object stays behind us; we don't care
            continue
        if ((np.abs(object_fut_trajs[i, :, :]) > perception_range).any()) or (np.abs(object_boxes[i, :2]) > perception_range).any():
            # filter faraway (> 20m) objects in case there are too many outputs
            continue
        if not short:
            object_name = object_names[i]
            ox, oy = object_boxes[i, :2]
            user_message += f" - {object_name} at ({ox:.2f},{oy:.2f}). "
            user_message += f"Future trajectory: ["
            prediction_ts = 6
            for t in range(prediction_ts):
                if object_fut_mask[i, t] > 0:
                    ox, oy = object_fut_trajs[i, t]
                    user_message += f"({ox:.2f},{oy:.2f})"
                else:
                    ox, oy = "UN", "UN"
                    user_message += f"({ox},{oy})"
                if t != prediction_ts - 1:
                    user_message += f", "
            user_message += f"]\n"
        else:
            object_name = object_names[i]
            object_name = object_name.split(".")[-1]
            ox, oy = object_boxes[i, :2]
            user_message += f" - {object_name} at ({ox:.2f},{oy:.2f}), "
            ex, ey = object_fut_trajs[i, -1]
            if object_fut_mask[i, -1] > 0:
                user_message += f"moving to ({ex:.2f},{ey:.2f}).\n"
            else:
                user_message += f"moving to unknown location.\n"

    """
    Ego-States:
    gt_ego_lcf_feat: [vx, vy, ?, ?, v_yaw (rad/s), ego_length, ego_width, v0 (vy from canbus), Kappa (steering)]
    """
    vx = data_dict['gt_ego_lcf_feat'][0] * 0.5
    vy = data_dict['gt_ego_lcf_feat'][1] * 0.5
    v_yaw = data_dict['gt_ego_lcf_feat'][4]
    ax = data_dict['gt_ego_his_diff'][-1, 0] - data_dict['gt_ego_his_diff'][-2, 0]
    ay = data_dict['gt_ego_his_diff'][-1, 1] - data_dict['gt_ego_his_diff'][-2, 1]
    cx = data_dict['gt_ego_lcf_feat'][2]
    cy = data_dict['gt_ego_lcf_feat'][3]
    vhead = data_dict['gt_ego_lcf_feat'][7] * 0.5
    steering = data_dict['gt_ego_lcf_feat'][8]
    user_message += f"Ego-States:\n"
    user_message += f" - Velocity (vx,vy): ({vx:.2f},{vy:.2f})\n"
    user_message += f" - Heading Angular Velocity (v_yaw): ({v_yaw:.2f})\n"
    user_message += f" - Acceleration (ax,ay): ({ax:.2f},{ay:.2f})\n"
    user_message += f" - Can Bus: ({cx:.2f},{cy:.2f})\n"
    user_message += f" - Heading Speed: ({vhead:.2f})\n"
    user_message += f" - Steering: ({steering:.2f})\n"

    """
    Historical Trajectory:
    gt_ego_his_trajs: [5, 2] last 2 seconds
    gt_ego_his_diff: [4, 2] last 2 seconds, differential format, viewed as velocity
    """
| """ 132 | xh1 = data_dict['gt_ego_his_trajs'][0][0] 133 | yh1 = data_dict['gt_ego_his_trajs'][0][1] 134 | xh2 = data_dict['gt_ego_his_trajs'][1][0] 135 | yh2 = data_dict['gt_ego_his_trajs'][1][1] 136 | xh3 = data_dict['gt_ego_his_trajs'][2][0] 137 | yh3 = data_dict['gt_ego_his_trajs'][2][1] 138 | xh4 = data_dict['gt_ego_his_trajs'][3][0] 139 | yh4 = data_dict['gt_ego_his_trajs'][3][1] 140 | user_message += f"Historical Trajectory (last 2 seconds):" 141 | user_message += f" [({xh1:.2f},{yh1:.2f}), ({xh2:.2f},{yh2:.2f}), ({xh3:.2f},{yh3:.2f}), ({xh4:.2f},{yh4:.2f})]\n" 142 | 143 | """ 144 | Mission goal: 145 | gt_ego_fut_cmd 146 | """ 147 | cmd_vec = data_dict['gt_ego_fut_cmd'] 148 | right, left, forward = cmd_vec 149 | if right > 0: 150 | mission_goal = "RIGHT" 151 | elif left > 0: 152 | mission_goal = "LEFT" 153 | else: 154 | assert forward > 0 155 | mission_goal = "FORWARD" 156 | user_message += f"Mission Goal: " 157 | user_message += f"{mission_goal}\n" 158 | 159 | return user_message 160 | 161 | def generate_assistant_message(data, token, traj_only = False): 162 | 163 | data_dict = data[token] 164 | if traj_only: 165 | assitant_message = "" 166 | else: 167 | assitant_message = generate_chain_of_thoughts(data_dict) 168 | 169 | x1 = data_dict['gt_ego_fut_trajs'][1][0] 170 | x2 = data_dict['gt_ego_fut_trajs'][2][0] 171 | x3 = data_dict['gt_ego_fut_trajs'][3][0] 172 | x4 = data_dict['gt_ego_fut_trajs'][4][0] 173 | x5 = data_dict['gt_ego_fut_trajs'][5][0] 174 | x6 = data_dict['gt_ego_fut_trajs'][6][0] 175 | y1 = data_dict['gt_ego_fut_trajs'][1][1] 176 | y2 = data_dict['gt_ego_fut_trajs'][2][1] 177 | y3 = data_dict['gt_ego_fut_trajs'][3][1] 178 | y4 = data_dict['gt_ego_fut_trajs'][4][1] 179 | y5 = data_dict['gt_ego_fut_trajs'][5][1] 180 | y6 = data_dict['gt_ego_fut_trajs'][6][1] 181 | if not traj_only: 182 | assitant_message += f"Trajectory:\n" 183 | assitant_message += f"[({x1:.2f},{y1:.2f}), ({x2:.2f},{y2:.2f}), ({x3:.2f},{y3:.2f}), ({x4:.2f},{y4:.2f}), ({x5:.2f},{y5:.2f}), ({x6:.2f},{y6:.2f})]" 184 | # assitant_message += f"[ {x1:.2f},{x2:.2f},{x3:.2f},{x4:.2f},{x5:.2f},{x6:.2f},{y1:.2f},{y2:.2f},{y3:.2f},{y4:.2f},{y5:.2f},{y6:.2f} ]" 185 | return assitant_message 186 | 187 | def generate_chain_of_thoughts(data_dict, perception_range=20.0, short=True): 188 | """ 189 | Generate chain of thoughts reasoning and prompting by simple rules 190 | """ 191 | ego_fut_trajs = data_dict['gt_ego_fut_trajs'] 192 | ego_his_trajs = data_dict['gt_ego_his_trajs'] 193 | ego_fut_diff = data_dict['gt_ego_fut_diff'] 194 | ego_his_diff = data_dict['gt_ego_his_diff'] 195 | vx = data_dict['gt_ego_lcf_feat'][0]*0.5 196 | vy = data_dict['gt_ego_lcf_feat'][1]*0.5 197 | ax = data_dict['gt_ego_his_diff'][-1, 0] - data_dict['gt_ego_his_diff'][-2, 0] 198 | ay = data_dict['gt_ego_his_diff'][-1, 1] - data_dict['gt_ego_his_diff'][-2, 1] 199 | ego_estimate_velos = [ 200 | [0, 0], 201 | [vx, vy], 202 | [vx+ax, vy+ay], 203 | [vx+2*ax, vy+2*ay], 204 | [vx+3*ax, vy+3*ay], 205 | [vx+4*ax, vy+4*ay], 206 | [vx+5*ax, vy+5*ay], 207 | ] 208 | ego_estimate_trajs = np.cumsum(ego_estimate_velos, axis=0) # [7, 2] 209 | # print(ego_estimate_trajs) 210 | object_boxes = data_dict['gt_boxes'] 211 | object_names = data_dict['gt_names'] 212 | 213 | object_rel_fut_trajs = data_dict['gt_agent_fut_trajs'].reshape(-1, 6, 2) 214 | object_fut_trajs = np.cumsum(object_rel_fut_trajs, axis=1) + object_boxes[:, None, :2] 215 | object_fut_trajs = np.concatenate([object_boxes[:, None, :2], object_fut_trajs], axis=1) 216 | object_fut_mask = 
    num_objects = object_boxes.shape[0]

    num_future_horizon = 7  # include current
    object_collisions = np.zeros((num_objects, num_future_horizon))
    for i in range(num_objects):
        if (object_fut_trajs[i, :, 1] <= 0).all():
            # negative Y, meaning the object stays behind us; we don't care
            continue
        if (np.abs(object_fut_trajs[i, :, :]) > perception_range).any():
            # filter faraway (> 20m) objects in case there are too many outputs
            continue
        for t in range(num_future_horizon):
            mask = object_fut_mask[i, t - 1] > 0 if t > 0 else True
            if not mask:
                continue
            ego_x, ego_y = ego_estimate_trajs[t]
            object_x, object_y = object_fut_trajs[i, t]
            size_x, size_y = object_boxes[i, 3:5] * 0.5  # half size
            collision = collision_detection(ego_x, ego_y, 0.925, 2.04, object_x, object_y, size_x, size_y)
            if collision:
                object_collisions[i, t] = 1
                break

    assistant_message = f"Thoughts:\n"
    if (object_collisions == 0).all():  # nothing to care about
        assistant_message += f" - Notable Objects from Perception: None\n"
        assistant_message += f" Potential Effects from Prediction: None\n"
    else:
        for i in range(num_objects):
            for t in range(num_future_horizon):
                if object_collisions[i, t] > 0:
                    object_name = object_names[i]
                    if short:
                        object_name = object_name.split(".")[-1]
                    ox, oy = object_boxes[i, :2]
                    time = t * 0.5
                    assistant_message += f" - Notable Objects from Perception: {object_name} at ({ox:.2f},{oy:.2f})\n"
                    assistant_message += f" Potential Effects from Prediction: within the safe zone of the ego-vehicle at the {time}-second timestep\n"
    meta_action = generate_meta_action(
        ego_fut_diff=ego_fut_diff,
        ego_fut_trajs=ego_fut_trajs,
        ego_his_diff=ego_his_diff,
        ego_his_trajs=ego_his_trajs
    )
    assistant_message += ("Meta Action: " + meta_action)
    return assistant_message


def collision_detection(x1, y1, sx1, sy1, x2, y2, sx2, sy2, x_space=1.0, y_space=3.0):  # safe distance
    if (np.abs(x1 - x2) < sx1 + sx2 + x_space) and (y2 > y1) and (y2 - y1 < sy1 + sy2 + y_space):  # in front of you
        return True
    else:
        return False


def generate_meta_action(
    ego_fut_diff,
    ego_fut_trajs,
    ego_his_diff,
    ego_his_trajs,
):
    meta_action = ""

    # speed meta
    constant_eps = 0.5
    his_velos = np.linalg.norm(ego_his_diff, axis=1)
    fut_velos = np.linalg.norm(ego_fut_diff, axis=1)
    cur_velo = his_velos[-1]
    end_velo = fut_velos[-1]

    if cur_velo < constant_eps and end_velo < constant_eps:
        speed_meta = "stop"
    elif end_velo < constant_eps:
        speed_meta = "a deceleration to zero"
    elif np.abs(end_velo - cur_velo) < constant_eps:
        speed_meta = "a constant speed"
    else:
        if cur_velo > end_velo:
            if cur_velo > 2 * end_velo:
                speed_meta = "a quick deceleration"
            else:
                speed_meta = "a deceleration"
        else:
            if end_velo > 2 * cur_velo:
                speed_meta = "a quick acceleration"
            else:
                speed_meta = "an acceleration"

    # behavior meta
    if speed_meta == "stop":
        meta_action += (speed_meta + "\n")
        return meta_action.upper()
    else:
        forward_th = 2.0
        lane_changing_th = 4.0
        if (np.abs(ego_fut_trajs[:, 0]) < forward_th).all():
            behavior_meta = "move forward"
        else:
            if ego_fut_trajs[-1, 0] < 0:  # left
                if np.abs(ego_fut_trajs[-1, 0]) > lane_changing_th:
                    behavior_meta = "turn left"
                else:
                    behavior_meta = "change lane to left"
            elif ego_fut_trajs[-1, 0] > 0:  # right
                if np.abs(ego_fut_trajs[-1, 0]) > lane_changing_th:
                    behavior_meta = "turn right"
                else:
                    behavior_meta = "change lane to right"
            else:
                raise ValueError(f"Undefined behaviors: {ego_fut_trajs}")

        # heading-based rules
        # ego_fut_headings = np.arctan(ego_fut_diff[:,0]/(ego_fut_diff[:,1]+1e-4))*180/np.pi # in degrees
        # ego_his_headings = np.arctan(ego_his_diff[:,0]/(ego_his_diff[:,1]+1e-4))*180/np.pi # in degrees

        # forward_heading_th = 5 # forward heading is always near 0
        # turn_heading_th = 45

        # if (np.abs(ego_fut_headings) < forward_heading_th).all():
        #     behavior_meta = "move forward"
        # else:
        #     # we extract a 5-s curve; if the largest heading change is above 45 degrees, we view it as a turn
        #     curve_headings = np.concatenate([ego_his_headings, ego_fut_headings])
        #     min_heading, max_heading = curve_headings.min(), curve_headings.max()
        #     if ego_fut_trajs[-1, 0] < 0: # left
        #         if np.abs(max_heading - min_heading) > turn_heading_th:
        #             behavior_meta = "turn left"
        #         else:
        #             behavior_meta = "change lane to left"
        #     elif ego_fut_trajs[-1, 0] > 0: # right
        #         if np.abs(max_heading - min_heading) > turn_heading_th:
        #             behavior_meta = "turn right"
        #         else:
        #             behavior_meta = "change lane to right"
        #     else:
        #         raise ValueError(f"Undefined behaviors: {ego_fut_trajs}")

        meta_action += (behavior_meta + " with " + speed_meta + "\n")
        return meta_action.upper()
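
# Example outputs (derived from the rules above, inputs hypothetical): an ego
# trajectory that stays within forward_th laterally while roughly doubling its
# speed yields "MOVE FORWARD WITH A QUICK ACCELERATION\n"; near-zero current
# and final speeds yield "STOP\n".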

# system_message = """
# As a professional autonomous driving system, you are tasked with plotting a secure and human-like path within a 3-second window using the following guidelines and inputs:

# ### Context
# - **Coordinate System**: You are in the ego-vehicle coordinate system positioned at (0,0). The X-axis is perpendicular to your heading direction, while the Y-axis represents the heading direction.
# - **Location**: You are mounted at the center of an ego-vehicle that has 4.08 meters length and 1.85 meters width.
# - **Objective**: Generate a route characterized by 6 waypoints, with a new waypoint established every 0.5 seconds.

# ### Inputs
# 1. **Perception & Prediction** (You observe the surrounding objects and estimate their future movements):
#    - object name at (ox1, ox2). Future trajectory: [(oxt1, oyt1), ..., (oxt6, oyt6)], 6 waypoints in 3 seconds, UN denotes that the future location at that timestep is unknown
#    - ...

# 2. **Historical Trajectory** (Your historical trajectory from the last 2 seconds, presented as 4 waypoints):
#    - [(xh1, yh1), (xh2, yh2), (xh3, yh3), (xh4, yh4)]

# 3. **Ego-States** (Your current states):
#    - **Velocity** (vx, vy) # meters per 0.5 second
#    - **Heading Angular Velocity** (v_yaw) # ego-vehicle heading change rate, rad per second
#    - **Acceleration** (ax, ay) # velocity change rate per 0.5 second
#    - **Heading Speed** # meters per 0.5 second
#    - **Steering** # steering signal

# 4. **Mission Goal**: Instructions outlining your objectives for the upcoming 3 seconds.

# ### Task
# - Integrate and process all the above inputs to construct a driving route.
# - Think about what you have received and make driving decisions. Write down your thoughts and the action.
# - Output a set of 6 new waypoints for the upcoming 3 seconds (Note: this task is of the utmost importance!). These should be formatted as coordinate pairs:
#    - (x1, y1) # 0.5 second
#    - (x2, y2) # 1.0 second
#    - (x3, y3) # 1.5 second
#    - (x4, y4) # 2.0 second
#    - (x5, y5) # 2.5 second
#    - (x6, y6) # 3.0 second
# - Final output format:
#    Thoughts:
#    - Notable Objects from Perception: ...
#      Potential Effects from Prediction: ...
#    Meta Action:
#    ...
#    Trajectory:
#    - [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), (x6, y6)]

# Ensure the safety and feasibility of the path devised within the given 3-second timeframe. Let's work on crafting a safe route!
# """


def generate_incontext_message(data, token):
    incontext_message = "\nFor example:\n"
    incontext_message += "Input:\n"
    user_message = generate_user_message(data, token)
    incontext_message += user_message
    incontext_message += "You should generate the following content:\n"
    assistant_message = generate_assistant_message(data, token)
    incontext_message += assistant_message
    return incontext_message
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
aiohttp==3.8.5
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.1.0
certifi==2023.7.22
charset-normalizer==3.2.0
frozenlist==1.4.0
idna==3.4
multidict==6.0.4
ndjson==0.3.1
numpy==1.26.0
openai==0.28.0
regex==2023.8.8
requests==2.31.0
tenacity==8.2.3
tiktoken==0.5.1
tqdm==4.66.1
urllib3==2.0.4
yarl==1.9.2
transformers
torch
bitsandbytes
scipy
datasets
peft
trl
omegaconf
hydra-core
--------------------------------------------------------------------------------
/run_evaluation.sh:
--------------------------------------------------------------------------------
for f in output/*.json ;

do
  echo Processing file: $f
  f_eval=${f/output/results}
  f_eval=${f_eval/json/eval.json}

  echo Saving to: $f_eval

  python evaluation.py \
    --prediction_file $f \
    --output_path $f_eval ;

  echo Saved to: $f_eval
  echo ================================

done
--------------------------------------------------------------------------------
/run_inference.sh:
--------------------------------------------------------------------------------
python inference.py \
    --validation_data_file "data/val_with_token.json" \
    --model_name "meta-llama/Llama-2-7b-hf" \
    --adapter_path "checkpoints/llama2_lora/checkpoint-70000-llama2" \
"output/llama2_lora/output_llama2_lora.json" \ 6 | -------------------------------------------------------------------------------- /run_training.sh: -------------------------------------------------------------------------------- 1 | python training.py \ 2 | --train_data_file "data/train_with_token.json" \ 3 | --model_name "meta-llama/Llama-2-7b-hf" \ 4 | --output_path "checkpoints/llama2_lora" \ 5 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from functools import partial 4 | import bitsandbytes as bnb 5 | from peft import ( 6 | LoraConfig, 7 | get_peft_model, 8 | prepare_model_for_kbit_training, 9 | AutoPeftModelForCausalLM, 10 | ) 11 | import torch 12 | from transformers import ( 13 | AutoModelForCausalLM, 14 | AutoTokenizer, 15 | set_seed, 16 | Trainer, 17 | TrainingArguments, 18 | BitsAndBytesConfig, 19 | DataCollatorForLanguageModeling, 20 | Trainer, 21 | TrainingArguments, 22 | ) 23 | from trl import SFTTrainer 24 | from datasets import load_dataset, Dataset 25 | 26 | 27 | def load_model(model_name, bnb_config, cache_dir, security_token): 28 | n_gpus = torch.cuda.device_count() 29 | max_memory = f"{40960}MB" 30 | 31 | model = AutoModelForCausalLM.from_pretrained( 32 | model_name, 33 | quantization_config=bnb_config, 34 | device_map="auto", # dispatch efficiently the model on the available resources 35 | max_memory={i: max_memory for i in range(n_gpus)}, 36 | cache_dir=cache_dir, 37 | token=security_token, 38 | ) 39 | tokenizer = AutoTokenizer.from_pretrained( 40 | model_name, cache_dir=cache_dir, token=security_token 41 | ) 42 | # Needed for LLaMA tokenizer 43 | tokenizer.pad_token = tokenizer.eos_token 44 | 45 | return model, tokenizer 46 | 47 | 48 | def create_prompt_formats(sample, for_validation: bool = False): 49 | """ 50 | Format various fields of the sample ('instruction', 'context', 'response') 51 | Then concatenate them using two newline characters 52 | :param sample: Sample dictionnary 53 | """ 54 | message = sample["messages"] 55 | system = f"{message[0]['role']}\n{message[0]['content']}" 56 | user = f"{message[1]['role']}\n{message[1]['content']}" 57 | assistant = f"{message[2]['role']}\n{message[2]['content']}" 58 | end = "### End" 59 | 60 | if for_validation: 61 | parts = [part for part in [system, user, end] if part] 62 | else: 63 | parts = [part for part in [system, user, assistant, end] if part] 64 | 65 | formatted_prompt = "\n\n".join(parts) 66 | sample["text"] = formatted_prompt 67 | 68 | return sample 69 | 70 | 71 | # SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py 72 | def get_max_length(model): 73 | conf = model.config 74 | max_length = None 75 | for length_setting in [ 76 | "n_positions", 77 | "max_position_embeddings", 78 | "seq_length", 79 | ]: 80 | max_length = getattr(model.config, length_setting, None) 81 | if max_length: 82 | print(f"Found max lenth: {max_length}") 83 | break 84 | if not max_length: 85 | max_length = 1024 86 | print(f"Using default max length: {max_length}") 87 | return max_length 88 | 89 | 90 | def preprocess_batch(batch, tokenizer, max_length): 91 | """ 92 | Tokenizing a batch 93 | """ 94 | return tokenizer( 95 | batch["text"], 96 | max_length=max_length, 97 | truncation=True, 98 | ) 99 | 100 | 101 | # SOURCE https://github.com/databrickslabs/dolly/blob/master/training/trainer.py 102 | def preprocess_dataset( 103 | tokenizer: 
    tokenizer: AutoTokenizer,
    max_length: int,
    seed: int,
    dataset: Dataset,
    for_validation: bool = False,
):
    """Format and tokenize the dataset so it is ready for training.

    :param tokenizer (AutoTokenizer): Model tokenizer
    :param max_length (int): Maximum number of tokens to emit from the tokenizer
    """

    # Add the prompt to each sample
    print("Preprocessing dataset...")
    dataset = dataset.map(
        create_prompt_formats, fn_kwargs={"for_validation": for_validation}
    )

    # Apply preprocessing to each batch of the dataset
    _preprocessing_function = partial(
        preprocess_batch, max_length=max_length, tokenizer=tokenizer
    )
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
    )

    # Filter out samples whose input_ids exceed max_length
    dataset = dataset.filter(
        lambda sample: len(sample["input_ids"]) < max_length
    )

    # Shuffle the dataset
    dataset = dataset.shuffle(seed=seed)

    return dataset


def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    return bnb_config


def create_peft_config(modules):
    """
    Create a Parameter-Efficient Fine-Tuning config for the model.
    :param modules: Names of the modules to apply LoRA to
    """
    config = LoraConfig(
        r=16,  # dimension of the updated matrices
        lora_alpha=32,  # parameter for scaling
        target_modules=modules,
        lora_dropout=0.1,  # dropout probability for layers
        bias="none",
        task_type="CAUSAL_LM",
    )

    return config
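
# For reference: with a 4-bit-quantized Llama-2 model, find_all_linear_names
# (defined below) typically returns the attention and MLP projections, e.g.
# ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
# so LoRA is applied to every quantized linear layer rather than only q/v.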
186 | """ 187 | trainable_params = 0 188 | all_param = 0 189 | for _, param in model.named_parameters(): 190 | num_params = param.numel() 191 | # if using DS Zero 3 and the weights are initialized empty 192 | if num_params == 0 and hasattr(param, "ds_numel"): 193 | num_params = param.ds_numel 194 | 195 | all_param += num_params 196 | if param.requires_grad: 197 | trainable_params += num_params 198 | if use_4bit: 199 | trainable_params /= 2 200 | print( 201 | f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {100 * trainable_params / all_param}" 202 | ) 203 | 204 | 205 | def train(model, tokenizer, dataset, output_dir): 206 | # Apply preprocessing to the model to prepare it by 207 | # 1 - Enabling gradient checkpointing to reduce memory usage during fine-tuning 208 | model.gradient_checkpointing_enable() 209 | 210 | # 2 - Using the prepare_model_for_kbit_training method from PEFT 211 | model = prepare_model_for_kbit_training(model) 212 | 213 | # Get LoRA module names 214 | modules = find_all_linear_names(model) 215 | 216 | # Create PEFT config for these modules and wrap the model to PEFT 217 | peft_config = create_peft_config(modules) 218 | model = get_peft_model(model, peft_config) 219 | 220 | # Print information about the percentage of trainable parameters 221 | print_trainable_parameters(model) 222 | 223 | # Training parameters 224 | trainer = SFTTrainer( 225 | model=model, 226 | tokenizer=tokenizer, 227 | train_dataset=dataset, 228 | dataset_text_field="text", 229 | max_seq_length=1024, 230 | # Training parameters from here: 231 | # https://huggingface.co/blog/Llama2-for-non-engineers 232 | args=TrainingArguments( 233 | per_device_train_batch_size=1, 234 | gradient_accumulation_steps=1, 235 | weight_decay=0.01, 236 | warmup_ratio=0.1, 237 | learning_rate=2e-4, 238 | fp16=True, 239 | logging_steps=1, 240 | save_steps=10000, 241 | output_dir=output_dir, 242 | optim="adamw_torch", # "paged_adamw_32bit", 243 | lr_scheduler_type="linear", 244 | num_train_epochs=15, 245 | ), 246 | data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), 247 | ) 248 | 249 | model.config.use_cache = False # re-enable for inference to speed up predictions for similar inputs 250 | 251 | ### SOURCE https://github.com/artidoro/qlora/blob/main/qlora.py 252 | # Verifying the datatypes before training 253 | dtypes = {} 254 | for _, p in model.named_parameters(): 255 | dtype = p.dtype 256 | if dtype not in dtypes: 257 | dtypes[dtype] = 0 258 | dtypes[dtype] += p.numel() 259 | total = 0 260 | for k, v in dtypes.items(): 261 | total += v 262 | for k, v in dtypes.items(): 263 | print(k, v, v / total) 264 | 265 | do_train = True 266 | 267 | # Launch training 268 | print("Training...") 269 | 270 | if do_train: 271 | train_result = trainer.train() 272 | metrics = train_result.metrics 273 | trainer.log_metrics("train", metrics) 274 | trainer.save_metrics("train", metrics) 275 | trainer.save_state() 276 | print(metrics) 277 | 278 | # Saving model 279 | print("Saving last checkpoint of the model...") 280 | os.makedirs(output_dir, exist_ok=True) 281 | trainer.model.save_pretrained(output_dir) 282 | 283 | # Free memory for merging weights 284 | del model 285 | del trainer 286 | torch.cuda.empty_cache() 287 | 288 | 289 | if __name__ == "__main__": 290 | parser = argparse.ArgumentParser() 291 | parser.add_argument( 292 | "--train_data_file", 293 | "-t", 294 | dest="train_data_file", 295 | type=str, 296 | help="Training JSON data file.", 297 | required=True, 298 | ) 299 | 
    parser.add_argument(
        "--model_name",
        "-m",
        dest="model_name",
        type=str,
        help="Name of the base model to train an adapter for.",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        dest="output_dir",
        type=str,
        help="Path to save the trained adapter model.",
        required=True,
    )
    parser.add_argument(
        "--cache_dir",
        dest="cache_dir",
        type=str,
        default=None,
        required=False,
        help="The cache directory to save the downloaded models.",
    )

    args = parser.parse_args()
    model_name = args.model_name
    training_data_file = args.train_data_file
    output_dir = args.output_dir
    cache_dir = args.cache_dir
    hf_token = os.getenv("HF_ACCESS_TOKEN")

    dataset = load_dataset(
        "json",
        data_files=training_data_file,
        split="train",
    )
    print(f"Number of prompts: {len(dataset)}")
    print(f"Column names are: {dataset.column_names}")

    bnb_config = create_bnb_config()

    model, tokenizer = load_model(model_name, bnb_config, cache_dir, hf_token)

    max_length = get_max_length(model)

    dataset = preprocess_dataset(tokenizer, max_length, 0, dataset)
    train(model, tokenizer, dataset, output_dir)
--------------------------------------------------------------------------------