├── xchat
│   ├── eval
│   │   ├── __init__.py
│   │   ├── mbpp
│   │   │   ├── data.py
│   │   │   ├── examplars.py
│   │   │   ├── prompts.py
│   │   │   ├── evaluation.py
│   │   │   ├── execution.py
│   │   │   └── run_eval.py
│   │   ├── humaneval
│   │   │   ├── prompts.py
│   │   │   └── run_eval.py
│   │   ├── mmlu
│   │   │   ├── categories.py
│   │   │   └── run_eval.py
│   │   ├── gsm
│   │   │   ├── examplars.py
│   │   │   └── run_eval.py
│   │   ├── dispatch_openai_requests.py
│   │   ├── bbh
│   │   │   └── run_eval.py
│   │   └── utils.py
│   └── train
│       ├── llama_flash_attn_monkey_patch.py
│       └── train.py
├── assets
│   ├── xlang.png
│   ├── interface.png
│   ├── training.png
│   ├── salesforce.webp
│   ├── transparent.png
│   ├── agent-scenarios.png
│   └── overall-perform.png
├── scripts
│   ├── deploy
│   │   ├── vllm_lemur.sh
│   │   └── tgi_lemur.sh
│   └── eval
│       ├── mbpp.sh
│       ├── humaneval.sh
│       ├── bbh.sh
│       ├── mmlu.sh
│       └── gsm.sh
├── .pre-commit-config.yaml
├── pyproject.toml
├── .gitignore
├── LICENSE
└── README.md
/xchat/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/xlang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/xlang.png
--------------------------------------------------------------------------------
/assets/interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/interface.png
--------------------------------------------------------------------------------
/assets/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/training.png
--------------------------------------------------------------------------------
/assets/salesforce.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/salesforce.webp
--------------------------------------------------------------------------------
/assets/transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/transparent.png
--------------------------------------------------------------------------------
/assets/agent-scenarios.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/agent-scenarios.png
--------------------------------------------------------------------------------
/assets/overall-perform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenLemur/Lemur/HEAD/assets/overall-perform.png
--------------------------------------------------------------------------------
/scripts/deploy/vllm_lemur.sh:
--------------------------------------------------------------------------------
1 | N_GPUS=4
2 | gpus='"device=0,1,2,3"'
3 | MODEL_PATH=OpenLemur/lemur-70b-v1
4 | MODEL_NAME=lemur-70b-v1
5 | # HF_HOME=""
6 |
7 | docker run --gpus $gpus --rm \
8 | -v $HF_HOME:/root/.cache/huggingface \
9 | --shm-size=10.24gb \
10 | --name vllm-$MODEL_NAME \
11 | ranpox/fastchat:lemur \
12 | python -m vllm.entrypoints.openai.api_server \
13 | --model $MODEL_PATH \
14 | --tensor-parallel-size $N_GPUS \
15 | --served-model-name $MODEL_NAME \
16 | --max-num-batched-tokens 4096 \
17 | --load-format pt \
18 | --port 8000
19 |
--------------------------------------------------------------------------------
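Note: once the container above is running, the server speaks the OpenAI completions protocol. A minimal client sketch using the legacy openai-python interface that this repo already depends on (it assumes the container's port 8000 is reachable from the host, e.g. via an added -p 8000:8000 mapping, and that the model name matches MODEL_NAME):

    import openai

    openai.api_key = "EMPTY"                      # vLLM's server does not validate the key
    openai.api_base = "http://localhost:8000/v1"  # assumes port 8000 is reachable on the host

    response = openai.Completion.create(
        model="lemur-70b-v1",
        prompt="def fibonacci(n):",
        max_tokens=128,
        temperature=0.0,
    )
    print(response["choices"][0]["text"])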
/scripts/eval/mbpp.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1,2,3
2 | DATA_DIR=data/eval/mbpp
3 | OUTPUT_DIR=results/mbpp/Llama-2-70b-chat-hf
4 | MODEL_PATH=meta-llama/Llama-2-70b-chat-hf
5 | python -m xchat.eval.mbpp.run_eval \
6 | --data_dir $DATA_DIR \
7 | --max_new_token 650 \
8 | --max_num_examples 500 \
9 | --save_dir $OUTPUT_DIR \
10 | --model $MODEL_PATH \
11 | --tokenizer $MODEL_PATH \
12 | --eval_batch_size 8 \
13 | --chat_format llama2chat \
14 | --load_in_8bit \
15 | --greedy_decoding \
16 | --few_shot \
17 | --eval_pass_at_ks 1 \
18 | --unbiased_sampling_size_n 1
19 |
--------------------------------------------------------------------------------
/scripts/eval/humaneval.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | DATA_DIR=data/eval/humaneval
3 |
4 | if [ ! -d "$DATA_DIR" ]; then
5 | echo "Downloading HumanEval data..."
6 | mkdir -p $DATA_DIR
7 | wget -P $DATA_DIR https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz
8 | fi
9 |
10 | export PYTHONPATH=$PWD
11 | OUTPUT_DIR=results/humaneval/Llama-2-70b-chat-hf
12 | MODEL=meta-llama/Llama-2-70b-chat-hf
13 |
14 | python -m xchat.eval.humaneval.run_eval \
15 | --data_file $DATA_DIR/HumanEval.jsonl.gz \
16 | --max_new_token 512 \
17 | --eval_pass_at_ks 1 \
18 | --unbiased_sampling_size_n 1 \
19 | --temperature 0.2 \
20 | --save_dir $OUTPUT_DIR \
21 | --model $MODEL \
22 | --tokenizer $MODEL \
23 | --eval_batch_size 24 \
24 | --load_in_8bit \
25 | --chat_format llama2chat
26 |
--------------------------------------------------------------------------------
/scripts/deploy/tgi_lemur.sh:
--------------------------------------------------------------------------------
1 | model=OpenLemur/lemur-70b-v1
2 | name="lemur-70b-v1"
3 | # HF_HOME=""
4 | volume=$HF_HOME/hub
5 | gpus='"device=0,1,2,3"'
6 | num_shard=4
7 | dtype="bfloat16"
8 | quantize="bitsandbytes"
9 | max_input_length=4095
10 | max_total_tokens=4096
11 | max_batch_prefill_tokens=16380
12 | max_batch_total_tokens=16384
13 |
14 | docker run --gpus $gpus \
15 | --shm-size 1g --rm \
16 | -p 8080:80 -v $volume:/data \
17 | --name tgi-$name \
18 | ghcr.io/huggingface/text-generation-inference:1.0.3 \
19 | --model-id $model \
20 | --sharded false \
21 | --num-shard $num_shard \
22 | --dtype $dtype \
23 | --max-input-length $max_input_length \
24 | --max-total-tokens $max_total_tokens \
25 | --max-batch-prefill-tokens $max_batch_prefill_tokens \
26 | --max-batch-total-tokens $max_batch_total_tokens \
27 | --max-stop-sequences 10
28 |
--------------------------------------------------------------------------------
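Note: the TGI container above publishes its REST API on host port 8080 (from the -p 8080:80 mapping). A minimal request sketch, using the requests library for illustration:

    import requests

    resp = requests.post(
        "http://localhost:8080/generate",
        json={"inputs": "def fibonacci(n):", "parameters": {"max_new_tokens": 128}},
        timeout=300,
    )
    resp.raise_for_status()
    print(resp.json()["generated_text"])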
/scripts/eval/bbh.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1
2 |
3 | DATA_DIR=data/eval/bbh
4 |
5 | if [ ! -d "data" ]; then
6 | mkdir data
7 | fi
8 |
9 | if [ ! -d "data/eval" ]; then
10 | mkdir data/eval
11 | fi
12 |
13 | if [ ! -d $DATA_DIR ]; then
14 | echo "Downloading BBH data..."
15 | mkdir -p data/downloads
16 | wget -O data/downloads/bbh_data.zip https://github.com/suzgunmirac/BIG-Bench-Hard/archive/refs/heads/main.zip
17 | unzip data/downloads/bbh_data.zip -d data/downloads/bbh
18 | mv data/downloads/bbh/BIG-Bench-Hard-main data/eval/bbh && rm -rf data/downloads/
19 | fi
20 |
21 | MODEL_DIR=codellama/CodeLlama-7b-Instruct-hf
22 | OUTPUT_DIR=results/bbh/CodeLlama-7b-Instruct-hf
23 |
24 | python -m xchat.eval.bbh.run_eval \
25 | --data_dir data/eval/bbh/ \
26 | --save_dir $OUTPUT_DIR \
27 | --model $MODEL_DIR \
28 | --tokenizer $MODEL_DIR \
29 | --eval_batch_size 20 \
30 | --load_in_8bit \
31 | --no_cot \
32 | --chat_format codellama-instruct
33 |
--------------------------------------------------------------------------------
/scripts/eval/mmlu.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=4,5,6,7
2 |
3 |
4 | DATA_DIR=data/eval/mmlu
5 |
6 | if [ ! -d $DATA_DIR ]; then
7 | echo "Downloading MMLU data..."
8 |     mkdir -p data/eval/mmlu_data
9 |     wget -O data/mmlu_data.tar https://people.eecs.berkeley.edu/~hendrycks/data.tar
10 | tar -xvf data/mmlu_data.tar -C data/eval/mmlu_data
11 | mv data/eval/mmlu_data/data $DATA_DIR && rm -r data/eval/mmlu_data data/mmlu_data.tar
12 | fi
13 |
14 | MODEL_DIR=meta-llama/Llama-2-7b-hf
15 | OUTPUT_DIR=results/mmlu/llama-2-7b-hf
16 | python -m xchat.eval.mmlu.run_eval \
17 | --ntrain 5 \
18 | --data_dir $DATA_DIR \
19 | --save_dir $OUTPUT_DIR \
20 | --model_name_or_path $MODEL_DIR \
21 | --tokenizer_name_or_path $MODEL_DIR \
22 | --eval_batch_size 4
23 |
24 | # MODEL_DIR=bigcode/starcoder
25 | # OUTPUT_DIR=results/mmlu/starcoder
26 | # python -m xchat.eval.mmlu.run_eval \
27 | # --ntrain 5 \
28 | # --data_dir $DATA_DIR \
29 | # --save_dir $OUTPUT_DIR \
30 | # --model_name_or_path $MODEL_DIR \
31 | # --tokenizer_name_or_path $MODEL_DIR \
32 | # --eval_batch_size 16
33 |
--------------------------------------------------------------------------------
/xchat/eval/mbpp/data.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | import os
4 | from typing import Dict, Iterable
5 |
6 |
7 | def read_problems(evalset_file: str) -> Dict[str, Dict]:
8 | return {task["task_id"]: task for task in stream_jsonl(evalset_file)}
9 |
10 |
11 | def stream_jsonl(filename: str) -> Iterable[Dict]:
12 | """
13 | Parses each jsonl line and yields it as a dictionary.
14 | """
15 | if filename.endswith(".gz"):
16 | with open(filename, "rb") as gzfp, gzip.open(gzfp, "rt") as fp:
17 | for line in fp:
18 | if any(not x.isspace() for x in line):
19 | yield json.loads(line)
20 | else:
21 | with open(filename) as fp:
22 | for line in fp:
23 | if any(not x.isspace() for x in line):
24 | yield json.loads(line)
25 |
26 |
27 | def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
28 | """
29 | Writes an iterable of dictionaries to jsonl.
30 | """
31 | mode = "ab" if append else "wb"
32 | filename = os.path.expanduser(filename)
33 | if filename.endswith(".gz"):
34 | with open(filename, mode) as fp, gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
35 | for x in data:
36 | gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
37 | else:
38 | with open(filename, mode) as fp:
39 | for x in data:
40 | fp.write((json.dumps(x) + "\n").encode("utf-8"))
41 |
--------------------------------------------------------------------------------
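Note: a round-trip sketch for the helpers above; write_jsonl switches to gzip when the filename ends in .gz, and stream_jsonl reads either form back lazily (the file name below is illustrative):

    from xchat.eval.mbpp.data import stream_jsonl, write_jsonl

    records = [{"task_id": "mbpp/1", "completion": "print('hi')"}]
    write_jsonl("samples.jsonl.gz", records)  # gzip-compressed because of the .gz suffix
    assert list(stream_jsonl("samples.jsonl.gz")) == records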
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autofix_prs: false
3 | exclude: ^(poetry.lock|.idea/)
4 | repos:
5 | - repo: https://github.com/charliermarsh/ruff-pre-commit
6 | rev: "v0.0.270"
7 | hooks:
8 | - id: ruff
9 | args: [ --fix, --exit-non-zero-on-fix ]
10 |
11 | - repo: https://github.com/psf/black
12 | rev: 23.3.0
13 | hooks:
14 | - id: black
15 | args:
16 | - "--line-length=120"
17 | - "--target-version=py39"
18 | - "--target-version=py310"
19 | - "--target-version=py311"
20 | types: [ python ]
21 |
22 | - repo: https://github.com/pre-commit/pre-commit-hooks
23 | rev: v4.4.0
24 | hooks:
25 | - id: check-added-large-files
26 | - id: check-ast
27 | - id: check-builtin-literals
28 | - id: check-case-conflict
29 | - id: check-docstring-first
30 | - id: check-shebang-scripts-are-executable
31 | - id: check-merge-conflict
32 | - id: check-json
33 | - id: check-toml
34 | - id: check-xml
35 | - id: check-yaml
36 | - id: debug-statements
37 | - id: destroyed-symlinks
38 | - id: detect-private-key
39 | - id: end-of-file-fixer
40 | exclude: ^LICENSE|\.(html|csv|txt|svg|py)$
41 | - id: pretty-format-json
42 | args: ["--autofix", "--no-ensure-ascii", "--no-sort-keys"]
43 | - id: requirements-txt-fixer
44 | - id: trailing-whitespace
45 | args: [--markdown-linebreak-ext=md]
46 | exclude: \.(html|svg)$
47 |
--------------------------------------------------------------------------------
/xchat/eval/mbpp/examplars.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/google-research/google-research/blob/master/mbpp/README.md?plain=1#L21
2 |
3 | EXAMPLARS = [
4 | {
5 | "task_id": 2,
6 | "text": "Write a function to find the similar elements from the given two tuple lists.",
7 | "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ",
8 | "test_list": [
9 | "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
10 | "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
11 | "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
12 | ],
13 | "test_setup_code": "",
14 | "challenge_test_list": [],
15 | },
16 | {
17 | "task_id": 3,
18 | "text": "Write a python function to identify non-prime numbers.",
19 | "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result",
20 | "test_list": [
21 | "assert is_not_prime(2) == False",
22 | "assert is_not_prime(10) == True",
23 | "assert is_not_prime(35) == True",
24 | ],
25 | "test_setup_code": "",
26 | "challenge_test_list": [],
27 | },
28 | {
29 | "task_id": 4,
30 | "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
31 | "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums",
32 | "test_list": [
33 | "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
34 | "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
35 | "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
36 | ],
37 | "test_setup_code": "",
38 | "challenge_test_list": [],
39 | },
40 | ]
41 |
--------------------------------------------------------------------------------
/scripts/eval/gsm.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 # ,4,5,6,7
2 |
3 | DATA_DIR=data/eval/gsm
4 |
5 | if [ ! -d $DATA_DIR ]; then
6 | echo "Downloading GSM data..."
7 | mkdir -p $DATA_DIR
8 | wget -P $DATA_DIR https://github.com/openai/grade-school-math/raw/master/grade_school_math/data/test.jsonl
9 | fi
10 |
11 | # MODEL_DIR=meta-llama/Llama-2-7b-hf
12 | # OUTPUT_DIR=results/gsm/llama-2-7b-hf
13 | # python -m xchat.eval.gsm.run_eval \
14 | # --data_dir $DATA_DIR \
15 | # --max_num_examples 32 \
16 | # --save_dir $OUTPUT_DIR \
17 | # --model $MODEL_DIR \
18 | # --tokenizer $MODEL_DIR \
19 | # --eval_batch_size 16 \
20 | # --n_shot 8
21 |
22 | MODEL_DIR=codellama/CodeLlama-7b-hf
23 | OUTPUT_DIR=results/gsm/CodeLlama-7b-hf
24 | python -m xchat.eval.gsm.run_eval \
25 | --data_dir $DATA_DIR \
26 | --max_num_examples 32 \
27 | --save_dir $OUTPUT_DIR \
28 | --model $MODEL_DIR \
29 | --tokenizer $MODEL_DIR \
30 | --eval_batch_size 16 \
31 | --n_shot 8
32 |
33 | # MODEL_DIR=OpenLemur/lemur-70b-v1
34 | # OUTPUT_DIR=results/gsm/lemur-70b-v1
35 | # python -m xchat.eval.gsm.run_eval \
36 | # --data_dir $DATA_DIR \
37 | # --max_num_examples 32 \
38 | # --save_dir $OUTPUT_DIR \
39 | # --model $MODEL_DIR \
40 | # --tokenizer $MODEL_DIR \
41 | # --eval_batch_size 16 \
42 | # --n_shot 8 \
43 | # --load_in_8bit
44 |
45 | # MODEL_DIR=lmsys/vicuna-13b-v1.5
46 | # OUTPUT_DIR=results/gsm/vicuna-13b-v1.5
47 | # python -m xchat.eval.gsm.run_eval \
48 | # --data_dir $DATA_DIR \
49 | # --max_num_examples 32 \
50 | # --save_dir $OUTPUT_DIR \
51 | # --model $MODEL_DIR \
52 | # --tokenizer $MODEL_DIR \
53 | # --eval_batch_size 16 \
54 | # --n_shot 8 \
55 | # --chat_format vicuna \
56 | # --load_in_8bit
57 |
58 | # MODEL_DIR=codellama/CodeLlama-34b-Instruct-hf
59 | # OUTPUT_DIR=results/gsm/codellama-34b-instruct-hf
60 | # python -m xchat.eval.gsm.run_eval \
61 | # --data_dir $DATA_DIR \
62 | # --save_dir $OUTPUT_DIR \
63 | # --model $MODEL_DIR \
64 | # --tokenizer $MODEL_DIR \
65 | # --eval_batch_size 48 \
66 | # --n_shot 8 \
67 | # --chat_format codellama-instruct
68 |
--------------------------------------------------------------------------------
/xchat/eval/humaneval/prompts.py:
--------------------------------------------------------------------------------
1 | class HumanEvalPromptBase:
2 | def __init__(self):
3 | pass
4 |
5 | def getPrompt(self, prompt):
6 | return prompt
7 |
8 |
9 | class ChatMLPrompt(HumanEvalPromptBase):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def getPrompt(self, prompt):
14 | return (
15 | "<|im_start|>user\n"
16 | + "Complete the following python function.\n\n\n"
17 | + prompt
18 | + "<|im_end|>\n<|im_start|>assistant\n"
19 | + "Here is the completed function:\n\n\n"
20 | + prompt
21 | )
22 |
23 |
24 | class LLaMA2ChatPrompt(HumanEvalPromptBase):
25 |     def __init__(self):
26 |         super().__init__()
27 |
28 | def getPrompt(self, prompt):
29 | return (
30 |             "[INST] <<SYS>>\n"
31 | + "Below is an instruction that describes a task. Write a response that appropriately completes the request."
32 |             + "\n<</SYS>>\n\n"
33 | + "Complete the following python function.\n\n\n"
34 | + prompt
35 | + "[/INST]"
36 | + "Here is the completed function:\n\n\n"
37 | + prompt
38 | )
39 |
40 |
41 | class VicunaPrompt(HumanEvalPromptBase):
42 |     def __init__(self):
43 |         super().__init__()
44 |
45 | def getPrompt(self, prompt):
46 | system_message = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
47 | return (
48 | "USER: "
49 | + system_message
50 | + "Complete the following python function.\n\n\n"
51 | + prompt
52 | + "\nASSISTANT: "
53 | + "Here is the completed function:\n\n\n"
54 | + prompt
55 | )
56 |
57 |
58 | class TuluPrompt(HumanEvalPromptBase):
59 |     def __init__(self):
60 |         super().__init__()
61 |
62 | def getPrompt(self, prompt):
63 | return (
64 | "<|user|>\n"
65 | + "Complete the following python function.\n\n\n"
66 | + prompt
67 | + "\n<|assistant|>\n"
68 | + "Here is the completed function:\n\n\n"
69 | + prompt
70 | )
71 |
72 |
73 | def getprompts(chat_format):
74 | if chat_format in ["base"]:
75 | return HumanEvalPromptBase()
76 |     elif chat_format in ["chatml"]:
77 | return ChatMLPrompt()
78 | elif chat_format in ["tulu"]:
79 | return TuluPrompt()
80 | elif chat_format in ["llama2chat"]:
81 | return LLaMA2ChatPrompt()
82 | elif chat_format in ["vicuna"]:
83 | return VicunaPrompt()
84 | raise NotImplementedError
85 |
--------------------------------------------------------------------------------
/xchat/eval/mmlu/categories.py:
--------------------------------------------------------------------------------
1 | subcategories = {
2 | "abstract_algebra": ["math"],
3 | "anatomy": ["health"],
4 | "astronomy": ["physics"],
5 | "business_ethics": ["business"],
6 | "clinical_knowledge": ["health"],
7 | "college_biology": ["biology"],
8 | "college_chemistry": ["chemistry"],
9 | "college_computer_science": ["computer science"],
10 | "college_mathematics": ["math"],
11 | "college_medicine": ["health"],
12 | "college_physics": ["physics"],
13 | "computer_security": ["computer science"],
14 | "conceptual_physics": ["physics"],
15 | "econometrics": ["economics"],
16 | "electrical_engineering": ["engineering"],
17 | "elementary_mathematics": ["math"],
18 | "formal_logic": ["philosophy"],
19 | "global_facts": ["other"],
20 | "high_school_biology": ["biology"],
21 | "high_school_chemistry": ["chemistry"],
22 | "high_school_computer_science": ["computer science"],
23 | "high_school_european_history": ["history"],
24 | "high_school_geography": ["geography"],
25 | "high_school_government_and_politics": ["politics"],
26 | "high_school_macroeconomics": ["economics"],
27 | "high_school_mathematics": ["math"],
28 | "high_school_microeconomics": ["economics"],
29 | "high_school_physics": ["physics"],
30 | "high_school_psychology": ["psychology"],
31 | "high_school_statistics": ["math"],
32 | "high_school_us_history": ["history"],
33 | "high_school_world_history": ["history"],
34 | "human_aging": ["health"],
35 | "human_sexuality": ["culture"],
36 | "international_law": ["law"],
37 | "jurisprudence": ["law"],
38 | "logical_fallacies": ["philosophy"],
39 | "machine_learning": ["computer science"],
40 | "management": ["business"],
41 | "marketing": ["business"],
42 | "medical_genetics": ["health"],
43 | "miscellaneous": ["other"],
44 | "moral_disputes": ["philosophy"],
45 | "moral_scenarios": ["philosophy"],
46 | "nutrition": ["health"],
47 | "philosophy": ["philosophy"],
48 | "prehistory": ["history"],
49 | "professional_accounting": ["other"],
50 | "professional_law": ["law"],
51 | "professional_medicine": ["health"],
52 | "professional_psychology": ["psychology"],
53 | "public_relations": ["politics"],
54 | "security_studies": ["politics"],
55 | "sociology": ["culture"],
56 | "us_foreign_policy": ["politics"],
57 | "virology": ["health"],
58 | "world_religions": ["philosophy"],
59 | }
60 |
61 | categories = {
62 | "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
63 | "humanities": ["history", "philosophy", "law"],
64 | "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
65 | "other (business, health, misc.)": ["other", "business", "health"],
66 | }
67 |
--------------------------------------------------------------------------------
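Note: the two dictionaries above compose into a subject-to-category lookup. The helper below is an illustrative sketch (not part of the repository) of how per-subject results can be rolled up into the four reporting categories:

    from xchat.eval.mmlu.categories import categories, subcategories

    def subject_to_category(subject: str) -> str:
        # e.g. "astronomy" -> ["physics"] -> "STEM"
        for category, fields in categories.items():
            if any(field in fields for field in subcategories[subject]):
                return category
        raise KeyError(subject)

    assert subject_to_category("astronomy") == "STEM"
    assert subject_to_category("jurisprudence") == "humanities"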
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "xchat"
7 | version = "0.0.1"
8 | description = "An open platform for training, serving, and evaluating large language model based chatbots."
9 | requires-python = ">=3.9"
10 | classifiers = [
11 | "Programming Language :: Python :: 3",
12 | "License :: OSI Approved :: Apache Software License",
13 | ]
14 |
15 | dependencies = [
16 | "accelerate>=0.21.0",
17 | "bitsandbytes>=0.41.1",
18 | "datasets",
19 | "deepspeed>=0.10.0",
20 | "einops",
21 | "evaluate>=0.4.0",
22 | "peft>=0.4.0",
23 | "scipy",
24 | "sentencepiece",
25 | "tokenizers>=0.13.3",
26 | "transformers==4.34.0",
27 | "wandb",
28 | "openai"
29 | ]
30 |
31 | [project.optional-dependencies]
32 | dev = ["black==23.3.0", "pylint==2.8.2", "pre-commit"]
33 | all = ["xchat[dev]"]
34 |
35 | [tool.setuptools.packages.find]
36 | exclude = ["assets", "data*", "tests*"]
37 |
38 | [tool.wheel]
39 | exclude = ["assets", "data*", "tests*"]
40 |
41 | [tool.ruff]
42 | target-version = "py39"
43 | select = [
44 | "B", # flake8-bugbear
45 | "C4", # flake8-comprehensions
46 | "D", # pydocstyle
47 | "E", # Error
48 | "F", # pyflakes
49 | "I", # isort
50 | "ISC", # flake8-implicit-str-concat
51 | "N", # pep8-naming
52 | "PGH", # pygrep-hooks
53 | # "PTH", # flake8-use-pathlib
54 | "Q", # flake8-quotes
55 | "S", # bandit
56 | "SIM", # flake8-simplify
57 | "TRY", # tryceratops
58 | "UP", # pyupgrade
59 | "W", # Warning
60 | "YTT", # flake8-2020
61 | ]
62 |
63 | exclude = [
64 | "migrations",
65 | "__pycache__",
66 | "manage.py",
67 | "settings.py",
68 | "env",
69 | ".env",
70 | "venv",
71 | ".venv",
72 | ]
73 |
74 | ignore = [
75 | "B905", # zip strict=True; remove once python <3.10 support is dropped.
76 | "D100",
77 | "D101",
78 | "D102",
79 | "D103",
80 | "D104",
81 | "D105",
82 | "D106",
83 | "D107",
84 | "D200",
85 | "D401",
86 | "E402",
87 | "E501",
88 | "F401",
89 | "TRY003", # Avoid specifying messages outside exception class; overly strict, especially for ValueError
90 | "S101", # Use of assert detected; overly strict, especially for tests
91 | ]
92 | line-length = 120 # Must agree with Black
93 |
94 | [tool.ruff.flake8-bugbear]
95 | extend-immutable-calls = [
96 | "chr",
97 | "typer.Argument",
98 | "typer.Option",
99 | ]
100 |
101 | [tool.ruff.pydocstyle]
102 | convention = "numpy"
103 |
104 | [tool.ruff.pep8-naming]
105 | staticmethod-decorators = [
106 | "pydantic.validator",
107 | "pydantic.root_validator",
108 | ]
109 |
--------------------------------------------------------------------------------
/xchat/eval/gsm/examplars.py:
--------------------------------------------------------------------------------
1 | # These exemplars are from Table 20 of the CoT paper (https://arxiv.org/pdf/2201.11903.pdf).
2 |
3 | EXAMPLARS = [
4 | {
5 | "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
6 | "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
7 | "short_answer": "6",
8 | },
9 | {
10 | "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
11 | "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.",
12 | "short_answer": "5",
13 | },
14 | {
15 | "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
16 | "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.",
17 | "short_answer": "39",
18 | },
19 | {
20 | "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
21 | "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.",
22 | "short_answer": "8",
23 | },
24 | {
25 | "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
26 | "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.",
27 | "short_answer": "9",
28 | },
29 | {
30 | "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
31 | "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.",
32 | "short_answer": "29",
33 | },
34 | {
35 | "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
36 | "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.",
37 | "short_answer": "33",
38 | },
39 | {
40 | "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
41 | "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.",
42 | "short_answer": "8",
43 | },
44 | ]
45 |
--------------------------------------------------------------------------------
/xchat/eval/dispatch_openai_requests.py:
--------------------------------------------------------------------------------
1 | # Copied and modified from https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a. Thanks to Graham Neubig for sharing the original code.
2 |
3 | import asyncio
4 | from typing import Any, Dict, List
5 |
6 | import openai
7 |
8 |
9 | async def dispatch_openai_chat_requesets(
10 | messages_list: List[List[Dict[str, Any]]],
11 | model: str,
12 | **completion_kwargs: Any,
13 | ) -> List[str]:
14 | """Dispatches requests to OpenAI chat completion API asynchronously.
15 |
16 | Args:
17 | messages_list: List of messages to be sent to OpenAI chat completion API.
18 | model: OpenAI model to use.
19 | completion_kwargs: Keyword arguments to be passed to OpenAI ChatCompletion API. See https://platform.openai.com/docs/api-reference/chat for details.
20 |
21 | Returns
22 | -------
23 | List of responses from OpenAI API.
24 | """
25 | async_responses = [
26 | openai.ChatCompletion.acreate(
27 | model=model,
28 | messages=x,
29 | **completion_kwargs,
30 | )
31 | for x in messages_list
32 | ]
33 | return await asyncio.gather(*async_responses)
34 |
35 |
36 | async def dispatch_openai_prompt_requesets(
37 | prompt_list: List[str],
38 | model: str,
39 | **completion_kwargs: Any,
40 | ) -> List[str]:
41 | """Dispatches requests to OpenAI text completion API asynchronously.
42 |
43 | Args:
44 | prompt_list: List of prompts to be sent to OpenAI text completion API.
45 | model: OpenAI model to use.
46 | completion_kwargs: Keyword arguments to be passed to OpenAI text completion API. See https://platform.openai.com/docs/api-reference/completions for details.
47 |
48 | Returns
49 | -------
50 | List of responses from OpenAI API.
51 | """
52 | async_responses = [
53 | openai.Completion.acreate(
54 | model=model,
55 | prompt=x,
56 | **completion_kwargs,
57 | )
58 | for x in prompt_list
59 | ]
60 | return await asyncio.gather(*async_responses)
61 |
62 |
63 | if __name__ == "__main__":
64 | chat_completion_responses = asyncio.run(
65 | dispatch_openai_chat_requesets(
66 | messages_list=[
67 | [{"role": "user", "content": "Write a poem about asynchronous execution."}],
68 | [{"role": "user", "content": "Write a poem about asynchronous pirates."}],
69 | ],
70 | model="gpt-3.5-turbo",
71 | temperature=0.3,
72 | max_tokens=200,
73 | top_p=1.0,
74 | )
75 | )
76 |
77 | for i, x in enumerate(chat_completion_responses):
78 | print(f"Chat completion response {i}:\n{x['choices'][0]['message']['content']}\n\n")
79 |
80 | prompt_completion_responses = asyncio.run(
81 | dispatch_openai_prompt_requesets(
82 | prompt_list=[
83 | "Write a poem about asynchronous execution.\n",
84 | "Write a poem about asynchronous pirates.\n",
85 | ],
86 | model="text-davinci-003",
87 | temperature=0.3,
88 | max_tokens=200,
89 | top_p=1.0,
90 | )
91 | )
92 |
93 | for i, x in enumerate(prompt_completion_responses):
94 | print(f"Prompt completion response {i}:\n{x['choices'][0]['text']}\n\n")
95 |
--------------------------------------------------------------------------------
/xchat/eval/mbpp/prompts.py:
--------------------------------------------------------------------------------
1 | from xchat.eval.mbpp.examplars import EXAMPLARS
2 |
3 |
4 | class MbppPromptBase:
5 | def __init__(self, instuction, end_seq, is_few_shot) -> None:
6 | self.instruction = instuction
7 | self.end_seq = end_seq
8 | self.is_few_shot = is_few_shot
9 |
10 | def get_prompt(self, example):
11 | prompt = ""
12 | if self.is_few_shot:
13 | for d in EXAMPLARS:
14 | prompt += (
15 | self.instruction.format(instruction=d["text"], test_cases="\n".join(d["test_list"]), code=d["code"])
16 | + self.end_seq
17 | )
18 | prompt += self.instruction.format(
19 | instruction=example["prompt"].strip(), test_cases=example["reference"], code=""
20 | )
21 | return prompt.strip()
22 |
23 | def get_parser_seq(self):
24 | return self.end_seq
25 |
26 |
27 | class MbppPromptChat(MbppPromptBase):
28 | def __init__(self, instuction, end_seq, is_few_shot) -> None:
29 | super().__init__(instuction, end_seq, is_few_shot)
30 |
31 | def get_prompt(self, example):
32 | prompt = "[INST]"
33 | if self.is_few_shot:
34 | for d in EXAMPLARS:
35 | prompt += (
36 | self.instruction.format(instruction=d["text"], test_cases="\n".join(d["test_list"]), code=d["code"])
37 | + self.end_seq
38 | )
39 | prompt += self.instruction.format(
40 | instruction=example["prompt"].strip(), test_cases=example["reference"], code=""
41 | )
42 | prompt = prompt.strip()
43 | assert prompt.endswith("[PYTHON]")
44 | prompt = prompt[:-8]
45 | prompt += "[/INST]\n[PYTHON]\n"
46 | return prompt.strip()
47 |
48 |
49 | BASE_INSTRUCTION = """
50 | You are an expert Python programmer, and here is your task: {instruction} Your code should pass these tests:\n\n{test_cases}
51 | [BEGIN]
52 | {code}"""
53 |
54 |
55 | class BasePrompt(MbppPromptBase):
56 | def __init__(self, is_few_shot) -> None:
57 | super().__init__(instuction=BASE_INSTRUCTION, end_seq="\n[DONE]", is_few_shot=is_few_shot)
58 |
59 |
60 | INSTRUCTION = """
61 | You are an expert Python programmer, and here is your task: {instruction}
62 | Your code should pass these tests:\n\n{test_cases}\nYour code should start with a [PYTHON] tag and end with a [/PYTHON] tag.
63 | [PYTHON]
64 | {code}"""
65 |
66 |
67 | class CodeLLaMAPrompt(MbppPromptBase):
68 | def __init__(self, is_few_shot) -> None:
69 | super().__init__(instuction=INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot)
70 |
71 |
72 | class LLaMA2ChatPrompt(MbppPromptChat):
73 | def __init__(self, is_few_shot) -> None:
74 | super().__init__(instuction=INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot)
75 |
76 |
77 | LEMUR_INSTRUCTION = """
78 | <|im_start|>user
79 | You are an expert Python programmer, and here is your task: {instruction}
80 | Your code should pass these tests:\n\n{test_cases}\nYour code should start with a [PYTHON] tag and end with a [/PYTHON] tag.
81 | <|im_end|>
82 | <|im_start|>assistant
83 | [PYTHON]
84 | {code}"""
85 |
86 |
87 | class LemurPrompt(MbppPromptBase):
88 | def __init__(self, is_few_shot) -> None:
89 | super().__init__(instuction=LEMUR_INSTRUCTION, end_seq="\n[/PYTHON]", is_few_shot=is_few_shot)
90 |
91 |
92 | def getprompts(chat_format, is_few_shot):
93 | if chat_format in ["base"]:
94 | return BasePrompt(is_few_shot)
95 | elif chat_format in ["lemur"]:
96 | return LemurPrompt(is_few_shot)
97 | elif chat_format in ["codellama", "wizardcoder"]:
98 | return CodeLLaMAPrompt(is_few_shot)
99 | elif chat_format in ["llama2chat"]:
100 | return LLaMA2ChatPrompt(is_few_shot)
101 | raise NotImplementedError
102 |
--------------------------------------------------------------------------------
/xchat/eval/mbpp/evaluation.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from collections import Counter, defaultdict
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 | from typing import Dict, Iterable, List, Union
5 |
6 | import numpy as np
7 | import tqdm
8 | from evaluate import load
9 |
10 | from xchat.eval.mbpp.data import read_problems, stream_jsonl, write_jsonl
11 | from xchat.eval.mbpp.execution import check_correctness
12 |
13 |
14 | def estimate_pass_at_k(
15 | num_samples: Union[int, List[int], np.ndarray], num_correct: Union[List[int], np.ndarray], k: int
16 | ) -> np.ndarray:
17 | """
18 | Estimates pass@k of each problem and returns them in an array.
19 | """
20 |
21 | def estimator(n: int, c: int, k: int) -> float:
22 | """
23 | Calculates 1 - comb(n - c, k) / comb(n, k).
24 | """
25 | if n - c < k:
26 | return 1.0
27 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
28 |
29 | if isinstance(num_samples, int):
30 | num_samples_it = itertools.repeat(num_samples, len(num_correct))
31 | else:
32 | assert len(num_samples) == len(num_correct)
33 | num_samples_it = iter(num_samples)
34 |
35 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
36 |
37 |
38 | def evaluate_functional_correctness(
39 | task: str,
40 | sample_file: str,
41 | k: List[int] = [1, 10, 100],
42 | n_workers: int = 4,
43 | timeout: float = 3.0,
44 | problems=None,
45 | problem_file=None,
46 | ):
47 | """
48 | Evaluates the functional correctness of generated samples, and writes
49 | results to f"{sample_file}_results.jsonl.gz".
50 | """
51 | if not problems:
52 | problems = read_problems(problem_file)
53 |
54 | # Check the generated samples against test suites.
55 | with ThreadPoolExecutor(max_workers=n_workers) as executor:
56 | futures = []
57 | completion_id = Counter()
58 | n_samples = 0
59 | results = defaultdict(list)
60 |
61 | print("Reading samples...")
62 | for sample in tqdm.tqdm(stream_jsonl(sample_file)):
63 | task_id = sample["task_id"]
64 | completion = sample["completion"]
65 | args = (task, problems[task_id], completion, timeout, completion_id[task_id])
66 | future = executor.submit(check_correctness, *args)
67 | futures.append(future)
68 | completion_id[task_id] += 1
69 | n_samples += 1
70 |
71 | assert len(completion_id) == len(problems), "Some problems are not attempted."
72 |
73 | print("Running test suites...")
74 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
75 | result = future.result()
76 | results[result["task_id"]].append((result["completion_id"], result))
77 |
78 | # Calculate pass@k.
79 | total, correct = [], []
80 | for result in results.values():
81 | result.sort()
82 | passed = [r[1]["passed"] for r in result]
83 | total.append(len(passed))
84 | correct.append(sum(passed))
85 | total = np.array(total)
86 | correct = np.array(correct)
87 |
88 | ks = k
89 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
90 |
91 | # Finally, save the results in one file:
92 | def combine_results():
93 | for sample in stream_jsonl(sample_file):
94 | task_id = sample["task_id"]
95 | result = results[task_id].pop(0)
96 | sample["result"] = result[1]["result"]
97 | sample["passed"] = result[1]["passed"]
98 | yield sample
99 |
100 | out_file = sample_file + "_results.jsonl"
101 | print(f"Writing results to {out_file}...")
102 | write_jsonl(out_file, tqdm.tqdm(combine_results(), total=n_samples))
103 |
104 | return pass_at_k
105 |
--------------------------------------------------------------------------------
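Note: estimate_pass_at_k above implements the unbiased HumanEval estimator pass@k = 1 - C(n-c, k) / C(n, k) for n samples per problem with c of them correct. A small numeric check (the sample counts below are illustrative):

    from xchat.eval.mbpp.evaluation import estimate_pass_at_k

    # Four problems, five samples each, with 0..3 correct completions respectively.
    print(estimate_pass_at_k(num_samples=[5, 5, 5, 5], num_correct=[0, 1, 2, 3], k=1))
    # -> [0.  0.2 0.4 0.6], i.e. pass@1 reduces to c / n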
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,macos,python,visualstudiocode
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,macos,python,visualstudiocode
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### macOS ###
20 | # General
21 | .DS_Store
22 | .AppleDouble
23 | .LSOverride
24 |
25 | # Icon must end with two \r
26 | Icon
27 |
28 |
29 | # Thumbnails
30 | ._*
31 |
32 | # Files that might appear in the root of a volume
33 | .DocumentRevisions-V100
34 | .fseventsd
35 | .Spotlight-V100
36 | .TemporaryItems
37 | .Trashes
38 | .VolumeIcon.icns
39 | .com.apple.timemachine.donotpresent
40 |
41 | # Directories potentially created on remote AFP share
42 | .AppleDB
43 | .AppleDesktop
44 | Network Trash Folder
45 | Temporary Items
46 | .apdisk
47 |
48 | ### macOS Patch ###
49 | # iCloud generated files
50 | *.icloud
51 |
52 | ### Python ###
53 | # Byte-compiled / optimized / DLL files
54 | __pycache__/
55 | *.py[cod]
56 | *$py.class
57 |
58 | # C extensions
59 | *.so
60 |
61 | # Distribution / packaging
62 | .Python
63 | build/
64 | develop-eggs/
65 | dist/
66 | downloads/
67 | eggs/
68 | .eggs/
69 | lib/
70 | lib64/
71 | parts/
72 | sdist/
73 | var/
74 | wheels/
75 | share/python-wheels/
76 | *.egg-info/
77 | .installed.cfg
78 | *.egg
79 | MANIFEST
80 |
81 | # PyInstaller
82 | # Usually these files are written by a python script from a template
83 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
84 | *.manifest
85 | *.spec
86 |
87 | # Installer logs
88 | pip-log.txt
89 | pip-delete-this-directory.txt
90 |
91 | # Unit test / coverage reports
92 | htmlcov/
93 | .tox/
94 | .nox/
95 | .coverage
96 | .coverage.*
97 | .cache
98 | nosetests.xml
99 | coverage.xml
100 | *.cover
101 | *.py,cover
102 | .hypothesis/
103 | .pytest_cache/
104 | cover/
105 |
106 | # Translations
107 | *.mo
108 | *.pot
109 |
110 | # Django stuff:
111 | *.log
112 | local_settings.py
113 | db.sqlite3
114 | db.sqlite3-journal
115 |
116 | # Flask stuff:
117 | instance/
118 | .webassets-cache
119 |
120 | # Scrapy stuff:
121 | .scrapy
122 |
123 | # Sphinx documentation
124 | docs/_build/
125 |
126 | # PyBuilder
127 | .pybuilder/
128 | target/
129 |
130 | # Jupyter Notebook
131 | .ipynb_checkpoints
132 |
133 | # IPython
134 | profile_default/
135 | ipython_config.py
136 |
137 | # pyenv
138 | # For a library or package, you might want to ignore these files since the code is
139 | # intended to run in multiple environments; otherwise, check them in:
140 | # .python-version
141 |
142 | # pipenv
143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
146 | # install all needed dependencies.
147 | #Pipfile.lock
148 |
149 | # poetry
150 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
151 | # This is especially recommended for binary packages to ensure reproducibility, and is more
152 | # commonly ignored for libraries.
153 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
154 | #poetry.lock
155 |
156 | # pdm
157 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
158 | #pdm.lock
159 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
160 | # in version control.
161 | # https://pdm.fming.dev/#use-with-ide
162 | .pdm.toml
163 |
164 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
165 | __pypackages__/
166 |
167 | # Celery stuff
168 | celerybeat-schedule
169 | celerybeat.pid
170 |
171 | # SageMath parsed files
172 | *.sage.py
173 |
174 | # Environments
175 | .env
176 | .venv
177 | env/
178 | venv/
179 | ENV/
180 | env.bak/
181 | venv.bak/
182 |
183 | # Spyder project settings
184 | .spyderproject
185 | .spyproject
186 |
187 | # Rope project settings
188 | .ropeproject
189 |
190 | # mkdocs documentation
191 | /site
192 |
193 | # mypy
194 | .mypy_cache/
195 | .dmypy.json
196 | dmypy.json
197 |
198 | # Pyre type checker
199 | .pyre/
200 |
201 | # pytype static type analyzer
202 | .pytype/
203 |
204 | # Cython debug symbols
205 | cython_debug/
206 |
207 | # PyCharm
208 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
209 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
210 | # and can be added to the global gitignore or merged into this file. For a more nuclear
211 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
212 | .idea/
213 |
214 | ### Python Patch ###
215 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
216 | poetry.toml
217 |
218 | # ruff
219 | .ruff_cache/
220 |
221 | # LSP config files
222 | pyrightconfig.json
223 |
224 | ### VisualStudioCode ###
225 | .vscode/*
226 | !.vscode/settings.json
227 | !.vscode/tasks.json
228 | !.vscode/launch.json
229 | !.vscode/extensions.json
230 | !.vscode/*.code-snippets
231 |
232 | # Local History for Visual Studio Code
233 | .history/
234 |
235 | # Built Visual Studio Code Extensions
236 | *.vsix
237 |
238 | ### VisualStudioCode Patch ###
239 | # Ignore all local history of files
240 | .history
241 | .ionide
242 |
243 | # End of https://www.toptal.com/developers/gitignore/api/linux,macos,python,visualstudiocode
244 |
245 | data
246 | results
247 | wandb
248 | outputs
249 |
--------------------------------------------------------------------------------
/xchat/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | from typing import List, Optional, Tuple
4 |
5 | import flash_attn
6 | import torch
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 | from flash_attn.bert_padding import pad_input, unpad_input
10 |
11 | logging.warning(f"Found flash_attn version {flash_attn.__version__}")
12 | if flash_attn.__version__ <= "1.0.9":
13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func
14 | else:
15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
16 |
17 | import transformers
18 | from torch import nn
19 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
20 |
21 |
22 | def forward(
23 | self,
24 | hidden_states: torch.Tensor,
25 | attention_mask: Optional[torch.Tensor] = None,
26 | position_ids: Optional[torch.Tensor] = None,
27 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
28 | output_attentions: bool = False,
29 | use_cache: bool = False,
30 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
31 | """Input shape: Batch x Time x Channel.
32 |
33 | attention_mask: [bsz, q_len]
34 | """
35 | if output_attentions:
36 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
37 |
38 | bsz, q_len, _ = hidden_states.size()
39 |
40 | if self.config.pretraining_tp > 1:
41 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
42 | query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
43 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
44 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
45 |
46 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
47 | query_states = torch.cat(query_states, dim=-1)
48 |
49 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
50 | key_states = torch.cat(key_states, dim=-1)
51 |
52 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
53 | value_states = torch.cat(value_states, dim=-1)
54 |
55 | else:
56 | query_states = self.q_proj(hidden_states)
57 | key_states = self.k_proj(hidden_states)
58 | value_states = self.v_proj(hidden_states)
59 |
60 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
61 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
62 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
63 |
64 | # [bsz, q_len, nh, hd]
65 | # [bsz, nh, q_len, hd]
66 |
67 | kv_seq_len = key_states.shape[-2]
68 | if past_key_value is not None:
69 | kv_seq_len += past_key_value[0].shape[-2]
70 |
71 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
72 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
73 | # [bsz, nh, t, hd]
74 | # Past Key value support
75 | if past_key_value is not None:
76 | # reuse k, v, self_attention
77 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
78 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
79 |
80 | past_key_value = (key_states, value_states) if use_cache else None
81 |
82 | # repeat k/v heads if n_kv_heads < n_heads
83 | key_states = repeat_kv(key_states, self.num_key_value_groups)
84 | value_states = repeat_kv(value_states, self.num_key_value_groups)
85 |
86 | # Flash attention codes from
87 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
88 |
89 | # transform the data into the format required by flash attention
90 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
91 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
92 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
93 | # the attention_mask should be the same as the key_padding_mask
94 | key_padding_mask = attention_mask
95 |
96 | if key_padding_mask is None:
97 | qkv = rearrange(qkv, "b s ... -> (b s) ...")
98 | max_s = q_len
99 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
100 | output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
101 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
102 | else:
103 | nheads = qkv.shape[-2]
104 | x = rearrange(qkv, "b s three h d -> b s (three h d)")
105 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
106 | x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
107 | output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
108 | output = rearrange(
109 | pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len),
110 | "b s (h d) -> b s h d",
111 | h=nheads,
112 | )
113 | output = rearrange(output, "b s h d -> b s (h d)")
114 | if self.config.pretraining_tp > 1:
115 | output = output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
116 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
117 | output = sum([F.linear(output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
118 | else:
119 | output = self.o_proj(output)
120 | return output, None, past_key_value
121 |
122 |
123 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
124 | # requires the attention mask to be the same as the key_padding_mask
125 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
126 | # [bsz, seq_len]
127 | return attention_mask
128 |
129 |
130 | def replace_llama_attn_with_flash_attn():
131 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
132 | if cuda_major < 8:
133 | logging.warning(
134 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
135 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
136 | )
137 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
138 | _prepare_decoder_attention_mask
139 | )
140 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
141 |
--------------------------------------------------------------------------------
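Note: the patch above must be applied before a Llama model is instantiated so that the replaced forward and attention-mask hook take effect. A minimal usage sketch (the checkpoint name is illustrative):

    import torch
    from xchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

    replace_llama_attn_with_flash_attn()  # patch LlamaAttention.forward before building the model

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.bfloat16,  # flash attention expects fp16/bf16 activations
    )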
/xchat/eval/mbpp/execution.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import contextlib
3 | import faulthandler
4 | import io
5 | import multiprocessing
6 | import os
7 | import platform
8 | import signal
9 | import tempfile
10 | from typing import Callable, Dict, Optional
11 |
12 |
13 | def check_correctness(
14 | task: str, problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None
15 | ) -> Dict:
16 | """
17 | Evaluates the functional correctness of a completion by running the test
18 | suite provided in the problem.
19 |
20 | :param completion_id: an optional completion ID so we can match
21 | the results later even if execution finishes asynchronously.
22 | """
23 |
24 | def unsafe_execute():
25 | with create_tempdir():
26 | # These system calls are needed when cleaning up tempdir.
27 | import os
28 | import shutil
29 |
30 | rmtree = shutil.rmtree
31 | rmdir = os.rmdir
32 | chdir = os.chdir
33 |
34 | # Disable functionalities that can make destructive changes to the test.
35 | reliability_guard()
36 |
37 | # Construct the check program and run it.
38 | if task == "humaneval":
39 | check_program = (
40 | problem["prompt"] + completion + "\n" + problem["test"] + "\n" + f"check({problem['entry_point']})"
41 | )
42 | elif task == "mbpp":
43 | check_program = completion + "\n" + problem["reference"]
44 | else:
45 | raise ValueError(f"Unknown task: {task}")
46 |
47 | try:
48 | exec_globals = {}
49 | with swallow_io(), time_limit(timeout):
50 | # WARNING
51 | # This program exists to execute untrusted model-generated code. Although
52 | # it is highly unlikely that model-generated code will do something overtly
53 | # malicious in response to this test suite, model-generated code may act
54 | # destructively due to a lack of model capability or alignment.
55 | # Users are strongly encouraged to sandbox this evaluation suite so that it
56 | # does not perform destructive actions on their host or network. For more
57 | # information on how OpenAI sandboxes its code, see the accompanying paper.
58 | # Once you have read this disclaimer and taken appropriate precautions,
59 |                     # the exec call below is enabled here; proceed at your own risk:
60 | exec(check_program, exec_globals)
61 | result.append("passed")
62 | except TimeoutException:
63 | result.append("timed out")
64 | except BaseException as e:
65 | result.append(f"failed: {e}")
66 |
67 | # Needed for cleaning up.
68 | shutil.rmtree = rmtree
69 | os.rmdir = rmdir
70 | os.chdir = chdir
71 |
72 | manager = multiprocessing.Manager()
73 | result = manager.list()
74 |
75 | p = multiprocessing.Process(target=unsafe_execute)
76 | p.start()
77 | p.join(timeout=timeout + 1)
78 | if p.is_alive():
79 | p.kill()
80 |
81 | if not result:
82 | result.append("timed out")
83 |
84 | return {
85 | "task_id": problem["task_id"],
86 | "passed": result[0] == "passed",
87 | "result": result[0],
88 | "completion_id": completion_id,
89 | }
90 |
91 |
92 | @contextlib.contextmanager
93 | def time_limit(seconds: float):
94 | def signal_handler(signum, frame):
95 | raise TimeoutException("Timed out!")
96 |
97 | signal.setitimer(signal.ITIMER_REAL, seconds)
98 | signal.signal(signal.SIGALRM, signal_handler)
99 | try:
100 | yield
101 | finally:
102 | signal.setitimer(signal.ITIMER_REAL, 0)
103 |
104 |
105 | @contextlib.contextmanager
106 | def swallow_io():
107 | stream = WriteOnlyStringIO()
108 | with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream), redirect_stdin(stream):
109 | yield
110 |
111 |
112 | @contextlib.contextmanager
113 | def create_tempdir():
114 | with tempfile.TemporaryDirectory() as dirname, chdir(dirname):
115 | yield dirname
116 |
117 |
118 | class TimeoutException(Exception):
119 | pass
120 |
121 |
122 | class WriteOnlyStringIO(io.StringIO):
123 | """StringIO that throws an exception when it's read from."""
124 |
125 | def read(self, *args, **kwargs):
126 | raise OSError
127 |
128 | def readline(self, *args, **kwargs):
129 | raise OSError
130 |
131 | def readlines(self, *args, **kwargs):
132 | raise OSError
133 |
134 | def readable(self, *args, **kwargs):
135 | """Returns True if the IO object can be read."""
136 | return False
137 |
138 |
139 | class redirect_stdin(contextlib._RedirectStream): # type: ignore
140 | _stream = "stdin"
141 |
142 |
143 | @contextlib.contextmanager
144 | def chdir(root):
145 | if root == ".":
146 | yield
147 | return
148 | cwd = os.getcwd()
149 | os.chdir(root)
150 | try:
151 | yield
152 | except BaseException as exc:
153 | raise exc
154 | finally:
155 | os.chdir(cwd)
156 |
157 |
158 | def reliability_guard(maximum_memory_bytes: Optional[int] = None):
159 | """
160 | This disables various destructive functions and prevents the generated code
161 | from interfering with the test (e.g. fork bomb, killing other processes,
162 | removing filesystem files, etc.).
163 |
164 | WARNING
165 |     This function is NOT a security sandbox. Untrusted code, including model-
166 | generated code, should not be blindly executed outside of one. See the
167 | Codex paper for more information about OpenAI's code sandbox, and proceed
168 | with caution.
169 | """
170 | if maximum_memory_bytes is not None:
171 | import resource
172 |
173 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
174 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
175 | if platform.uname().system != "Darwin":
176 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
177 |
178 | faulthandler.disable()
179 |
180 | import builtins
181 |
182 | builtins.exit = None
183 | builtins.quit = None
184 |
185 | import os
186 |
187 | os.environ["OMP_NUM_THREADS"] = "1"
188 |
189 | os.kill = None
190 | os.system = None
191 | os.putenv = None
192 | os.remove = None
193 | os.removedirs = None
194 | os.rmdir = None
195 | os.fchdir = None
196 | os.setuid = None
197 | os.fork = None
198 | os.forkpty = None
199 | os.killpg = None
200 | os.rename = None
201 | os.renames = None
202 | os.truncate = None
203 | os.replace = None
204 | os.unlink = None
205 | os.fchmod = None
206 | os.fchown = None
207 | os.chmod = None
208 | os.chown = None
209 | os.chroot = None
210 | os.fchdir = None
211 | os.lchflags = None
212 | os.lchmod = None
213 | os.lchown = None
214 | os.getcwd = None
215 | os.chdir = None
216 |
217 | import shutil
218 |
219 | shutil.rmtree = None
220 | shutil.move = None
221 | shutil.chown = None
222 |
223 | import subprocess
224 |
225 | subprocess.Popen = None # type: ignore
226 |
227 | __builtins__["help"] = None
228 |
229 | import sys
230 |
231 | sys.modules["ipdb"] = None
232 | sys.modules["joblib"] = None
233 | sys.modules["resource"] = None
234 | sys.modules["psutil"] = None
235 | sys.modules["tkinter"] = None
236 |
--------------------------------------------------------------------------------
/xchat/eval/humaneval/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 | import torch
7 |
8 | from xchat.eval.humaneval.prompts import getprompts
9 | from xchat.eval.mbpp.data import read_problems, write_jsonl
10 | from xchat.eval.mbpp.evaluation import evaluate_functional_correctness
11 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
12 |
13 |
14 | def main(args):
15 | random.seed(42)
16 | if not os.path.exists(args.save_dir):
17 | os.makedirs(args.save_dir, exist_ok=True)
18 | data_file = args.data_file
19 | test_data = list(read_problems(data_file).values())
20 |
21 | if args.max_num_examples is not None and len(test_data) > args.max_num_examples:
22 | test_data = random.sample(test_data, args.max_num_examples)
23 | print("Number of examples:", len(test_data))
24 | promptclass = getprompts(args.chat_format)
25 | prompts = [promptclass.getPrompt(example["prompt"]) for example in test_data]
26 |
27 | if args.model_name_or_path:
28 | print("Loading model and tokenizer...")
29 | model, tokenizer = load_hf_lm_and_tokenizer(
30 | model_name_or_path=args.model_name_or_path,
31 | tokenizer_name_or_path=args.tokenizer_name_or_path,
32 | load_in_8bit=args.load_in_8bit,
33 | load_in_half=True,
34 | # device map is determined by the number of gpus available.
35 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
36 | gptq_model=args.gptq,
37 | )
38 |
39 | # these stop sequences are those mentioned in the codex paper.
40 | stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]
41 | # Because many tokenizers will treat the word after space differently from the original word alone,
42 | # to be consistent, we add a space before tokenization and remove it after tokenization.
43 | stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences]
44 | print("stop sequences:")
45 | print(stop_sequences)
46 | if args.chat_format == "vicuna":
47 | # Vicuna needs a different stop sequence; otherwise meaningful tokens may be cut off.
48 | stop_sequences = [[tokenizer.eos_token_id]]
49 | outputs_per_sampling_iter = []
50 | for sampling_iter in range(args.unbiased_sampling_size_n):
51 | print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}")
52 | sampling_outputs = generate_completions(
53 | model=model,
54 | tokenizer=tokenizer,
55 | prompts=prompts,
56 | max_new_tokens=args.max_new_tokens,
57 | batch_size=args.eval_batch_size,
58 | stop_id_sequences=stop_sequences,
59 | num_return_sequences=1, # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom.
60 | do_sample=True, # always sample here; the unbiased pass@k estimate needs multiple stochastic samples per prompt.
61 | top_p=args.top_p,
62 | temperature=args.temperature,
63 | eos_token_id=tokenizer.eos_token_id,
64 | )
65 | outputs_per_sampling_iter.append(sampling_outputs)
66 | # regroup the outputs to match the number of test data.
67 | outputs = []
68 | for i in range(len(prompts)):
69 | for j in range(args.unbiased_sampling_size_n):
70 | outputs.append(outputs_per_sampling_iter[j][i])
71 | else:
72 | instances = [
73 | {
74 | "id": examle["task_id"],
75 | "prompt": "Complete the following python function. Please only output the code for the completed function.\n\n\n"
76 | + prompt,
77 | }
78 | for example, prompt in zip(test_data, prompts)
79 | ]
80 | results = query_openai_chat_model(
81 | engine=args.openai_engine,
82 | instances=instances,
83 | output_path=os.path.join(args.save_dir, "openai_query_results.jsonl"),
84 | batch_size=args.eval_batch_size,
85 | top_p=0.95,
86 | temperature=args.temperature,
87 | n=args.unbiased_sampling_size_n,
88 | )
89 | outputs = []
90 | for result in results:
91 | for choice in result["response_metadata"]["choices"]:
92 | outputs.append(choice["message"]["content"])
93 |
94 | # duplicates test data to match the number of outputs.
95 | duplicate_test_data = [example for example in test_data for _ in range(args.unbiased_sampling_size_n)]
96 | predictions = []
97 | for _, example, output in zip(prompts, duplicate_test_data, outputs):
98 | output_lines = output.split("\n")
99 | idx = len(output_lines) - 1
100 | while idx >= 0 and not output_lines[idx].strip().startswith("return"):
101 | idx -= 1
102 | output_lines = output_lines[: idx + 1]
103 | processed_output = "\n".join(output_lines) + "\n"
104 | predictions.append(
105 | {
106 | "task_id": example["task_id"],
107 | "prompt": example["prompt"],
108 | "completion": processed_output,
109 | "raw_response": output,
110 | }
111 | )
112 |
113 | prediction_save_path = os.path.join(args.save_dir, "humaneval_predictions.jsonl")
114 |
115 | write_jsonl(prediction_save_path, predictions)
116 | pass_at_k_results = evaluate_functional_correctness(
117 | task="humaneval",
118 | sample_file=prediction_save_path,
119 | k=args.eval_pass_at_ks,
120 | problems={example["task_id"]: example for example in test_data},
121 | )
122 |
123 | print(pass_at_k_results)
124 |
125 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
126 | json.dump(pass_at_k_results, fout)
127 |
128 |
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument("--data_file", type=str, help="Path to the data file.")
132 | parser.add_argument("--max_num_examples", type=int, default=None, help="Maximum number of examples to evaluate.")
133 | parser.add_argument(
134 | "--model_name_or_path",
135 | type=str,
136 | default=None,
137 | help="If specified, we will load the model to generate the predictions.",
138 | )
139 | parser.add_argument(
140 | "--tokenizer_name_or_path", type=str, default=None, help="If specified, we will load the tokenizer from here."
141 | )
142 | parser.add_argument(
143 | "--openai_engine",
144 | type=str,
145 | default=None,
146 | help="If specified, we will use the OpenAI API to generate the predictions.",
147 | )
148 | parser.add_argument("--data_dir", type=str, help="Directory to save the data.")
149 | parser.add_argument("--cache_dir", type=str, help="Directory to save the cache.")
150 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the results.")
151 | parser.add_argument("--max_new_tokens", type=int, required=True, help="Maximum number of tokens to generate.")
152 | parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation.")
153 | parser.add_argument("--top_p", type=float, default=0.95)
154 | parser.add_argument(
155 | "--temperature",
156 | type=float,
157 | default=0.1,
158 | help="Temperature for sampling. This is should be low for evaluating smaller pass@k, and high for larger pass@k.",
159 | )
160 | parser.add_argument(
161 | "--eval_pass_at_ks", nargs="+", type=int, default=[1], help="Multiple k's that we will report pass@k."
162 | )
163 | parser.add_argument(
164 | "--unbiased_sampling_size_n",
165 | type=int,
166 | default=20,
167 | help="Codex HumanEval requires `n` sampled generations per prompt, to estimate the unbiased pass@k. ",
168 | )
169 |
170 | parser.add_argument(
171 | "--load_in_8bit",
172 | action="store_true",
173 | help="Load model in 8bit mode, which will reduce memory and speed up inference.",
174 | )
175 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
176 | parser.add_argument(
177 | "--chat_format",
178 | type=str,
179 | default=None,
180 | choices=["tulu", "chatml", "llama-2-chat", "vicuna", "base"],
181 | help="if given, we will use the chat format to generate the predictions.",
182 | )
183 | args = parser.parse_args()
184 | # model_name_or_path and openai_engine cannot be both None or both not None.
185 | assert (args.model_name_or_path is None) != (
186 | args.openai_engine is None
187 | ), "Either model_name_or_path or openai_engine should be specified."
188 | assert args.unbiased_sampling_size_n >= max(
189 | args.eval_pass_at_ks
190 | ), "n should be larger than the largest k in eval_pass_at_ks."
191 | main(args)
192 |
--------------------------------------------------------------------------------
/xchat/eval/mbpp/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 | import torch
7 | from datasets import load_dataset
8 |
9 | from xchat.eval.mbpp.data import read_problems, write_jsonl
10 | from xchat.eval.mbpp.evaluation import evaluate_functional_correctness
11 | from xchat.eval.mbpp.prompts import getprompts
12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
13 |
14 |
15 | def prepare_inference_arg(args, test_data):
16 | res = []
17 | for example in test_data:
18 | d = {}
19 | d["prompt"] = getprompts(args.chat_format, args.few_shot).get_prompt(example)
20 | d["reference"] = example["reference"]
21 | d["task_id"] = example["task_id"]
22 | res.append(d)
23 | return res
24 |
25 |
26 | def prepare_data(args):
27 | if not os.path.exists(os.path.join(args.save_dir, "mbpp_inference_args.jsonl")):
28 | os.makedirs(args.save_dir, exist_ok=True)
29 | raw_dataset = load_dataset(path="mbpp", name=None, cache_dir=args.cache_dir)["test"]
30 | if len(raw_dataset) != 500:
31 | raise ValueError(f"Expected 500 test examples, got {len(raw_dataset)}")
32 | with open(os.path.join(args.save_dir, "mbpp_inference_args.jsonl"), "w") as f:
33 | for i, example in enumerate(raw_dataset):
34 | d__ = {
35 | "task_id": i,
36 | "prompt": example["text"],
37 | "reference": "\n".join(example["test_list"]),
38 | }
39 | f.write(json.dumps(d__) + "\n")
40 | return os.path.join(args.save_dir, "mbpp_inference_args.jsonl")
41 |
42 |
43 | def main(args):
44 | random.seed(42)
45 | if not os.path.exists(args.data_dir):
46 | os.makedirs(args.data_dir, exist_ok=True)
47 | if not os.path.exists(args.save_dir):
48 | os.makedirs(args.save_dir, exist_ok=True)
49 |
50 | data_file = prepare_data(args)
51 |
52 | test_data = list(read_problems(data_file).values())
53 | if args.max_num_examples is not None and len(test_data) > args.max_num_examples:
54 | test_data = random.sample(test_data, args.max_num_examples)
55 |
56 | test_data = prepare_inference_arg(args, test_data)
57 | print("Number of examples:", len(test_data))
58 |
59 | if args.model_name_or_path:
60 | print("Loading model and tokenizer...")
61 | model, tokenizer = load_hf_lm_and_tokenizer(
62 | model_name_or_path=args.model_name_or_path,
63 | tokenizer_name_or_path=args.tokenizer_name_or_path,
64 | load_in_8bit=args.load_in_8bit,
65 | load_in_half=True,
66 | # device map is determined by the number of gpus available.
67 | device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
68 | gptq_model=args.gptq,
69 | )
70 | tokenizer.padding_side = "left"
71 | tokenizer.pad_token_id = 0
72 | test_data = sorted(
73 | test_data, key=lambda x: len(tokenizer(x["prompt"], return_tensors="pt").input_ids[0]), reverse=True
74 | )
75 |
76 | stop_sequences = ["\nassert", "\nprint"]
77 | stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences]
78 | prompts = [example["prompt"] for example in test_data]
79 |
80 | outputs_per_sampling_iter = []
81 | for sampling_iter in range(args.unbiased_sampling_size_n):
82 | assert args.greedy_decoding is True
83 | print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}")
84 | sampling_outputs = generate_completions(
85 | model=model,
86 | tokenizer=tokenizer,
87 | prompts=prompts,
88 | max_new_tokens=args.max_new_tokens,
89 | batch_size=args.eval_batch_size,
90 | stop_id_sequences=stop_sequences,
91 | num_return_sequences=1, # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom.
92 | do_sample=not args.greedy_decoding, # if only pass@1 is evaluated, we do greedy decoding.
93 | top_p=args.top_p,
94 | temperature=args.temperature,
95 | eos_token_id=tokenizer.eos_token_id,
96 | )
97 | assert len(sampling_outputs) == len(prompts)
98 | outputs_per_sampling_iter.append(sampling_outputs)
99 | outputs = []
100 | for i in range(len(prompts)):
101 | for j in range(args.unbiased_sampling_size_n):
102 | outputs.append(outputs_per_sampling_iter[j][i])
103 | else:
104 | instances = [
105 | {
106 | "id": examle["task_id"],
107 | "prompt": "Complete the following python function. Please only output the code for the completed function.\n\n\n"
108 | + prompt,
109 | }
110 | for example, prompt in zip(test_data, prompts)
111 | ]
112 | results = query_openai_chat_model(
113 | engine=args.openai_engine,
114 | instances=instances,
115 | output_path=os.path.join(args.save_dir, "openai_query_results.jsonl"),
116 | batch_size=args.eval_batch_size,
117 | top_p=0.95,
118 | temperature=args.temperature,
119 | n=args.unbiased_sampling_size_n,
120 | )
121 | outputs = []
122 | for result in results:
123 | for choice in result["response_metadata"]["choices"]:
124 | outputs.append(choice["message"]["content"])
125 |
126 | duplicate_test_data = [example for example in test_data for _ in range(args.unbiased_sampling_size_n)]
127 | assert len(duplicate_test_data) == len(outputs)
128 |
129 | def process(args, s):
130 | parse_seq = getprompts(args.chat_format, args.few_shot).get_parser_seq()
131 | return s.split(parse_seq)[0].strip()
132 |
133 | predictions = [
134 | {
135 | "task_id": example["task_id"],
136 | "prompt": example["prompt"],
137 | "completion": process(args, output.strip()),
138 | "reference": example["reference"],
139 | }
140 | for example, output in zip(duplicate_test_data, outputs)
141 | ]
142 | prediction_save_path = os.path.join(args.save_dir, "mbpp_predictions.jsonl")
143 | write_jsonl(prediction_save_path, predictions, append=True)
144 |
145 | pass_at_k_results = evaluate_functional_correctness(
146 | sample_file=prediction_save_path,
147 | k=args.eval_pass_at_ks,
148 | problems={example["task_id"]: example for example in test_data},
149 | )
150 |
151 | print(pass_at_k_results)
152 |
153 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
154 | json.dump(pass_at_k_results, fout)
155 |
156 |
157 | if __name__ == "__main__":
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument("--data_file", type=str, help="Path to the data file.")
160 | parser.add_argument("--max_num_examples", type=int, default=None, help="Maximum number of examples to evaluate.")
161 | parser.add_argument(
162 | "--model_name_or_path",
163 | type=str,
164 | default=None,
165 | help="If specified, we will load the model to generate the predictions.",
166 | )
167 | parser.add_argument(
168 | "--tokenizer_name_or_path", type=str, default=None, help="If specified, we will load the tokenizer from here."
169 | )
170 | parser.add_argument(
171 | "--openai_engine",
172 | type=str,
173 | default=None,
174 | help="If specified, we will use the OpenAI API to generate the predictions.",
175 | )
176 | parser.add_argument("--data_dir", type=str, help="Directory to save the data.")
177 | parser.add_argument("--cache_dir", type=str, help="Directory to save the cache.")
178 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the results.")
179 | parser.add_argument("--max_new_tokens", type=int, required=True, help="Maximum number of tokens to generate.")
180 | parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation.")
181 | parser.add_argument("--unbiased_sampling_size_n", type=int, default=1, help="Number of unbiased samples.")
182 | parser.add_argument("--few_shot", action="store_true", help="If given, we're evaluating 3-shot performance.")
183 | parser.add_argument("--greedy_decoding", action="store_true")
184 | parser.add_argument("--top_p", type=float)
185 | parser.add_argument(
186 | "--eval_pass_at_ks", nargs="+", type=int, default=[1], help="Multiple k's that we will report pass@k."
187 | )
188 | parser.add_argument(
189 | "--temperature",
190 | type=float,
191 | default=0.1,
192 | help="Temperature for sampling. This is should be low for evaluating smaller pass@k, and high for larger pass@k.",
193 | )
194 | parser.add_argument(
195 | "--load_in_8bit",
196 | action="store_true",
197 | help="Load mdodel in 8bit mode, which will reduce memory and speed up inference.",
198 | )
199 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
200 | parser.add_argument("--chat_format", type=str)
201 | args = parser.parse_args()
202 | # model_name_or_path and openai_engine cannot be both None or both not None.
203 | assert (args.model_name_or_path is None) != (
204 | args.openai_engine is None
205 | ), "Either model_name_or_path or openai_engine should be specified."
206 | main(args)
207 |
--------------------------------------------------------------------------------
/xchat/eval/gsm/run_eval.py:
--------------------------------------------------------------------------------
1 | # Copied and modified from https://github.com/allenai/open-instruct
2 |
3 | import argparse
4 | import json
5 | import os
6 | import random
7 | import re
8 |
9 | import evaluate
10 |
11 | from xchat.eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS
12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
13 |
14 | exact_match = evaluate.load("exact_match")
15 |
16 |
17 | def main(args):
18 | random.seed(42)
19 |
20 | print("Loading data...")
21 | test_data = []
22 | with open(os.path.join(args.data_dir, "test.jsonl")) as fin:
23 | for line in fin:
24 | example = json.loads(line)
25 | test_data.append({"question": example["question"], "answer": example["answer"].split("####")[1].strip()})
26 |
27 | # some numbers are in the `x,xxx` format, and we want to remove the comma
28 | for example in test_data:
29 | example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
30 | assert float(example["answer"]), f"answer is not a valid number: {example['answer']}"
31 |
32 | if args.max_num_examples and len(test_data) > args.max_num_examples:
33 | test_data = random.sample(test_data, args.max_num_examples)
34 |
35 | if not os.path.exists(args.save_dir):
36 | os.makedirs(args.save_dir, exist_ok=True)
37 |
38 | global GSM_EXAMPLARS
39 | if args.n_shot:
40 | if len(GSM_EXAMPLARS) > args.n_shot:
41 | GSM_EXAMPLARS = random.sample(GSM_EXAMPLARS, args.n_shot)
42 | demonstrations = []
43 | for example in GSM_EXAMPLARS:
44 | if args.no_cot:
45 | demonstrations.append("Question: " + example["question"] + "\n" + "Answer: " + example["short_answer"])
46 | else:
47 | demonstrations.append("Question: " + example["question"] + "\n" + "Answer: " + example["cot_answer"])
48 | prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
49 | else:
50 | prompt_prefix = "Answer the following question.\n\n"
51 |
52 | prompts = []
53 | for example in test_data:
54 | if args.chat_format == "tulu":
55 | prompt = (
56 | "<|user|>\n"
57 | + prompt_prefix
58 | + "Question: "
59 | + example["question"].strip()
60 | + "\n<|assistant|>\n"
61 | + "Answer:"
62 | )
63 | elif args.chat_format == "lemur":
64 | prompt = (
65 | "<|im_start|>user\n"
66 | + prompt_prefix
67 | + "Question: "
68 | + example["question"].strip()
69 | + "<|im_end|>\n<|im_start|>assistant\n"
70 | + "Answer:"
71 | )
72 | elif args.chat_format == "wizardcoder":
73 | prompt = (
74 | "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
75 | "### Instruction:\n"
76 | f"{prompt_prefix}"
77 | "Question: "
78 | f"{example['question'].strip()}"
79 | "\n### Response:\n"
80 | "Answer:"
81 | )
82 | elif args.chat_format == "vicuna":
83 | prompt = (
84 | "A chat between a curious user and an artificial intelligence assistant. "
85 | "The assistant gives helpful, detailed, and polite answers to the user's questions. "
86 | "USER: " + prompt_prefix + "Question: " + example["question"].strip() + "\nASSISTANT: " + "Answer:"
87 | )
88 | elif args.chat_format == "codellama-instruct":
89 | prompt = "[INST] " + prompt_prefix + "Question: " + example["question"].strip() + "[/INST]" + "Answer:"
90 | elif args.chat_format == "llama-2-chat":
91 | system_message = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
92 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
93 | that your responses are socially unbiased and positive in nature.
94 |
95 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
96 | correct. If you don't know the answer to a question, please don't share false information."""
97 | prompt = (
98 | "[INST] <>\n"
99 | f"{system_message}"
100 | "\n<>\n\n"
101 | f"{prompt_prefix}"
102 | "Question: "
103 | f"{example['question'].strip()}"
104 | "[/INST]"
105 | "Answer:"
106 | )
107 | else:
108 | prompt = prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:"
109 | prompts.append(prompt)
110 |
111 | if args.model_name_or_path:
112 | print("Loading model and tokenizer...")
113 | model, tokenizer = load_hf_lm_and_tokenizer(
114 | model_name_or_path=args.model_name_or_path,
115 | tokenizer_name_or_path=args.tokenizer_name_or_path,
116 | load_in_8bit=args.load_in_8bit,
117 | load_in_half=True,
118 | gptq_model=args.gptq,
119 | )
120 |
121 | if args.chat_format == "tulu":
122 | stop_id_sequences = [[tokenizer.encode("<|assistant|>", add_special_tokens=False)[-1]]]
123 | elif args.chat_format == "lemur":
124 | stop_id_sequences = [
125 | [tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]],
126 | [tokenizer.eos_token_id],
127 | ]
128 | elif args.chat_format == "vicuna" or args.chat_format == "llama-2-chat":
129 | stop_id_sequences = [[tokenizer.encode("", add_special_tokens=False)[-1]]]
130 | elif args.chat_format == "codellama-instruct":
131 | stop_id_sequences = [
132 | [tokenizer.encode("[INST]", add_special_tokens=False)[-1]],
133 | [tokenizer.encode("", add_special_tokens=False)[-1]],
134 | ]
135 | else:
136 | # get the last token because the tokenizer may add space tokens at the start.
137 | stop_id_sequences = [[tokenizer.encode("\n", add_special_tokens=False)[-1]]]
138 | outputs = generate_completions(
139 | model=model,
140 | tokenizer=tokenizer,
141 | prompts=prompts,
142 | max_new_tokens=512,
143 | batch_size=args.eval_batch_size,
144 | stop_id_sequences=stop_id_sequences,
145 | )
146 | else:
147 | instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)]
148 | results = query_openai_chat_model(
149 | engine=args.openai_engine,
150 | instances=instances,
151 | batch_size=args.eval_batch_size if args.eval_batch_size else 10,
152 | output_path=os.path.join(args.save_dir, "openai_results.jsonl"),
153 | )
154 | outputs = [result["output"] for result in results]
155 |
156 | predictions = []
157 | for output in outputs:
158 | # replace numbers like `x,xxx` with `xxxx`
159 | output = re.sub(r"(\d),(\d)", r"\1\2", output)
160 | numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output)
161 | if numbers:
162 | predictions.append(numbers[-1])
163 | else:
164 | predictions.append(output)
165 |
166 | raw_predictions = [
167 | {"question": example["question"], "answer": example["answer"], "model_output": output, "prediction": pred}
168 | for example, output, pred in zip(test_data, outputs, predictions)
169 | ]
170 |
171 | with open(os.path.join(args.save_dir, "raw_predictions.jsonl"), "w") as fout:
172 | for raw_prediction in raw_predictions:
173 | fout.write(json.dumps(raw_prediction) + "\n")
174 |
175 | print("Calculating accuracy...")
176 | targets = [example["answer"] for example in test_data]
177 |
178 | em_score = exact_match.compute(
179 | predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True
180 | )["exact_match"]
181 | print(f"Exact match : {em_score}")
182 |
183 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
184 | json.dump({"exact_match": em_score}, fout, indent=4)
185 |
186 |
187 | if __name__ == "__main__":
188 | parser = argparse.ArgumentParser()
189 | parser.add_argument("--data_dir", type=str, default="data/mgsm")
190 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.")
191 | parser.add_argument("--save_dir", type=str, default="results/mgsm")
192 | parser.add_argument(
193 | "--model_name_or_path",
194 | type=str,
195 | default=None,
196 | help="if specified, we will load the model to generate the predictions.",
197 | )
198 | parser.add_argument(
199 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here."
200 | )
201 | parser.add_argument(
202 | "--openai_engine",
203 | type=str,
204 | default=None,
205 | help="if specified, we will use the OpenAI API to generate the predictions.",
206 | )
207 | parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.")
208 | parser.add_argument(
209 | "--no_cot", action="store_true", help="If given, we're evaluating a model without chain-of-thought."
210 | )
211 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
212 | parser.add_argument(
213 | "--load_in_8bit",
214 | action="store_true",
215 | help="load model in 8bit mode, which will reduce memory and speed up inference.",
216 | )
217 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
218 | parser.add_argument(
219 | "--chat_format",
220 | type=str,
221 | default=None,
222 | choices=["tulu", "lemur", "wizardcoder", "vicuna", "llama-2-chat", "codellama-instruct"],
223 | help="if given, we will use the chat format to generate the predictions.",
224 | )
225 | args = parser.parse_args()
226 |
227 | # model_name_or_path and openai_engine cannot be both None or both not None.
228 | assert (args.model_name_or_path is None) != (
229 | args.openai_engine is None
230 | ), "Either model_name_or_path or openai_engine should be specified."
231 | main(args)
232 |
--------------------------------------------------------------------------------
/xchat/eval/mmlu/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from xchat.eval.mmlu.categories import categories, subcategories
12 | from xchat.eval.utils import get_next_word_predictions, load_hf_lm_and_tokenizer, query_openai_chat_model
13 |
14 | choices = ["A", "B", "C", "D"]
15 |
16 |
17 | def format_subject(subject):
18 | subject_list = subject.split("_")
19 | s = ""
20 | for entry in subject_list:
21 | s += " " + entry
22 | return s
23 |
24 |
25 | def format_example(df, idx, include_answer=True):
26 | prompt = df.iloc[idx, 0]
27 | k = df.shape[1] - 2
28 | for j in range(k):
29 | prompt += f"\n{choices[j]}. {df.iloc[idx, j + 1]}"
30 | prompt += "\nAnswer:"
31 | if include_answer:
32 | prompt += f" {df.iloc[idx, k + 1]}\n\n"
33 | return prompt
34 |
35 |
36 | def gen_prompt(train_df, subject, k=-1):
37 | prompt = f"The following are multiple choice questions (with answers) about {format_subject(subject)}.\n\n"
38 | if k == -1:
39 | k = train_df.shape[0]
40 | for i in range(k):
41 | prompt += format_example(train_df, i)
42 | return prompt
43 |
44 |
45 | @torch.no_grad()
46 | def eval_hf_model(args, subject, model, tokenizer, dev_df, test_df, batch_size=1):
47 | prompts = []
48 | for i in range(0, test_df.shape[0]):
49 | k = args.ntrain
50 | prompt_end = format_example(test_df, i, include_answer=False)
51 | train_prompt = gen_prompt(dev_df, subject, k)
52 | prompt = train_prompt + prompt_end
53 |
54 | tokenized_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
55 | # make sure every prompt is less than 2048 tokens
56 | while tokenized_prompt.shape[-1] > 2048:
57 | k -= 1
58 | train_prompt = gen_prompt(dev_df, subject, k)
59 | prompt = train_prompt + prompt_end
60 | tokenized_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
61 |
62 | if args.use_chat_format:
63 | prompt = "<|user|>\n" + prompt.strip() + "\n<|assistant|>\nThe answer is:"
64 |
65 | prompts.append(prompt)
66 |
67 | # get the answer for all examples
68 | # note: here we cannot directly use convert_tokens_to_ids because some tokenizers will automatically add a space prefix.
69 | answer_choice_ids = [tokenizer.encode(answer_choice, add_special_tokens=False)[0] for answer_choice in choices]
70 | pred_indices, all_probs = get_next_word_predictions(
71 | model,
72 | tokenizer,
73 | prompts,
74 | candidate_token_ids=answer_choice_ids,
75 | return_token_predictions=False,
76 | batch_size=batch_size,
77 | )
78 |
79 | # get the metrics
80 | cors = []
81 | ground_truths = test_df.iloc[:, -1].values
82 | for i in range(len(pred_indices)):
83 | prediction = choices[pred_indices[i]]
84 | ground_truth = ground_truths[i]
85 | cors.append(prediction == ground_truth)
86 |
87 | acc = np.mean(cors)
88 | cors = np.array(cors)
89 |
90 | all_probs = np.array(all_probs)
91 | print(f"Average accuracy {acc:.3f} - {subject}")
92 | return cors, acc, all_probs
93 |
94 |
95 | def eval_openai_chat_engine(args, subject, engine, dev_df, test_df, batch_size=1):
96 | import tiktoken
97 |
98 | gpt_tokenizer = tiktoken.get_encoding("cl100k_base")
99 | answer_choice_ids = [
100 | gpt_tokenizer.encode(" " + x)[0] for x in choices
101 | ] # be careful, the tokenizer will tokenize " A" and "A" differently.
102 |
103 | prompts = []
104 | for i in range(0, test_df.shape[0]):
105 | k = args.ntrain
106 | prompt_end = format_example(test_df, i, include_answer=False)
107 | train_prompt = gen_prompt(dev_df, subject, k)
108 | prompt = train_prompt + prompt_end
109 | prompts.append(prompt)
110 |
111 | instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)]
112 | results = query_openai_chat_model(
113 | engine=args.openai_engine,
114 | instances=instances,
115 | batch_size=args.eval_batch_size if args.eval_batch_size else 10,
116 | output_path=os.path.join(args.save_dir, f"{subject}_openai_results.jsonl"),
117 | logit_bias={token_id: 100 for token_id in answer_choice_ids},
118 | max_tokens=1,
119 | )
120 |
121 | # get the metrics
122 | cors = []
123 | ground_truths = test_df.iloc[:, -1].values
124 | for i in range(len(test_df)):
125 | prediction = results[i]["output"].strip()
126 | ground_truth = ground_truths[i]
127 | cors.append(prediction == ground_truth)
128 |
129 | acc = np.mean(cors)
130 | cors = np.array(cors)
131 |
132 | all_probs = np.array(
133 | [[0.25, 0.25, 0.25, 0.25] for _ in range(len(test_df))]
134 | ) # dummy probs, just don't want to dig into the openai probs
135 |
136 | print(f"Average accuracy {acc:.3f} - {subject}")
137 | return cors, acc, all_probs
138 |
139 |
140 | def main(args):
141 | if args.model_name_or_path:
142 | print("Loading model and tokenizer...")
143 | model, tokenizer = load_hf_lm_and_tokenizer(
144 | model_name_or_path=args.model_name_or_path,
145 | tokenizer_name_or_path=args.tokenizer_name_or_path,
146 | load_in_8bit=args.load_in_8bit,
147 | load_in_half=True,
148 | gptq_model=args.gptq,
149 | )
150 |
151 | subjects = sorted(
152 | [f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f]
153 | )
154 |
155 | if args.subjects:
156 | assert all(
157 | subj in subjects for subj in args.subjects
158 | ), f"Some of the subjects you specified are not valid: {args.subjects}"
159 | subjects = args.subjects
160 |
161 | if not os.path.exists(args.save_dir):
162 | os.makedirs(args.save_dir)
163 | if not os.path.exists(os.path.join(args.save_dir)):
164 | os.makedirs(os.path.join(args.save_dir))
165 |
166 | all_cors = []
167 | subcat_cors = {subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists}
168 | cat_cors = {cat: [] for cat in categories}
169 |
170 | for subject in tqdm(subjects, desc="Evaluating subjects: "):
171 | dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[: args.ntrain]
172 | test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)
173 | if args.n_instances and args.n_instances < test_df.shape[0]:
174 | test_df = test_df.sample(args.n_instances, random_state=42)
175 |
176 | if args.model_name_or_path:
177 | cors, acc, probs = eval_hf_model(args, subject, model, tokenizer, dev_df, test_df, args.eval_batch_size)
178 | else:
179 | cors, acc, probs = eval_openai_chat_engine(
180 | args, subject, args.openai_engine, dev_df, test_df, args.eval_batch_size
181 | )
182 |
183 | subcats = subcategories[subject]
184 | for subcat in subcats:
185 | subcat_cors[subcat].append(cors)
186 | for key in categories:
187 | if subcat in categories[key]:
188 | cat_cors[key].append(cors)
189 | all_cors.append(cors)
190 |
191 | test_df["correct"] = cors
192 | for j in range(probs.shape[1]):
193 | choice = choices[j]
194 | test_df[f"choice{choice}_probs"] = probs[:, j]
195 | test_df.to_csv(
196 | os.path.join(args.save_dir, f"{subject}.csv"),
197 | index=None,
198 | )
199 |
200 | for subcat in subcat_cors:
201 | subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
202 | print(f"Average accuracy {subcat_acc:.3f} - {subcat}")
203 |
204 | for cat in cat_cors:
205 | cat_acc = np.mean(np.concatenate(cat_cors[cat]))
206 | print(f"Average accuracy {cat_acc:.3f} - {cat}")
207 | weighted_acc = np.mean(np.concatenate(all_cors))
208 | print(f"Average accuracy: {weighted_acc:.3f}")
209 |
210 | # save results
211 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as f:
212 | json.dump(
213 | {
214 | "average_acc": weighted_acc,
215 | "subcat_acc": {subcat: np.mean(np.concatenate(subcat_cors[subcat])) for subcat in subcat_cors},
216 | "cat_acc": {cat: np.mean(np.concatenate(cat_cors[cat])) for cat in cat_cors},
217 | },
218 | f,
219 | )
220 |
221 |
222 | if __name__ == "__main__":
223 | parser = argparse.ArgumentParser()
224 | parser.add_argument("--ntrain", type=int, default=5)
225 | parser.add_argument("--data_dir", type=str, default="data/mmlu")
226 | parser.add_argument("--save_dir", type=str, default="results/mmlu/llama-7B/")
227 | parser.add_argument(
228 | "--model_name_or_path",
229 | type=str,
230 | default=None,
231 | help="if specified, we will load the model to generate the predictions.",
232 | )
233 | parser.add_argument(
234 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here."
235 | )
236 | parser.add_argument(
237 | "--openai_engine",
238 | type=str,
239 | default=None,
240 | help="if specified, we will use the OpenAI API to generate the predictions.",
241 | )
242 | parser.add_argument(
243 | "--subjects",
244 | nargs="*",
245 | help="which subjects to evaluate. If not specified, all the 57 subjects will be evaluated.",
246 | )
247 | parser.add_argument(
248 | "--n_instances",
249 | type=int,
250 | help="if specified, a maximum of n_instances per subject will be used for the evaluation.",
251 | )
252 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
253 | parser.add_argument(
254 | "--load_in_8bit",
255 | action="store_true",
256 | help="load model in 8bit mode, which will reduce memory and speed up inference.",
257 | )
258 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
259 | parser.add_argument(
260 | "--use_chat_format",
261 | action="store_true",
262 | help="If given, the prompt will be encoded as a chat format with the roles in prompt.",
263 | )
264 | args = parser.parse_args()
265 |
266 | # model_name_or_path and openai_engine cannot be both None or both not None.
267 | assert (args.model_name_or_path is None) != (
268 | args.openai_engine is None
269 | ), "Either model_name_or_path or openai_engine should be specified."
270 | main(args)
271 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Lemur: Open Foundation Models for Language Agents](https://arxiv.org/abs/2310.06830)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 | Lemur is an openly accessible language model optimized for both natural language and coding capabilities to serve as the backbone of versatile language agents.
32 | As language models continue to evolve from conversational chatbots to functional agents that can act in the real world, they need both strong language understanding and the ability to execute actions. Lemur balances natural language and coding skills to enable agents to follow instructions, reason about tasks, and take grounded actions.
33 |
34 |
35 |

36 |
37 |
38 | Please refer to our paper and code for more details:
39 | - [[Paper](https://arxiv.org/abs/2310.06830)] Lemur: Harmonizing Natural Language and Code for Language Agents
40 | - [[Blog](https://www.xlang.ai/blog/openlemur)] Introducing Lemur: Open Foundation Models for Language Agents
41 |
42 |
43 | ## 🔥 News
44 | * **[18 October, 2023]:** 🎉 We open-sourced [code for OpenAgents](https://github.com/xlang-ai/OpenAgents): An Open Platform for Language Agents in the Wild.
45 | * **[11 October, 2023]:** 🎉 We released the research paper and codebase. We will continue updating this repository.
46 | * **[23 August, 2023]:** 🎉 We released the weights of [`OpenLemur/lemur-70b-v1`](https://huggingface.co/OpenLemur/lemur-70b-v1) and [`OpenLemur/lemur-70b-chat-v1`](https://huggingface.co/OpenLemur/lemur-70b-chat-v1)! Check them out on the [HuggingFace Hub](https://huggingface.co/OpenLemur).
47 |
48 | ## Models
49 | We released our models on the HuggingFace Hub:
50 | * [OpenLemur/lemur-70b-v1](https://huggingface.co/OpenLemur/lemur-70b-v1)
51 | * [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1)
52 |
53 | ## Table of Contents
54 | - [Lemur: Open Foundation Models for Language Agents](#lemur-open-foundation-models-for-language-agents)
55 | - [🔥 News](#-news)
56 | - [Models](#models)
57 | - [Table of Contents](#table-of-contents)
58 | - [Why Lemur?](#why-lemur)
59 | - [Quickstart](#quickstart)
60 | - [Setup](#setup)
61 | - [Lemur-70B](#lemur-70b)
62 | - [Lemur-70B-Chat](#lemur-70b-chat)
63 | - [Training](#training)
64 | - [Evaluation](#evaluation)
65 | - [Foundational Abilities](#foundational-abilities)
66 | - [Interactive Agent Skills](#interactive-agent-skills)
67 | - [Deploy](#deploy)
68 | - [MINT](#mint)
69 | - [WebArena](#webarena)
70 | - [InterCode](#intercode)
71 | - [Citation](#citation)
72 | - [Acknowledgements](#acknowledgements)
73 |
74 |
75 |
76 | ## Why Lemur?
77 | Most existing open-source models specialize in either natural language or code. Lemur combines both strengths by:
78 |
79 | - Pretraining Llama-2-70B on a 90B-token corpus with a 10:1 ratio of code to text to obtain Lemur-70B-v1
80 | - Instruction tuning Lemur-70B-v1 on 300K examples covering both text and code to obtain Lemur-70B-Chat-v1
81 |
82 | This two-stage training produces state-of-the-art performance averaged across diverse language and coding benchmarks, surpassing other available open-source models and narrowing the gap between open-source and commercial models on agent abilities.
83 |
84 | ## Quickstart
85 |
86 | ### Setup
87 | First, set up the environment and install the required libraries:
88 |
89 | ```bash
90 | conda create -n xchat python=3.10
91 | conda activate xchat
92 | conda install pytorch==2.0.1 pytorch-cuda=11.8 -c pytorch -c nvidia
93 | conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
94 | ```
95 | Then, install the xchat package:
96 | ```bash
97 | git clone git@github.com:OpenLemur/Lemur.git
98 | cd Lemur
99 | pip install -e .
100 | ```
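
After installation, you can quickly confirm that PyTorch was built with CUDA support and can see your GPUs before launching any training or evaluation script:

```python
# Quick environment sanity check (run inside the xchat conda environment).
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
```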
101 |
102 | ### Lemur-70B
103 | You can use the base model lemur-70b-v1 in the following way:
104 |
105 |
106 |
107 |
108 |
109 | ```python
110 | from transformers import AutoTokenizer, AutoModelForCausalLM
111 |
112 | tokenizer = AutoTokenizer.from_pretrained("OpenLemur/lemur-70b-v1")
113 | model = AutoModelForCausalLM.from_pretrained("OpenLemur/lemur-70b-v1", device_map="auto", load_in_8bit=True)
114 |
115 | # Text Generation Example
116 | prompt = "The world is "
117 | input = tokenizer(prompt, return_tensors="pt")
118 | output = model.generate(**input, max_length=50, num_return_sequences=1)
119 | generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
120 | print(generated_text)
121 |
122 | # Code Generation Example
123 | prompt = """
124 | def factorial(n):
125 | if n == 0:
126 | return 1
127 | """
128 | input = tokenizer(prompt, return_tensors="pt")
129 | output = model.generate(**input, max_length=200, num_return_sequences=1)
130 | generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
131 | print(generated_code)
132 | ```
133 |
134 |
135 |
136 |
137 |
138 |
139 | ### Lemur-70B-Chat
140 | We instruction-finetune the lemur-70b-v1 model with the ChatML format to obtain lemur-70b-chat-v1. You can use it in the following way:
141 |
142 |
143 |
144 |
145 |
146 | ```python
147 | from transformers import AutoTokenizer, AutoModelForCausalLM
148 |
149 | tokenizer = AutoTokenizer.from_pretrained("OpenLemur/lemur-70b-chat-v1")
150 | model = AutoModelForCausalLM.from_pretrained("OpenLemur/lemur-70b-chat-v1", device_map="auto", load_in_8bit=True)
151 |
152 | # Text Generation Example
153 | prompt = """<|im_start|>system
154 | You are a helpful, respectful, and honest assistant.
155 | <|im_end|>
156 | <|im_start|>user
157 | What's a lemur's favorite fruit?<|im_end|>
158 | <|im_start|>assistant
159 | """
160 | input = tokenizer(prompt, return_tensors="pt")
161 | output = model.generate(**input, max_length=50, num_return_sequences=1)
162 | generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
163 | print(generated_text)
164 |
165 | # Code Generation Example
166 | prompt = """<|im_start|>system
167 | Below is an instruction that describes a task. Write a response that appropriately completes the request.
168 | <|im_end|>
169 | <|im_start|>user
170 | Write a Python function to merge two sorted lists into one sorted list without using any built-in sort functions.<|im_end|>
171 | <|im_start|>assistant
172 | """
173 | input = tokenizer(prompt, return_tensors="pt")
174 | output = model.generate(**input, max_length=200, num_return_sequences=1)
175 | generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
176 | print(generated_code)
177 | ```
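
If you prefer to assemble ChatML prompts programmatically (e.g. for multi-turn conversations), a minimal helper along the lines of the sketch below can be used. Note that `build_chatml_prompt` is only an illustration of the format shown above and is not part of the xchat package:

```python
# Illustrative ChatML prompt builder for lemur-70b-chat-v1.
# This helper is not part of this repository; it only mirrors the format above.
def build_chatml_prompt(messages):
    """messages: list of {"role": "system" | "user" | "assistant", "content": str}."""
    prompt = ""
    for message in messages:
        prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    # Leave the assistant turn open so the model generates the reply.
    return prompt + "<|im_start|>assistant\n"

prompt = build_chatml_prompt([
    {"role": "system", "content": "You are a helpful, respectful, and honest assistant."},
    {"role": "user", "content": "What's a lemur's favorite fruit?"},
])
```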
178 |
179 |
180 |
181 |
182 |
183 | ## Training
184 |
185 |
186 |

187 |
188 |
189 |
190 | ## Evaluation
191 | We evaluated Lemur across:
192 | - 8 language and code datasets like MMLU, BBH, GSM8K, HumanEval, and Spider to validate balanced capabilities
193 | - 13 interactive agent datasets to test skills like tool usage, adapting to feedback from environments or humans, and exploring partially observable digital or physical environments.
194 |
195 |
196 |

197 |
198 |
199 |
200 |

201 |
202 |
203 | ### Foundational Abilities
204 | We built the evaluation suite based on [open-instruct](https://github.com/allenai/open-instruct) and will keep adding more tasks and models.
205 |
206 | Currently, we support the following tasks (a short sketch of the pass@k estimator reported by the code benchmarks follows the list):
207 | - [✅] [MMLU](./scripts/eval/mmlu.sh)
208 | - [✅] [BBH](./scripts/eval/bbh.sh)
209 | - [✅] [GSM8K](./scripts/eval/gsm.sh)
210 | - [✅] [HumanEval](./scripts/eval/humaneval.sh)
211 | - [✅] [MBPP](./scripts/eval/mbpp.sh)
212 | - [🚧] [Spider]()
213 | - [🚧] [MultiPL-E]()
214 | - [🚧] [DS-1000]()
215 | - [🚧] ...
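
For reference, HumanEval and MBPP report pass@k computed from `n` sampled completions per problem (see `--unbiased_sampling_size_n` and `--eval_pass_at_ks` in the corresponding `run_eval.py` scripts). The repository's `evaluate_functional_correctness` produces these numbers; the sketch below only illustrates the standard unbiased estimator from the Codex paper, and the helper name is ours:

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: n samples per problem, c of them passing, k <= n."""
    if n - c < k:
        return 1.0
    # 1 - C(n - c, k) / C(n, k), computed in a numerically stable way.
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# Example: 20 samples with 5 passing -> pass@1 and pass@10 estimates.
print(pass_at_k(20, 5, 1), pass_at_k(20, 5, 10))
```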
216 |
217 | ### Interactive Agent Skills
218 | We use the evaluation frameworks provided by [MINT](https://github.com/xingyaoww/mint-bench), [InterCode](https://github.com/princeton-nlp/intercode), and [WebArena](https://github.com/web-arena-x/webarena) to evaluate interactive agent skills.
219 |
220 | #### Deploy
221 | We use vLLM to serve the Lemur model. However, the official FastChat codebase does not yet support Lemur-Chat, so we provide a Docker image that serves Lemur with vLLM. Please refer to [vllm_lemur.sh](./scripts/deploy/vllm_lemur.sh) for more details.
222 |
223 | ```bash
224 | bash scripts/deploy/vllm_lemur.sh
225 | ```
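
Once the container is running, it exposes an OpenAI-compatible API. A minimal client-side check might look like the snippet below; the port (8000) and served model name (`lemur-70b-v1`) are assumptions based on the deploy script, so adjust them if your configuration differs:

```python
# Minimal smoke test against the vLLM OpenAI-compatible server.
# Port and model name are assumptions based on the deploy script; adjust as needed.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "lemur-70b-v1",
        "prompt": "def fibonacci(n):",
        "max_tokens": 64,
        "temperature": 0,
    },
    timeout=120,
)
print(response.json()["choices"][0]["text"])
```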
226 |
227 | #### MINT
228 | We [forked the MINT](https://github.com/OpenLemur/mint-bench) codebase to share the configs we used; please refer to [this config folder](https://github.com/OpenLemur/mint-bench/tree/main/configs) for more details. Serve the model with the [`vllm_lemur.sh`](./scripts/deploy/vllm_lemur.sh) script before running the benchmark.
229 |
230 | #### WebArena
231 | We [fork the WebArena](https://github.com/OpenLemur/webarena) codebase to enable evaluation with vLLM-served models. To run the evaluation on WebArena, please refer to our [forked WebArena codebase](https://github.com/OpenLemur/webarena).
232 |
233 | #### InterCode
234 | We [fork the InterCode](https://github.com/OpenLemur/intercode) codebase and modify it to enable Lemur evaluation. Please refer to [this script folder](https://github.com/OpenLemur/intercode/tree/master/scripts) for more details, and serve the model with `text-generation-inference` via the [`tgi_lemur.sh`](scripts/deploy/tgi_lemur.sh) script.
235 |
236 |
237 | ## Citation
238 | If you find our work helpful, please cite us:
239 | ```
240 | @misc{xu2023lemur,
241 | title={Lemur: Harmonizing Natural Language and Code for Language Agents},
242 | author={Yiheng Xu and Hongjin Su and Chen Xing and Boyu Mi and Qian Liu and Weijia Shi and Binyuan Hui and Fan Zhou and Yitao Liu and Tianbao Xie and Zhoujun Cheng and Siheng Zhao and Lingpeng Kong and Bailin Wang and Caiming Xiong and Tao Yu},
243 | year={2023},
244 | eprint={2310.06830},
245 | archivePrefix={arXiv},
246 | primaryClass={cs.CL}
247 | }
248 | ```
249 |
250 | ## Acknowledgements
251 |
252 | The Lemur project is an open collaborative research effort between [XLang Lab](https://www.xlang.ai/) and [Salesforce Research](https://www.salesforceairesearch.com/). We thank the following institutions for their gift support:
253 |
254 |
276 |
--------------------------------------------------------------------------------
/xchat/eval/bbh/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 | import random
6 | import re
7 |
8 | import evaluate
9 | import torch
10 | import tqdm
11 |
12 | from xchat.eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
13 |
14 | exact_match = evaluate.load("exact_match")
15 |
16 |
17 | @torch.no_grad()
18 | def eval_hf_model(args, model, tokenizer, examples, task_prompt, save_path=None):
19 | targets = [example["target"] for example in examples]
20 | if save_path:
21 | fout = open(save_path, "w")
22 |
23 | prompts = []
24 | for example in examples:
25 | if args.chat_format == "tulu":
26 | prompt = "<|user|>\n" + task_prompt.strip() + "\n\nQ: " + example["input"] + "\n<|assistant|>\nA:"
27 | elif args.chat_format == "lemur":
28 | prompt = (
29 | "<|im_start|>user\n"
30 | + task_prompt.strip()
31 | + "\n\nQ: "
32 | + example["input"]
33 | + "<|im_end|>\n<|im_start|>assistant\nA:"
34 | )
35 | elif args.chat_format == "codellama-instruct":
36 | prompt = "[INST]" + task_prompt.strip() + "\n\nQ:" + example["input"] + "[/INST]" + "\nA:"
37 | else:
38 | prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA: "
39 | prompts.append(prompt)
40 |
41 | new_line_sequence = tokenizer.encode("\n", add_special_tokens=False)[-1]
42 | q_sequence = tokenizer.encode("Q: ", add_special_tokens=False)
43 | if args.chat_format == "tulu":
44 | stop_id_sequences = [
45 | [tokenizer.encode("<|assistant|>", add_special_tokens=False)[-1]],
46 | [new_line_sequence],
47 | [q_sequence],
48 | ]
49 | elif args.chat_format == "lemur":
50 | stop_id_sequences = [
51 | [tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]],
52 | [tokenizer.eos_token_id],
53 | [new_line_sequence],
54 | [q_sequence],
55 | ]
56 | elif args.chat_format == "codellama-instruct":
57 | stop_id_sequences = [
58 | [tokenizer.encode("[INST]", add_special_tokens=False)[-1]],
59 | [tokenizer.encode("</s>", add_special_tokens=False)[-1]],
60 | [new_line_sequence],
61 | [q_sequence],
62 | ]
63 | else:
64 | stop_id_sequences = [
65 | [new_line_sequence],
66 | ]
67 |
68 | outputs = generate_completions(
69 | model=model,
70 | tokenizer=tokenizer,
71 | prompts=prompts,
72 | max_new_tokens=512,
73 | batch_size=args.eval_batch_size if args.eval_batch_size else 1,
74 | stop_id_sequences=stop_id_sequences,
75 | )
76 |
77 | predictions = []
78 | for example, output in zip(examples, outputs):
79 | example["raw_output"] = output
80 |
81 | # only keep the first part of the output - this is mainly for vanilla language models.
82 | output = output.strip().split("\n\n")[0].strip()
83 |
84 | # take the text before the first period as the answer (the few-shot prompt ends with "A:", so the answer comes first).
85 | # if there is no period, we will just use the raw output.
86 | results = re.search(r"(.*?)\.", output)
87 |
88 | prediction = results.group(1).strip() if results else output.strip()
89 |
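# for multiple-choice tasks, normalize the prediction to just the option letter, e.g. "(A)".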
90 | if "(" in prediction and ")" in prediction and "[" not in prediction and "]" not in prediction:
91 | pattern = r"\(\w\)"
92 | match = re.search(pattern, prediction)
93 | prediction = f"{match.group(0)}" if match else prediction
94 |
95 | example["prediction"] = prediction
96 | predictions.append(prediction)
97 | if save_path:
98 | fout.write(json.dumps(example) + "\n")
99 |
100 | assert len(predictions) == len(targets), "number of predictions and targets are not the same."
101 | return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)[
102 | "exact_match"
103 | ]
104 |
105 |
106 | def eval_openai_chat_engine(args, examples, task_prompt, save_path=None):
107 | targets = [example["target"] for example in examples]
108 | instances = []
109 | for i, example in enumerate(examples):
110 | prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:"
111 | instances.append(
112 | {
113 | "id": example["id"] if "id" in example else i,
114 | "prompt": prompt,
115 | }
116 | )
117 |
118 | if save_path:
119 | openai_result_save_path = os.path.join(
120 | os.path.dirname(save_path), os.path.basename(save_path).split(".")[0] + "_openai_results.jsonl"
121 | )
122 |
123 | results = query_openai_chat_model(
124 | engine=args.openai_engine,
125 | instances=instances,
126 | batch_size=args.eval_batch_size if args.eval_batch_size else 10,
127 | output_path=openai_result_save_path if save_path else None,
128 | )
129 |
130 | outputs = [result["output"] for result in results]
131 | assert len(outputs) == len(targets), "number of predictions and targets are not the same."
132 |
133 | if save_path:
134 | fout = open(save_path, "w")
135 |
136 | predictions = []
137 | for example, output in zip(examples, outputs):
138 | example["raw_output"] = output
139 | # extract the first answer after `So the answer is` and before the next period.
140 | # if there is no such answer, we will just use the raw output.
141 | results = re.search(r"So the answer is (.*?)\.", output)
142 | prediction = results.group(1).strip() if results else output.strip()
143 | example["prediction"] = prediction
144 | predictions.append(prediction)
145 | if save_path:
146 | fout.write(json.dumps(example) + "\n")
147 |
148 | assert len(predictions) == len(targets), "number of predictions and targets are not the same."
149 | return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)[
150 | "exact_match"
151 | ]
152 |
153 |
154 | def main(args):
155 | random.seed(args.seed)
156 |
157 | all_tasks = {}
158 | task_files = glob.glob(os.path.join(args.data_dir, "bbh", "*.json"))
159 | task_files = sorted(task_files)
160 | if args.task_start_idx is not None and args.task_end_idx is not None:
161 | task_files = task_files[args.task_start_idx : args.task_end_idx]
162 | print(task_files)
163 | for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
164 | with open(task_file) as f:
165 | task_name = os.path.basename(task_file).split(".")[0]
166 | all_tasks[task_name] = json.load(f)["examples"]
167 | if args.max_num_examples_per_task:
168 | all_tasks[task_name] = random.sample(all_tasks[task_name], args.max_num_examples_per_task)
169 |
170 | all_prompts = {}
171 | cot_prompt_files = glob.glob(os.path.join(args.data_dir, "cot-prompts", "*.txt"))
172 | cot_prompt_files = sorted(cot_prompt_files)
173 | if args.task_start_idx is not None and args.task_end_idx is not None:
174 | cot_prompt_files = cot_prompt_files[args.task_start_idx : args.task_end_idx]
175 | print(cot_prompt_files)
176 | for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
177 | with open(cot_prompt_file) as f:
178 | task_name = os.path.basename(cot_prompt_file).split(".")[0]
179 | task_prompt = "".join(f.readlines()[2:])
180 | if args.no_cot:
181 | prompt_fields = task_prompt.split("\n\n")
182 | new_prompt_fields = []
183 | for prompt_field in prompt_fields:
184 | if prompt_field.startswith("Q:"):
185 | assert (
186 | "So the answer is" in prompt_field
187 | ), f"`So the answer is` not found in prompt field of {task_name}.txt."
188 | assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
189 | answer = prompt_field.split("So the answer is")[-1].strip()
190 | question = prompt_field.split("\nA:")[0].strip()
191 | new_prompt_fields.append(question + "\nA: " + answer)
192 | else:
193 | new_prompt_fields.append(prompt_field)
194 | task_prompt = "\n\n".join(new_prompt_fields)
195 | all_prompts[task_name] = task_prompt
196 |
197 | assert set(all_tasks.keys()) == set(
198 | all_prompts.keys()
199 | ), "task names in task data and task prompts are not the same."
200 |
201 | os.makedirs(args.save_dir, exist_ok=True)
202 | os.makedirs(os.path.join(args.save_dir, "predictions"), exist_ok=True)
203 |
204 | if args.model_name_or_path:
205 | print("Loading model and tokenizer...")
206 | model, tokenizer = load_hf_lm_and_tokenizer(
207 | model_name_or_path=args.model_name_or_path,
208 | tokenizer_name_or_path=args.tokenizer_name_or_path,
209 | load_in_8bit=args.load_in_8bit,
210 | load_in_half=True,
211 | gptq_model=args.gptq,
212 | )
213 |
214 | performance = {}
215 | for task_name in tqdm.tqdm(all_tasks.keys(), desc="Evaluating"):
216 | if args.special_task is not None:
217 | special_tasks = args.special_task.split(",")
218 | if task_name not in special_tasks:
219 | continue
220 | task_examples = all_tasks[task_name]
221 | prompt = all_prompts[task_name]
222 | if args.model_name_or_path:
223 | task_perf = eval_hf_model(
224 | args,
225 | model,
226 | tokenizer,
227 | task_examples,
228 | prompt,
229 | save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl"),
230 | )
231 | else:
232 | task_perf = eval_openai_chat_engine(
233 | args, task_examples, prompt, save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")
234 | )
235 | performance[task_name] = task_perf
236 | print(f"Task {task_name} - EM: {task_perf}")
237 |
238 | with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
239 | performance["average_exact_match"] = sum(performance.values()) / len(performance)
240 | print(f"Average EM: {performance['average_exact_match']}")
241 | json.dump(performance, fout, indent=4)
242 |
243 |
244 | if __name__ == "__main__":
245 | parser = argparse.ArgumentParser()
246 | parser.add_argument("--data_dir", type=str, default="data/bbh")
247 | parser.add_argument("--save_dir", type=str, default="results/bbh")
248 | parser.add_argument(
249 | "--model_name_or_path",
250 | type=str,
251 | default=None,
252 | help="if specified, we will load the model to generate the predictions.",
253 | )
254 | parser.add_argument(
255 | "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here."
256 | )
257 | parser.add_argument(
258 | "--openai_engine",
259 | type=str,
260 | default=None,
261 | help="if specified, we will use the OpenAI API to generate the predictions.",
262 | )
263 | parser.add_argument(
264 | "--no_cot", action="store_true", help="if specified, chain of thoughts will be removed from the prompts."
265 | )
266 | parser.add_argument(
267 | "--max_num_examples_per_task", type=int, default=None, help="maximum number of examples to evaluate per task."
268 | )
269 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
270 | parser.add_argument(
271 | "--load_in_8bit",
272 | action="store_true",
273 | help="load model in 8bit mode, which will reduce memory and speed up inference.",
274 | )
275 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
276 | parser.add_argument(
277 | "--chat_format",
278 | type=str,
279 | default=None,
280 | choices=["tulu", "lemur", "codellama-instruct"],
281 | help="If given, the prompt will be encoded as a chat format with the roles in prompt.",
282 | )
283 | parser.add_argument(
284 | "--seed",
285 | type=int,
286 | default=42,
287 | )
288 | parser.add_argument(
289 | "--task_start_idx",
290 | type=int,
291 | default=None,
292 | )
293 | parser.add_argument(
294 | "--task_end_idx",
295 | type=int,
296 | default=None,
297 | )
298 | parser.add_argument("--special_task", type=str, default=None)
299 |
300 | args = parser.parse_args()
301 |
302 | # model_name_or_path and openai_engine cannot be both None or both not None.
303 | assert (args.model_name_or_path is None) != (
304 | args.openai_engine is None
305 | ), "Either model_name_or_path or openai_engine should be specified."
306 | main(args)
307 |
--------------------------------------------------------------------------------
/xchat/eval/utils.py:
--------------------------------------------------------------------------------
1 | # This code is based on https://github.com/allenai/open-instruct.
2 |
3 | import asyncio
4 | import json
5 | import os
6 | import time
7 |
8 | import torch
9 | import tqdm
10 | from transformers import StoppingCriteria
11 |
12 | from xchat.eval.dispatch_openai_requests import dispatch_openai_chat_requesets, dispatch_openai_prompt_requesets
13 |
14 |
15 | class KeyWordsCriteria(StoppingCriteria):
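"""Stopping criterion that halts generation once every sequence in the batch ends with one of the given stop id sequences."""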
16 | def __init__(self, stop_id_sequences):
17 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
18 | self.stop_sequences = stop_id_sequences
19 |
20 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
21 | sequences_should_be_stopped = []
22 | for i in range(input_ids.shape[0]):
23 | sequence_should_be_stopped = False
24 | for stop_sequence in self.stop_sequences:
25 | if input_ids[i][-len(stop_sequence) :].tolist() == stop_sequence:
26 | sequence_should_be_stopped = True
27 | break
28 | sequences_should_be_stopped.append(sequence_should_be_stopped)
29 | return all(sequences_should_be_stopped)
30 |
31 |
32 | @torch.no_grad()
33 | def generate_completions(
34 | model,
35 | tokenizer,
36 | prompts,
37 | batch_size=1,
38 | stop_id_sequences=None,
39 | disable_tqdm=False,
40 | assigned_temperatures=None,
41 | **generation_kwargs,
42 | ):
43 | # sort prompts by length to reduce the number of padding tokens
44 | prompt_lengths = [len(prompt) for prompt in prompts]
45 | sorted_prompt_indices = sorted(range(len(prompt_lengths)), key=lambda k: prompt_lengths[k])
46 | prompts = [prompts[i] for i in sorted_prompt_indices]
47 |
48 | if assigned_temperatures is not None:
49 | assigned_temperatures = [assigned_temperatures[i] for i in sorted_prompt_indices]
50 |
51 | generations = []
52 | if not disable_tqdm:
53 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
54 |
55 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
56 |
57 | # when sampling is enabled, remember the requested temperature (0.0 by default) so that
58 | # per-prompt assigned temperatures can override it for each batch below.
59 | if generation_kwargs.get("do_sample", False):
60 | generate_temperature = generation_kwargs.get("temperature", 0.0)
61 |
62 | for i in range(0, len(prompts), batch_size):
63 | batch_prompts = prompts[i : i + batch_size]
64 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
65 | batch_input_ids = tokenized_prompts.input_ids
66 | attention_mask = tokenized_prompts.attention_mask
67 |
68 | if model.device.type == "cuda":
69 | batch_input_ids = batch_input_ids.cuda()
70 | attention_mask = attention_mask.cuda()
71 |
72 | # Set temperature for each batch if assigned
73 | if generation_kwargs.get("do_sample", False):
74 | assigned_temperature = (
75 | assigned_temperatures[i : i + batch_size][0] if assigned_temperatures is not None else None
76 | )
77 | # temperature = assigned_temperature if assigned_temperature is not None else generate_temperature
78 | generation_kwargs["temperature"] = (
79 | assigned_temperature if assigned_temperature is not None else generate_temperature
80 | )
81 |
82 | try:
83 | batch_outputs = model.generate(
84 | input_ids=batch_input_ids,
85 | attention_mask=attention_mask,
86 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,
87 | **generation_kwargs,
88 | )
89 |
90 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
91 | # so some outputs still have the stop sequence, which we need to remove.
92 | if stop_id_sequences:
93 | for output_idx in range(batch_outputs.shape[0]):
94 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
95 | if any(
96 | batch_outputs[output_idx, token_idx : token_idx + len(stop_sequence)].tolist()
97 | == stop_sequence
98 | for stop_sequence in stop_id_sequences
99 | ):
100 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
101 | break
102 |
103 | # remove the prompt from the output
104 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
105 | # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
106 | # space is important for some tasks (e.g., code completion).
107 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
108 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
109 | # duplicate the prompts to match the number of return sequences
110 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
111 | batch_generations = [output[len(prompt) :] for prompt, output in zip(batch_prompts, batch_outputs)]
112 | except Exception as e:
113 | print("Error when generating completions for batch:")
114 | print(batch_prompts)
115 | print("Error message:")
116 | print(e)
117 | print("Use empty string as the completion.")
118 | batch_generations = [""] * len(batch_prompts) * num_return_sequences
119 |
120 | generations += batch_generations
121 |
122 | # for prompt, generation in zip(batch_prompts, batch_generations):
123 | # print("========")
124 | # print(prompt)
125 | # print("--------")
126 | # print(generation)
127 |
128 | if not disable_tqdm:
129 | progress.update(len(batch_prompts) // num_return_sequences)
130 |
131 | assert (
132 | len(generations) == len(prompts) * num_return_sequences
133 | ), "number of generations should be equal to number of prompts * num_return_sequences"
134 |
135 | generations = sorted(zip(sorted_prompt_indices, generations), key=lambda x: x[0])
136 | generations = [generation for _, generation in generations]
137 | return generations
138 |
139 |
140 | @torch.no_grad()
141 | def get_next_word_predictions(
142 | model,
143 | tokenizer,
144 | prompts,
145 | candidate_token_ids=None,
146 | batch_size=1,
147 | return_token_predictions=False,
148 | disable_tqdm=False,
149 | ):
150 | predictions, probs = [], []
151 | if not disable_tqdm:
152 | progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
153 |
154 | for i in range(0, len(prompts), batch_size):
155 | batch_prompts = prompts[i : i + batch_size]
156 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
157 | batch_input_ids = tokenized_prompts.input_ids
158 | attention_mask = tokenized_prompts.attention_mask
159 |
160 | if model.device.type == "cuda":
161 | batch_input_ids = batch_input_ids.cuda()
162 | attention_mask = attention_mask.cuda()
163 |
164 | batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
165 | if candidate_token_ids is not None:
166 | batch_logits = batch_logits[:, candidate_token_ids]
167 | batch_probs = torch.softmax(batch_logits, dim=-1)
168 | batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
169 | if return_token_predictions:
170 | if candidate_token_ids is not None:
171 | candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
172 | batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices]
173 | else:
174 | batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices)
175 | predictions += batch_predictions
176 | else:
177 | predictions += batch_prediction_indices.tolist()
178 | probs += batch_probs.tolist()
179 |
180 | if not disable_tqdm:
181 | progress.update(len(batch_prompts))
182 |
183 | assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
184 | return predictions, probs
185 |
186 |
187 | def load_hf_lm_and_tokenizer(
188 | model_name_or_path,
189 | tokenizer_name_or_path=None,
190 | device_map="auto",
191 | load_in_8bit=False,
192 | load_in_half=False,
193 | gptq_model=False,
194 | use_fast_tokenizer=False,
195 | padding_side="left",
196 | ):
197 | from transformers import AutoModelForCausalLM, AutoTokenizer
198 |
199 | if not tokenizer_name_or_path:
200 | tokenizer_name_or_path = model_name_or_path
201 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, legacy=True)
202 | # set padding side to left for batch generation
203 | tokenizer.padding_side = padding_side
204 | # set pad token to unk token if pad token is not set (as is the case for llama models)
205 | if tokenizer.pad_token is None:
206 | tokenizer.pad_token = tokenizer.unk_token
207 | tokenizer.pad_token_id = tokenizer.unk_token_id
208 | if gptq_model:
209 | from auto_gptq import AutoGPTQForCausalLM
210 |
211 | model_wrapper = AutoGPTQForCausalLM.from_quantized(model_name_or_path, device="cuda:0", use_triton=True)
212 | model = model_wrapper.model
213 | elif load_in_8bit:
214 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map, load_in_8bit=True)
215 | else:
216 | if device_map:
217 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map)
218 | else:
219 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
220 | if torch.cuda.is_available():
221 | model = model.cuda()
222 | if load_in_half:
223 | model = model.half()
224 | model.eval()
225 | return model, tokenizer
226 |
227 |
228 | def query_openai_chat_model(
229 | engine,
230 | instances,
231 | output_path=None,
232 | batch_size=10,
233 | retry_limit=5,
234 | reuse_existing_outputs=True,
235 | organized_message_format=False,
236 | assigned_temperatures=None,
237 | **completion_kwargs,
238 | ):
239 | """Query OpenAI chat model and save the results to output_path.
240 |
241 | `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id".
242 | """
243 | existing_data = {}
244 | if reuse_existing_outputs and output_path is not None and os.path.exists(output_path):
245 | with open(output_path) as f:
246 | for line in f:
247 | instance = json.loads(line)
248 | existing_data[instance["id"]] = instance
249 |
250 | # by default, we use temperature 0.0 to get the most likely completion.
251 | generate_temperature = completion_kwargs.pop("temperature", 0.0)
252 |
253 | results = []
254 | if output_path is not None:
255 | fout = open(output_path, "w") # noqa: SIM115
256 |
257 | retry_count = 0
258 | progress_bar = tqdm.tqdm(total=len(instances))
259 | for i in range(0, len(instances), batch_size):
260 | batch = instances[i : i + batch_size]
261 | if all(x["id"] in existing_data for x in batch):
262 | results.extend([existing_data[x["id"]] for x in batch])
263 | if output_path is not None:
264 | for instance in batch:
265 | fout.write(json.dumps(existing_data[instance["id"]]) + "\n")
266 | fout.flush()
267 | progress_bar.update(batch_size)
268 | continue
269 | messages_list = []
270 | for instance in batch:
271 | if organized_message_format:
272 | messages = instance["prompt"]
273 | else:
274 | messages = [{"role": "user", "content": instance["prompt"]}]
275 | messages_list.append(messages)
276 |
277 | # Set temperature for each batch if assigned
278 | assigned_temperature = (
279 | assigned_temperatures[i : i + batch_size][0] if assigned_temperatures is not None else None
280 | )
281 | temperature = assigned_temperature if assigned_temperature is not None else generate_temperature
282 |
283 | while retry_count < retry_limit:
284 | try:
285 | outputs = asyncio.run(
286 | dispatch_openai_chat_requesets(
287 | messages_list=messages_list,
288 | model=engine,
289 | temperature=temperature,
290 | **completion_kwargs,
291 | )
292 | )
293 | retry_count = 0
294 | break
295 | except Exception as e:
296 | retry_count += 1
297 | print("Error while requesting OpenAI API.")
298 | print(e)
299 | print(f"Sleep for {30*retry_count} seconds.")
300 | time.sleep(30 * retry_count)
301 | print(f"Retry for the {retry_count} time.")
302 | if retry_count == retry_limit:
303 | raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.")
304 | assert len(outputs) == len(batch)
305 | for instance, output in zip(batch, outputs):
306 | instance["output"] = output["choices"][0]["message"]["content"]
307 | instance["response_metadata"] = output
308 | results.append(instance)
309 | if output_path is not None:
310 | fout.write(json.dumps(instance) + "\n")
311 | fout.flush()
312 | progress_bar.update(batch_size)
313 | return results
314 |
315 |
316 | def query_openai_model(
317 | engine, instances, output_path=None, batch_size=10, retry_limit=5, reuse_existing_outputs=True, **completion_kwargs
318 | ):
319 | """Query OpenAI chat model and save the results to output_path.
320 |
321 | `instances` is a list of dictionaries, each dictionary contains a key "prompt" and a key "id".
322 | """
323 | existing_data = {}
324 | if reuse_existing_outputs and output_path is not None and os.path.exists(output_path):
325 | with open(output_path) as f:
326 | for line in f:
327 | instance = json.loads(line)
328 | existing_data[instance["id"]] = instance
329 |
330 | # by default, we use temperature 0.0 to get the most likely completion.
331 | if "temperature" not in completion_kwargs:
332 | completion_kwargs["temperature"] = 0.0
333 |
334 | results = []
335 | if output_path is not None:
336 | fout = open(output_path, "w") # noqa: SIM115
337 |
338 | retry_count = 0
339 | progress_bar = tqdm.tqdm(total=len(instances))
340 | for i in range(0, len(instances), batch_size):
341 | batch = instances[i : i + batch_size]
342 | if all(x["id"] in existing_data for x in batch):
343 | results.extend([existing_data[x["id"]] for x in batch])
344 | if output_path is not None:
345 | for instance in batch:
346 | fout.write(json.dumps(existing_data[instance["id"]]) + "\n")
347 | fout.flush()
348 | progress_bar.update(batch_size)
349 | continue
350 | messages_list = []
351 | for instance in batch:
352 | messages = instance["prompt"]
353 | messages_list.append(messages)
354 | while retry_count < retry_limit:
355 | try:
356 | outputs = asyncio.run(
357 | dispatch_openai_prompt_requesets(
358 | prompt_list=messages_list,
359 | model=engine,
360 | **completion_kwargs,
361 | )
362 | )
363 | retry_count = 0
364 | break
365 | except Exception as e:
366 | retry_count += 1
367 | print("Error while requesting OpenAI API.")
368 | print(e)
369 | print(f"Sleep for {30*retry_count} seconds.")
370 | time.sleep(30 * retry_count)
371 | print(f"Retry for the {retry_count} time.")
372 | if retry_count == retry_limit:
373 | raise RuntimeError(f"Failed to get response from OpenAI API after {retry_limit} retries.")
374 | assert len(outputs) == len(batch)
375 | for instance, output in zip(batch, outputs):
376 | instance["output"] = output["choices"][0]["text"]
377 | instance["response_metadata"] = output
378 | results.append(instance)
379 | if output_path is not None:
380 | fout.write(json.dumps(instance) + "\n")
381 | fout.flush()
382 | progress_bar.update(batch_size)
383 | return results
384 |
--------------------------------------------------------------------------------
/xchat/train/train.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is adapted from the https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py.
3 | """
4 |
5 | import argparse
6 | import logging
7 | import math
8 | import os
9 | import random
10 | from functools import partial
11 |
12 | import datasets
13 | import torch
14 | import transformers
15 | from accelerate import Accelerator
16 | from accelerate.logging import get_logger
17 | from accelerate.utils import set_seed
18 | from datasets import load_dataset
19 | from peft import LoraConfig, TaskType, get_peft_model
20 | from torch.utils.data import DataLoader
21 | from tqdm.auto import tqdm
22 | from transformers import (
23 | AutoConfig,
24 | AutoModelForCausalLM,
25 | AutoTokenizer,
26 | DataCollatorForSeq2Seq,
27 | GPT2Tokenizer,
28 | GPTNeoXTokenizerFast,
29 | LlamaTokenizer,
30 | LlamaTokenizerFast,
31 | OPTForCausalLM,
32 | SchedulerType,
33 | get_scheduler,
34 | )
35 |
36 | logger = get_logger(__name__)
37 |
38 |
39 | def parse_args():
40 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
41 | parser.add_argument(
42 | "--dataset_name",
43 | type=str,
44 | default=None,
45 | help="The name of the dataset to use (via the datasets library).",
46 | )
47 | parser.add_argument(
48 | "--dataset_config_name",
49 | type=str,
50 | default=None,
51 | help="The configuration name of the dataset to use (via the datasets library).",
52 | )
53 | parser.add_argument(
54 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
55 | )
56 | parser.add_argument(
57 | "--chat_format",
58 | type=str,
59 | default="chatml",
60 | choices=["chatml", "tulu"],
61 | )
62 | parser.add_argument(
63 | "--data_packing",
64 | action="store_true",
65 | help="If passed, will pack multiple examples into one sequence.",
66 | )
67 | parser.add_argument(
68 | "--model_name_or_path",
69 | type=str,
70 | help="Path to pretrained model or model identifier from huggingface.co/models.",
71 | required=False,
72 | )
73 | parser.add_argument(
74 | "--config_name",
75 | type=str,
76 | default=None,
77 | help="Pretrained config name or path if not the same as model_name",
78 | )
79 | parser.add_argument(
80 | "--use_lora",
81 | action="store_true",
82 | help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.",
83 | )
84 | parser.add_argument(
85 | "--lora_rank",
86 | type=int,
87 | default=64,
88 | help="The rank of lora.",
89 | )
90 | parser.add_argument(
91 | "--lora_alpha",
92 | type=float,
93 | default=16,
94 | help="The alpha parameter of lora.",
95 | )
96 | parser.add_argument(
97 | "--lora_dropout",
98 | type=float,
99 | default=0.1,
100 | help="The dropout rate of lora modules.",
101 | )
102 | parser.add_argument(
103 | "--save_merged_lora_model",
104 | action="store_true",
105 | help="If passed, will merge the lora modules and save the entire model.",
106 | )
107 | parser.add_argument(
108 | "--use_flash_attn",
109 | action="store_true",
110 | help="If passed, will use flash attention to train the model.",
111 | )
112 | parser.add_argument(
113 | "--tokenizer_name",
114 | type=str,
115 | default=None,
116 | help="Pretrained tokenizer name or path if not the same as model_name",
117 | )
118 | parser.add_argument(
119 | "--use_slow_tokenizer",
120 | action="store_true",
121 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
122 | )
123 | parser.add_argument(
124 | "--max_seq_length",
125 | type=int,
126 | default=512,
127 | help="The maximum total sequence length (prompt+completion) of each training example.",
128 | )
129 | parser.add_argument(
130 | "--per_device_train_batch_size",
131 | type=int,
132 | default=8,
133 | help="Batch size (per device) for the training dataloader.",
134 | )
135 | parser.add_argument(
136 | "--learning_rate",
137 | type=float,
138 | default=5e-5,
139 | help="Initial learning rate (after the potential warmup period) to use.",
140 | )
141 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
142 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
143 | parser.add_argument(
144 | "--max_train_steps",
145 | type=int,
146 | default=None,
147 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
148 | )
149 | parser.add_argument(
150 | "--gradient_accumulation_steps",
151 | type=int,
152 | default=1,
153 | help="Number of updates steps to accumulate before performing a backward/update pass.",
154 | )
155 | parser.add_argument(
156 | "--gradient_checkpointing",
157 | action="store_true",
158 | help="If passed, use gradient checkpointing to save memory at the expense of slower backward pass.",
159 | )
160 | parser.add_argument(
161 | "--lr_scheduler_type",
162 | type=SchedulerType,
163 | default="linear",
164 | help="The scheduler type to use.",
165 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
166 | )
167 | parser.add_argument("--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup.")
168 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
169 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
170 | parser.add_argument(
171 | "--preprocessing_num_workers",
172 | type=int,
173 | default=None,
174 | help="The number of processes to use for the preprocessing.",
175 | )
176 | parser.add_argument(
177 | "--cache_dir",
178 | type=str,
179 | default=None,
180 | help="Where to store the preprocessed datasets downloaded from the datasets library.",
181 | )
182 | parser.add_argument(
183 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
184 | )
185 | parser.add_argument(
186 | "--checkpointing_steps",
187 | type=str,
188 | default=None,
189 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
190 | )
191 | parser.add_argument(
192 | "--logging_steps",
193 | type=int,
194 | default=None,
195 | help="Log the training loss and learning rate every logging_steps steps.",
196 | )
197 | parser.add_argument(
198 | "--resume_from_checkpoint",
199 | type=str,
200 | default=None,
201 | help="If the training should continue from a checkpoint folder.",
202 | )
203 | parser.add_argument(
204 | "--with_tracking",
205 | action="store_true",
206 | help="Whether to enable experiment trackers for logging.",
207 | )
208 | parser.add_argument(
209 | "--report_to",
210 | type=str,
211 | default="all",
212 | help=(
213 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
214 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
215 | "Only applicable when `--with_tracking` is passed."
216 | ),
217 | )
218 | parser.add_argument(
219 | "--low_cpu_mem_usage",
220 | action="store_true",
221 | help=(
222 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
223 | "If passed, LLM loading time and RAM consumption will be benefited."
224 | ),
225 | )
226 | args = parser.parse_args()
227 |
228 | # Sanity checks
229 | if args.dataset_name is None and args.train_file is None:
230 | raise ValueError("Need either a dataset name or a training file.")
231 | else:
232 | if args.train_file is not None:
233 | extension = args.train_file.split(".")[-1]
234 | assert extension in ["json", "jsonl"], "`train_file` should be a json/jsonl file."
235 | return args
236 |
237 |
238 | def build_dataset(args, tokenizer, accelerator):
239 | def encode_with_prompt_completion_format(example, max_seq_length):
240 | """
241 | Here we assume each example has 'prompt' and 'completion' fields.
242 | We concatenate prompt and completion and tokenize them together, because otherwise the prompt would be padded/truncated
243 | on its own and the completion would no longer directly follow it.
244 | """
245 | # if prompt doesn't end with space and completion doesn't start with space, add space
246 | if not example["prompt"].endswith((" ", "\n", "\t")) and not example["completion"].startswith(
247 | (" ", "\n", "\t")
248 | ):
249 | example_text = example["prompt"] + " " + example["completion"]
250 | else:
251 | example_text = example["prompt"] + example["completion"]
252 | example_text = example_text + tokenizer.eos_token
253 | tokenized_example = tokenizer(example_text, return_tensors="pt", max_length=max_seq_length, truncation=True)
254 | input_ids = tokenized_example.input_ids
255 | labels = input_ids.clone()
256 | tokenized_prompt = tokenizer(example["prompt"], return_tensors="pt", max_length=max_seq_length, truncation=True)
257 | # mask the prompt part for avoiding loss
258 | labels[:, : tokenized_prompt.input_ids.shape[1]] = -100
259 | attention_mask = torch.ones_like(input_ids)
260 | return {
261 | "input_ids": input_ids.flatten(),
262 | "labels": labels.flatten(),
263 | "attention_mask": attention_mask.flatten(),
264 | }
265 |
266 | def encode_with_messages_format(example, max_seq_length, chat_format):
267 | """
268 | Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
269 | We concatenate all messages with the roles as delimiters and tokenize them together.
270 | """
271 | messages = example["messages"]
272 | if len(messages) == 0:
273 | raise ValueError("messages field is empty.")
274 |
275 | def _concat_messages_tulu(messages):
276 | message_text = ""
277 | for message in messages:
278 | if message["role"] == "system":
279 | message_text += "<|system|>\n" + message["content"].strip() + "\n"
280 | elif message["role"] == "user":
281 | message_text += "<|user|>\n" + message["content"].strip() + "\n"
282 | elif message["role"] == "assistant":
283 | message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
284 | else:
285 | raise ValueError("Invalid role: {}".format(message["role"]))
286 | return message_text
287 |
288 | def _concat_messages_chatml(messages):
289 | message_text = ""
290 | for message in messages:
291 | if message["role"] == "system":
292 | message_text += "<|im_start|>system\n" + message["content"].strip() + "<|im_end|>\n"
293 | elif message["role"] == "user":
294 | message_text += "<|im_start|>user\n" + message["content"].strip() + "<|im_end|>\n"
295 | elif message["role"] == "assistant":
296 | message_text += "<|im_start|>assistant\n" + message["content"].strip() + "<|im_end|>\n"
297 | else:
298 | raise ValueError("Invalid role: {}".format(message["role"]))
299 | return message_text
300 |
301 | if chat_format == "chatml":
302 | _concat_messages = _concat_messages_chatml
303 | assistant_start_token = "<|im_start|>assistant\n"
304 | elif chat_format == "tulu":
305 | _concat_messages = _concat_messages_tulu
306 | assistant_start_token = "<|assistant|>\n"
307 | else:
308 | raise ValueError(f"Invalid chat format: {chat_format}")
309 | example_text = _concat_messages(messages)
310 | tokenized_example = tokenizer(example_text, return_tensors="pt", max_length=max_seq_length, truncation=True)
311 | input_ids = tokenized_example.input_ids
312 | labels = input_ids.clone()
313 |
314 | # mask the non-assistant part for avoiding loss
315 | for message_idx, message in enumerate(messages):
316 | if message["role"] != "assistant":
317 | if message_idx == 0:
318 | message_start_idx = 0
319 | else:
320 | message_start_idx = tokenizer(
321 | _concat_messages(messages[:message_idx]),
322 | return_tensors="pt",
323 | max_length=max_seq_length,
324 | truncation=True,
325 | ).input_ids.shape[1]
326 | if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
327 | # also mask the assistant role header, so the loss starts at the assistant's reply content
328 | messages_so_far = _concat_messages(messages[: message_idx + 1]) + assistant_start_token
329 | else:
330 | messages_so_far = _concat_messages(messages[: message_idx + 1])
331 | message_end_idx = tokenizer(
332 | messages_so_far, return_tensors="pt", max_length=max_seq_length, truncation=True
333 | ).input_ids.shape[1]
334 | labels[:, message_start_idx:message_end_idx] = -100
335 |
336 | if message_end_idx >= max_seq_length:
337 | break
338 |
339 | attention_mask = torch.ones_like(input_ids)
340 | return {
341 | "input_ids": input_ids.flatten(),
342 | "labels": labels.flatten(),
343 | "attention_mask": attention_mask.flatten(),
344 | }
345 |
346 | if args.dataset_name is not None:
347 | # Downloading and loading a dataset from the hub.
348 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir)
349 | else:
350 | data_files = {}
351 | dataset_args = {}
352 | if args.train_file is not None:
353 | data_files["train"] = args.train_file
354 | raw_datasets = load_dataset(
355 | "json",
356 | data_files=data_files,
357 | **dataset_args,
358 | )
359 | # raw_datasets["train"] = raw_datasets["train"].select(range(100))
360 |
361 | if "input_ids" in raw_datasets["train"].column_names:
362 | lm_datasets = raw_datasets.remove_columns(
363 | [
364 | name
365 | for name in raw_datasets["train"].column_names
366 | if name not in ["input_ids", "labels", "attention_mask"]
367 | ]
368 | )
369 | lm_datasets.set_format(type="pt")
370 | else:
371 | # Preprocessing the datasets.
372 | if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names:
373 | encode_function = partial(
374 | encode_with_prompt_completion_format,
375 | max_seq_length=args.max_seq_length,
376 | )
377 | elif "messages" in raw_datasets["train"].column_names:
378 | encode_function = partial(
379 | encode_with_messages_format,
380 | max_seq_length=args.max_seq_length,
381 | chat_format=args.chat_format,
382 | )
383 | else:
384 | raise ValueError("You need to have either 'prompt'&'completion' or 'messages' in your column names.")
385 | with accelerator.main_process_first():
386 | lm_datasets = raw_datasets.map(
387 | encode_function,
388 | batched=False,
389 | num_proc=args.preprocessing_num_workers,
390 | load_from_cache_file=not args.overwrite_cache,
391 | remove_columns=[
392 | name
393 | for name in raw_datasets["train"].column_names
394 | if name not in ["input_ids", "labels", "attention_mask"]
395 | ],
396 | desc="Tokenizing and reformatting instruction data",
397 | )
398 | lm_datasets.set_format(type="pt")
399 | lm_datasets = lm_datasets.filter(lambda example: (example["labels"] != -100).any())
400 |
401 | train_dataset = lm_datasets["train"]
402 |
403 | if args.data_packing:
404 |
405 | def data_packing(examples, max_seq_length):
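# Append each tokenized example to the previous packed sequence when their combined length
# is at most max_seq_length, adding the end-of-conversation token after each example;
# examples longer than max_seq_length are split into max_seq_length-sized chunks, and
# anything else starts a new packed sequence.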
406 | num_examples = len(examples["input_ids"])
407 | packed_examples = {"input_ids": [], "labels": [], "attention_mask": []}
408 | eoc_id = tokenizer.encode(tokenizer.eoc_token, return_tensors="pt", add_special_tokens=False).squeeze(0)
409 | eoc_label = torch.tensor([-100])
410 | eoc_mask = torch.tensor([1])
411 |
412 | for i in range(num_examples):
413 | if (
414 | len(packed_examples["input_ids"]) > 0
415 | and len(packed_examples["input_ids"][-1]) + len(examples["input_ids"][i]) <= max_seq_length
416 | ):
417 | packed_examples["input_ids"][-1] = torch.cat(
418 | [packed_examples["input_ids"][-1], examples["input_ids"][i], eoc_id]
419 | )
420 | packed_examples["labels"][-1] = torch.cat(
421 | [packed_examples["labels"][-1], examples["labels"][i], eoc_label]
422 | )
423 | packed_examples["attention_mask"][-1] = torch.cat(
424 | [packed_examples["attention_mask"][-1], examples["attention_mask"][i], eoc_mask]
425 | )
426 | elif len(examples["input_ids"][i]) > max_seq_length:
427 | for j in range(0, len(examples["input_ids"][i]), max_seq_length):
428 | packed_examples["input_ids"].append(examples["input_ids"][i][j : j + max_seq_length])
429 | packed_examples["labels"].append(examples["labels"][i][j : j + max_seq_length])
430 | packed_examples["attention_mask"].append(examples["attention_mask"][i][j : j + max_seq_length])
431 | else:
432 | packed_examples["input_ids"].append(torch.cat([examples["input_ids"][i], eoc_id]))
433 | packed_examples["labels"].append(torch.cat([examples["labels"][i], eoc_label]))
434 | packed_examples["attention_mask"].append(torch.cat([examples["attention_mask"][i], eoc_mask]))
435 |
436 | assert (
437 | len(packed_examples["input_ids"][-1])
438 | == len(packed_examples["labels"][-1])
439 | == len(packed_examples["labels"][-1])
440 | == len(packed_examples["attention_mask"][-1])
441 | )
442 |
443 | return packed_examples
444 |
445 | pack_function = partial(data_packing, max_seq_length=args.max_seq_length)
446 | with accelerator.main_process_first():
447 | logger.info(f"Training dataset size before packing: {len(train_dataset)}")
448 | train_dataset = train_dataset.map(
449 | pack_function,
450 | batched=True,
451 | num_proc=args.preprocessing_num_workers,
452 | remove_columns=train_dataset.column_names,
453 | desc="Packing data",
454 | )
455 | logger.info(f"Training dataset size after packing: {len(train_dataset)}")
456 |
457 | return train_dataset
458 |
459 |
460 | def build_tokenizer(args):
461 | if args.tokenizer_name:
462 | tokenizer = AutoTokenizer.from_pretrained(
463 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, legacy=True
464 | )
465 | elif args.model_name_or_path:
466 | tokenizer = AutoTokenizer.from_pretrained(
467 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, legacy=True
468 | )
469 | else:
470 | raise ValueError(
471 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
472 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
473 | )
474 | print(f"Tokenizer type: {type(tokenizer)}")
475 | # no default pad token for llama!
476 | # here we add all special tokens again, because the default ones are not in the special_tokens_map
477 | if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
478 | num_added_tokens = tokenizer.add_special_tokens(
479 | {
480 | "bos_token": "",
481 | "eos_token": "",
482 | "unk_token": "",
483 | "pad_token": "",
484 | }
485 | )
486 | assert num_added_tokens in [
487 | 0,
488 | 1,
489 | ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
490 | if args.chat_format == "tulu":
491 | if args.data_packing:
492 | tokenizer.add_special_tokens(
493 | {
494 | "additional_special_tokens": ["<|end_of_conv|>"],
495 | }
496 | )
497 | tokenizer.eoc_token = "<|end_of_conv|>"
498 | elif args.chat_format == "chatml":
499 | tokenizer.add_special_tokens(
500 | {
501 | "additional_special_tokens": [
502 | "<|im_start|>",
503 | "<|im_end|>",
504 | "<|im_sep|>",
505 | "<|diff_marker|>",
506 | ],
507 | }
508 | )
509 | tokenizer.eoc_token = tokenizer.eos_token
510 | else:
511 | raise ValueError(f"Unknown chat format {args.chat_format}")
512 | elif isinstance(tokenizer, GPTNeoXTokenizerFast):
513 | num_added_tokens = tokenizer.add_special_tokens(
514 | {
515 | "pad_token": "",
516 | }
517 | )
518 | assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token."
519 | elif isinstance(tokenizer, GPT2Tokenizer):
520 | num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
521 |
522 | return tokenizer
523 |
524 |
525 | def save_with_accelerate(accelerator, model, tokenizer, output_dir, args):
526 | unwrapped_model = accelerator.unwrap_model(model)
527 | # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
528 | # Otherwise, sometimes the model will be saved with only part of the parameters.
529 | # Also, accelerator needs to use the wrapped model to get the state_dict.
530 | state_dict = accelerator.get_state_dict(model)
531 | if args.use_lora:
532 | # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
533 | # and has its own save_pretrained function for only saving lora modules.
534 | # We have to manually specify the is_main_process outside the save_pretrained function.
535 | if accelerator.is_main_process:
536 | unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
537 | else:
538 | unwrapped_model.save_pretrained(
539 | output_dir,
540 | is_main_process=accelerator.is_main_process,
541 | save_function=accelerator.save,
542 | state_dict=state_dict,
543 | )
544 |
545 |
546 | def main():
547 | args = parse_args()
548 |
549 | # A hacky way to make llama work with flash attention
550 | if args.use_flash_attn:
551 | from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
552 |
553 | replace_llama_attn_with_flash_attn()
554 |
555 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
556 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
557 | # in the environment
558 | accelerator_log_kwargs = {}
559 |
560 | if args.with_tracking:
561 | accelerator_log_kwargs["log_with"] = args.report_to
562 | accelerator_log_kwargs["project_dir"] = args.output_dir
563 |
564 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
565 |
566 | # Make one log on every process with the configuration for debugging.
567 | logging.basicConfig(
568 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
569 | datefmt="%m/%d/%Y %H:%M:%S",
570 | level=logging.INFO,
571 | )
572 | logger.info(accelerator.state, main_process_only=False)
573 | if accelerator.is_local_main_process:
574 | datasets.utils.logging.set_verbosity_warning()
575 | transformers.utils.logging.set_verbosity_info()
576 | else:
577 | datasets.utils.logging.set_verbosity_error()
578 | transformers.utils.logging.set_verbosity_error()
579 |
580 | # If passed along, set the training seed now.
581 | if args.seed is not None:
582 | set_seed(args.seed)
583 |
584 | if accelerator.is_main_process and args.output_dir is not None:
585 | os.makedirs(args.output_dir, exist_ok=True)
586 |
587 | accelerator.wait_for_everyone()
588 |
589 | # Load pretrained model and tokenizer
590 | if args.config_name:
591 | config = AutoConfig.from_pretrained(args.config_name)
592 | elif args.model_name_or_path:
593 | config = AutoConfig.from_pretrained(args.model_name_or_path)
594 | else:
595 | raise ValueError(
596 | "You are instantiating a new config instance from scratch. This is not supported by this script."
597 | )
598 |
599 | tokenizer = build_tokenizer(args)
600 |
601 | if args.model_name_or_path:
602 | model = AutoModelForCausalLM.from_pretrained(
603 | args.model_name_or_path,
604 | cache_dir=args.cache_dir,
605 | from_tf=bool(".ckpt" in args.model_name_or_path),
606 | config=config,
607 | low_cpu_mem_usage=args.low_cpu_mem_usage,
608 | )
609 | else:
610 | logger.info("Training new model from scratch")
611 | model = AutoModelForCausalLM.from_config(config)
612 |
613 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
614 | # on a small vocab and want a smaller embedding size, remove this test.
615 | embedding_size = model.get_input_embeddings().weight.shape[0]
616 | if len(tokenizer) > embedding_size:
617 | model.resize_token_embeddings(len(tokenizer))
618 |
619 | if args.gradient_checkpointing:
620 | model.gradient_checkpointing_enable()
621 |
622 | if args.use_lora:
623 | logger.info("Initializing LORA model...")
624 | peft_config = LoraConfig(
625 | task_type=TaskType.CAUSAL_LM,
626 | inference_mode=False,
627 | r=args.lora_rank,
628 | lora_alpha=args.lora_alpha,
629 | lora_dropout=args.lora_dropout,
630 | )
631 | model = get_peft_model(model, peft_config)
632 | model.print_trainable_parameters()
633 |
634 | train_dataset = build_dataset(args, tokenizer, accelerator)
635 |
636 | # Log a few random samples from the training set:
637 | for index in random.sample(range(len(train_dataset)), 3):
638 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
639 |
640 | # DataLoaders creation:
641 | train_dataloader = DataLoader(
642 | train_dataset,
643 | shuffle=True,
644 | collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
645 | batch_size=args.per_device_train_batch_size,
646 | )
647 |
648 | # Optimizer
649 | # Split weights in two groups, one with weight decay and the other not.
650 | no_decay = ["bias", "layer_norm.weight"]
651 | optimizer_grouped_parameters = [
652 | {
653 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
654 | "weight_decay": args.weight_decay,
655 | },
656 | {
657 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
658 | "weight_decay": 0.0,
659 | },
660 | ]
661 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
662 |
663 | # Scheduler and math around the number of training steps.
664 | overrode_max_train_steps = False
665 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
666 | if args.max_train_steps is None:
667 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
668 | overrode_max_train_steps = True
669 |
670 | # Create the learning rate scheduler.
671 | # Note: the current accelerator.step() calls the .step() of the real scheduler for the `num_processes` times. This is because they assume
672 | # the user initialize the scheduler with the entire training set. In the case of data parallel training, each process only
673 | # sees a subset (1/num_processes) of the training set. So each time the process needs to update the lr multiple times so that the total
674 | # number of updates in the end matches the num_training_steps here.
675 | # Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the
676 | # num_training_steps by num_processes so that the total number of updates matches the num_training_steps.
677 | num_training_steps_for_scheduler = (
678 | args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
679 | )
680 | lr_scheduler = get_scheduler(
681 | name=args.lr_scheduler_type,
682 | optimizer=optimizer,
683 | num_training_steps=num_training_steps_for_scheduler,
684 | num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
685 | )
686 |
687 | # Prepare everything with `accelerator`.
688 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
689 | model, optimizer, train_dataloader, lr_scheduler
690 | )
691 |
692 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
693 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
694 | if overrode_max_train_steps:
695 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
696 | # Afterwards we recalculate our number of training epochs
697 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
698 |
699 | # Figure out how many steps we should save the Accelerator states
700 | checkpointing_steps = args.checkpointing_steps
701 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
702 | checkpointing_steps = int(checkpointing_steps)
703 |
704 | # We need to initialize the trackers we use, and also store our configuration.
705 | # The trackers initializes automatically on the main process.
706 | if args.with_tracking:
707 | experiment_config = vars(args)
708 | # TensorBoard cannot log Enums, need the raw value
709 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
710 | accelerator.init_trackers("sft-chat", experiment_config)
711 |
712 | # Train!
713 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
714 |
715 | logger.info("***** Running training *****")
716 | logger.info(f" Num examples = {len(train_dataset)}")
717 | logger.info(f" Num Epochs = {args.num_train_epochs}")
718 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
719 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
720 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
721 | logger.info(f" Total optimization steps = {args.max_train_steps}")
722 | # Only show the progress bar once on each machine.
723 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
724 | completed_steps = 0
725 | starting_epoch = 0
726 |
727 | # Potentially load in the weights and states from a previous save
728 | if args.resume_from_checkpoint:
729 | if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
730 | checkpoint_path = args.resume_from_checkpoint
731 | path = os.path.basename(args.resume_from_checkpoint)
732 | else:
733 | # Get the most recent checkpoint
734 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
735 | dirs.sort(key=os.path.getctime)
736 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
737 | checkpoint_path = path
738 | path = os.path.basename(checkpoint_path)
739 |
740 | accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
741 | accelerator.load_state(checkpoint_path)
742 | # Extract `epoch_{i}` or `step_{i}`
743 | training_difference = os.path.splitext(path)[0]
744 |
745 | if "epoch" in training_difference:
746 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
747 | resume_step = None
748 | completed_steps = starting_epoch * num_update_steps_per_epoch
749 | else:
750 | # need to multiply by `gradient_accumulation_steps` to convert optimizer steps back into dataloader (micro-batch) steps
751 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
752 | starting_epoch = resume_step // len(train_dataloader)
753 | completed_steps = resume_step // args.gradient_accumulation_steps
754 | resume_step -= starting_epoch * len(train_dataloader)
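# Worked example with assumed numbers: resuming from "step_500" with gradient_accumulation_steps=4
# and a 1200-batch per-process dataloader gives resume_step = 500 * 4 = 2000 micro-batches,
# starting_epoch = 2000 // 1200 = 1, completed_steps = 2000 // 4 = 500, and finally
# resume_step = 2000 - 1 * 1200 = 800 batches to skip inside the resumed epoch.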
755 |
756 | # update the progress_bar when resuming from a checkpoint
757 | progress_bar.update(completed_steps)
758 |
759 | for epoch in range(starting_epoch, args.num_train_epochs):
760 | model.train()
761 | total_loss = 0
762 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
763 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint
764 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
765 | else:
766 | active_dataloader = train_dataloader
767 | for _step, batch in enumerate(active_dataloader):
768 | with accelerator.accumulate(model):
769 | outputs = model(**batch, use_cache=False)
770 | loss = outputs.loss
771 | # We keep track of the loss at each logged step
772 | total_loss += loss.detach().float()
773 | accelerator.backward(loss)
774 | optimizer.step()
775 | optimizer.zero_grad()
776 | lr_scheduler.step()
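# Under accelerator.accumulate(model), gradients are only synchronized every
# gradient_accumulation_steps micro-batches; in recent accelerate versions the prepared optimizer
# (and the scheduler stepped alongside it) skips the real update on the intermediate micro-batches,
# so calling step()/zero_grad() on every iteration as above is safe.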
777 |
778 | # Checks if the accelerator has performed an optimization step behind the scenes
779 | if accelerator.sync_gradients:
780 | progress_bar.update(1)
781 | completed_steps += 1
782 | if args.logging_steps and completed_steps % args.logging_steps == 0:
783 | avg_loss = (
784 | accelerator.gather(total_loss).mean().item()
785 | / args.gradient_accumulation_steps
786 | / args.logging_steps
787 | )
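# total_loss sums logging_steps * gradient_accumulation_steps micro-batch losses per process;
# gather(...).mean() averages those per-process sums, and the two divisions convert the result
# into an average loss per micro-batch before it is logged.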
788 | logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
789 | if args.with_tracking:
790 | accelerator.log(
791 | {
792 | "learning_rate": lr_scheduler.get_last_lr()[0],
793 | "train_loss": avg_loss,
794 | },
795 | step=completed_steps,
796 | )
797 | total_loss = 0
798 |
799 | if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
800 | output_dir = f"step_{completed_steps}"
801 | if args.output_dir is not None:
802 | output_dir = os.path.join(args.output_dir, output_dir)
803 | accelerator.save_state(output_dir)
804 | if completed_steps >= args.max_train_steps:
805 | break
806 |
807 | if args.checkpointing_steps == "epoch":
808 | output_dir = f"epoch_{epoch}"
809 | if args.output_dir is not None:
810 | output_dir = os.path.join(args.output_dir, output_dir)
811 | accelerator.save_state(output_dir)
812 |
813 | if args.with_tracking:
814 | accelerator.end_training()
815 |
816 | if args.output_dir is not None:
817 | accelerator.wait_for_everyone()
818 | if accelerator.is_main_process:
819 | tokenizer.save_pretrained(args.output_dir)
820 | save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args)
821 |
822 |
823 | if __name__ == "__main__":
824 | main()
825 |
--------------------------------------------------------------------------------