├── xchat
│   ├── eval
│   │   ├── __init__.py
│   │   ├── mbpp
│   │   │   ├── data.py
│   │   │   ├── examplars.py
│   │   │   ├── prompts.py
│   │   │   ├── evaluation.py
│   │   │   ├── execution.py
│   │   │   └── run_eval.py
│   │   ├── humaneval
│   │   │   ├── prompts.py
│   │   │   └── run_eval.py
│   │   ├── mmlu
│   │   │   ├── categories.py
│   │   │   └── run_eval.py
│   │   ├── gsm
│   │   │   ├── examplars.py
│   │   │   └── run_eval.py
│   │   ├── dispatch_openai_requests.py
│   │   ├── bbh
│   │   │   └── run_eval.py
│   │   └── utils.py
│   └── train
│       ├── llama_flash_attn_monkey_patch.py
│       └── train.py
├── assets
│   ├── xlang.png
│   ├── interface.png
│   ├── training.png
│   ├── salesforce.webp
│   ├── transparent.png
│   ├── agent-scenarios.png
│   └── overall-perform.png
├── scripts
│   ├── deploy
│   │   ├── vllm_lemur.sh
│   │   └── tgi_lemur.sh
│   └── eval
│       ├── mbpp.sh
│       ├── humaneval.sh
│       ├── bbh.sh
│       ├── mmlu.sh
│       └── gsm.sh
├── .pre-commit-config.yaml
├── pyproject.toml
├── .gitignore
├── LICENSE
└── README.md

/xchat/eval/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/assets/xlang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/xlang.png
--------------------------------------------------------------------------------
/assets/interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/interface.png
--------------------------------------------------------------------------------
/assets/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/training.png
--------------------------------------------------------------------------------
/assets/salesforce.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/salesforce.webp
--------------------------------------------------------------------------------
/assets/transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/transparent.png
--------------------------------------------------------------------------------
/assets/agent-scenarios.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/agent-scenarios.png
--------------------------------------------------------------------------------
/assets/overall-perform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/overall-perform.png
--------------------------------------------------------------------------------
/scripts/deploy/vllm_lemur.sh:
--------------------------------------------------------------------------------
N_GPUS=4
gpus='"device=0,1,2,3"'
MODEL_PATH=OpenLemur/lemur-70b-v1
MODEL_NAME=lemur-70b-v1
# HF_HOME=""

docker run --gpus $gpus --rm \
    -v $HF_HOME:/root/.cache/huggingface \
    --shm-size=10.24gb \
    --name vllm-$MODEL_NAME \
    ranpox/fastchat:lemur \
    python -m vllm.entrypoints.openai.api_server \
    --model $MODEL_PATH \
    --tensor-parallel-size $N_GPUS \
    --served-model-name $MODEL_NAME \
    --max-num-batched-tokens 4096 \
    --load-format pt \
    --port 8000
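The container above serves an OpenAI-compatible HTTP API on port 8000. The snippet below is a minimal smoke test for that endpoint and is not part of the repository: it assumes the server is reachable on localhost, that the `requests` package is installed, and that --served-model-name was left at lemur-70b-v1.

import requests

# Hypothetical smoke test for the server started by vllm_lemur.sh.
# Endpoint and model name mirror the script's --port and --served-model-name flags.
payload = {
    "model": "lemur-70b-v1",
    "prompt": "def fibonacci(n):",
    "max_tokens": 64,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])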
-------------------------------------------------------------------------------- /scripts/eval/mbpp.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | DATA_DIR=data/eval/mbpp 3 | OUTPUT_DIR=results/mbpp/Llama-2-70b-chat-hf 4 | MODEL_PATH=meta-llama/Llama-2-70b-chat-hf 5 | python -m xchat.eval.mbpp.run_eval \ 6 | --data_dir $DATA_DIR \ 7 | --max_new_token 650 \ 8 | --max_num_examples 500 \ 9 | --save_dir $OUTPUT_DIR \ 10 | --model $MODEL_PATH \ 11 | --tokenizer $MODEL_PATH \ 12 | --eval_batch_size 8 \ 13 | --chat_format llama2chat \ 14 | --load_in_8bit \ 15 | --greedy_decoding \ 16 | --few_shot \ 17 | --eval_pass_at_ks 1 \ 18 | --unbiased_sampling_size_n 1 19 | -------------------------------------------------------------------------------- /scripts/eval/humaneval.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 2 | DATA_DIR=data/eval/humaneval 3 | 4 | if [ ! -d "$DATA_DIR" ]; then 5 | echo "Downloading HumanEval data..." 6 | mkdir -p $DATA_DIR 7 | wget -P $DATA_DIR https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz 8 | fi 9 | 10 | PYTHONPATH=$PWD 11 | OUTPUT_DIR=results/humaneval/Llama-2-70b-chat-hf 12 | MODEL=meta-llama/Llama-2-70b-chat-hf 13 | 14 | python -m xchat.eval.humaneval.run_eval \ 15 | --data_file $DATA_DIR/HumanEval.jsonl.gz \ 16 | --max_new_token 512 \ 17 | --eval_pass_at_ks 1 \ 18 | --unbiased_sampling_size_n 1 \ 19 | --temperature 0.2 \ 20 | --save_dir $OUTPUT_DIR \ 21 | --model $MODEL \ 22 | --tokenizer $MODEL \ 23 | --eval_batch_size 24 \ 24 | --load_in_8bit \ 25 | --chat_format llama2chat 26 | -------------------------------------------------------------------------------- /scripts/deploy/tgi_lemur.sh: -------------------------------------------------------------------------------- 1 | model=OpenLemur/lemur-70b-v1 2 | name="lemur-70b-v1" 3 | # HF_HOME="" 4 | volume=$HF_HOME/hub 5 | gpus='"device=0,1,2,3"' 6 | num_shard=4 7 | dtype="bfloat16" 8 | quantize="bitsandbytes" 9 | max_input_length=4095 10 | max_total_tokens=4096 11 | max_batch_prefill_tokens=16380 12 | max_batch_total_tokens=16384 13 | 14 | docker run --gpus $gpus \ 15 | --shm-size 1g --rm \ 16 | -p 8080:80 -v $volume:/data \ 17 | --name tgi-$name \ 18 | ghcr.io/huggingface/text-generation-inference:1.0.3 \ 19 | --model-id $model \ 20 | --sharded false \ 21 | --num-shard $num_shard \ 22 | --dtype $dtype \ 23 | --max-input-length $max_input_length \ 24 | --max-total-tokens $max_total_tokens \ 25 | --max-batch-prefill-tokens $max_batch_prefill_tokens \ 26 | --max-batch-total-tokens $max_batch_total_tokens \ 27 | --max-stop-sequences 10 28 | -------------------------------------------------------------------------------- /scripts/eval/bbh.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1 2 | 3 | DATA_DIR=data/eval/bbh 4 | 5 | if [ ! -d "data" ]; then 6 | mkdir data 7 | fi 8 | 9 | if [ ! -d "data/eval" ]; then 10 | mkdir data/eval 11 | fi 12 | 13 | if [ ! -d $DATA_DIR ]; then 14 | echo "Downloading BBH data..." 
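    # NOTE: the upstream zip unpacks to BIG-Bench-Hard-main/, which the commands below move into data/eval/bbh before removing the download scratch directory.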
15 | mkdir -p data/downloads 16 | wget -O data/downloads/bbh_data.zip https://github.com/suzgunmirac/BIG-Bench-Hard/archive/refs/heads/main.zip 17 | unzip data/downloads/bbh_data.zip -d data/downloads/bbh 18 | mv data/downloads/bbh/BIG-Bench-Hard-main data/eval/bbh && rm -rf data/downloads/ 19 | fi 20 | 21 | MODEL_DIR=codellama/CodeLlama-7b-Instruct-hf 22 | OUTPUT_DIR=results/bbh/llama-2-7b-hf 23 | 24 | python -m xchat.eval.bbh.run_eval \ 25 | --data_dir data/eval/bbh/ \ 26 | --save_dir $OUTPUT_DIR \ 27 | --model $MODEL_DIR \ 28 | --tokenizer $MODEL_DIR \ 29 | --eval_batch_size 20 \ 30 | --load_in_8bit \ 31 | --no_cot \ 32 | --chat_format codellama-instruct 33 | -------------------------------------------------------------------------------- /scripts/eval/mmlu.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4,5,6,7 2 | 3 | 4 | DATA_DIR=data/eval/mmlu 5 | 6 | if [ ! -d $DATA_DIR ]; then 7 | echo "Downloading MMLU data..." 8 | wget -O data/mmlu_data.tar https://people.eecs.berkeley.edu/~hendrycks/data.tar 9 | mkdir -p data/eval/mmlu_data 10 | tar -xvf data/mmlu_data.tar -C data/eval/mmlu_data 11 | mv data/eval/mmlu_data/data $DATA_DIR && rm -r data/eval/mmlu_data data/mmlu_data.tar 12 | fi 13 | 14 | MODEL_DIR=meta-llama/Llama-2-7b-hf 15 | OUTPUT_DIR=results/mmlu/llama-2-7b-hf 16 | python -m xchat.eval.mmlu.run_eval \ 17 | --ntrain 5 \ 18 | --data_dir $DATA_DIR \ 19 | --save_dir $OUTPUT_DIR \ 20 | --model_name_or_path $MODEL_DIR \ 21 | --tokenizer_name_or_path $MODEL_DIR \ 22 | --eval_batch_size 4 23 | 24 | # MODEL_DIR=bigcode/starcoder 25 | # OUTPUT_DIR=results/mmlu/starcoder 26 | # python -m xchat.eval.mmlu.run_eval \ 27 | # --ntrain 5 \ 28 | # --data_dir $DATA_DIR \ 29 | # --save_dir $OUTPUT_DIR \ 30 | # --model_name_or_path $MODEL_DIR \ 31 | # --tokenizer_name_or_path $MODEL_DIR \ 32 | # --eval_batch_size 16 33 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os 4 | from typing import Dict, Iterable 5 | 6 | 7 | def read_problems(evalset_file: str) -> Dict[str, Dict]: 8 | return {task["task_id"]: task for task in stream_jsonl(evalset_file)} 9 | 10 | 11 | def stream_jsonl(filename: str) -> Iterable[Dict]: 12 | """ 13 | Parses each jsonl line and yields it as a dictionary. 14 | """ 15 | if filename.endswith(".gz"): 16 | with open(filename, "rb") as gzfp, gzip.open(gzfp, "rt") as fp: 17 | for line in fp: 18 | if any(not x.isspace() for x in line): 19 | yield json.loads(line) 20 | else: 21 | with open(filename) as fp: 22 | for line in fp: 23 | if any(not x.isspace() for x in line): 24 | yield json.loads(line) 25 | 26 | 27 | def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False): 28 | """ 29 | Writes an iterable of dictionaries to jsonl. 
30 | """ 31 | mode = "ab" if append else "wb" 32 | filename = os.path.expanduser(filename) 33 | if filename.endswith(".gz"): 34 | with open(filename, mode) as fp, gzip.GzipFile(fileobj=fp, mode="wb") as gzfp: 35 | for x in data: 36 | gzfp.write((json.dumps(x) + "\n").encode("utf-8")) 37 | else: 38 | with open(filename, mode) as fp: 39 | for x in data: 40 | fp.write((json.dumps(x) + "\n").encode("utf-8")) 41 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | exclude: ^(poetry.lock|.idea/) 4 | repos: 5 | - repo: https://github.com/charliermarsh/ruff-pre-commit 6 | rev: "v0.0.270" 7 | hooks: 8 | - id: ruff 9 | args: [ --fix, --exit-non-zero-on-fix ] 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 23.3.0 13 | hooks: 14 | - id: black 15 | args: 16 | - "--line-length=120" 17 | - "--target-version=py39" 18 | - "--target-version=py310" 19 | - "--target-version=py311" 20 | types: [ python ] 21 | 22 | - repo: https://github.com/pre-commit/pre-commit-hooks 23 | rev: v4.4.0 24 | hooks: 25 | - id: check-added-large-files 26 | - id: check-ast 27 | - id: check-builtin-literals 28 | - id: check-case-conflict 29 | - id: check-docstring-first 30 | - id: check-shebang-scripts-are-executable 31 | - id: check-merge-conflict 32 | - id: check-json 33 | - id: check-toml 34 | - id: check-xml 35 | - id: check-yaml 36 | - id: debug-statements 37 | - id: destroyed-symlinks 38 | - id: detect-private-key 39 | - id: end-of-file-fixer 40 | exclude: ^LICENSE|\.(html|csv|txt|svg|py)$ 41 | - id: pretty-format-json 42 | args: ["--autofix", "--no-ensure-ascii", "--no-sort-keys"] 43 | - id: requirements-txt-fixer 44 | - id: trailing-whitespace 45 | args: [--markdown-linebreak-ext=md] 46 | exclude: \.(html|svg)$ 47 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/examplars.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/google-research/google-research/blob/master/mbpp/README.md?plain=1#L21 2 | 3 | EXAMPLARS = [ 4 | { 5 | "task_id": 2, 6 | "text": "Write a function to find the similar elements from the given two tuple lists.", 7 | "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ", 8 | "test_list": [ 9 | "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)", 10 | "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)", 11 | "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)", 12 | ], 13 | "test_setup_code": "", 14 | "challenge_test_list": [], 15 | }, 16 | { 17 | "task_id": 3, 18 | "text": "Write a python function to identify non-prime numbers.", 19 | "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result", 20 | "test_list": [ 21 | "assert is_not_prime(2) == False", 22 | "assert is_not_prime(10) == True", 23 | "assert is_not_prime(35) == True", 24 | ], 25 | "test_setup_code": "", 26 | "challenge_test_list": [], 27 | }, 28 | { 29 | "task_id": 4, 30 | "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.", 31 | "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums", 32 | "test_list": [ 33 | 
"assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ", 34 | "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ", 35 | "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]", 36 | ], 37 | "test_setup_code": "", 38 | "challenge_test_list": [], 39 | }, 40 | ] 41 | -------------------------------------------------------------------------------- /scripts/eval/gsm.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 # ,4,5,6,7 2 | 3 | DATA_DIR=data/eval/gsm 4 | 5 | if [ ! -d $DATA_DIR ]; then 6 | echo "Downloading GSM data..." 7 | mkdir -p $DATA_DIR 8 | wget -P $DATA_DIR https://github.com/openai/grade-school-math/raw/master/grade_school_math/data/test.jsonl 9 | fi 10 | 11 | # MODEL_DIR=meta-llama/Llama-2-7b-hf 12 | # OUTPUT_DIR=results/gsm/llama-2-7b-hf 13 | # python -m xchat.eval.gsm.run_eval \ 14 | # --data_dir $DATA_DIR \ 15 | # --max_num_examples 32 \ 16 | # --save_dir $OUTPUT_DIR \ 17 | # --model $MODEL_DIR \ 18 | # --tokenizer $MODEL_DIR \ 19 | # --eval_batch_size 16 \ 20 | # --n_shot 8 21 | 22 | MODEL_DIR=codellama/CodeLlama-7b-hf 23 | OUTPUT_DIR=results/gsm/CodeLlama-7b-hf 24 | python -m xchat.eval.gsm.run_eval \ 25 | --data_dir $DATA_DIR \ 26 | --max_num_examples 32 \ 27 | --save_dir $OUTPUT_DIR \ 28 | --model $MODEL_DIR \ 29 | --tokenizer $MODEL_DIR \ 30 | --eval_batch_size 16 \ 31 | --n_shot 8 32 | 33 | # MODEL_DIR=OpenLemur/lemur-70b-v1 34 | # OUTPUT_DIR=results/gsm/lemur-70b-v1 35 | # python -m xchat.eval.gsm.run_eval \ 36 | # --data_dir $DATA_DIR \ 37 | # --max_num_examples 32 \ 38 | # --save_dir $OUTPUT_DIR \ 39 | # --model $MODEL_DIR \ 40 | # --tokenizer $MODEL_DIR \ 41 | # --eval_batch_size 16 \ 42 | # --n_shot 8 \ 43 | # --load_in_8bit 44 | 45 | # MODEL_DIR=lmsys/vicuna-13b-v1.5 46 | # OUTPUT_DIR=results/gsm/vicuna-13b-v1.5 47 | # python -m xchat.eval.gsm.run_eval \ 48 | # --data_dir $DATA_DIR \ 49 | # --max_num_examples 32 \ 50 | # --save_dir $OUTPUT_DIR \ 51 | # --model $MODEL_DIR \ 52 | # --tokenizer $MODEL_DIR \ 53 | # --eval_batch_size 16 \ 54 | # --n_shot 8 \ 55 | # --chat_format vicuna \ 56 | # --load_in_8bit 57 | 58 | # MODEL_DIR=codellama/CodeLlama-34b-Instruct-hf 59 | # OUTPUT_DIR=results/gsm/codellama-34b-instruct-hf 60 | # python -m xchat.eval.gsm.run_eval \ 61 | # --data_dir $DATA_DIR \ 62 | # --save_dir $OUTPUT_DIR \ 63 | # --model $MODEL_DIR \ 64 | # --tokenizer $MODEL_DIR \ 65 | # --eval_batch_size 48 \ 66 | # --n_shot 8 \ 67 | # --chat_format codellama-instruct 68 | -------------------------------------------------------------------------------- /xchat/eval/humaneval/prompts.py: -------------------------------------------------------------------------------- 1 | class HumanEvalPromptBase: 2 | def __init__(self): 3 | pass 4 | 5 | def getPrompt(self, prompt): 6 | return prompt 7 | 8 | 9 | class ChatMLPrompt(HumanEvalPromptBase): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def getPrompt(self, prompt): 14 | return ( 15 | "<|im_start|>user\n" 16 | + "Complete the following python function.\n\n\n" 17 | + prompt 18 | + "<|im_end|>\n<|im_start|>assistant\n" 19 | + "Here is the completed function:\n\n\n" 20 | + prompt 21 | ) 22 | 23 | 24 | class LLaMA2ChatPrompt(HumanEvalPromptBase): 25 | def __init__(self, prompt): 26 | super().__init__(prompt) 27 | 28 | def getPrompt(self, prompt): 29 | return ( 30 | "[INST] <>\n" 31 | + "Below is an instruction that describes a task. 
Write a response that appropriately completes the request." 32 | + "\n<>\n\n" 33 | + "Complete the following python function.\n\n\n" 34 | + prompt 35 | + "[/INST]" 36 | + "Here is the completed function:\n\n\n" 37 | + prompt 38 | ) 39 | 40 | 41 | class VicunaPrompt(HumanEvalPromptBase): 42 | def __init__(self, prompt): 43 | super().__init__(prompt) 44 | 45 | def getPrompt(self, prompt): 46 | system_message = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" 47 | return ( 48 | "USER: " 49 | + system_message 50 | + "Complete the following python function.\n\n\n" 51 | + prompt 52 | + "\nASSISTANT: " 53 | + "Here is the completed function:\n\n\n" 54 | + prompt 55 | ) 56 | 57 | 58 | class TuluPrompt(HumanEvalPromptBase): 59 | def __init__(self, prompt): 60 | super().__init__(prompt) 61 | 62 | def getPrompt(self, prompt): 63 | return ( 64 | "<|user|>\n" 65 | + "Complete the following python function.\n\n\n" 66 | + prompt 67 | + "\n<|assistant|>\n" 68 | + "Here is the completed function:\n\n\n" 69 | + prompt 70 | ) 71 | 72 | 73 | def getprompts(chat_format): 74 | if chat_format in ["base"]: 75 | return HumanEvalPromptBase() 76 | if chat_format in ["chatml"]: 77 | return ChatMLPrompt() 78 | elif chat_format in ["tulu"]: 79 | return TuluPrompt() 80 | elif chat_format in ["llama2chat"]: 81 | return LLaMA2ChatPrompt() 82 | elif chat_format in ["vicuna"]: 83 | return VicunaPrompt() 84 | raise NotImplementedError 85 | -------------------------------------------------------------------------------- /xchat/eval/mmlu/categories.py: -------------------------------------------------------------------------------- 1 | subcategories = { 2 | "abstract_algebra": ["math"], 3 | "anatomy": ["health"], 4 | "astronomy": ["physics"], 5 | "business_ethics": ["business"], 6 | "clinical_knowledge": ["health"], 7 | "college_biology": ["biology"], 8 | "college_chemistry": ["chemistry"], 9 | "college_computer_science": ["computer science"], 10 | "college_mathematics": ["math"], 11 | "college_medicine": ["health"], 12 | "college_physics": ["physics"], 13 | "computer_security": ["computer science"], 14 | "conceptual_physics": ["physics"], 15 | "econometrics": ["economics"], 16 | "electrical_engineering": ["engineering"], 17 | "elementary_mathematics": ["math"], 18 | "formal_logic": ["philosophy"], 19 | "global_facts": ["other"], 20 | "high_school_biology": ["biology"], 21 | "high_school_chemistry": ["chemistry"], 22 | "high_school_computer_science": ["computer science"], 23 | "high_school_european_history": ["history"], 24 | "high_school_geography": ["geography"], 25 | "high_school_government_and_politics": ["politics"], 26 | "high_school_macroeconomics": ["economics"], 27 | "high_school_mathematics": ["math"], 28 | "high_school_microeconomics": ["economics"], 29 | "high_school_physics": ["physics"], 30 | "high_school_psychology": ["psychology"], 31 | "high_school_statistics": ["math"], 32 | "high_school_us_history": ["history"], 33 | "high_school_world_history": ["history"], 34 | "human_aging": ["health"], 35 | "human_sexuality": ["culture"], 36 | "international_law": ["law"], 37 | "jurisprudence": ["law"], 38 | "logical_fallacies": ["philosophy"], 39 | "machine_learning": ["computer science"], 40 | "management": ["business"], 41 | "marketing": ["business"], 42 | "medical_genetics": ["health"], 43 | "miscellaneous": ["other"], 44 | "moral_disputes": ["philosophy"], 45 | "moral_scenarios": ["philosophy"], 46 | "nutrition": ["health"], 47 | "philosophy": 
["philosophy"], 48 | "prehistory": ["history"], 49 | "professional_accounting": ["other"], 50 | "professional_law": ["law"], 51 | "professional_medicine": ["health"], 52 | "professional_psychology": ["psychology"], 53 | "public_relations": ["politics"], 54 | "security_studies": ["politics"], 55 | "sociology": ["culture"], 56 | "us_foreign_policy": ["politics"], 57 | "virology": ["health"], 58 | "world_religions": ["philosophy"], 59 | } 60 | 61 | categories = { 62 | "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"], 63 | "humanities": ["history", "philosophy", "law"], 64 | "social sciences": ["politics", "culture", "economics", "geography", "psychology"], 65 | "other (business, health, misc.)": ["other", "business", "health"], 66 | } 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "xchat" 7 | version = "0.0.1" 8 | description = "An open platform for training, serving, and evaluating large language model based chatbots." 9 | requires-python = ">=3.9" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: Apache Software License", 13 | ] 14 | 15 | dependencies = [ 16 | "accelerate>=0.21.0", 17 | "bitsandbytes>=0.41.1", 18 | "datasets", 19 | "deepspeed>=0.10.0", 20 | "einops", 21 | "evaluate>=0.4.0", 22 | "peft>=0.4.0", 23 | "scipy", 24 | "sentencepiece", 25 | "tokenizers>=0.13.3", 26 | "transformers==4.34.0", 27 | "wandb", 28 | "openai" 29 | ] 30 | 31 | [project.optional-dependencies] 32 | dev = ["black==23.3.0", "pylint==2.8.2", "pre-commit"] 33 | all = ["xchat[dev]"] 34 | 35 | [tool.setuptools.packages.find] 36 | exclude = ["assets", "data*", "tests*"] 37 | 38 | [tool.wheel] 39 | exclude = ["assets", "data*", "tests*"] 40 | 41 | [tool.ruff] 42 | target-version = 'py38' 43 | select = [ 44 | "B", # flake8-bugbear 45 | "C4", # flake8-comprehensions 46 | "D", # pydocstyle 47 | "E", # Error 48 | "F", # pyflakes 49 | "I", # isort 50 | "ISC", # flake8-implicit-str-concat 51 | "N", # pep8-naming 52 | "PGH", # pygrep-hooks 53 | # "PTH", # flake8-use-pathlib 54 | "Q", # flake8-quotes 55 | "S", # bandit 56 | "SIM", # flake8-simplify 57 | "TRY", # tryceratops 58 | "UP", # pyupgrade 59 | "W", # Warning 60 | "YTT", # flake8-2020 61 | ] 62 | 63 | exclude = [ 64 | "migrations", 65 | "__pycache__", 66 | "manage.py", 67 | "settings.py", 68 | "env", 69 | ".env", 70 | "venv", 71 | ".venv", 72 | ] 73 | 74 | ignore = [ 75 | "B905", # zip strict=True; remove once python <3.10 support is dropped. 
76 | "D100", 77 | "D101", 78 | "D102", 79 | "D103", 80 | "D104", 81 | "D105", 82 | "D106", 83 | "D107", 84 | "D200", 85 | "D401", 86 | "E402", 87 | "E501", 88 | "F401", 89 | "TRY003", # Avoid specifying messages outside exception class; overly strict, especially for ValueError 90 | "S101", # Use of assert detected; overly strict, especially for tests 91 | ] 92 | line-length = 120 # Must agree with Black 93 | 94 | [tool.ruff.flake8-bugbear] 95 | extend-immutable-calls = [ 96 | "chr", 97 | "typer.Argument", 98 | "typer.Option", 99 | ] 100 | 101 | [tool.ruff.pydocstyle] 102 | convention = "numpy" 103 | 104 | [tool.ruff.pep8-naming] 105 | staticmethod-decorators = [ 106 | "pydantic.validator", 107 | "pydantic.root_validator", 108 | ] 109 | -------------------------------------------------------------------------------- /xchat/eval/gsm/examplars.py: -------------------------------------------------------------------------------- 1 | # These examplars are from the Table 20 of CoT paper (https://arxiv.org/pdf/2201.11903.pdf). 2 | 3 | EXAMPLARS = [ 4 | { 5 | "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?", 6 | "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.", 7 | "short_answer": "6", 8 | }, 9 | { 10 | "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?", 11 | "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.", 12 | "short_answer": "5", 13 | }, 14 | { 15 | "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?", 16 | "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.", 17 | "short_answer": "39", 18 | }, 19 | { 20 | "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?", 21 | "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.", 22 | "short_answer": "8", 23 | }, 24 | { 25 | "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?", 26 | "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.", 27 | "short_answer": "9", 28 | }, 29 | { 30 | "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", 31 | "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.", 32 | "short_answer": "29", 33 | }, 34 | { 35 | "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", 36 | "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.", 37 | "short_answer": "33", 38 | }, 39 | { 40 | "question": "Olivia has $23. 
She bought five bagels for $3 each. How much money does she have left?", 41 | "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.", 42 | "short_answer": "8", 43 | }, 44 | ] 45 | -------------------------------------------------------------------------------- /xchat/eval/dispatch_openai_requests.py: -------------------------------------------------------------------------------- 1 | # Copied and modified from https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a. Thanks to Graham Neubig for sharing the original code. 2 | 3 | import asyncio 4 | from typing import Any, Dict, List 5 | 6 | import openai 7 | 8 | 9 | async def dispatch_openai_chat_requesets( 10 | messages_list: List[List[Dict[str, Any]]], 11 | model: str, 12 | **completion_kwargs: Any, 13 | ) -> List[str]: 14 | """Dispatches requests to OpenAI chat completion API asynchronously. 15 | 16 | Args: 17 | messages_list: List of messages to be sent to OpenAI chat completion API. 18 | model: OpenAI model to use. 19 | completion_kwargs: Keyword arguments to be passed to OpenAI ChatCompletion API. See https://platform.openai.com/docs/api-reference/chat for details. 20 | 21 | Returns 22 | ------- 23 | List of responses from OpenAI API. 24 | """ 25 | async_responses = [ 26 | openai.ChatCompletion.acreate( 27 | model=model, 28 | messages=x, 29 | **completion_kwargs, 30 | ) 31 | for x in messages_list 32 | ] 33 | return await asyncio.gather(*async_responses) 34 | 35 | 36 | async def dispatch_openai_prompt_requesets( 37 | prompt_list: List[str], 38 | model: str, 39 | **completion_kwargs: Any, 40 | ) -> List[str]: 41 | """Dispatches requests to OpenAI text completion API asynchronously. 42 | 43 | Args: 44 | prompt_list: List of prompts to be sent to OpenAI text completion API. 45 | model: OpenAI model to use. 46 | completion_kwargs: Keyword arguments to be passed to OpenAI text completion API. See https://platform.openai.com/docs/api-reference/completions for details. 47 | 48 | Returns 49 | ------- 50 | List of responses from OpenAI API. 
51 | """ 52 | async_responses = [ 53 | openai.Completion.acreate( 54 | model=model, 55 | prompt=x, 56 | **completion_kwargs, 57 | ) 58 | for x in prompt_list 59 | ] 60 | return await asyncio.gather(*async_responses) 61 | 62 | 63 | if __name__ == "__main__": 64 | chat_completion_responses = asyncio.run( 65 | dispatch_openai_chat_requesets( 66 | messages_list=[ 67 | [{"role": "user", "content": "Write a poem about asynchronous execution."}], 68 | [{"role": "user", "content": "Write a poem about asynchronous pirates."}], 69 | ], 70 | model="gpt-3.5-turbo", 71 | temperature=0.3, 72 | max_tokens=200, 73 | top_p=1.0, 74 | ) 75 | ) 76 | 77 | for i, x in enumerate(chat_completion_responses): 78 | print(f"Chat completion response {i}:\n{x['choices'][0]['message']['content']}\n\n") 79 | 80 | prompt_completion_responses = asyncio.run( 81 | dispatch_openai_prompt_requesets( 82 | prompt_list=[ 83 | "Write a poem about asynchronous execution.\n", 84 | "Write a poem about asynchronous pirates.\n", 85 | ], 86 | model="text-davinci-003", 87 | temperature=0.3, 88 | max_tokens=200, 89 | top_p=1.0, 90 | ) 91 | ) 92 | 93 | for i, x in enumerate(prompt_completion_responses): 94 | print(f"Prompt completion response {i}:\n{x['choices'][0]['text']}\n\n") 95 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/prompts.py: -------------------------------------------------------------------------------- 1 | from xchat.eval.mbpp.examplars import EXAMPLARS 2 | 3 | 4 | class MbppPromptBase: 5 | def __init__(self, instuction, end_seq, is_few_shot) -> None: 6 | self.instruction = instuction 7 | self.end_seq = end_seq 8 | self.is_few_shot = is_few_shot 9 | 10 | def get_prompt(self, example): 11 | prompt = "" 12 | if self.is_few_shot: 13 | for d in EXAMPLARS: 14 | prompt += ( 15 | self.instruction.format(instruction=d["text"], test_cases="\n".join(d["test_list"]), code=d["code"]) 16 | + self.end_seq 17 | ) 18 | prompt += self.instruction.format( 19 | instruction=example["prompt"].strip(), test_cases=example["reference"], code="" 20 | ) 21 | return prompt.strip() 22 | 23 | def get_parser_seq(self): 24 | return self.end_seq 25 | 26 | 27 | class MbppPromptChat(MbppPromptBase): 28 | def __init__(self, instuction, end_seq, is_few_shot) -> None: 29 | super().__init__(instuction, end_seq, is_few_shot) 30 | 31 | def get_prompt(self, example): 32 | prompt = "[INST]" 33 | if self.is_few_shot: 34 | for d in EXAMPLARS: 35 | prompt += ( 36 | self.instruction.format(instruction=d["text"], test_cases="\n".join(d["test_list"]), code=d["code"]) 37 | + self.end_seq 38 | ) 39 | prompt += self.instruction.format( 40 | instruction=example["prompt"].strip(), test_cases=example["reference"], code="" 41 | ) 42 | prompt = prompt.strip() 43 | assert prompt.endswith("[PYTHON]") 44 | prompt = prompt[:-8] 45 | prompt += "[/INST]\n[PYTHON]\n" 46 | return prompt.strip() 47 | 48 | 49 | BASE_INSTRUCTION = """ 50 | You are an expert Python programmer, and here is your task: {instruction} Your code should pass these tests:\n\n{test_cases} 51 | [BEGIN] 52 | {code}""" 53 | 54 | 55 | class BasePrompt(MbppPromptBase): 56 | def __init__(self, is_few_shot) -> None: 57 | super().__init__(instuction=BASE_INSTRUCTION, end_seq="\n[DONE]", is_few_shot=is_few_shot) 58 | 59 | 60 | INSTRUCTION = """ 61 | You are an expert Python programmer, and here is your task: {instruction} 62 | Your code should pass these tests:\n\n{test_cases}\nYour code should start with a [PYTHON] tag and end with a [/PYTHON] tag. 
63 | [PYTHON] 64 | {code}""" 65 | 66 | 67 | class CodeLLaMAPrompt(MbppPromptBase): 68 | def __init__(self, is_few_shot) -> None: 69 | super().__init__(instuction=INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot) 70 | 71 | 72 | class LLaMA2ChatPrompt(MbppPromptChat): 73 | def __init__(self, is_few_shot) -> None: 74 | super().__init__(instuction=INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot) 75 | 76 | 77 | LEMUR_INSTRUCTION = """ 78 | <|im_start|>user 79 | You are an expert Python programmer, and here is your task: {instruction} 80 | Your code should pass these tests:\n\n{test_cases}\nYour code should start with a [PYTHON] tag and end with a [/PYTHON] tag. 81 | <|im_end|> 82 | <|im_start|>assistant 83 | [PYTHON] 84 | {code}""" 85 | 86 | 87 | class LemurPrompt(MbppPromptBase): 88 | def __init__(self, is_few_shot) -> None: 89 | super().__init__(instuction=LEMUR_INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot) 90 | 91 | 92 | def getprompts(chat_format, is_few_shot): 93 | if chat_format in ["base"]: 94 | return BasePrompt(is_few_shot) 95 | elif chat_format in ["lemur"]: 96 | return LemurPrompt(is_few_shot) 97 | elif chat_format in ["codellama", "wizardcoder"]: 98 | return CodeLLaMAPrompt(is_few_shot) 99 | elif chat_format in ["llama2chat"]: 100 | return LLaMA2ChatPrompt(is_few_shot) 101 | raise NotImplementedError 102 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/evaluation.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import Counter, defaultdict 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from typing import Dict, Iterable, List, Union 5 | 6 | import numpy as np 7 | import tqdm 8 | from evaluate import load 9 | 10 | from xchat.eval.mbpp.data import read_problems, stream_jsonl, write_jsonl 11 | from xchat.eval.mbpp.execution import check_correctness 12 | 13 | 14 | def estimate_pass_at_k( 15 | num_samples: Union[int, List[int], np.ndarray], num_correct: Union[List[int], np.ndarray], k: int 16 | ) -> np.ndarray: 17 | """ 18 | Estimates pass@k of each problem and returns them in an array. 19 | """ 20 | 21 | def estimator(n: int, c: int, k: int) -> float: 22 | """ 23 | Calculates 1 - comb(n - c, k) / comb(n, k). 24 | """ 25 | if n - c < k: 26 | return 1.0 27 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 28 | 29 | if isinstance(num_samples, int): 30 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 31 | else: 32 | assert len(num_samples) == len(num_correct) 33 | num_samples_it = iter(num_samples) 34 | 35 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) 36 | 37 | 38 | def evaluate_functional_correctness( 39 | task: str, 40 | sample_file: str, 41 | k: List[int] = [1, 10, 100], 42 | n_workers: int = 4, 43 | timeout: float = 3.0, 44 | problems=None, 45 | problem_file=None, 46 | ): 47 | """ 48 | Evaluates the functional correctness of generated samples, and writes 49 | results to f"{sample_file}_results.jsonl.gz". 50 | """ 51 | if not problems: 52 | problems = read_problems(problem_file) 53 | 54 | # Check the generated samples against test suites. 
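    # Each completion is executed in its own subprocess with a hard timeout (see xchat.eval.mbpp.execution.check_correctness), so a hanging or destructive sample cannot stall the thread pool.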
55 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 56 | futures = [] 57 | completion_id = Counter() 58 | n_samples = 0 59 | results = defaultdict(list) 60 | 61 | print("Reading samples...") 62 | for sample in tqdm.tqdm(stream_jsonl(sample_file)): 63 | task_id = sample["task_id"] 64 | completion = sample["completion"] 65 | args = (task, problems[task_id], completion, timeout, completion_id[task_id]) 66 | future = executor.submit(check_correctness, *args) 67 | futures.append(future) 68 | completion_id[task_id] += 1 69 | n_samples += 1 70 | 71 | assert len(completion_id) == len(problems), "Some problems are not attempted." 72 | 73 | print("Running test suites...") 74 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 75 | result = future.result() 76 | results[result["task_id"]].append((result["completion_id"], result)) 77 | 78 | # Calculate pass@k. 79 | total, correct = [], [] 80 | for result in results.values(): 81 | result.sort() 82 | passed = [r[1]["passed"] for r in result] 83 | total.append(len(passed)) 84 | correct.append(sum(passed)) 85 | total = np.array(total) 86 | correct = np.array(correct) 87 | 88 | ks = k 89 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()} 90 | 91 | # Finally, save the results in one file: 92 | def combine_results(): 93 | for sample in stream_jsonl(sample_file): 94 | task_id = sample["task_id"] 95 | result = results[task_id].pop(0) 96 | sample["result"] = result[1]["result"] 97 | sample["passed"] = result[1]["passed"] 98 | yield sample 99 | 100 | out_file = sample_file + "_results.jsonl" 101 | print(f"Writing results to {out_file}...") 102 | write_jsonl(out_file, tqdm.tqdm(combine_results(), total=n_samples)) 103 | 104 | return pass_at_k 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,macos,python,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,macos,python,visualstudiocode 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### macOS Patch ### 49 | # iCloud generated files 50 | *.icloud 51 | 52 | ### Python ### 53 | # Byte-compiled / optimized / DLL files 54 | __pycache__/ 55 | *.py[cod] 56 | *$py.class 57 | 58 | # C extensions 59 | *.so 60 | 61 | # Distribution / packaging 62 | .Python 63 | build/ 64 | develop-eggs/ 65 | dist/ 66 | downloads/ 67 | eggs/ 68 | .eggs/ 69 | lib/ 70 | lib64/ 71 | parts/ 72 | sdist/ 73 | var/ 74 | wheels/ 75 | 
share/python-wheels/ 76 | *.egg-info/ 77 | .installed.cfg 78 | *.egg 79 | MANIFEST 80 | 81 | # PyInstaller 82 | # Usually these files are written by a python script from a template 83 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 84 | *.manifest 85 | *.spec 86 | 87 | # Installer logs 88 | pip-log.txt 89 | pip-delete-this-directory.txt 90 | 91 | # Unit test / coverage reports 92 | htmlcov/ 93 | .tox/ 94 | .nox/ 95 | .coverage 96 | .coverage.* 97 | .cache 98 | nosetests.xml 99 | coverage.xml 100 | *.cover 101 | *.py,cover 102 | .hypothesis/ 103 | .pytest_cache/ 104 | cover/ 105 | 106 | # Translations 107 | *.mo 108 | *.pot 109 | 110 | # Django stuff: 111 | *.log 112 | local_settings.py 113 | db.sqlite3 114 | db.sqlite3-journal 115 | 116 | # Flask stuff: 117 | instance/ 118 | .webassets-cache 119 | 120 | # Scrapy stuff: 121 | .scrapy 122 | 123 | # Sphinx documentation 124 | docs/_build/ 125 | 126 | # PyBuilder 127 | .pybuilder/ 128 | target/ 129 | 130 | # Jupyter Notebook 131 | .ipynb_checkpoints 132 | 133 | # IPython 134 | profile_default/ 135 | ipython_config.py 136 | 137 | # pyenv 138 | # For a library or package, you might want to ignore these files since the code is 139 | # intended to run in multiple environments; otherwise, check them in: 140 | # .python-version 141 | 142 | # pipenv 143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 146 | # install all needed dependencies. 147 | #Pipfile.lock 148 | 149 | # poetry 150 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 151 | # This is especially recommended for binary packages to ensure reproducibility, and is more 152 | # commonly ignored for libraries. 153 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 154 | #poetry.lock 155 | 156 | # pdm 157 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 158 | #pdm.lock 159 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 160 | # in version control. 161 | # https://pdm.fming.dev/#use-with-ide 162 | .pdm.toml 163 | 164 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 165 | __pypackages__/ 166 | 167 | # Celery stuff 168 | celerybeat-schedule 169 | celerybeat.pid 170 | 171 | # SageMath parsed files 172 | *.sage.py 173 | 174 | # Environments 175 | .env 176 | .venv 177 | env/ 178 | venv/ 179 | ENV/ 180 | env.bak/ 181 | venv.bak/ 182 | 183 | # Spyder project settings 184 | .spyderproject 185 | .spyproject 186 | 187 | # Rope project settings 188 | .ropeproject 189 | 190 | # mkdocs documentation 191 | /site 192 | 193 | # mypy 194 | .mypy_cache/ 195 | .dmypy.json 196 | dmypy.json 197 | 198 | # Pyre type checker 199 | .pyre/ 200 | 201 | # pytype static type analyzer 202 | .pytype/ 203 | 204 | # Cython debug symbols 205 | cython_debug/ 206 | 207 | # PyCharm 208 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 209 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 210 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 211 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 212 | .idea/ 213 | 214 | ### Python Patch ### 215 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 216 | poetry.toml 217 | 218 | # ruff 219 | .ruff_cache/ 220 | 221 | # LSP config files 222 | pyrightconfig.json 223 | 224 | ### VisualStudioCode ### 225 | .vscode/* 226 | !.vscode/settings.json 227 | !.vscode/tasks.json 228 | !.vscode/launch.json 229 | !.vscode/extensions.json 230 | !.vscode/*.code-snippets 231 | 232 | # Local History for Visual Studio Code 233 | .history/ 234 | 235 | # Built Visual Studio Code Extensions 236 | *.vsix 237 | 238 | ### VisualStudioCode Patch ### 239 | # Ignore all local history of files 240 | .history 241 | .ionide 242 | 243 | # End of https://www.toptal.com/developers/gitignore/api/linux,macos,python,visualstudiocode 244 | 245 | data 246 | results 247 | wandb 248 | outputs 249 | -------------------------------------------------------------------------------- /xchat/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import List, Optional, Tuple 4 | 5 | import flash_attn 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from flash_attn.bert_padding import pad_input, unpad_input 10 | 11 | logging.warning(f"Find flash_attn version {flash_attn.__version__}") 12 | if flash_attn.__version__ <= "1.0.9": 13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func 14 | else: 15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 16 | 17 | import transformers 18 | from torch import nn 19 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 20 | 21 | 22 | def forward( 23 | self, 24 | hidden_states: torch.Tensor, 25 | attention_mask: Optional[torch.Tensor] = None, 26 | position_ids: Optional[torch.Tensor] = None, 27 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 28 | output_attentions: bool = False, 29 | use_cache: bool = False, 30 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 31 | """Input shape: Batch x Time x Channel. 
32 | 33 | attention_mask: [bsz, q_len] 34 | """ 35 | if output_attentions: 36 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 37 | 38 | bsz, q_len, _ = hidden_states.size() 39 | 40 | if self.config.pretraining_tp > 1: 41 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 42 | query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) 43 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 44 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 45 | 46 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 47 | query_states = torch.cat(query_states, dim=-1) 48 | 49 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 50 | key_states = torch.cat(key_states, dim=-1) 51 | 52 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 53 | value_states = torch.cat(value_states, dim=-1) 54 | 55 | else: 56 | query_states = self.q_proj(hidden_states) 57 | key_states = self.k_proj(hidden_states) 58 | value_states = self.v_proj(hidden_states) 59 | 60 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 61 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 62 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 63 | 64 | # [bsz, q_len, nh, hd] 65 | # [bsz, nh, q_len, hd] 66 | 67 | kv_seq_len = key_states.shape[-2] 68 | if past_key_value is not None: 69 | kv_seq_len += past_key_value[0].shape[-2] 70 | 71 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 72 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 73 | # [bsz, nh, t, hd] 74 | # Past Key value support 75 | if past_key_value is not None: 76 | # reuse k, v, self_attention 77 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 78 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 79 | 80 | past_key_value = (key_states, value_states) if use_cache else None 81 | 82 | # repeat k/v heads if n_kv_heads < n_heads 83 | key_states = repeat_kv(key_states, self.num_key_value_groups) 84 | value_states = repeat_kv(value_states, self.num_key_value_groups) 85 | 86 | # Flash attention codes from 87 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 88 | 89 | # transform the data into the format required by flash attention 90 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 91 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 92 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 93 | # the attention_mask should be the same as the key_padding_mask 94 | key_padding_mask = attention_mask 95 | 96 | if key_padding_mask is None: 97 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 98 | max_s = q_len 99 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 100 | output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 101 | output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 102 | else: 103 | nheads = qkv.shape[-2] 104 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 105 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 106 | x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) 107 | output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 108 | output = rearrange( 109 | pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len), 110 | "b s (h d) -> b s h d", 111 | h=nheads, 112 | ) 113 | output = rearrange(output, "b s h d -> b s (h d)") 114 | if self.config.pretraining_tp > 1: 115 | output = output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 116 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 117 | output = sum([F.linear(output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 118 | else: 119 | output = self.o_proj(output) 120 | return output, None, past_key_value 121 | 122 | 123 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 124 | # requires the attention mask to be the same as the key_padding_mask 125 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 126 | # [bsz, seq_len] 127 | return attention_mask 128 | 129 | 130 | def replace_llama_attn_with_flash_attn(): 131 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 132 | if cuda_major < 8: 133 | logging.warning( 134 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 135 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 136 | ) 137 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 138 | _prepare_decoder_attention_mask 139 | ) 140 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 141 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/execution.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import contextlib 3 | import faulthandler 4 | import io 5 | import multiprocessing 6 | import os 7 | import platform 8 | import signal 9 | import tempfile 10 | from typing import Callable, Dict, Optional 11 | 12 | 13 | def check_correctness( 14 | task: str, problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None 15 | ) -> Dict: 16 | """ 17 | Evaluates the functional correctness of a completion by running the test 18 | suite provided in the problem. 19 | 20 | :param completion_id: an optional completion ID so we can match 21 | the results later even if execution finishes asynchronously. 22 | """ 23 | 24 | def unsafe_execute(): 25 | with create_tempdir(): 26 | # These system calls are needed when cleaning up tempdir. 27 | import os 28 | import shutil 29 | 30 | rmtree = shutil.rmtree 31 | rmdir = os.rmdir 32 | chdir = os.chdir 33 | 34 | # Disable functionalities that can make destructive changes to the test. 35 | reliability_guard() 36 | 37 | # Construct the check program and run it. 
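            # For HumanEval, the generated completion is appended to the original prompt and validated with the dataset's check(entry_point) harness; for MBPP, the reference assert statements are executed directly.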
38 | if task == "humaneval": 39 | check_program = ( 40 | problem["prompt"] + completion + "\n" + problem["test"] + "\n" + f"check({problem['entry_point']})" 41 | ) 42 | elif task == "mbpp": 43 | check_program = completion + "\n" + problem["reference"] 44 | else: 45 | raise ValueError(f"Unknown task: {task}") 46 | 47 | try: 48 | exec_globals = {} 49 | with swallow_io(), time_limit(timeout): 50 | # WARNING 51 | # This program exists to execute untrusted model-generated code. Although 52 | # it is highly unlikely that model-generated code will do something overtly 53 | # malicious in response to this test suite, model-generated code may act 54 | # destructively due to a lack of model capability or alignment. 55 | # Users are strongly encouraged to sandbox this evaluation suite so that it 56 | # does not perform destructive actions on their host or network. For more 57 | # information on how OpenAI sandboxes its code, see the accompanying paper. 58 | # Once you have read this disclaimer and taken appropriate precautions, 59 | # uncomment the following line and proceed at your own risk: 60 | exec(check_program, exec_globals) 61 | result.append("passed") 62 | except TimeoutException: 63 | result.append("timed out") 64 | except BaseException as e: 65 | result.append(f"failed: {e}") 66 | 67 | # Needed for cleaning up. 68 | shutil.rmtree = rmtree 69 | os.rmdir = rmdir 70 | os.chdir = chdir 71 | 72 | manager = multiprocessing.Manager() 73 | result = manager.list() 74 | 75 | p = multiprocessing.Process(target=unsafe_execute) 76 | p.start() 77 | p.join(timeout=timeout + 1) 78 | if p.is_alive(): 79 | p.kill() 80 | 81 | if not result: 82 | result.append("timed out") 83 | 84 | return { 85 | "task_id": problem["task_id"], 86 | "passed": result[0] == "passed", 87 | "result": result[0], 88 | "completion_id": completion_id, 89 | } 90 | 91 | 92 | @contextlib.contextmanager 93 | def time_limit(seconds: float): 94 | def signal_handler(signum, frame): 95 | raise TimeoutException("Timed out!") 96 | 97 | signal.setitimer(signal.ITIMER_REAL, seconds) 98 | signal.signal(signal.SIGALRM, signal_handler) 99 | try: 100 | yield 101 | finally: 102 | signal.setitimer(signal.ITIMER_REAL, 0) 103 | 104 | 105 | @contextlib.contextmanager 106 | def swallow_io(): 107 | stream = WriteOnlyStringIO() 108 | with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream), redirect_stdin(stream): 109 | yield 110 | 111 | 112 | @contextlib.contextmanager 113 | def create_tempdir(): 114 | with tempfile.TemporaryDirectory() as dirname, chdir(dirname): 115 | yield dirname 116 | 117 | 118 | class TimeoutException(Exception): 119 | pass 120 | 121 | 122 | class WriteOnlyStringIO(io.StringIO): 123 | """StringIO that throws an exception when it's read from.""" 124 | 125 | def read(self, *args, **kwargs): 126 | raise OSError 127 | 128 | def readline(self, *args, **kwargs): 129 | raise OSError 130 | 131 | def readlines(self, *args, **kwargs): 132 | raise OSError 133 | 134 | def readable(self, *args, **kwargs): 135 | """Returns True if the IO object can be read.""" 136 | return False 137 | 138 | 139 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 140 | _stream = "stdin" 141 | 142 | 143 | @contextlib.contextmanager 144 | def chdir(root): 145 | if root == ".": 146 | yield 147 | return 148 | cwd = os.getcwd() 149 | os.chdir(root) 150 | try: 151 | yield 152 | except BaseException as exc: 153 | raise exc 154 | finally: 155 | os.chdir(cwd) 156 | 157 | 158 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 
159 | """ 160 | This disables various destructive functions and prevents the generated code 161 | from interfering with the test (e.g. fork bomb, killing other processes, 162 | removing filesystem files, etc.). 163 | 164 | WARNING 165 | This function is NOT a security sandbox. Untrusted code, including, model- 166 | generated code, should not be blindly executed outside of one. See the 167 | Codex paper for more information about OpenAI's code sandbox, and proceed 168 | with caution. 169 | """ 170 | if maximum_memory_bytes is not None: 171 | import resource 172 | 173 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 174 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 175 | if platform.uname().system != "Darwin": 176 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 177 | 178 | faulthandler.disable() 179 | 180 | import builtins 181 | 182 | builtins.exit = None 183 | builtins.quit = None 184 | 185 | import os 186 | 187 | os.environ["OMP_NUM_THREADS"] = "1" 188 | 189 | os.kill = None 190 | os.system = None 191 | os.putenv = None 192 | os.remove = None 193 | os.removedirs = None 194 | os.rmdir = None 195 | os.fchdir = None 196 | os.setuid = None 197 | os.fork = None 198 | os.forkpty = None 199 | os.killpg = None 200 | os.rename = None 201 | os.renames = None 202 | os.truncate = None 203 | os.replace = None 204 | os.unlink = None 205 | os.fchmod = None 206 | os.fchown = None 207 | os.chmod = None 208 | os.chown = None 209 | os.chroot = None 210 | os.fchdir = None 211 | os.lchflags = None 212 | os.lchmod = None 213 | os.lchown = None 214 | os.getcwd = None 215 | os.chdir = None 216 | 217 | import shutil 218 | 219 | shutil.rmtree = None 220 | shutil.move = None 221 | shutil.chown = None 222 | 223 | import subprocess 224 | 225 | subprocess.Popen = None # type: ignore 226 | 227 | __builtins__["help"] = None 228 | 229 | import sys 230 | 231 | sys.modules["ipdb"] = None 232 | sys.modules["joblib"] = None 233 | sys.modules["resource"] = None 234 | sys.modules["psutil"] = None 235 | sys.modules["tkinter"] = None 236 | -------------------------------------------------------------------------------- /xchat/eval/humaneval/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import torch 7 | 8 | from xchat.eval.humaneval.prompts import getprompts 9 | from xchat.eval.mbpp.data import read_problems, write_jsonl 10 | from xchat.eval.mbpp.evaluation import evaluate_functional_correctness 11 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model 12 | 13 | 14 | def main(args): 15 | random.seed(42) 16 | if not os.path.exists(args.save_dir): 17 | os.makedirs(args.save_dir, exist_ok=True) 18 | data_file = args.data_file 19 | test_data = list(read_problems(data_file).values()) 20 | 21 | if args.max_num_examples is not None and len(test_data) > args.max_num_examples: 22 | test_data = random.sample(test_data, args.max_num_examples) 23 | print("Number of examples:", len(test_data)) 24 | promptclass = getprompts(args.chat_format) 25 | prompts = [promptclass.getPrompt(example["prompt"]) for example in test_data] 26 | 27 | if args.model_name_or_path: 28 | print("Loading model and tokenizer...") 29 | model, tokenizer = load_hf_lm_and_tokenizer( 30 | model_name_or_path=args.model_name_or_path, 31 | 
tokenizer_name_or_path=args.tokenizer_name_or_path, 32 | load_in_8bit=args.load_in_8bit, 33 | load_in_half=True, 34 | # device map is determined by the number of gpus available. 35 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto", 36 | gptq_model=args.gptq, 37 | ) 38 | 39 | # these stop sequences are those mentioned in the codex paper. 40 | stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"] 41 | # Because many tokenizers will treat the word after space differently from the original word alone, 42 | # to be consistent, we add a space before tokenization and remove it after tokenization. 43 | stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences] 44 | print("stop sequences:") 45 | print(stop_sequences) 46 | if args.chat_format == "vicuna": 47 | # Vicuna uses a different stop sequence otherwise meaning tokens may be cut off. 48 | stop_sequences = [[tokenizer.eos_token_id]] 49 | outputs_per_sampling_iter = [] 50 | for sampling_iter in range(args.unbiased_sampling_size_n): 51 | print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}") 52 | samping_outputs = generate_completions( 53 | model=model, 54 | tokenizer=tokenizer, 55 | prompts=prompts, 56 | max_new_tokens=args.max_new_tokens, 57 | batch_size=args.eval_batch_size, 58 | stop_id_sequences=stop_sequences, 59 | num_return_sequences=1, # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom. 60 | do_sample=True, # if only pass@1 is evaluated, we do greedy decoding. 61 | top_p=args.top_p, 62 | temperature=args.temperature, 63 | eos_token_id=tokenizer.eos_token_id, 64 | ) 65 | outputs_per_sampling_iter.append(samping_outputs) 66 | # regroup the outputs to match the number of test data. 67 | outputs = [] 68 | for i in range(len(prompts)): 69 | for j in range(args.unbiased_sampling_size_n): 70 | outputs.append(outputs_per_sampling_iter[j][i]) 71 | else: 72 | instances = [ 73 | { 74 | "id": examle["task_id"], 75 | "prompt": "Complete the following python function. Please only output the code for the completed function.\n\n\n" 76 | + prompt, 77 | } 78 | for examle, prompt in zip(test_data, prompts) 79 | ] 80 | results = query_openai_chat_model( 81 | engine=args.openai_engine, 82 | instances=instances, 83 | output_path=os.path.join(args.save_dir, "openai_query_results.jsonl"), 84 | batch_size=args.eval_batch_size, 85 | top_p=0.95, 86 | temperature=args.temperature, 87 | n=args.unbiased_sampling_size_n, 88 | ) 89 | outputs = [] 90 | for result in results: 91 | for choice in result["response_metadata"]["choices"]: 92 | outputs.append(choice["message"]["content"]) 93 | 94 | # duplicates test data to match the number of outputs. 
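    # (Editor's note, kept as comments so this script is unchanged.) Each prompt is sampled
    # `unbiased_sampling_size_n` times so that `evaluate_functional_correctness` can report an
    # unbiased pass@k. A minimal sketch of the standard Codex-paper estimator, assuming numpy is
    # available; the repository's own implementation lives in xchat.eval.mbpp.evaluation and is
    # not shown in this listing, so treat this as the textbook version rather than a literal copy:
    #
    #     import numpy as np
    #
    #     def estimate_pass_at_k(n: int, c: int, k: int) -> float:
    #         """Probability that at least one of k draws from n samples (c correct) passes."""
    #         if n - c < k:
    #             return 1.0
    #         return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
    #
    #     estimate_pass_at_k(n=20, c=5, k=1)   # 0.25 (= c / n when k = 1)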
95 | duplicate_test_data = [example for example in test_data for _ in range(args.unbiased_sampling_size_n)] 96 | predictions = [] 97 | for _, example, output in zip(prompts, duplicate_test_data, outputs): 98 | output_lines = output.split("\n") 99 | idx = len(output_lines) - 1 100 | while idx >= 0 and not output_lines[idx].strip().startswith("return"): 101 | idx -= 1 102 | output_lines = output_lines[: idx + 1] 103 | processed_output = "\n".join(output_lines) + "\n" 104 | predictions.append( 105 | { 106 | "task_id": example["task_id"], 107 | "prompt": example["prompt"], 108 | "completion": processed_output, 109 | "raw_response": output, 110 | } 111 | ) 112 | 113 | prediction_save_path = os.path.join(args.save_dir, "humaneval_predictions.jsonl") 114 | 115 | write_jsonl(prediction_save_path, predictions) 116 | pass_at_k_results = evaluate_functional_correctness( 117 | task="humaneval", 118 | sample_file=prediction_save_path, 119 | k=args.eval_pass_at_ks, 120 | problems={example["task_id"]: example for example in test_data}, 121 | ) 122 | 123 | print(pass_at_k_results) 124 | 125 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: 126 | json.dump(pass_at_k_results, fout) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--data_file", type=str, help="Path to the data file.") 132 | parser.add_argument("--max_num_examples", type=int, default=None, help="Maximum number of examples to evaluate.") 133 | parser.add_argument( 134 | "--model_name_or_path", 135 | type=str, 136 | default=None, 137 | help="If specified, we will load the model to generate the predictions.", 138 | ) 139 | parser.add_argument( 140 | "--tokenizer_name_or_path", type=str, default=None, help="If specified, we will load the tokenizer from here." 141 | ) 142 | parser.add_argument( 143 | "--openai_engine", 144 | type=str, 145 | default=None, 146 | help="If specified, we will use the OpenAI API to generate the predictions.", 147 | ) 148 | parser.add_argument("--data_dir", type=str, help="Directory to save the data.") 149 | parser.add_argument("--cache_dir", type=str, help="Directory to save the cache.") 150 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the results.") 151 | parser.add_argument("--max_new_tokens", type=int, required=True, help="Maximum number of tokens to generate.") 152 | parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation.") 153 | parser.add_argument("--top_p", type=float, default=0.95) 154 | parser.add_argument( 155 | "--temperature", 156 | type=float, 157 | default=0.1, 158 | help="Temperature for sampling. This is should be low for evaluating smaller pass@k, and high for larger pass@k.", 159 | ) 160 | parser.add_argument( 161 | "--eval_pass_at_ks", nargs="+", type=int, default=[1], help="Multiple k's that we will report pass@k." 162 | ) 163 | parser.add_argument( 164 | "--unbiased_sampling_size_n", 165 | type=int, 166 | default=20, 167 | help="Codex HumanEval requires `n` sampled generations per prompt, to estimate the unbiased pass@k. 
", 168 | ) 169 | 170 | parser.add_argument( 171 | "--load_in_8bit", 172 | action="store_true", 173 | help="Load model in 8bit mode, which will reduce memory and speed up inference.", 174 | ) 175 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 176 | parser.add_argument( 177 | "--chat_format", 178 | type=str, 179 | default=None, 180 | choices=["tulu", "chatml", "llama-2-chat", "vicuna", "base"], 181 | help="if given, we will use the chat format to generate the predictions.", 182 | ) 183 | args = parser.parse_args() 184 | # model_name_or_path and openai_engine cannot be both None or both not None. 185 | assert (args.model_name_or_path is None) != ( 186 | args.openai_engine is None 187 | ), "Either model_name_or_path or openai_engine should be specified." 188 | assert args.unbiased_sampling_size_n >= max( 189 | args.eval_pass_at_ks 190 | ), "n should be larger than the largest k in eval_pass_at_ks." 191 | main(args) 192 | -------------------------------------------------------------------------------- /xchat/eval/mbpp/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import torch 7 | from datasets import load_dataset 8 | 9 | from xchat.eval.mbpp.data import read_problems, write_jsonl 10 | from xchat.eval.mbpp.evaluation import evaluate_functional_correctness 11 | from xchat.eval.mbpp.prompts import getprompts 12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model 13 | 14 | 15 | def prepare_inference_arg(args, test_data): 16 | res = [] 17 | for example in test_data: 18 | d = {} 19 | d["prompt"] = getprompts(args.chat_format, args.few_shot).get_prompt(example) 20 | d["reference"] = example["reference"] 21 | d["task_id"] = example["task_id"] 22 | res.append(d) 23 | return res 24 | 25 | 26 | def prepare_data(args): 27 | if not os.path.exists(os.path.join(args.save_dir, "mbpp_inference_args.jsonl")): 28 | os.makedirs(args.save_dir, exist_ok=True) 29 | raw_dataset = load_dataset(path="mbpp", name=None, cache_dir=args.cache_dir)["test"] 30 | if len(raw_dataset) != 500: 31 | raise ValueError(f"Expected 500 test examples, got {len(raw_dataset)}") 32 | with open(os.path.join(args.save_dir, "mbpp_inference_args.jsonl"), "w") as f: 33 | for i, example in enumerate(raw_dataset): 34 | d__ = { 35 | "task_id": i, 36 | "prompt": example["text"], 37 | "reference": "\n".join(example["test_list"]), 38 | } 39 | f.write(json.dumps(d__) + "\n") 40 | return os.path.join(args.save_dir, "mbpp_inference_args.jsonl") 41 | 42 | 43 | def main(args): 44 | random.seed(42) 45 | if not os.path.exists(args.data_dir): 46 | os.makedirs(args.data_dir, exist_ok=True) 47 | if not os.path.exists(args.save_dir): 48 | os.makedirs(args.save_dir, exist_ok=True) 49 | 50 | data_file = prepare_data(args) 51 | 52 | test_data = list(read_problems(data_file).values()) 53 | if args.max_num_examples is not None and len(test_data) > args.max_num_examples: 54 | test_data = random.sample(test_data, args.max_num_examples) 55 | 56 | test_data = prepare_inference_arg(args, test_data) 57 | print("Number of examples:", len(test_data)) 58 | 59 | if args.model_name_or_path: 60 | print("Loading model and tokenizer...") 61 | model, tokenizer = load_hf_lm_and_tokenizer( 62 | model_name_or_path=args.model_name_or_path, 63 | tokenizer_name_or_path=args.tokenizer_name_or_path, 64 | load_in_8bit=args.load_in_8bit, 65 | 
load_in_half=True, 66 | # device map is determined by the number of gpus available. 67 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto", 68 | gptq_model=args.gptq, 69 | ) 70 | tokenizer.padding_side = "left" 71 | tokenizer.pad_token_id = 0 72 | test_data = sorted( 73 | test_data, key=lambda x: len(tokenizer(x["prompt"], return_tensors="pt").input_ids[0]), reverse=True 74 | ) 75 | 76 | stop_sequences = ["\nassert", "\nprint"] 77 | stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences] 78 | prompts = [example["prompt"] for example in test_data] 79 | 80 | outputs_per_sampling_iter = [] 81 | for sampling_iter in range(args.unbiased_sampling_size_n): 82 | assert args.greedy_decoding is True 83 | print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}") 84 | samping_outputs = generate_completions( 85 | model=model, 86 | tokenizer=tokenizer, 87 | prompts=prompts, 88 | max_new_tokens=args.max_new_tokens, 89 | batch_size=args.eval_batch_size, 90 | stop_id_sequences=stop_sequences, 91 | num_return_sequences=1, # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom. 92 | do_sample=not args.greedy_decoding, # if only pass@1 is evaluated, we do greedy decoding. 93 | top_p=args.top_p, 94 | temperature=args.temperature, 95 | eos_token_id=tokenizer.eos_token_id, 96 | ) 97 | assert len(samping_outputs) == len(prompts) 98 | outputs_per_sampling_iter.append(samping_outputs) 99 | outputs = [] 100 | for i in range(len(prompts)): 101 | for j in range(args.unbiased_sampling_size_n): 102 | outputs.append(outputs_per_sampling_iter[j][i]) 103 | else: 104 | instances = [ 105 | { 106 | "id": examle["task_id"], 107 | "prompt": "Complete the following python function. 
Please only output the code for the completed function.\n\n\n" 108 | + prompt, 109 | } 110 | for examle, prompt in zip(test_data, prompts) 111 | ] 112 | results = query_openai_chat_model( 113 | engine=args.openai_engine, 114 | instances=instances, 115 | output_path=os.path.join(args.save_dir, "openai_query_results.jsonl"), 116 | batch_size=args.eval_batch_size, 117 | top_p=0.95, 118 | temperature=args.temperature, 119 | n=args.unbiased_sampling_size_n, 120 | ) 121 | outputs = [] 122 | for result in results: 123 | for choice in result["response_metadata"]["choices"]: 124 | outputs.append(choice["message"]["content"]) 125 | 126 | duplicate_test_data = [example for example in test_data for _ in range(args.unbiased_sampling_size_n)] 127 | assert len(duplicate_test_data) == len(outputs) 128 | 129 | def process(args, s): 130 | parse_seq = getprompts(args.chat_format, args.few_shot).get_parser_seq() 131 | return s.split(parse_seq)[0].strip() 132 | 133 | predictions = [ 134 | { 135 | "task_id": example["task_id"], 136 | "prompt": example["prompt"], 137 | "completion": process(args, output.strip()), 138 | "reference": example["reference"], 139 | } 140 | for example, output in zip(duplicate_test_data, outputs) 141 | ] 142 | prediction_save_path = os.path.join(args.save_dir, "mbpp_predictions.jsonl") 143 | write_jsonl(prediction_save_path, predictions, append=True) 144 | 145 | pass_at_k_results = evaluate_functional_correctness( 146 | sample_file=prediction_save_path, 147 | k=args.eval_pass_at_ks, 148 | problems={example["task_id"]: example for example in test_data}, 149 | ) 150 | 151 | print(pass_at_k_results) 152 | 153 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: 154 | json.dump(pass_at_k_results, fout) 155 | 156 | 157 | if __name__ == "__main__": 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("--data_file", type=str, help="Path to the data file.") 160 | parser.add_argument("--max_num_examples", type=int, default=None, help="Maximum number of examples to evaluate.") 161 | parser.add_argument( 162 | "--model_name_or_path", 163 | type=str, 164 | default=None, 165 | help="If specified, we will load the model to generate the predictions.", 166 | ) 167 | parser.add_argument( 168 | "--tokenizer_name_or_path", type=str, default=None, help="If specified, we will load the tokenizer from here." 169 | ) 170 | parser.add_argument( 171 | "--openai_engine", 172 | type=str, 173 | default=None, 174 | help="If specified, we will use the OpenAI API to generate the predictions.", 175 | ) 176 | parser.add_argument("--data_dir", type=str, help="Directory to save the data.") 177 | parser.add_argument("--cache_dir", type=str, help="Directory to save the cache.") 178 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the results.") 179 | parser.add_argument("--max_new_tokens", type=int, required=True, help="Maximum number of tokens to generate.") 180 | parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation.") 181 | parser.add_argument("--unbiased_sampling_size_n", type=int, default=1, help="Number of unbiased samples.") 182 | parser.add_argument("--few_shot", action="store_true", help="If given, we're evaluating 3-shot performance.") 183 | parser.add_argument("--greedy_decoding", action="store_true") 184 | parser.add_argument("--top_p", type=float) 185 | parser.add_argument( 186 | "--eval_pass_at_ks", nargs="+", type=int, default=[1], help="Multiple k's that we will report pass@k." 
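        # (Editor's note, kept as comments; placed here only because of how this listing is split.)
        # Reporting an unbiased pass@k only makes sense when --unbiased_sampling_size_n is at least
        # as large as the largest k requested, e.g.
        #
        #     --eval_pass_at_ks 1 10 --unbiased_sampling_size_n 10
        #
        # The HumanEval runner above asserts this relationship; this script does not appear to, so
        # the caller should keep it in mind.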
187 | ) 188 | parser.add_argument( 189 | "--temperature", 190 | type=float, 191 | default=0.1, 192 | help="Temperature for sampling. This is should be low for evaluating smaller pass@k, and high for larger pass@k.", 193 | ) 194 | parser.add_argument( 195 | "--load_in_8bit", 196 | action="store_true", 197 | help="Load mdodel in 8bit mode, which will reduce memory and speed up inference.", 198 | ) 199 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 200 | parser.add_argument("--chat_format", type=str) 201 | args = parser.parse_args() 202 | # model_name_or_path and openai_engine cannot be both None or both not None. 203 | assert (args.model_name_or_path is None) != ( 204 | args.openai_engine is None 205 | ), "Either model_name_or_path or openai_engine should be specified." 206 | main(args) 207 | -------------------------------------------------------------------------------- /xchat/eval/gsm/run_eval.py: -------------------------------------------------------------------------------- 1 | # Copied and modified from https://github.com/allenai/open-instruct 2 | 3 | import argparse 4 | import json 5 | import os 6 | import random 7 | import re 8 | 9 | import evaluate 10 | 11 | from xchat.eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS 12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model 13 | 14 | exact_match = evaluate.load("exact_match") 15 | 16 | 17 | def main(args): 18 | random.seed(42) 19 | 20 | print("Loading data...") 21 | test_data = [] 22 | with open(os.path.join(args.data_dir, "test.jsonl")) as fin: 23 | for line in fin: 24 | example = json.loads(line) 25 | test_data.append({"question": example["question"], "answer": example["answer"].split("####")[1].strip()}) 26 | 27 | # some numbers are in the `x,xxx` format, and we want to remove the comma 28 | for example in test_data: 29 | example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"]) 30 | assert float(example["answer"]), f"answer is not a valid number: {example['answer']}" 31 | 32 | if args.max_num_examples and len(test_data) > args.max_num_examples: 33 | test_data = random.sample(test_data, args.max_num_examples) 34 | 35 | if not os.path.exists(args.save_dir): 36 | os.makedirs(args.save_dir, exist_ok=True) 37 | 38 | global GSM_EXAMPLARS 39 | if args.n_shot: 40 | if len(GSM_EXAMPLARS) > args.n_shot: 41 | GSM_EXAMPLARS = random.sample(GSM_EXAMPLARS, args.n_shot) 42 | demonstrations = [] 43 | for example in GSM_EXAMPLARS: 44 | if args.no_cot: 45 | demonstrations.append("Question: " + example["question"] + "\n" + "Answer: " + example["short_answer"]) 46 | else: 47 | demonstrations.append("Question: " + example["question"] + "\n" + "Answer: " + example["cot_answer"]) 48 | prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n" 49 | else: 50 | prompt_prefix = "Answer the following question.\n\n" 51 | 52 | prompts = [] 53 | for example in test_data: 54 | if args.chat_format == "tulu": 55 | prompt = ( 56 | "<|user|>\n" 57 | + prompt_prefix 58 | + "Question: " 59 | + example["question"].strip() 60 | + "\n<|assistant|>\n" 61 | + "Answer:" 62 | ) 63 | elif args.chat_format == "lemur": 64 | prompt = ( 65 | "<|im_start|>user\n" 66 | + prompt_prefix 67 | + "Question: " 68 | + example["question"].strip() 69 | + "<|im_end|>\n<|im_start|>assistant\n" 70 | + "Answer:" 71 | ) 72 | elif args.chat_format == "wizardcoder": 73 | prompt = ( 74 | "Below is an instruction that describes a 
task. Write a response that appropriately completes the request.\n\n" 75 | "### Instruction:\n" 76 | f"{prompt_prefix}" 77 | "Question: " 78 | f"{example['question'].strip()}" 79 | "\n### Response:\n" 80 | "Answer:" 81 | ) 82 | elif args.chat_format == "vicuna": 83 | prompt = ( 84 | "A chat between a curious user and an artificial intelligence assistant. " 85 | "The assistant gives helpful, detailed, and polite answers to the user's questions. " 86 | "USER: " + prompt_prefix + "Question: " + example["question"].strip() + "\nASSISTANT: " + "Answer:" 87 | ) 88 | elif args.chat_format == "codellama-instruct": 89 | prompt = "[INST] " + prompt_prefix + "Question: " + example["question"].strip() + "[/INST]" + "Answer:" 90 | elif args.chat_format == "llama-2-chat": 91 | system_message = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ 92 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ 93 | that your responses are socially unbiased and positive in nature. 94 | 95 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ 96 | correct. If you don't know the answer to a question, please don't share false information.""" 97 | prompt = ( 98 | "[INST] <>\n" 99 | f"{system_message}" 100 | "\n<>\n\n" 101 | f"{prompt_prefix}" 102 | "Question: " 103 | f"{example['question'].strip()}" 104 | "[/INST]" 105 | "Answer:" 106 | ) 107 | else: 108 | prompt = prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" 109 | prompts.append(prompt) 110 | 111 | if args.model_name_or_path: 112 | print("Loading model and tokenizer...") 113 | model, tokenizer = load_hf_lm_and_tokenizer( 114 | model_name_or_path=args.model_name_or_path, 115 | tokenizer_name_or_path=args.tokenizer_name_or_path, 116 | load_in_8bit=args.load_in_8bit, 117 | load_in_half=True, 118 | gptq_model=args.gptq, 119 | ) 120 | 121 | if args.chat_format == "tulu": 122 | stop_id_sequences = [[tokenizer.encode("<|assistant|>", add_special_tokens=False)[-1]]] 123 | elif args.chat_format == "lemur": 124 | stop_id_sequences = [ 125 | [tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]], 126 | [tokenizer.eos_token_id], 127 | ] 128 | elif args.chat_format == "vicuna" or args.chat_format == "llama-2-chat": 129 | stop_id_sequences = [[tokenizer.encode("", add_special_tokens=False)[-1]]] 130 | elif args.chat_format == "codellama-instruct": 131 | stop_id_sequences = [ 132 | [tokenizer.encode("[INST]", add_special_tokens=False)[-1]], 133 | [tokenizer.encode("", add_special_tokens=False)[-1]], 134 | ] 135 | else: 136 | # get the last token because the tokenizer may add space tokens at the start. 
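        # (Editor's note, kept as comments so this script is unchanged.) With SentencePiece-style
        # tokenizers such as Llama's, encoding a short string often yields a leading "space" piece,
        # so the id of interest is the last one. A hedged illustration; the exact ids depend on the
        # tokenizer and vocabulary:
        #
        #     tokenizer.encode("\n", add_special_tokens=False)      # e.g. [29871, 13]
        #     tokenizer.encode("\n", add_special_tokens=False)[-1]  # 13 -> the actual newline token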
137 | stop_id_sequences = [[tokenizer.encode("\n", add_special_tokens=False)[-1]]] 138 | outputs = generate_completions( 139 | model=model, 140 | tokenizer=tokenizer, 141 | prompts=prompts, 142 | max_new_tokens=512, 143 | batch_size=args.eval_batch_size, 144 | stop_id_sequences=stop_id_sequences, 145 | ) 146 | else: 147 | instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)] 148 | results = query_openai_chat_model( 149 | engine=args.openai_engine, 150 | instances=instances, 151 | batch_size=args.eval_batch_size if args.eval_batch_size else 10, 152 | output_path=os.path.join(args.save_dir, "openai_results.jsonl"), 153 | ) 154 | outputs = [result["output"] for result in results] 155 | 156 | predictions = [] 157 | for output in outputs: 158 | # replace numbers like `x,xxx` with `xxxx` 159 | output = re.sub(r"(\d),(\d)", r"\1\2", output) 160 | numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output) 161 | if numbers: 162 | predictions.append(numbers[-1]) 163 | else: 164 | predictions.append(output) 165 | 166 | raw_predictions = [ 167 | {"question": example["question"], "answer": example["answer"], "model_output": output, "prediction": pred} 168 | for example, output, pred in zip(test_data, outputs, predictions) 169 | ] 170 | 171 | with open(os.path.join(args.save_dir, "raw_predictions.jsonl"), "w") as fout: 172 | for raw_prediction in raw_predictions: 173 | fout.write(json.dumps(raw_prediction) + "\n") 174 | 175 | print("Calculating accuracy...") 176 | targets = [example["answer"] for example in test_data] 177 | 178 | em_score = exact_match.compute( 179 | predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True 180 | )["exact_match"] 181 | print(f"Exact match : {em_score}") 182 | 183 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: 184 | json.dump({"exact_match": em_score}, fout, indent=4) 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("--data_dir", type=str, default="data/mgsm") 190 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.") 191 | parser.add_argument("--save_dir", type=str, default="results/mgsm") 192 | parser.add_argument( 193 | "--model_name_or_path", 194 | type=str, 195 | default=None, 196 | help="if specified, we will load the model to generate the predictions.", 197 | ) 198 | parser.add_argument( 199 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here." 200 | ) 201 | parser.add_argument( 202 | "--openai_engine", 203 | type=str, 204 | default=None, 205 | help="if specified, we will use the OpenAI API to generate the predictions.", 206 | ) 207 | parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.") 208 | parser.add_argument( 209 | "--no_cot", action="store_true", help="If given, we're evaluating a model without chain-of-thought." 
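        # (Editor's note, kept as comments; placed here only because of how this listing is split.)
        # The prediction step above keeps the last number in the model output after stripping
        # thousands separators. A small illustration of that logic:
        #
        #     import re
        #     output = "She pays 3 * $400 = $1,200 per year. So the answer is 1,200."
        #     output = re.sub(r"(\d),(\d)", r"\1\2", output)        # "1,200" -> "1200"
        #     numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output)    # ['3', '400', '1200', '1200']
        #     prediction = numbers[-1]                              # "1200"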
210 | ) 211 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 212 | parser.add_argument( 213 | "--load_in_8bit", 214 | action="store_true", 215 | help="load model in 8bit mode, which will reduce memory and speed up inference.", 216 | ) 217 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 218 | parser.add_argument( 219 | "--chat_format", 220 | type=str, 221 | default=None, 222 | choices=["tulu", "lemur", "wizardcoder", "vicuna", "llama-2-chat", "codellama-instruct"], 223 | help="if given, we will use the chat format to generate the predictions.", 224 | ) 225 | args = parser.parse_args() 226 | 227 | # model_name_or_path and openai_engine cannot be both None or both not None. 228 | assert (args.model_name_or_path is None) != ( 229 | args.openai_engine is None 230 | ), "Either model_name_or_path or openai_engine should be specified." 231 | main(args) 232 | -------------------------------------------------------------------------------- /xchat/eval/mmlu/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from xchat.eval.mmlu.categories import categories, subcategories 12 | from xchat.eval.utils import get_next_word_predictions, load_hf_lm_and_tokenizer, query_openai_chat_model 13 | 14 | choices = ["A", "B", "C", "D"] 15 | 16 | 17 | def format_subject(subject): 18 | subject_list = subject.split("_") 19 | s = "" 20 | for entry in subject_list: 21 | s += " " + entry 22 | return s 23 | 24 | 25 | def format_example(df, idx, include_answer=True): 26 | prompt = df.iloc[idx, 0] 27 | k = df.shape[1] - 2 28 | for j in range(k): 29 | prompt += f"\n{choices[j]}. {df.iloc[idx, j + 1]}" 30 | prompt += "\nAnswer:" 31 | if include_answer: 32 | prompt += f" {df.iloc[idx, k + 1]}\n\n" 33 | return prompt 34 | 35 | 36 | def gen_prompt(train_df, subject, k=-1): 37 | prompt = f"The following are multiple choice questions (with answers) about {format_subject(subject)}.\n\n" 38 | if k == -1: 39 | k = train_df.shape[0] 40 | for i in range(k): 41 | prompt += format_example(train_df, i) 42 | return prompt 43 | 44 | 45 | @torch.no_grad() 46 | def eval_hf_model(args, subject, model, tokenizer, dev_df, test_df, batch_size=1): 47 | prompts = [] 48 | for i in range(0, test_df.shape[0]): 49 | k = args.ntrain 50 | prompt_end = format_example(test_df, i, include_answer=False) 51 | train_prompt = gen_prompt(dev_df, subject, k) 52 | prompt = train_prompt + prompt_end 53 | 54 | tokenized_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids 55 | # make sure every prompt is less than 2048 tokens 56 | while tokenized_prompt.shape[-1] > 2048: 57 | k -= 1 58 | train_prompt = gen_prompt(dev_df, subject, k) 59 | prompt = train_prompt + prompt_end 60 | tokenized_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids 61 | 62 | if args.use_chat_format: 63 | prompt = "<|user|>\n" + prompt.strip() + "\n<|assistant|>\nThe answer is:" 64 | 65 | prompts.append(prompt) 66 | 67 | # get the answer for all examples 68 | # note: here we cannot directly use convert_tokens_to_ids because the some tokenizers will automatically add space prefix. 
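    # (Editor's note, kept as comments so this script is unchanged.) A sketch of the difference the
    # line below works around, assuming a SentencePiece-style tokenizer; exact ids vary by model:
    #
    #     tokenizer.convert_tokens_to_ids("A")                  # id of the bare token "A" (or unk)
    #     tokenizer.encode("A", add_special_tokens=False)[0]    # id of the space-prefixed "▁A",
    #                                                           # which is what follows "Answer:"
    #
    # The first form can therefore score the wrong vocabulary entry, so the code below encodes each
    # answer choice and keeps the first resulting id.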
69 | answer_choice_ids = [tokenizer.encode(answer_choice, add_special_tokens=False)[0] for answer_choice in choices] 70 | pred_indices, all_probs = get_next_word_predictions( 71 | model, 72 | tokenizer, 73 | prompts, 74 | candidate_token_ids=answer_choice_ids, 75 | return_token_predictions=False, 76 | batch_size=batch_size, 77 | ) 78 | 79 | # get the metrics 80 | cors = [] 81 | groud_truths = test_df.iloc[:, -1].values 82 | for i in range(len(pred_indices)): 83 | prediction = choices[pred_indices[i]] 84 | ground_truth = groud_truths[i] 85 | cors.append(prediction == ground_truth) 86 | 87 | acc = np.mean(cors) 88 | cors = np.array(cors) 89 | 90 | all_probs = np.array(all_probs) 91 | print(f"Average accuracy {acc:.3f} - {subject}") 92 | return cors, acc, all_probs 93 | 94 | 95 | def eval_openai_chat_engine(args, subject, engine, dev_df, test_df, batch_size=1): 96 | import tiktoken 97 | 98 | gpt_tokenizer = tiktoken.get_encoding("cl100k_base") 99 | answer_choice_ids = [ 100 | gpt_tokenizer.encode(" " + x)[0] for x in choices 101 | ] # be careful, the tokenizer will tokenize " A" and "A" differently. 102 | 103 | prompts = [] 104 | for i in range(0, test_df.shape[0]): 105 | k = args.ntrain 106 | prompt_end = format_example(test_df, i, include_answer=False) 107 | train_prompt = gen_prompt(dev_df, subject, k) 108 | prompt = train_prompt + prompt_end 109 | prompts.append(prompt) 110 | 111 | instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)] 112 | results = query_openai_chat_model( 113 | engine=args.openai_engine, 114 | instances=instances, 115 | batch_size=args.eval_batch_size if args.eval_batch_size else 10, 116 | output_path=os.path.join(args.save_dir, f"{subject}_openai_results.jsonl"), 117 | logit_bias={token_id: 100 for token_id in answer_choice_ids}, 118 | max_tokens=1, 119 | ) 120 | 121 | # get the metrics 122 | cors = [] 123 | groud_truths = test_df.iloc[:, -1].values 124 | for i in range(len(test_df)): 125 | prediction = results[i]["output"].strip() 126 | ground_truth = groud_truths[i] 127 | cors.append(prediction == ground_truth) 128 | 129 | acc = np.mean(cors) 130 | cors = np.array(cors) 131 | 132 | all_probs = np.array( 133 | [[0.25, 0.25, 0.25, 0.25] for _ in range(len(test_df))] 134 | ) # dummy probs, just don't want to dig into the openai probs 135 | 136 | print(f"Average accuracy {acc:.3f} - {subject}") 137 | return cors, acc, all_probs 138 | 139 | 140 | def main(args): 141 | if args.model_name_or_path: 142 | print("Loading model and tokenizer...") 143 | model, tokenizer = load_hf_lm_and_tokenizer( 144 | model_name_or_path=args.model_name_or_path, 145 | tokenizer_name_or_path=args.tokenizer_name_or_path, 146 | load_in_8bit=args.load_in_8bit, 147 | load_in_half=True, 148 | gptq_model=args.gptq, 149 | ) 150 | 151 | subjects = sorted( 152 | [f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f] 153 | ) 154 | 155 | if args.subjects: 156 | assert all( 157 | subj in subjects for subj in args.subjects 158 | ), f"Some of the subjects you specified are not valid: {args.subjects}" 159 | subjects = args.subjects 160 | 161 | if not os.path.exists(args.save_dir): 162 | os.makedirs(args.save_dir) 163 | if not os.path.exists(os.path.join(args.save_dir)): 164 | os.makedirs(os.path.join(args.save_dir)) 165 | 166 | all_cors = [] 167 | subcat_cors = {subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists} 168 | cat_cors = {cat: [] for cat in categories} 169 | 170 | for subject in tqdm(subjects, 
desc="Evaluating subjects: "): 171 | dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[: args.ntrain] 172 | test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None) 173 | if args.n_instances and args.n_instances < test_df.shape[0]: 174 | test_df = test_df.sample(args.n_instances, random_state=42) 175 | 176 | if args.model_name_or_path: 177 | cors, acc, probs = eval_hf_model(args, subject, model, tokenizer, dev_df, test_df, args.eval_batch_size) 178 | else: 179 | cors, acc, probs = eval_openai_chat_engine( 180 | args, subject, args.openai_engine, dev_df, test_df, args.eval_batch_size 181 | ) 182 | 183 | subcats = subcategories[subject] 184 | for subcat in subcats: 185 | subcat_cors[subcat].append(cors) 186 | for key in categories: 187 | if subcat in categories[key]: 188 | cat_cors[key].append(cors) 189 | all_cors.append(cors) 190 | 191 | test_df["correct"] = cors 192 | for j in range(probs.shape[1]): 193 | choice = choices[j] 194 | test_df[f"choice{choice}_probs"] = probs[:, j] 195 | test_df.to_csv( 196 | os.path.join(args.save_dir, f"{subject}.csv"), 197 | index=None, 198 | ) 199 | 200 | for subcat in subcat_cors: 201 | subcat_acc = np.mean(np.concatenate(subcat_cors[subcat])) 202 | print(f"Average accuracy {subcat_acc:.3f} - {subcat}") 203 | 204 | for cat in cat_cors: 205 | cat_acc = np.mean(np.concatenate(cat_cors[cat])) 206 | print(f"Average accuracy {cat_acc:.3f} - {cat}") 207 | weighted_acc = np.mean(np.concatenate(all_cors)) 208 | print(f"Average accuracy: {weighted_acc:.3f}") 209 | 210 | # save results 211 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as f: 212 | json.dump( 213 | { 214 | "average_acc": weighted_acc, 215 | "subcat_acc": {subcat: np.mean(np.concatenate(subcat_cors[subcat])) for subcat in subcat_cors}, 216 | "cat_acc": {cat: np.mean(np.concatenate(cat_cors[cat])) for cat in cat_cors}, 217 | }, 218 | f, 219 | ) 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument("--ntrain", type=int, default=5) 225 | parser.add_argument("--data_dir", type=str, default="data/mmlu") 226 | parser.add_argument("--save_dir", type=str, default="results/mmlu/llama-7B/") 227 | parser.add_argument( 228 | "--model_name_or_path", 229 | type=str, 230 | default=None, 231 | help="if specified, we will load the model to generate the predictions.", 232 | ) 233 | parser.add_argument( 234 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here." 235 | ) 236 | parser.add_argument( 237 | "--openai_engine", 238 | type=str, 239 | default=None, 240 | help="if specified, we will use the OpenAI API to generate the predictions.", 241 | ) 242 | parser.add_argument( 243 | "--subjects", 244 | nargs="*", 245 | help="which subjects to evaluate. 
If not specified, all the 57 subjects will be evaluated.", 246 | ) 247 | parser.add_argument( 248 | "--n_instances", 249 | type=int, 250 | help="if specified, a maximum of n_instances per subject will be used for the evaluation.", 251 | ) 252 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 253 | parser.add_argument( 254 | "--load_in_8bit", 255 | action="store_true", 256 | help="load model in 8bit mode, which will reduce memory and speed up inference.", 257 | ) 258 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 259 | parser.add_argument( 260 | "--use_chat_format", 261 | action="store_true", 262 | help="If given, the prompt will be encoded as a chat format with the roles in prompt.", 263 | ) 264 | args = parser.parse_args() 265 | 266 | # model_name_or_path and openai_engine cannot be both None or both not None. 267 | assert (args.model_name_or_path is None) != ( 268 | args.openai_engine is None 269 | ), "Either model_name_or_path or openai_engine should be specified." 270 | main(args) 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Lemur: Open Foundation Models for Language Agents](https://arxiv.org/abs/2310.06830) 2 | 3 |

4 | Lemur 5 |

6 | 7 | Models 8 | 9 | 10 | Blog 11 | 12 | 13 | Paper 14 | 15 | 16 | Stars 17 | 18 | 19 | Open Issues 20 | 21 | 22 | Twitter Follow 23 | 24 | 25 | Join Slack 26 | 27 | 28 | Discord 29 | 30 | 31 | Lemur is an openly accessible language model optimized for both natural language and coding capabilities to serve as the backbone of versatile language agents. 32 | As language models continue to evolve from conversational chatbots to functional agents that can act in the real world, they need both strong language understanding and the ability to execute actions. Lemur balances natural language and coding skills to enable agents to follow instructions, reason for tasks, and take grounded actions. 33 | 34 |
35 | 36 |
37 | 38 | Please refer to our paper and code for more details: 39 | - [[Paper](https://arxiv.org/abs/2310.06830)] Lemur: Harmonizing Natural Language and Code for Language Agents 40 | - [[Blog](https://www.xlang.ai/blog/openlemur)] Introducing Lemur: Open Foundation Models for Language Agents 41 | 42 | 43 | ## 🔥 News 44 | * **[18 October, 2023]:** 🎉 We open-sourced [code for OpenAgents](https://github.com/xlang-ai/OpenAgents): An Open Platform for Language Agents in the Wild. 45 | * **[11 October, 2023]:** 🎉 We released the research paper and codebase. We will continue updating this repository. 46 | * **[23 August, 2023]:** 🎉 We released the weights of [`OpenLemur/lemur-70b-v1`](https://huggingface.co/OpenLemur/lemur-70b-v1), and [`OpenLemur/lemur-70b-chat-v1`](https://huggingface.co/OpenLemur/lemur-70b-chat-v1)! Check it out in [HuggingFace Hub](https://huggingface.co/OpenLemur). 47 | 48 | ## Models 49 | We released our models on the HuggingFace Hub: 50 | * [OpenLemur/lemur-70b-v1](https://huggingface.co/OpenLemur/lemur-70b-v1) 51 | * [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1) 52 | 53 | ## Table of Contents 54 | - [Lemur: Open Foundation Models for Language Agents](#lemur-open-foundation-models-for-language-agents) 55 | - [🔥 News](#-news) 56 | - [Models](#models) 57 | - [Table of Contents](#table-of-contents) 58 | - [Why Lemur?](#why-lemur) 59 | - [Quickstart](#quickstart) 60 | - [Setup](#setup) 61 | - [Lemur-70B](#lemur-70b) 62 | - [Lemur-70B-Chat](#lemur-70b-chat) 63 | - [Training](#training) 64 | - [Evaluation](#evaluation) 65 | - [Foundational Abilities](#foundational-abilities) 66 | - [Interactive Agent Skills](#interactive-agent-skills) 67 | - [Deploy](#deploy) 68 | - [MINT](#mint) 69 | - [WebArena](#webarena) 70 | - [InterCode](#intercode) 71 | - [Citation](#citation) 72 | - [Acknowledgements](#acknowledgements) 73 | 74 | 75 | 76 | ## Why Lemur? 77 | Most existing open-source models specialize in either natural language or code. Lemur combines both strengths by: 78 | 79 | - Pretraining Llama-2-70B on a 90B token corpus with the 10:1 ratio of code to text and obtaining Lemur-70B-v1 80 | - Instruction tuning Lemur-70B-v1 on 300K examples covering both text and code and obtaining Lemur-70B-Chat-v1 81 | 82 | This two-stage training produces state-of-the-art performance averaged across diverse language and coding benchmarks, surpassing other available open-source models and narrowing the gap between open-source and commercial models on agent abilities. 83 | 84 | ## Quickstart 85 | 86 | ### Setup 87 | First, we have to install all the libraries listed in `requirements.txt` 88 | 89 | ```bash 90 | conda create -n xchat python=3.10 91 | conda activate xchat 92 | conda install pytorch==2.0.1 pytorch-cuda=11.8 -c pytorch -c nvidia 93 | conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc 94 | ``` 95 | Then, install the xchat package: 96 | ```bash 97 | git clone git@github.com:OpenLemur/Lemur.git 98 | cd Lemur 99 | pip install -e . 100 | ``` 101 | 102 | ### Lemur-70B 103 | For the base model lemur-70b-v1, you can use it in this way: 104 | 105 |
106 | Click me 107 |

108 | 109 | ```python 110 | from transformers import AutoTokenizer, AutoModelForCausalLM 111 | 112 | tokenizer = AutoTokenizer.from_pretrained("OpenLemur/lemur-70b-v1") 113 | model = AutoModelForCausalLM.from_pretrained("OpenLemur/lemur-70b-v1", device_map="auto", load_in_8bit=True) 114 | 115 | # Text Generation Example 116 | prompt = "The world is " 117 | input = tokenizer(prompt, return_tensors="pt") 118 | output = model.generate(**input, max_length=50, num_return_sequences=1) 119 | generated_text = tokenizer.decode(output[0], skip_special_tokens=True) 120 | print(generated_text) 121 | 122 | # Code Generation Example 123 | prompt = """ 124 | def factorial(n): 125 | if n == 0: 126 | return 1 127 | """ 128 | input = tokenizer(prompt, return_tensors="pt") 129 | output = model.generate(**input, max_length=200, num_return_sequences=1) 130 | generated_code = tokenizer.decode(output[0], skip_special_tokens=True) 131 | print(generated_code) 132 | ``` 133 | 134 |

135 | 136 |
137 | 138 | 139 | ### Lemur-70B-Chat 140 | We instruction-finetune lemur-70b-v1 model with ChatML format to obtain lemur-70b-chat-v1. You can use lemur-70b-chat-v1 in this way: 141 | 142 |
143 | Click me 144 |

145 | 146 | ```python 147 | from transformers import AutoTokenizer, AutoModelForCausalLM 148 | 149 | tokenizer = AutoTokenizer.from_pretrained("OpenLemur/lemur-70b-chat-v1") 150 | model = AutoModelForCausalLM.from_pretrained("OpenLemur/lemur-70b-chat-v1", device_map="auto", load_in_8bit=True) 151 | 152 | # Text Generation Example 153 | prompt = """<|im_start|>system 154 | You are a helpful, respectful, and honest assistant. 155 | <|im_end|> 156 | <|im_start|>user 157 | What's a lemur's favorite fruit?<|im_end|> 158 | <|im_start|>assistant 159 | """ 160 | input = tokenizer(prompt, return_tensors="pt") 161 | output = model.generate(**input, max_length=50, num_return_sequences=1) 162 | generated_text = tokenizer.decode(output[0], skip_special_tokens=True) 163 | print(generated_text) 164 | 165 | # Code Generation Example 166 | prompt = """<|im_start|>system 167 | Below is an instruction that describes a task. Write a response that appropriately completes the request. 168 | <|im_end|> 169 | <|im_start|>user 170 | Write a Python function to merge two sorted lists into one sorted list without using any built-in sort functions.<|im_end|> 171 | <|im_start|>assistant 172 | """ 173 | input = tokenizer(prompt, return_tensors="pt") 174 | output = model.generate(**input, max_length=200, num_return_sequences=1) 175 | generated_code = tokenizer.decode(output[0], skip_special_tokens=True) 176 | print(generated_code) 177 | ``` 178 | 179 |

180 | 181 |
182 | 183 | ## Training 184 | 185 |
186 | 187 |
188 | 189 | 190 | ## Evaluation 191 | We evaluated Lemur across: 192 | - 8 language and code datasets like MMLU, BBH, GSM8K, HumanEval, and Spider to validate balanced capabilities 193 | - 13 interactive agent datasets to test skills like tool usage, adapting to feedback from environments or humans, and exploring partially observable digital or physical environments. 194 | 195 |
196 | 197 |
198 | 199 |
200 | 201 |
202 | 203 | ### Foundational Abilities 204 | We build the evaluation suite based on [open-instruct](https://github.com/allenai/open-instruct). We will keep updating more tasks and models. 205 | 206 | Currently, we support the following tasks: 207 | - [✅] [MMLU](./scripts/eval/mmlu.sh) 208 | - [✅] [BBH](./scripts/eval/bbh.sh) 209 | - [✅] [GSM8K](./scripts/eval/gsm8k.sh) 210 | - [✅] [HumanEval](./scripts/eval/human_eval.sh) 211 | - [✅] [MBPP](./scripts/eval/mbpp.sh) 212 | - [🚧] [Spider]() 213 | - [🚧] [MultiPL-E]() 214 | - [🚧] [DS-1000]() 215 | - [🚧] ... 216 | 217 | ### Interactive Agent Skills 218 | We use the evaluation frameworks provided by [MINT](https://github.com/xingyaoww/mint-bench), [InterCode](https://github.com/princeton-nlp/intercode), and [WebArena](https://github.com/web-arena-x/webarena) to evaluate interactive agent skills. 219 | 220 | #### Deploy 221 | We use vLLM to serve the Lemur model. However, the official FastChat codebase does not yet support Lemur-Chat. Therefore, we provide a docker to serve vLLM for Lemur. Please refer to [vllm_lemur.sh](./scripts/deploy/vllm_lemur.sh) for more detailed information. 222 | 223 | ```bash 224 | bash scripts/deploy/vllm_lemur.sh 225 | ``` 226 | 227 | #### MINT 228 | We [fork MINT](https://github.com/OpenLemur/mint-bench) codebase to share the configs we used. Please refer to [this config folder](https://github.com/OpenLemur/mint-bench/tree/main/configs) for more details. Please run vllm with [`vllm_lemur.sh`](./scripts/deploy/vllm_lemur.sh) script. 229 | 230 | #### WebArena 231 | We [fork WebArena](https://github.com/OpenLemur/webarena) codebase to enable vLLM evaluation. To run the evaluation on WebArena, please refer to our [forked WebArena codebase](https://github.com/OpenLemur/webarena). 232 | 233 | #### InterCode 234 | We [fork InterCode](https://github.com/OpenLemur/intercode) codebase and do modifications to enable Lemur evaluation. Please refer to [this script folder](https://github.com/OpenLemur/intercode/tree/master/scripts) for more details. Please run `text-generation-inference` with [`tgi_lemur.sh`](scripts/deploy/tgi_lemur.sh) script. 235 | 236 | 237 | ## Citation 238 | If you find our work helpful, please cite us: 239 | ``` 240 | @misc{xu2023lemur, 241 | title={Lemur: Harmonizing Natural Language and Code for Language Agents}, 242 | author={Yiheng Xu and Hongjin Su and Chen Xing and Boyu Mi and Qian Liu and Weijia Shi and Binyuan Hui and Fan Zhou and Yitao Liu and Tianbao Xie and Zhoujun Cheng and Siheng Zhao and Lingpeng Kong and Bailin Wang and Caiming Xiong and Tao Yu}, 243 | year={2023}, 244 | eprint={2310.06830}, 245 | archivePrefix={arXiv}, 246 | primaryClass={cs.CL} 247 | } 248 | ``` 249 | 250 | ## Acknowledgements 251 | 252 | The Lemur project is an open collaborative research effort between [XLang Lab](https://www.xlang.ai/) and [Salesforce Research](https://www.salesforceairesearch.com/). We thank the following institutions for their gift support: 253 | 254 |
255 | - Salesforce Research 256 | - Google Research 257 | - Amazon AWS 258 | 259 |
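To sanity-check the vLLM deployment described in the Deploy section above, the hedged sketch below sends one ChatML-formatted completion request to the OpenAI-compatible server started by [`vllm_lemur.sh`](./scripts/deploy/vllm_lemur.sh). The host, port (`8000`), and served model name (`lemur-70b-v1`) are assumptions taken from that script; adjust them to match your deployment.

```python
# Minimal client-side check for the vLLM OpenAI-compatible server started by
# scripts/deploy/vllm_lemur.sh. Assumes the server is reachable on localhost:8000
# and serves the model under the name "lemur-70b-v1"; adjust both if yours differs.
import requests

prompt = (
    "<|im_start|>system\nYou are a helpful, respectful, and honest assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat's a lemur's favorite fruit?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "lemur-70b-v1",
        "prompt": prompt,
        "max_tokens": 128,
        "temperature": 0.0,
        "stop": ["<|im_end|>"],
    },
    timeout=120,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])
```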
276 | -------------------------------------------------------------------------------- /xchat/eval/bbh/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | import random 6 | import re 7 | 8 | import evaluate 9 | import torch 10 | import tqdm 11 | 12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model 13 | 14 | exact_match = evaluate.load("exact_match") 15 | 16 | 17 | @torch.no_grad() 18 | def eval_hf_model(args, model, tokenizer, examples, task_prompt, save_path=None): 19 | targets = [example["target"] for example in examples] 20 | if save_path: 21 | fout = open(save_path, "w") 22 | 23 | prompts = [] 24 | for example in examples: 25 | if args.chat_format == "tulu": 26 | prompt = "<|user|>\n" + task_prompt.strip() + "\n\nQ: " + example["input"] + "\n<|assistant|>\nA:" 27 | elif args.chat_format == "lemur": 28 | prompt = ( 29 | "<|im_start|>user\n" 30 | + task_prompt.strip() 31 | + "\n\nQ: " 32 | + example["input"] 33 | + "<|im_end|>\n<|im_start|>assistant\nA:" 34 | ) 35 | elif args.chat_format == "codellama-instruct": 36 | prompt = "[INST]" + task_prompt.strip() + "\n\nQ:" + example["input"] + "[/INST]" + "\nA:" 37 | else: 38 | prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA: " 39 | prompts.append(prompt) 40 | 41 | new_line_sequence = tokenizer.encode("\n", add_special_tokens=False)[-1] 42 | q_sequence = tokenizer.encode("Q: ", add_special_tokens=False) 43 | if args.chat_format == "tulu": 44 | stop_id_sequences = [ 45 | [tokenizer.encode("<|assistant|>", add_special_tokens=False)[-1]], 46 | [new_line_sequence], 47 | [q_sequence], 48 | ] 49 | elif args.chat_format == "lemur": 50 | stop_id_sequences = [ 51 | [tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]], 52 | [tokenizer.eos_token_id], 53 | [new_line_sequence], 54 | [q_sequence], 55 | ] 56 | elif args.chat_format == "codellama-instruct": 57 | stop_id_sequences = [ 58 | [tokenizer.encode("[INST]", add_special_tokens=False)[-1]], 59 | [tokenizer.encode("", add_special_tokens=False)[-1]], 60 | [new_line_sequence], 61 | [q_sequence], 62 | ] 63 | else: 64 | stop_id_sequences = [ 65 | [new_line_sequence], 66 | ] 67 | 68 | outputs = generate_completions( 69 | model=model, 70 | tokenizer=tokenizer, 71 | prompts=prompts, 72 | max_new_tokens=512, 73 | batch_size=args.eval_batch_size if args.eval_batch_size else 1, 74 | stop_id_sequences=stop_id_sequences, 75 | ) 76 | 77 | predictions = [] 78 | for example, output in zip(examples, outputs): 79 | example["raw_output"] = output 80 | 81 | # only keep the first part of the output - this is mainly for vanilla language models. 82 | output = output.strip().split("\n\n")[0].strip() 83 | 84 | # extract the first answer after `So the answer is` and before the next period. 85 | # if there is no such answer, we will just use the raw output. 
86 | results = re.search(r"(.*?)\.", output) 87 | 88 | prediction = results.group(1).strip() if results else output.strip() 89 | 90 | if "(" in prediction and ")" in prediction and "[" not in prediction and "]" not in prediction: 91 | pattern = r"\(\w\)" 92 | match = re.search(pattern, prediction) 93 | prediction = f"{match.group(0)}" if match else prediction 94 | 95 | example["prediction"] = prediction 96 | predictions.append(prediction) 97 | if save_path: 98 | fout.write(json.dumps(example) + "\n") 99 | 100 | assert len(predictions) == len(targets), "number of predictions and targets are not the same." 101 | return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)[ 102 | "exact_match" 103 | ] 104 | 105 | 106 | def eval_openai_chat_engine(args, examples, task_prompt, save_path=None): 107 | targets = [example["target"] for example in examples] 108 | instances = [] 109 | for i, example in enumerate(examples): 110 | prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:" 111 | instances.append( 112 | { 113 | "id": example["id"] if "id" in example else i, 114 | "prompt": prompt, 115 | } 116 | ) 117 | 118 | if save_path: 119 | openai_result_save_path = os.path.join( 120 | os.path.dirname(save_path), os.path.basename(save_path).split(".")[0] + "_openai_results.jsonl" 121 | ) 122 | 123 | results = query_openai_chat_model( 124 | engine=args.openai_engine, 125 | instances=instances, 126 | batch_size=args.eval_batch_size if args.eval_batch_size else 10, 127 | output_path=openai_result_save_path if save_path else None, 128 | ) 129 | 130 | outputs = [result["output"] for result in results] 131 | assert len(outputs) == len(targets), "number of predictions and targets are not the same." 132 | 133 | if save_path: 134 | fout = open(save_path, "w") 135 | 136 | predictions = [] 137 | for example, output in zip(examples, outputs): 138 | example["raw_output"] = output 139 | # extract the first answer after `So the answer is` and before the next period. 140 | # if there is no such answer, we will just use the raw output. 141 | results = re.search(r"So the answer is (.*?)\.", output) 142 | prediction = results.group(1).strip() if results else output.strip() 143 | example["prediction"] = prediction 144 | predictions.append(prediction) 145 | if save_path: 146 | fout.write(json.dumps(example) + "\n") 147 | 148 | assert len(predictions) == len(targets), "number of predictions and targets are not the same." 
149 | return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)[ 150 | "exact_match" 151 | ] 152 | 153 | 154 | def main(args): 155 | random.seed(args.seed) 156 | 157 | all_tasks = {} 158 | task_files = glob.glob(os.path.join(args.data_dir, "bbh", "*.json")) 159 | task_files = sorted(task_files) 160 | if args.task_start_idx is not None and args.task_end_idx is not None: 161 | task_files = task_files[args.task_start_idx : args.task_end_idx] 162 | print(task_files) 163 | for task_file in tqdm.tqdm(task_files, desc="Loading tasks"): 164 | with open(task_file) as f: 165 | task_name = os.path.basename(task_file).split(".")[0] 166 | all_tasks[task_name] = json.load(f)["examples"] 167 | if args.max_num_examples_per_task: 168 | all_tasks[task_name] = random.sample(all_tasks[task_name], args.max_num_examples_per_task) 169 | 170 | all_prompts = {} 171 | cot_prompt_files = glob.glob(os.path.join(args.data_dir, "cot-prompts", "*.txt")) 172 | cot_prompt_files = sorted(cot_prompt_files) 173 | if args.task_start_idx is not None and args.task_end_idx is not None: 174 | cot_prompt_files = cot_prompt_files[args.task_start_idx : args.task_end_idx] 175 | print(cot_prompt_files) 176 | for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"): 177 | with open(cot_prompt_file) as f: 178 | task_name = os.path.basename(cot_prompt_file).split(".")[0] 179 | task_prompt = "".join(f.readlines()[2:]) 180 | if args.no_cot: 181 | prompt_fields = task_prompt.split("\n\n") 182 | new_prompt_fields = [] 183 | for prompt_field in prompt_fields: 184 | if prompt_field.startswith("Q:"): 185 | assert ( 186 | "So the answer is" in prompt_field 187 | ), f"`So the answer is` not found in prompt field of {task_name}.txt." 188 | assert "\nA:" in prompt_field, "`\nA:` not found in prompt field." 189 | answer = prompt_field.split("So the answer is")[-1].strip() 190 | question = prompt_field.split("\nA:")[0].strip() 191 | new_prompt_fields.append(question + "\nA: " + answer) 192 | else: 193 | new_prompt_fields.append(prompt_field) 194 | task_prompt = "\n\n".join(new_prompt_fields) 195 | all_prompts[task_name] = task_prompt 196 | 197 | assert set(all_tasks.keys()) == set( 198 | all_prompts.keys() 199 | ), "task names in task data and task prompts are not the same." 
200 | 201 | os.makedirs(args.save_dir, exist_ok=True) 202 | os.makedirs(os.path.join(args.save_dir, "predictions"), exist_ok=True) 203 | 204 | if args.model_name_or_path: 205 | print("Loading model and tokenizer...") 206 | model, tokenizer = load_hf_lm_and_tokenizer( 207 | model_name_or_path=args.model_name_or_path, 208 | tokenizer_name_or_path=args.tokenizer_name_or_path, 209 | load_in_8bit=args.load_in_8bit, 210 | load_in_half=True, 211 | gptq_model=args.gptq, 212 | ) 213 | 214 | performance = {} 215 | for task_name in tqdm.tqdm(all_tasks.keys(), desc="Evaluating"): 216 | if args.special_task is not None: 217 | special_tasks = args.special_task.split(",") 218 | if task_name not in special_tasks: 219 | continue 220 | task_examples = all_tasks[task_name] 221 | prompt = all_prompts[task_name] 222 | if args.model_name_or_path: 223 | task_perf = eval_hf_model( 224 | args, 225 | model, 226 | tokenizer, 227 | task_examples, 228 | prompt, 229 | save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl"), 230 | ) 231 | else: 232 | task_perf = eval_openai_chat_engine( 233 | args, task_examples, prompt, save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl") 234 | ) 235 | performance[task_name] = task_perf 236 | print(f"Task {task_name} - EM: {task_perf}") 237 | 238 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: 239 | performance["average_exact_match"] = sum(performance.values()) / len(performance) 240 | print(f"Average EM: {performance['average_exact_match']}") 241 | json.dump(performance, fout, indent=4) 242 | 243 | 244 | if __name__ == "__main__": 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument("--data_dir", type=str, default="data/bbh") 247 | parser.add_argument("--save_dir", type=str, default="results/bbh") 248 | parser.add_argument( 249 | "--model_name_or_path", 250 | type=str, 251 | default=None, 252 | help="if specified, we will load the model to generate the predictions.", 253 | ) 254 | parser.add_argument( 255 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here." 256 | ) 257 | parser.add_argument( 258 | "--openai_engine", 259 | type=str, 260 | default=None, 261 | help="if specified, we will use the OpenAI API to generate the predictions.", 262 | ) 263 | parser.add_argument( 264 | "--no_cot", action="store_true", help="if specified, chain of thoughts will be removed from the prompts." 265 | ) 266 | parser.add_argument( 267 | "--max_num_examples_per_task", type=int, default=None, help="maximum number of examples to evaluate per task." 
268 | ) 269 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 270 | parser.add_argument( 271 | "--load_in_8bit", 272 | action="store_true", 273 | help="load model in 8bit mode, which will reduce memory and speed up inference.", 274 | ) 275 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 276 | parser.add_argument( 277 | "--chat_format", 278 | type=str, 279 | default=None, 280 | choices=["tulu", "lemur", "codellama-instruct"], 281 | help="If given, the prompt will be encoded as a chat format with the roles in prompt.", 282 | ) 283 | parser.add_argument( 284 | "--seed", 285 | type=int, 286 | default=42, 287 | ) 288 | parser.add_argument( 289 | "--task_start_idx", 290 | type=int, 291 | default=None, 292 | ) 293 | parser.add_argument( 294 | "--task_end_idx", 295 | type=int, 296 | default=None, 297 | ) 298 | parser.add_argument("--special_task", type=str, default=None) 299 | 300 | args = parser.parse_args() 301 | 302 | # model_name_or_path and openai_engine cannot be both None or both not None. 303 | assert (args.model_name_or_path is None) != ( 304 | args.openai_engine is None 305 | ), "Either model_name_or_path or openai_engine should be specified." 306 | main(args) 307 | -------------------------------------------------------------------------------- /xchat/eval/utils.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/allenai/open-instruct. 2 | 3 | import asyncio 4 | import json 5 | import os 6 | import time 7 | 8 | import torch 9 | import tqdm 10 | from transformers import StoppingCriteria 11 | 12 | from xchat.eval.dispatch_openai_requests import dispatch_openai_chat_requesets, dispatch_openai_prompt_requesets 13 | 14 | 15 | class KeyWordsCriteria(StoppingCriteria): 16 | def __init__(self, stop_id_sequences): 17 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 18 | self.stop_sequences = stop_id_sequences 19 | 20 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 21 | sequences_should_be_stopped = [] 22 | for i in range(input_ids.shape[0]): 23 | sequence_should_be_stopped = False 24 | for stop_sequence in self.stop_sequences: 25 | if input_ids[i][-len(stop_sequence) :].tolist() == stop_sequence: 26 | sequence_should_be_stopped = True 27 | break 28 | sequences_should_be_stopped.append(sequence_should_be_stopped) 29 | return all(sequences_should_be_stopped) 30 | 31 | 32 | @torch.no_grad() 33 | def generate_completions( 34 | model, 35 | tokenizer, 36 | prompts, 37 | batch_size=1, 38 | stop_id_sequences=None, 39 | disable_tqdm=False, 40 | assigned_temperatures=None, 41 | **generation_kwargs, 42 | ): 43 | # sort prompts by length to reduce the number of padding tokens 44 | prompt_lengths = [len(prompt) for prompt in prompts] 45 | sorted_prompt_indices = sorted(range(len(prompt_lengths)), key=lambda k: prompt_lengths[k]) 46 | prompts = [prompts[i] for i in sorted_prompt_indices] 47 | 48 | if assigned_temperatures is not None: 49 | assigned_temperatures = [assigned_temperatures[i] for i in sorted_prompt_indices] 50 | 51 | generations = [] 52 | if not disable_tqdm: 53 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions") 54 | 55 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1) 56 | 57 | # if do_sample=True then set default temperature to 0.0 58 | # by default, we 
use temperature 0.0 to get the most likely completion. 59 | if generation_kwargs.get("do_sample", False): 60 | generate_temperature = generation_kwargs.get("temperature", 0.0) 61 | 62 | for i in range(0, len(prompts), batch_size): 63 | batch_prompts = prompts[i : i + batch_size] 64 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False) 65 | batch_input_ids = tokenized_prompts.input_ids 66 | attention_mask = tokenized_prompts.attention_mask 67 | 68 | if model.device.type == "cuda": 69 | batch_input_ids = batch_input_ids.cuda() 70 | attention_mask = attention_mask.cuda() 71 | 72 | # Set temperature for each batch if assigned 73 | if generation_kwargs.get("do_sample", False): 74 | assigned_temperature = ( 75 | assigned_temperatures[i : i + batch_size][0] if assigned_temperatures is not None else None 76 | ) 77 | # temperature = assigned_temperature if assigned_temperature is not None else generate_temperature 78 | generation_kwargs["temperature"] = ( 79 | assigned_temperature if assigned_temperature is not None else generate_temperature 80 | ) 81 | 82 | try: 83 | batch_outputs = model.generate( 84 | input_ids=batch_input_ids, 85 | attention_mask=attention_mask, 86 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None, 87 | **generation_kwargs, 88 | ) 89 | 90 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate. 91 | # so some outputs still have the stop sequence, which we need to remove. 92 | if stop_id_sequences: 93 | for output_idx in range(batch_outputs.shape[0]): 94 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): 95 | if any( 96 | batch_outputs[output_idx, token_idx : token_idx + len(stop_sequence)].tolist() 97 | == stop_sequence 98 | for stop_sequence in stop_id_sequences 99 | ): 100 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id 101 | break 102 | 103 | # remove the prompt from the output 104 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. 105 | # we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token. 106 | # space is important for some tasks (e.g., code completion). 
107 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True) 108 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True) 109 | # duplicate the prompts to match the number of return sequences 110 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)] 111 | batch_generations = [output[len(prompt) :] for prompt, output in zip(batch_prompts, batch_outputs)] 112 | except Exception as e: 113 | print("Error when generating completions for batch:") 114 | print(batch_prompts) 115 | print("Error message:") 116 | print(e) 117 | print("Use empty string as the completion.") 118 | batch_generations = [""] * len(batch_prompts) * num_return_sequences 119 | 120 | generations += batch_generations 121 | 122 | # for prompt, generation in zip(batch_prompts, batch_generations): 123 | # print("========") 124 | # print(prompt) 125 | # print("--------") 126 | # print(generation) 127 | 128 | if not disable_tqdm: 129 | progress.update(len(batch_prompts) // num_return_sequences) 130 | 131 | assert ( 132 | len(generations) == len(prompts) * num_return_sequences 133 | ), "number of generations should be equal to number of prompts * num_return_sequences" 134 | 135 | generations = sorted(zip(sorted_prompt_indices, generations), key=lambda x: x[0]) 136 | generations = [generation for _, generation in generations] 137 | return generations 138 | 139 | 140 | @torch.no_grad() 141 | def get_next_word_predictions( 142 | model, 143 | tokenizer, 144 | prompts, 145 | candidate_token_ids=None, 146 | batch_size=1, 147 | return_token_predictions=False, 148 | disable_tqdm=False, 149 | ): 150 | predictions, probs = [], [] 151 | if not disable_tqdm: 152 | progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions") 153 | 154 | for i in range(0, len(prompts), batch_size): 155 | batch_prompts = prompts[i : i + batch_size] 156 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False) 157 | batch_input_ids = tokenized_prompts.input_ids 158 | attention_mask = tokenized_prompts.attention_mask 159 | 160 | if model.device.type == "cuda": 161 | batch_input_ids = batch_input_ids.cuda() 162 | attention_mask = attention_mask.cuda() 163 | 164 | batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :] 165 | if candidate_token_ids is not None: 166 | batch_logits = batch_logits[:, candidate_token_ids] 167 | batch_probs = torch.softmax(batch_logits, dim=-1) 168 | batch_prediction_indices = torch.argmax(batch_probs, dim=-1) 169 | if return_token_predictions: 170 | if candidate_token_ids is not None: 171 | candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids) 172 | batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices] 173 | else: 174 | batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices) 175 | predictions += batch_predictions 176 | else: 177 | predictions += batch_prediction_indices.tolist() 178 | probs += batch_probs.tolist() 179 | 180 | if not disable_tqdm: 181 | progress.update(len(batch_prompts)) 182 | 183 | assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts" 184 | return predictions, probs 185 | 186 | 187 | def load_hf_lm_and_tokenizer( 188 | model_name_or_path, 189 | tokenizer_name_or_path=None, 190 | device_map="auto", 191 | load_in_8bit=False, 192 | load_in_half=False, 193 | gptq_model=False, 194 | use_fast_tokenizer=False, 195 | 
padding_side="left", 196 | ): 197 | from transformers import AutoModelForCausalLM, AutoTokenizer 198 | 199 | if not tokenizer_name_or_path: 200 | tokenizer_name_or_path = model_name_or_path 201 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, legacy=True) 202 | # set padding side to left for batch generation 203 | tokenizer.padding_side = padding_side 204 | # set pad token to eos token if pad token is not set (as is the case for llama models) 205 | if tokenizer.pad_token is None: 206 | tokenizer.pad_token = tokenizer.unk_token 207 | tokenizer.pad_token_id = tokenizer.unk_token_id 208 | if gptq_model: 209 | from auto_gptq import AutoGPTQForCausalLM 210 | 211 | model_wrapper = AutoGPTQForCausalLM.from_quantized(model_name_or_path, device="cuda:0", use_triton=True) 212 | model = model_wrapper.model 213 | elif load_in_8bit: 214 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map, load_in_8bit=True) 215 | else: 216 | if device_map: 217 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map) 218 | else: 219 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path) 220 | if torch.cuda.is_available(): 221 | model = model.cuda() 222 | if load_in_half: 223 | model = model.half() 224 | model.eval() 225 | return model, tokenizer 226 | 227 | 228 | def query_openai_chat_model( 229 | engine, 230 | instances, 231 | output_path=None, 232 | batch_size=10, 233 | retry_limit=5, 234 | reuse_existing_outputs=True, 235 | organized_message_format=False, 236 | assigned_temperatures=None, 237 | **completion_kwargs, 238 | ): 239 | """Query OpenAI chat model and save the results to output_path. 240 | 241 | `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id". 242 | """ 243 | existing_data = {} 244 | if reuse_existing_outputs and output_path is not None and os.path.exists(output_path): 245 | with open(output_path) as f: 246 | for line in f: 247 | instance = json.loads(line) 248 | existing_data[instance["id"]] = instance 249 | 250 | # by default, we use temperature 0.0 to get the most likely completion. 
251 | generate_temperature = completion_kwargs.pop("temperature", 0.0) 252 | 253 | results = [] 254 | if output_path is not None: 255 | fout = open(output_path, "w") # noqa: SIM115 256 | 257 | retry_count = 0 258 | progress_bar = tqdm.tqdm(total=len(instances)) 259 | for i in range(0, len(instances), batch_size): 260 | batch = instances[i : i + batch_size] 261 | if all(x["id"] in existing_data for x in batch): 262 | results.extend([existing_data[x["id"]] for x in batch]) 263 | if output_path is not None: 264 | for instance in batch: 265 | fout.write(json.dumps(existing_data[instance["id"]]) + "\n") 266 | fout.flush() 267 | progress_bar.update(batch_size) 268 | continue 269 | messages_list = [] 270 | for instance in batch: 271 | if organized_message_format: 272 | messages = instance["prompt"] 273 | else: 274 | messages = [{"role": "user", "content": instance["prompt"]}] 275 | messages_list.append(messages) 276 | 277 | # Set temperature for each batch if assigned 278 | assigned_temperature = ( 279 | assigned_temperatures[i : i + batch_size][0] if assigned_temperatures is not None else None 280 | ) 281 | temperature = assigned_temperature if assigned_temperature is not None else generate_temperature 282 | 283 | while retry_count < retry_limit: 284 | try: 285 | outputs = asyncio.run( 286 | dispatch_openai_chat_requesets( 287 | messages_list=messages_list, 288 | model=engine, 289 | temperature=temperature, 290 | **completion_kwargs, 291 | ) 292 | ) 293 | retry_count = 0 294 | break 295 | except Exception as e: 296 | retry_count += 1 297 | print("Error while requesting OpenAI API.") 298 | print(e) 299 | print(f"Sleep for {30*retry_count} seconds.") 300 | time.sleep(30 * retry_count) 301 | print(f"Retry for the {retry_count} time.") 302 | if retry_count == retry_limit: 303 | raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.") 304 | assert len(outputs) == len(batch) 305 | for instance, output in zip(batch, outputs): 306 | instance["output"] = output["choices"][0]["message"]["content"] 307 | instance["response_metadata"] = output 308 | results.append(instance) 309 | if output_path is not None: 310 | fout.write(json.dumps(instance) + "\n") 311 | fout.flush() 312 | progress_bar.update(batch_size) 313 | return results 314 | 315 | 316 | def query_openai_model( 317 | engine, instances, output_path=None, batch_size=10, retry_limit=5, reuse_existing_outputs=True, **completion_kwargs 318 | ): 319 | """Query OpenAI chat model and save the results to output_path. 320 | 321 | `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id". 322 | """ 323 | existing_data = {} 324 | if reuse_existing_outputs and output_path is not None and os.path.exists(output_path): 325 | with open(output_path) as f: 326 | for line in f: 327 | instance = json.loads(line) 328 | existing_data[instance["id"]] = instance 329 | 330 | # by default, we use temperature 0.0 to get the most likely completion. 
331 | if "temperature" not in completion_kwargs: 332 | completion_kwargs["temperature"] = 0.0 333 | 334 | results = [] 335 | if output_path is not None: 336 | fout = open(output_path, "w") # noqa: SIM115 337 | 338 | retry_count = 0 339 | progress_bar = tqdm.tqdm(total=len(instances)) 340 | for i in range(0, len(instances), batch_size): 341 | batch = instances[i : i + batch_size] 342 | if all(x["id"] in existing_data for x in batch): 343 | results.extend([existing_data[x["id"]] for x in batch]) 344 | if output_path is not None: 345 | for instance in batch: 346 | fout.write(json.dumps(existing_data[instance["id"]]) + "\n") 347 | fout.flush() 348 | progress_bar.update(batch_size) 349 | continue 350 | messages_list = [] 351 | for instance in batch: 352 | messages = instance["prompt"] 353 | messages_list.append(messages) 354 | while retry_count < retry_limit: 355 | try: 356 | outputs = asyncio.run( 357 | dispatch_openai_prompt_requesets( 358 | prompt_list=messages_list, 359 | model=engine, 360 | **completion_kwargs, 361 | ) 362 | ) 363 | retry_count = 0 364 | break 365 | except Exception as e: 366 | retry_count += 1 367 | print("Error while requesting OpenAI API.") 368 | print(e) 369 | print(f"Sleep for {30*retry_count} seconds.") 370 | time.sleep(30 * retry_count) 371 | print(f"Retry for the {retry_count} time.") 372 | if retry_count == retry_limit: 373 | raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.") 374 | assert len(outputs) == len(batch) 375 | for instance, output in zip(batch, outputs): 376 | instance["output"] = output["choices"][0]["text"] 377 | instance["response_metadata"] = output 378 | results.append(instance) 379 | if output_path is not None: 380 | fout.write(json.dumps(instance) + "\n") 381 | fout.flush() 382 | progress_bar.update(batch_size) 383 | return results 384 | -------------------------------------------------------------------------------- /xchat/train/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is adapted from the https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py. 
3 | """ 4 | 5 | import argparse 6 | import logging 7 | import math 8 | import os 9 | import random 10 | from functools import partial 11 | 12 | import datasets 13 | import torch 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from datasets import load_dataset 19 | from peft import LoraConfig, TaskType, get_peft_model 20 | from torch.utils.data import DataLoader 21 | from tqdm.auto import tqdm 22 | from transformers import ( 23 | AutoConfig, 24 | AutoModelForCausalLM, 25 | AutoTokenizer, 26 | DataCollatorForSeq2Seq, 27 | GPT2Tokenizer, 28 | GPTNeoXTokenizerFast, 29 | LlamaTokenizer, 30 | LlamaTokenizerFast, 31 | OPTForCausalLM, 32 | SchedulerType, 33 | get_scheduler, 34 | ) 35 | 36 | logger = get_logger(__name__) 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") 41 | parser.add_argument( 42 | "--dataset_name", 43 | type=str, 44 | default=None, 45 | help="The name of the dataset to use (via the datasets library).", 46 | ) 47 | parser.add_argument( 48 | "--dataset_config_name", 49 | type=str, 50 | default=None, 51 | help="The configuration name of the dataset to use (via the datasets library).", 52 | ) 53 | parser.add_argument( 54 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 55 | ) 56 | parser.add_argument( 57 | "--chat_format", 58 | type=str, 59 | default="chatml", 60 | choices=["chatml", "tulu"], 61 | ) 62 | parser.add_argument( 63 | "--data_packing", 64 | action="store_true", 65 | help="If passed, will pack multiple examples into one sequence.", 66 | ) 67 | parser.add_argument( 68 | "--model_name_or_path", 69 | type=str, 70 | help="Path to pretrained model or model identifier from huggingface.co/models.", 71 | required=False, 72 | ) 73 | parser.add_argument( 74 | "--config_name", 75 | type=str, 76 | default=None, 77 | help="Pretrained config name or path if not the same as model_name", 78 | ) 79 | parser.add_argument( 80 | "--use_lora", 81 | action="store_true", 82 | help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.", 83 | ) 84 | parser.add_argument( 85 | "--lora_rank", 86 | type=int, 87 | default=64, 88 | help="The rank of lora.", 89 | ) 90 | parser.add_argument( 91 | "--lora_alpha", 92 | type=float, 93 | default=16, 94 | help="The alpha parameter of lora.", 95 | ) 96 | parser.add_argument( 97 | "--lora_dropout", 98 | type=float, 99 | default=0.1, 100 | help="The dropout rate of lora modules.", 101 | ) 102 | parser.add_argument( 103 | "--save_merged_lora_model", 104 | action="store_true", 105 | help="If passed, will merge the lora modules and save the entire model.", 106 | ) 107 | parser.add_argument( 108 | "--use_flash_attn", 109 | action="store_true", 110 | help="If passed, will use flash attention to train the model.", 111 | ) 112 | parser.add_argument( 113 | "--tokenizer_name", 114 | type=str, 115 | default=None, 116 | help="Pretrained tokenizer name or path if not the same as model_name", 117 | ) 118 | parser.add_argument( 119 | "--use_slow_tokenizer", 120 | action="store_true", 121 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 122 | ) 123 | parser.add_argument( 124 | "--max_seq_length", 125 | type=int, 126 | default=512, 127 | help="The maximum total sequence length (prompt+completion) of each training example.", 128 | ) 129 | 
parser.add_argument( 130 | "--per_device_train_batch_size", 131 | type=int, 132 | default=8, 133 | help="Batch size (per device) for the training dataloader.", 134 | ) 135 | parser.add_argument( 136 | "--learning_rate", 137 | type=float, 138 | default=5e-5, 139 | help="Initial learning rate (after the potential warmup period) to use.", 140 | ) 141 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 142 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 143 | parser.add_argument( 144 | "--max_train_steps", 145 | type=int, 146 | default=None, 147 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 148 | ) 149 | parser.add_argument( 150 | "--gradient_accumulation_steps", 151 | type=int, 152 | default=1, 153 | help="Number of updates steps to accumulate before performing a backward/update pass.", 154 | ) 155 | parser.add_argument( 156 | "--gradient_checkpointing", 157 | action="store_true", 158 | help="If passed, use gradient checkpointing to save memory at the expense of slower backward pass.", 159 | ) 160 | parser.add_argument( 161 | "--lr_scheduler_type", 162 | type=SchedulerType, 163 | default="linear", 164 | help="The scheduler type to use.", 165 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 166 | ) 167 | parser.add_argument("--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup.") 168 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 169 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 170 | parser.add_argument( 171 | "--preprocessing_num_workers", 172 | type=int, 173 | default=None, 174 | help="The number of processes to use for the preprocessing.", 175 | ) 176 | parser.add_argument( 177 | "--cache_dir", 178 | type=str, 179 | default=None, 180 | help="Where to store the preprocessed datasets downloaded from the datasets library.", 181 | ) 182 | parser.add_argument( 183 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 184 | ) 185 | parser.add_argument( 186 | "--checkpointing_steps", 187 | type=str, 188 | default=None, 189 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 190 | ) 191 | parser.add_argument( 192 | "--logging_steps", 193 | type=int, 194 | default=None, 195 | help="Log the training loss and learning rate every logging_steps steps.", 196 | ) 197 | parser.add_argument( 198 | "--resume_from_checkpoint", 199 | type=str, 200 | default=None, 201 | help="If the training should continue from a checkpoint folder.", 202 | ) 203 | parser.add_argument( 204 | "--with_tracking", 205 | action="store_true", 206 | help="Whether to enable experiment trackers for logging.", 207 | ) 208 | parser.add_argument( 209 | "--report_to", 210 | type=str, 211 | default="all", 212 | help=( 213 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 214 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' 215 | "Only applicable when `--with_tracking` is passed." 
216 | ), 217 | ) 218 | parser.add_argument( 219 | "--low_cpu_mem_usage", 220 | action="store_true", 221 | help=( 222 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 223 | "If passed, LLM loading time and RAM consumption will be benefited." 224 | ), 225 | ) 226 | args = parser.parse_args() 227 | 228 | # Sanity checks 229 | if args.dataset_name is None and args.train_file is None: 230 | raise ValueError("Need either a dataset name or a training file.") 231 | else: 232 | if args.train_file is not None: 233 | extension = args.train_file.split(".")[-1] 234 | assert extension in ["json", "jsonl"], "`train_file` should be a json/jsonl file." 235 | return args 236 | 237 | 238 | def build_dataset(args, tokenizer, accelerator): 239 | def encode_with_prompt_completion_format(example, max_seq_length): 240 | """ 241 | Here we assume each example has 'prompt' and 'completion' fields. 242 | We concatenate prompt and completion and tokenize them together because otherwise prompt will be padded/trancated 243 | and it doesn't make sense to follow directly with the completion. 244 | """ 245 | # if prompt doesn't end with space and completion doesn't start with space, add space 246 | if not example["prompt"].endswith((" ", "\n", "\t")) and not example["completion"].startswith( 247 | (" ", "\n", "\t") 248 | ): 249 | example_text = example["prompt"] + " " + example["completion"] 250 | else: 251 | example_text = example["prompt"] + example["completion"] 252 | example_text = example_text + tokenizer.eos_token 253 | tokenized_example = tokenizer(example_text, return_tensors="pt", max_length=max_seq_length, truncation=True) 254 | input_ids = tokenized_example.input_ids 255 | labels = input_ids.clone() 256 | tokenized_prompt = tokenizer(example["prompt"], return_tensors="pt", max_length=max_seq_length, truncation=True) 257 | # mask the prompt part for avoiding loss 258 | labels[:, : tokenized_prompt.input_ids.shape[1]] = -100 259 | attention_mask = torch.ones_like(input_ids) 260 | return { 261 | "input_ids": input_ids.flatten(), 262 | "labels": labels.flatten(), 263 | "attention_mask": attention_mask.flatten(), 264 | } 265 | 266 | def encode_with_messages_format(example, max_seq_length, chat_format): 267 | """ 268 | Here we assume each example has a 'messages' field Each message is a dict with 'role' and 'content' fields. 269 | We concatenate all messages with the roles as delimiters and tokenize them together. 
270 | """ 271 | messages = example["messages"] 272 | if len(messages) == 0: 273 | raise ValueError("messages field is empty.") 274 | 275 | def _concat_messages_tulu(messages): 276 | message_text = "" 277 | for message in messages: 278 | if message["role"] == "system": 279 | message_text += "<|system|>\n" + message["content"].strip() + "\n" 280 | elif message["role"] == "user": 281 | message_text += "<|user|>\n" + message["content"].strip() + "\n" 282 | elif message["role"] == "assistant": 283 | message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n" 284 | else: 285 | raise ValueError("Invalid role: {}".format(message["role"])) 286 | return message_text 287 | 288 | def _concat_messages_chatml(messages): 289 | message_text = "" 290 | for message in messages: 291 | if message["role"] == "system": 292 | message_text += "<|im_start|>system\n" + message["content"].strip() + "<|im_end|>\n" 293 | elif message["role"] == "user": 294 | message_text += "<|im_start|>user\n" + message["content"].strip() + "<|im_end|>\n" 295 | elif message["role"] == "assistant": 296 | message_text += "<|im_start|>assistant\n" + message["content"].strip() + "<|im_end|>\n" 297 | else: 298 | raise ValueError("Invalid role: {}".format(message["role"])) 299 | return message_text 300 | 301 | if chat_format == "chatml": 302 | _concat_messages = _concat_messages_chatml 303 | assistant_start_token = "<|im_start|>assistant\n" 304 | elif chat_format == "tulu": 305 | _concat_messages = _concat_messages_tulu 306 | assistant_start_token = "<|assistant|>\n" 307 | else: 308 | raise ValueError(f"Invalid chat format: {chat_format}") 309 | example_text = _concat_messages(messages) 310 | tokenized_example = tokenizer(example_text, return_tensors="pt", max_length=max_seq_length, truncation=True) 311 | input_ids = tokenized_example.input_ids 312 | labels = input_ids.clone() 313 | 314 | # mask the non-assistant part for avoiding loss 315 | for message_idx, message in enumerate(messages): 316 | if message["role"] != "assistant": 317 | if message_idx == 0: 318 | message_start_idx = 0 319 | else: 320 | message_start_idx = tokenizer( 321 | _concat_messages(messages[:message_idx]), 322 | return_tensors="pt", 323 | max_length=max_seq_length, 324 | truncation=True, 325 | ).input_ids.shape[1] 326 | if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": 327 | # here we also ignore the role of the assistant 328 | messages_so_far = _concat_messages(messages[: message_idx + 1]) + assistant_start_token 329 | else: 330 | messages_so_far = _concat_messages(messages[: message_idx + 1]) 331 | message_end_idx = tokenizer( 332 | messages_so_far, return_tensors="pt", max_length=max_seq_length, truncation=True 333 | ).input_ids.shape[1] 334 | labels[:, message_start_idx:message_end_idx] = -100 335 | 336 | if message_end_idx >= max_seq_length: 337 | break 338 | 339 | attention_mask = torch.ones_like(input_ids) 340 | return { 341 | "input_ids": input_ids.flatten(), 342 | "labels": labels.flatten(), 343 | "attention_mask": attention_mask.flatten(), 344 | } 345 | 346 | if args.dataset_name is not None: 347 | # Downloading and loading a dataset from the hub. 
348 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir) 349 | else: 350 | data_files = {} 351 | dataset_args = {} 352 | if args.train_file is not None: 353 | data_files["train"] = args.train_file 354 | raw_datasets = load_dataset( 355 | "json", 356 | data_files=data_files, 357 | **dataset_args, 358 | ) 359 | # raw_datasets["train"] = raw_datasets["train"].select(range(100)) 360 | 361 | if "input_ids" in raw_datasets["train"].column_names: 362 | lm_datasets = raw_datasets.remove_columns( 363 | [ 364 | name 365 | for name in raw_datasets["train"].column_names 366 | if name not in ["input_ids", "labels", "attention_mask"] 367 | ] 368 | ) 369 | lm_datasets.set_format(type="pt") 370 | else: 371 | # Preprocessing the datasets. 372 | if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: 373 | encode_function = partial( 374 | encode_with_prompt_completion_format, 375 | max_seq_length=args.max_seq_length, 376 | ) 377 | elif "messages" in raw_datasets["train"].column_names: 378 | encode_function = partial( 379 | encode_with_messages_format, 380 | max_seq_length=args.max_seq_length, 381 | chat_format=args.chat_format, 382 | ) 383 | else: 384 | raise ValueError("You need to have either 'prompt'&'completion' or 'messages' in your column names.") 385 | with accelerator.main_process_first(): 386 | lm_datasets = raw_datasets.map( 387 | encode_function, 388 | batched=False, 389 | num_proc=args.preprocessing_num_workers, 390 | load_from_cache_file=not args.overwrite_cache, 391 | remove_columns=[ 392 | name 393 | for name in raw_datasets["train"].column_names 394 | if name not in ["input_ids", "labels", "attention_mask"] 395 | ], 396 | desc="Tokenizing and reformatting instruction data", 397 | ) 398 | lm_datasets.set_format(type="pt") 399 | lm_datasets = lm_datasets.filter(lambda example: (example["labels"] != -100).any()) 400 | 401 | train_dataset = lm_datasets["train"] 402 | 403 | if args.data_packing: 404 | 405 | def data_packing(examples, max_seq_length): 406 | num_examples = len(examples["input_ids"]) 407 | packed_examples = {"input_ids": [], "labels": [], "attention_mask": []} 408 | eoc_id = tokenizer.encode(tokenizer.eoc_token, return_tensors="pt", add_special_tokens=False).squeeze(0) 409 | eoc_label = torch.tensor([-100]) 410 | eoc_mask = torch.tensor([1]) 411 | 412 | for i in range(num_examples): 413 | if ( 414 | len(packed_examples["input_ids"]) > 0 415 | and len(packed_examples["input_ids"][-1]) + len(examples["input_ids"][i]) <= max_seq_length 416 | ): 417 | packed_examples["input_ids"][-1] = torch.cat( 418 | [packed_examples["input_ids"][-1], examples["input_ids"][i], eoc_id] 419 | ) 420 | packed_examples["labels"][-1] = torch.cat( 421 | [packed_examples["labels"][-1], examples["labels"][i], eoc_label] 422 | ) 423 | packed_examples["attention_mask"][-1] = torch.cat( 424 | [packed_examples["attention_mask"][-1], examples["attention_mask"][i], eoc_mask] 425 | ) 426 | elif len(examples["input_ids"][i]) > max_seq_length: 427 | for j in range(0, len(examples["input_ids"][i]), max_seq_length): 428 | packed_examples["input_ids"].append(examples["input_ids"][i][j : j + max_seq_length]) 429 | packed_examples["labels"].append(examples["labels"][i][j : j + max_seq_length]) 430 | packed_examples["attention_mask"].append(examples["attention_mask"][i][j : j + max_seq_length]) 431 | else: 432 | packed_examples["input_ids"].append(torch.cat([examples["input_ids"][i], eoc_id])) 433 | 
packed_examples["labels"].append(torch.cat([examples["labels"][i], eoc_label])) 434 | packed_examples["attention_mask"].append(torch.cat([examples["attention_mask"][i], eoc_mask])) 435 | 436 | assert ( 437 | len(packed_examples["input_ids"][-1]) 438 | == len(packed_examples["labels"][-1]) 439 | == len(packed_examples["labels"][-1]) 440 | == len(packed_examples["attention_mask"][-1]) 441 | ) 442 | 443 | return packed_examples 444 | 445 | pack_function = partial(data_packing, max_seq_length=args.max_seq_length) 446 | with accelerator.main_process_first(): 447 | logger.info(f"Training dataset size before packing: {len(train_dataset)}") 448 | train_dataset = train_dataset.map( 449 | pack_function, 450 | batched=True, 451 | num_proc=args.preprocessing_num_workers, 452 | remove_columns=train_dataset.column_names, 453 | desc="Packing data", 454 | ) 455 | logger.info(f"Training dataset size after packing: {len(train_dataset)}") 456 | 457 | return train_dataset 458 | 459 | 460 | def build_tokenizer(args): 461 | if args.tokenizer_name: 462 | tokenizer = AutoTokenizer.from_pretrained( 463 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, legacy=True 464 | ) 465 | elif args.model_name_or_path: 466 | tokenizer = AutoTokenizer.from_pretrained( 467 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, legacy=True 468 | ) 469 | else: 470 | raise ValueError( 471 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 472 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 473 | ) 474 | print(f"Tokenizer type: {type(tokenizer)}") 475 | # no default pad token for llama! 476 | # here we add all special tokens again, because the default ones are not in the special_tokens_map 477 | if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)): 478 | num_added_tokens = tokenizer.add_special_tokens( 479 | { 480 | "bos_token": "", 481 | "eos_token": "", 482 | "unk_token": "", 483 | "pad_token": "", 484 | } 485 | ) 486 | assert num_added_tokens in [ 487 | 0, 488 | 1, 489 | ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." 490 | if args.chat_format == "tulu": 491 | if args.data_packing: 492 | tokenizer.add_special_tokens( 493 | { 494 | "additional_special_tokens": ["<|end_of_conv|>"], 495 | } 496 | ) 497 | tokenizer.eoc_token = "<|end_of_conv|>" 498 | elif args.chat_format == "chatml": 499 | tokenizer.add_special_tokens( 500 | { 501 | "additional_special_tokens": [ 502 | "<|im_start|>", 503 | "<|im_end|>", 504 | "<|im_sep|>", 505 | "<|diff_marker|>", 506 | ], 507 | } 508 | ) 509 | tokenizer.eoc_token = tokenizer.eos_token 510 | else: 511 | raise ValueError(f"Unknown chat format {args.chat_format}") 512 | elif isinstance(tokenizer, GPTNeoXTokenizerFast): 513 | num_added_tokens = tokenizer.add_special_tokens( 514 | { 515 | "pad_token": "", 516 | } 517 | ) 518 | assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." 519 | elif isinstance(tokenizer, GPT2Tokenizer): 520 | num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) 521 | 522 | return tokenizer 523 | 524 | 525 | def save_with_accelerate(accelerator, model, tokenizer, output_dir, args): 526 | unwrapped_model = accelerator.unwrap_model(model) 527 | # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict. 528 | # Otherwise, sometimes the model will be saved with only part of the parameters. 
529 | # Also, accelerator needs to use the wrapped model to get the state_dict. 530 | state_dict = accelerator.get_state_dict(model) 531 | if args.use_lora: 532 | # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process 533 | # and has its own save_pretrained function for only saving lora modules. 534 | # We have to manually specify the is_main_process outside the save_pretrained function. 535 | if accelerator.is_main_process: 536 | unwrapped_model.save_pretrained(output_dir, state_dict=state_dict) 537 | else: 538 | unwrapped_model.save_pretrained( 539 | output_dir, 540 | is_main_process=accelerator.is_main_process, 541 | save_function=accelerator.save, 542 | state_dict=state_dict, 543 | ) 544 | 545 | 546 | def main(): 547 | args = parse_args() 548 | 549 | # A hacky way to make llama work with flash attention 550 | if args.use_flash_attn: 551 | from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 552 | 553 | replace_llama_attn_with_flash_attn() 554 | 555 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 556 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 557 | # in the environment 558 | accelerator_log_kwargs = {} 559 | 560 | if args.with_tracking: 561 | accelerator_log_kwargs["log_with"] = args.report_to 562 | accelerator_log_kwargs["project_dir"] = args.output_dir 563 | 564 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) 565 | 566 | # Make one log on every process with the configuration for debugging. 567 | logging.basicConfig( 568 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 569 | datefmt="%m/%d/%Y %H:%M:%S", 570 | level=logging.INFO, 571 | ) 572 | logger.info(accelerator.state, main_process_only=False) 573 | if accelerator.is_local_main_process: 574 | datasets.utils.logging.set_verbosity_warning() 575 | transformers.utils.logging.set_verbosity_info() 576 | else: 577 | datasets.utils.logging.set_verbosity_error() 578 | transformers.utils.logging.set_verbosity_error() 579 | 580 | # If passed along, set the training seed now. 581 | if args.seed is not None: 582 | set_seed(args.seed) 583 | 584 | if accelerator.is_main_process and args.output_dir is not None: 585 | os.makedirs(args.output_dir, exist_ok=True) 586 | 587 | accelerator.wait_for_everyone() 588 | 589 | # Load pretrained model and tokenizer 590 | if args.config_name: 591 | config = AutoConfig.from_pretrained(args.config_name) 592 | elif args.model_name_or_path: 593 | config = AutoConfig.from_pretrained(args.model_name_or_path) 594 | else: 595 | raise ValueError( 596 | "You are instantiating a new config instance from scratch. This is not supported by this script." 597 | ) 598 | 599 | tokenizer = build_tokenizer(args) 600 | 601 | if args.model_name_or_path: 602 | model = AutoModelForCausalLM.from_pretrained( 603 | args.model_name_or_path, 604 | cache_dir=args.cache_dir, 605 | from_tf=bool(".ckpt" in args.model_name_or_path), 606 | config=config, 607 | low_cpu_mem_usage=args.low_cpu_mem_usage, 608 | ) 609 | else: 610 | logger.info("Training new model from scratch") 611 | model = AutoModelForCausalLM.from_config(config) 612 | 613 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 614 | # on a small vocab and want a smaller embedding size, remove this test. 
615 | embedding_size = model.get_input_embeddings().weight.shape[0] 616 | if len(tokenizer) > embedding_size: 617 | model.resize_token_embeddings(len(tokenizer)) 618 | 619 | if args.gradient_checkpointing: 620 | model.gradient_checkpointing_enable() 621 | 622 | if args.use_lora: 623 | logger.info("Initializing LORA model...") 624 | peft_config = LoraConfig( 625 | task_type=TaskType.CAUSAL_LM, 626 | inference_mode=False, 627 | r=args.lora_rank, 628 | lora_alpha=args.lora_alpha, 629 | lora_dropout=args.lora_dropout, 630 | ) 631 | model = get_peft_model(model, peft_config) 632 | model.print_trainable_parameters() 633 | 634 | train_dataset = build_dataset(args, tokenizer, accelerator) 635 | 636 | # Log a few random samples from the training set: 637 | for index in random.sample(range(len(train_dataset)), 3): 638 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 639 | 640 | # DataLoaders creation: 641 | train_dataloader = DataLoader( 642 | train_dataset, 643 | shuffle=True, 644 | collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"), 645 | batch_size=args.per_device_train_batch_size, 646 | ) 647 | 648 | # Optimizer 649 | # Split weights in two groups, one with weight decay and the other not. 650 | no_decay = ["bias", "layer_norm.weight"] 651 | optimizer_grouped_parameters = [ 652 | { 653 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 654 | "weight_decay": args.weight_decay, 655 | }, 656 | { 657 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 658 | "weight_decay": 0.0, 659 | }, 660 | ] 661 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 662 | 663 | # Scheduler and math around the number of training steps. 664 | overrode_max_train_steps = False 665 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 666 | if args.max_train_steps is None: 667 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 668 | overrode_max_train_steps = True 669 | 670 | # Create the learning rate scheduler. 671 | # Note: the current accelerator.step() calls the .step() of the real scheduler for the `num_processes` times. This is because they assume 672 | # the user initialize the scheduler with the entire training set. In the case of data parallel training, each process only 673 | # sees a subset (1/num_processes) of the training set. So each time the process needs to update the lr multiple times so that the total 674 | # number of updates in the end matches the num_training_steps here. 675 | # Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the 676 | # num_training_steps by num_processes so that the total number of updates matches the num_training_steps. 677 | num_training_steps_for_scheduler = ( 678 | args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes 679 | ) 680 | lr_scheduler = get_scheduler( 681 | name=args.lr_scheduler_type, 682 | optimizer=optimizer, 683 | num_training_steps=num_training_steps_for_scheduler, 684 | num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), 685 | ) 686 | 687 | # Prepare everything with `accelerator`. 
688 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 689 | model, optimizer, train_dataloader, lr_scheduler 690 | ) 691 | 692 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 693 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 694 | if overrode_max_train_steps: 695 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 696 | # Afterwards we recalculate our number of training epochs 697 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 698 | 699 | # Figure out how many steps we should save the Accelerator states 700 | checkpointing_steps = args.checkpointing_steps 701 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 702 | checkpointing_steps = int(checkpointing_steps) 703 | 704 | # We need to initialize the trackers we use, and also store our configuration. 705 | # The trackers initializes automatically on the main process. 706 | if args.with_tracking: 707 | experiment_config = vars(args) 708 | # TensorBoard cannot log Enums, need the raw value 709 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 710 | accelerator.init_trackers("sft-chat", experiment_config) 711 | 712 | # Train! 713 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 714 | 715 | logger.info("***** Running training *****") 716 | logger.info(f" Num examples = {len(train_dataset)}") 717 | logger.info(f" Num Epochs = {args.num_train_epochs}") 718 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 719 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 720 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 721 | logger.info(f" Total optimization steps = {args.max_train_steps}") 722 | # Only show the progress bar once on each machine. 
723 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 724 | completed_steps = 0 725 | starting_epoch = 0 726 | 727 | # Potentially load in the weights and states from a previous save 728 | if args.resume_from_checkpoint: 729 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 730 | checkpoint_path = args.resume_from_checkpoint 731 | path = os.path.basename(args.resume_from_checkpoint) 732 | else: 733 | # Get the most recent checkpoint 734 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 735 | dirs.sort(key=os.path.getctime) 736 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 737 | checkpoint_path = path 738 | path = os.path.basename(checkpoint_path) 739 | 740 | accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") 741 | accelerator.load_state(checkpoint_path) 742 | # Extract `epoch_{i}` or `step_{i}` 743 | training_difference = os.path.splitext(path)[0] 744 | 745 | if "epoch" in training_difference: 746 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 747 | resume_step = None 748 | completed_steps = starting_epoch * num_update_steps_per_epoch 749 | else: 750 | # need to multiply `gradient_accumulation_steps` to reflect real steps 751 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps 752 | starting_epoch = resume_step // len(train_dataloader) 753 | completed_steps = resume_step // args.gradient_accumulation_steps 754 | resume_step -= starting_epoch * len(train_dataloader) 755 | 756 | # update the progress_bar if load from checkpoint 757 | progress_bar.update(completed_steps) 758 | 759 | for epoch in range(starting_epoch, args.num_train_epochs): 760 | model.train() 761 | total_loss = 0 762 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: 763 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint 764 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 765 | else: 766 | active_dataloader = train_dataloader 767 | for _step, batch in enumerate(active_dataloader): 768 | with accelerator.accumulate(model): 769 | outputs = model(**batch, use_cache=False) 770 | loss = outputs.loss 771 | # We keep track of the loss at each logged step 772 | total_loss += loss.detach().float() 773 | accelerator.backward(loss) 774 | optimizer.step() 775 | optimizer.zero_grad() 776 | lr_scheduler.step() 777 | 778 | # Checks if the accelerator has performed an optimization step behind the scenes 779 | if accelerator.sync_gradients: 780 | progress_bar.update(1) 781 | completed_steps += 1 782 | if args.logging_steps and completed_steps % args.logging_steps == 0: 783 | avg_loss = ( 784 | accelerator.gather(total_loss).mean().item() 785 | / args.gradient_accumulation_steps 786 | / args.logging_steps 787 | ) 788 | logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}") 789 | if args.with_tracking: 790 | accelerator.log( 791 | { 792 | "learning_rate": lr_scheduler.get_last_lr()[0], 793 | "train_loss": avg_loss, 794 | }, 795 | step=completed_steps, 796 | ) 797 | total_loss = 0 798 | 799 | if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0: 800 | output_dir = f"step_{completed_steps}" 801 | if args.output_dir is not None: 802 | output_dir = os.path.join(args.output_dir, output_dir) 803 | accelerator.save_state(output_dir) 804 | if 
completed_steps >= args.max_train_steps: 805 | break 806 | 807 | if args.checkpointing_steps == "epoch": 808 | output_dir = f"epoch_{epoch}" 809 | if args.output_dir is not None: 810 | output_dir = os.path.join(args.output_dir, output_dir) 811 | accelerator.save_state(output_dir) 812 | 813 | if args.with_tracking: 814 | accelerator.end_training() 815 | 816 | if args.output_dir is not None: 817 | accelerator.wait_for_everyone() 818 | if accelerator.is_main_process: 819 | tokenizer.save_pretrained(args.output_dir) 820 | save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args) 821 | 822 | 823 | if __name__ == "__main__": 824 | main() 825 | --------------------------------------------------------------------------------