├── assets └── __init__.py ├── chat ├── data │ ├── __init__.py │ ├── pretty_json.py │ ├── merge.py │ ├── extract_single_round.py │ ├── inspect_data.py │ ├── extract_gpt4_only.py │ ├── convert_alpaca.py │ ├── split_train_test.py │ ├── filter_wrong_format.py │ ├── sample.py │ ├── prepare_all.py │ ├── optional_replace.py │ ├── get_stats.py │ ├── optional_clean.py │ ├── split_long_conversation.py │ ├── hardcoded_questions.py │ └── clean_sharegpt.py ├── server │ ├── __init__.py │ ├── monitor │ │ ├── conv_release_scripts │ │ │ ├── upload_hf_dataset.py │ │ │ ├── count_unique_users.py │ │ │ ├── merge_field.py │ │ │ ├── sample.py │ │ │ └── filter_bad_conv.py │ │ ├── replace_model_name.py │ │ ├── leaderboard_csv_to_html.py │ │ ├── tag_openai_moderation.py │ │ ├── summarize_cluster.py │ │ ├── inspect_conv.py │ │ ├── clean_chat_data.py │ │ ├── basic_stats.py │ │ ├── clean_battle_data.py │ │ ├── topic_clustering.py │ │ └── elo_analysis.py │ ├── register_worker.py │ ├── shutdown_serve.py │ ├── test.py │ ├── gateway │ │ ├── README.md │ │ └── nginx.conf │ ├── huggingface_api.py │ ├── test_message.py │ ├── api_provider.py │ ├── test_throughput.py │ ├── vllm_worker.py │ ├── multi_model_worker.py │ └── launch_all_serve.py ├── modules │ ├── __init__.py │ ├── gptq.py │ └── awq.py ├── __init__.py ├── model │ ├── __init__.py │ ├── convert_fp16.py │ ├── upload_hub.py │ ├── apply_lora.py │ ├── make_delta.py │ ├── rwkv_model.py │ ├── llama_condense_monkey_patch.py │ ├── model_chatglm.py │ ├── model_codet5p.py │ ├── monkey_patch_non_inplace.py │ ├── model_falcon.py │ └── apply_delta.py ├── constants.py └── protocol │ ├── api_protocol.py │ └── openai_api_protocol.py ├── pilot └── __init__.py ├── tests ├── killall_python.sh ├── test_cli_inputs.txt ├── launch_openai_api_test_server.py ├── test_openai_langchain.py ├── test_cli.py └── test_openai_api.py ├── requirements.txt ├── .gitignore └── README.md /assets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chat/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chat/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pilot/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /chat/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chat/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.26" 2 | -------------------------------------------------------------------------------- /tests/killall_python.sh: -------------------------------------------------------------------------------- 1 | kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}') 2 | -------------------------------------------------------------------------------- /chat/model/__init__.py: -------------------------------------------------------------------------------- 1 | from chat.model.model_adapter import ( 2 | load_model, 3 | get_conversation_template, 4 | add_model_args, 5 | 
) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.105.0 2 | numpy==1.26.2 3 | requests==2.31.0 4 | uvicorn==0.25.0 5 | transformers==4.33.2 6 | torch==2.1.2 7 | accelerate==0.25.0 8 | gradio==3.44.3 9 | -------------------------------------------------------------------------------- /tests/test_cli_inputs.txt: -------------------------------------------------------------------------------- 1 | Who are you? __END_OF_A_MESSAGE_47582648__ 2 | Three tips for staying healthy. __END_OF_A_MESSAGE_47582648__ 3 | One more tip. __END_OF_A_MESSAGE_47582648__ 4 | !!exit __END_OF_A_MESSAGE_47582648__ 5 | -------------------------------------------------------------------------------- /chat/server/monitor/conv_release_scripts/upload_hf_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload to huggingface. 3 | """ 4 | import json 5 | from datasets import Dataset, DatasetDict, load_dataset 6 | 7 | objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) 8 | data = Dataset.from_list(objs) 9 | data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | .venv 7 | 8 | # Log 9 | *.log 10 | *.log.* 11 | *.json 12 | !playground/deepspeed_config_s2.json 13 | !playground/deepspeed_config_s3.json 14 | 15 | # Editor 16 | .idea 17 | *.swp 18 | 19 | # models 20 | models 21 | 22 | # Other 23 | .DS_Store 24 | wandb 25 | output 26 | checkpoints_flant5_3b 27 | 28 | # Data 29 | *.pkl 30 | *.csv 31 | tests/state_of_the_union.txt 32 | 33 | # Build 34 | build 35 | -------------------------------------------------------------------------------- /chat/data/pretty_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 pretty_json.py --in in.json --out out.json 4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | with open(args.in_file, "r") as fin: 17 | data = json.load(fin) 18 | 19 | with open(args.out_file, "w") as fout: 20 | json.dump(data, fout, indent=2, ensure_ascii=False) 21 | -------------------------------------------------------------------------------- /chat/server/monitor/replace_model_name.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 replace_model_name.py --in clean_conv_20230809_10k.json 4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | from chat.server.monitor.clean_battle_data import replace_model_name 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in-file", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | convs = json.load(open(args.in_file)) 17 | for x in convs: 18 | x["model"] = replace_model_name(x["model"]) 19 | 20 | with open(args.in_file, "w") as fout: 21 | json.dump(convs, fout, indent=2, ensure_ascii=False) 22 | 
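# Note: the updated conversations are written back to --in-file in place, overwriting the original file.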
-------------------------------------------------------------------------------- /chat/server/monitor/conv_release_scripts/count_unique_users.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | args = parser.parse_args() 11 | 12 | lines = json.load(open(args.input)) 13 | ct_anony_votes = 0 14 | all_users = set() 15 | all_models = set() 16 | for l in lines: 17 | if not l["anony"]: 18 | continue 19 | all_users.add(l["judge"]) 20 | all_models.add(l["model_a"]) 21 | all_models.add(l["model_b"]) 22 | ct_anony_votes += 1 23 | 24 | print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") 25 | print(f"#model: {len(all_models)}") 26 | -------------------------------------------------------------------------------- /chat/data/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge two conversation files into one 3 | 4 | Usage: python3 -m chat.data.merge --in file1.json file2.json --out merged.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | from typing import Dict, Sequence, Optional 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True, nargs="+") 15 | parser.add_argument("--out-file", type=str, default="merged.json") 16 | args = parser.parse_args() 17 | 18 | new_content = [] 19 | for in_file in args.in_file: 20 | content = json.load(open(in_file, "r")) 21 | new_content.extend(content) 22 | 23 | print(f"#out: {len(new_content)}") 24 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 25 | -------------------------------------------------------------------------------- /tests/launch_openai_api_test_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Launch an OpenAI API server with multiple model workers. 3 | """ 4 | import os 5 | 6 | 7 | def launch_process(cmd): 8 | os.popen(cmd) 9 | 10 | 11 | if __name__ == "__main__": 12 | launch_process("python3 -m chat.server.controller") 13 | launch_process("python3 -m chat.server.openai_api_server") 14 | 15 | models = [ 16 | "lmsys/vicuna-7b-v1.3", 17 | "lmsys/chat-t5-3b-v1.0", 18 | "THUDM/chatglm-6b", 19 | "mosaicml/mpt-7b-chat", 20 | ] 21 | 22 | for i, model_path in enumerate(models): 23 | launch_process( 24 | f"CUDA_VISIBLE_DEVICES={i} python3 -m chat.server.model_worker " 25 | f"--model-path {model_path} --port {30000+i} --worker http://localhost:{30000+i}" 26 | ) 27 | 28 | while True: 29 | pass 30 | -------------------------------------------------------------------------------- /chat/server/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m chat.server.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /chat/server/shutdown_serve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python shutdown_serve.py --down all 4 | options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers 5 | """ 6 | 7 | import argparse 8 | import os 9 | import subprocess 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--down", choices=["all", "controller", "model_worker", "openai_api_server"] 14 | ) 15 | args = parser.parse_args() 16 | base_shell = "ps -eo user,pid,cmd|grep chat.server{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" 17 | if args.down == "all": 18 | shell_script = base_shell.format("") 19 | else: 20 | serve = f".{args.down}" 21 | shell_script = base_shell.format(serve) 22 | print(f"execute shell cmd: {shell_script}") 23 | subprocess.run(shell_script, shell=True, check=True) 24 | print(f"{args.down} has been shutdown!") 25 | -------------------------------------------------------------------------------- /chat/server/monitor/conv_release_scripts/merge_field.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | parser.add_argument("--tag-file", type=str) 11 | args = parser.parse_args() 12 | 13 | # build index 14 | objs = json.load(open(args.tag_file)) 15 | new_field_dict = {} 16 | for obj in objs: 17 | new_field_dict[obj["question_id"]] = obj["toxic_chat"] 18 | 19 | objs = json.load(open(args.input)) 20 | for obj in objs: 21 | obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] 22 | 23 | output = args.input.replace(".json", "_added.json") 24 | with open(output, "w") as fout: 25 | json.dump(objs, fout, indent=2, ensure_ascii=False) 26 | -------------------------------------------------------------------------------- /chat/server/test.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | 4 | # save your HF API token from https:/hf.co/settings/tokens as an env variable to avoid rate limiting 5 | auth_token = os.getenv("auth_token") 6 | 7 | # load a model from https://hf.co/models as an interface, then use it as an api 8 | # you can remove the api_key parameter if you don't care about rate limiting. 
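# gr.load returns a callable interface that proxies the hosted model's inference API;
# passing hf_token authenticates the requests so they are not subject to the anonymous rate limit.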
9 | api = gr.load("huggingface/gpt2-xl", hf_token=auth_token) 10 | 11 | 12 | def complete_with_gpt(text): 13 | return text[:-50] + api(text[-50:]) 14 | 15 | 16 | with gr.Blocks() as demo: 17 | textbox = gr.Textbox(placeholder="Type here...", lines=4) 18 | btn = gr.Button("Autocomplete") 19 | 20 | # define what will run when the button is clicked, here the textbox is used as both an input and an output 21 | btn.click(fn=complete_with_gpt, inputs=textbox, outputs=textbox, queue=False) 22 | 23 | demo.launch() -------------------------------------------------------------------------------- /chat/server/monitor/conv_release_scripts/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Count the unique users in a battle log file. 3 | 4 | Usage: 5 | python3 -input in.json --number 1000 6 | """ 7 | 8 | import argparse 9 | import json 10 | import random 11 | 12 | K = 1000 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--input", type=str) 17 | parser.add_argument("--number", type=int, nargs="+") 18 | args = parser.parse_args() 19 | 20 | convs = json.load(open(args.input)) 21 | random.seed(0) 22 | random.shuffle(convs) 23 | 24 | for number in args.number: 25 | new_convs = convs[:number] 26 | 27 | output = args.input.replace(".json", f"_{number//K}k.json") 28 | with open(output, "w") as fout: 29 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 30 | 31 | print(f"#in: {len(convs)}, #out: {len(new_convs)}") 32 | print(f"Write to file: {output}") 33 | -------------------------------------------------------------------------------- /chat/model/convert_fp16.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m chat.model.convert_fp16 --in in-folder --out out-folder 4 | """ 5 | import argparse 6 | 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import torch 9 | 10 | 11 | def convert_fp16(in_checkpoint, out_checkpoint): 12 | tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) 13 | model = AutoModelForCausalLM.from_pretrained( 14 | in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True 15 | ) 16 | model.save_pretrained(out_checkpoint) 17 | tokenizer.save_pretrained(out_checkpoint) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--in-checkpoint", type=str, help="Path to the model") 23 | parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") 24 | args = parser.parse_args() 25 | 26 | convert_fp16(args.in_checkpoint, args.out_checkpoint) 27 | -------------------------------------------------------------------------------- /chat/data/extract_single_round.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the first round of the conversations. 
3 | 4 | Usage: python3 -m chat.data.extract_single_round --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | for c in content: 21 | c["conversations"] = c["conversations"][:2] 22 | 23 | if args.out_file: 24 | out_file = args.out_file 25 | else: 26 | out_file = args.in_file.replace(".json", "_single.json") 27 | 28 | print(f"#in: {len(content)}, #out: {len(content)}") 29 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 30 | -------------------------------------------------------------------------------- /chat/data/inspect_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m chat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json 4 | """ 5 | import argparse 6 | import json 7 | import random 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--begin", type=int) 14 | parser.add_argument("--random-n", type=int) 15 | args = parser.parse_args() 16 | 17 | content = json.load(open(args.in_file, "r")) 18 | 19 | if args.random_n: 20 | indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)] 21 | elif args.begin: 22 | indices = range(args.begin, len(content)) 23 | else: 24 | indices = range(0, len(content)) 25 | 26 | for idx in indices: 27 | sample = content[idx] 28 | print("=" * 40) 29 | print(f"no: {idx}, id: {sample['id']}") 30 | for conv in sample["conversations"]: 31 | print(conv["from"] + ": ") 32 | print(conv["value"]) 33 | input() 34 | -------------------------------------------------------------------------------- /chat/data/extract_gpt4_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the conversations generated by GPT-4 only. 3 | 4 | Usage: python3 -m chat.data.extract_gpt4_only --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | new_content = [] 21 | for c in content: 22 | model = c.get("model", None) 23 | if model == "gpt4" or model is None: 24 | new_content.append(c) 25 | 26 | if args.out_file: 27 | out_file = args.out_file 28 | else: 29 | out_file = args.in_file.replace(".json", "_gpt4.json") 30 | 31 | print(f"#in: {len(content)}, #out: {len(new_content)}") 32 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 33 | -------------------------------------------------------------------------------- /chat/data/convert_alpaca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert alpaca dataset into sharegpt format. 
3 | 4 | Usage: python3 -m chat.data.convert_alpaca --in alpaca_data.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | import numpy as np 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--in-file", type=str) 17 | parser.add_argument("--out-file", type=str) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | new_content = [] 22 | for i, c in enumerate(content): 23 | if len(c["input"].strip()) > 1: 24 | q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"] 25 | else: 26 | q, a = c["instruction"], c["output"] 27 | new_content.append( 28 | { 29 | "id": f"alpaca_{i}", 30 | "conversations": [ 31 | {"from": "human", "value": q}, 32 | {"from": "gpt", "value": a}, 33 | ], 34 | } 35 | ) 36 | 37 | print(f"#out: {len(new_content)}") 38 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 39 | -------------------------------------------------------------------------------- /chat/data/split_train_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split the dataset into training and test set. 3 | 4 | Usage: python3 -m chat.data.split_train_test --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--begin", type=int, default=0) 16 | parser.add_argument("--end", type=int, default=100) 17 | parser.add_argument("--ratio", type=float, default=0.9) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | np.random.seed(0) 22 | 23 | perm = np.random.permutation(len(content)) 24 | content = [content[i] for i in perm] 25 | split = int(args.ratio * len(content)) 26 | 27 | train_set = content[:split] 28 | test_set = content[split:] 29 | 30 | print(f"#train: {len(train_set)}, #test: {len(test_set)}") 31 | train_name = args.in_file.replace(".json", "_train.json") 32 | test_name = args.in_file.replace(".json", "_test.json") 33 | json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False) 34 | json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /chat/data/filter_wrong_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations with wrong formats. 3 | 4 | Usage: 5 | python3 -m chat.data.filter_wrong_format --in input.json --out output.json 6 | 7 | """ 8 | import argparse 9 | import json 10 | import re 11 | 12 | from tqdm import tqdm 13 | 14 | wrong_indices_pattern = re.compile("\n1\. [^2]*\n1\. 
") 15 | 16 | 17 | def should_skip(conv): 18 | # Filter wrong list indices like https://sharegpt.com/c/1pREAGO 19 | for sentence in conv["conversations"]: 20 | val = sentence["value"] 21 | sub = re.search(wrong_indices_pattern, val) 22 | if sub is not None: 23 | return True 24 | 25 | return False 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--in-file", type=str, required=True) 31 | parser.add_argument("--out-file", type=str, required=True) 32 | args = parser.parse_args() 33 | 34 | content = json.load(open(args.in_file, "r")) 35 | 36 | new_content = [] 37 | for conv in tqdm(content): 38 | if should_skip(conv): 39 | print(f"{conv['id']} contains a wrong format.") 40 | else: 41 | new_content.append(conv) 42 | 43 | print(f"#in: {len(content)}, #out: {len(new_content)}") 44 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 45 | -------------------------------------------------------------------------------- /chat/data/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample some conversations from a file. 3 | 4 | Usage: python3 -m chat.data.sample --in sharegpt.json --out sampled.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--out-file", type=str, default="sampled.json") 16 | parser.add_argument("--begin", type=int, default=0) 17 | parser.add_argument("--end", type=int, default=100) 18 | parser.add_argument("--max-length", type=int, default=1024) 19 | parser.add_argument("--keep-order", action="store_true") 20 | args = parser.parse_args() 21 | 22 | content = json.load(open(args.in_file, "r")) 23 | if not args.keep_order: 24 | np.random.seed(42) 25 | np.random.shuffle(content) 26 | 27 | new_content = [] 28 | for i in range(args.begin, min(args.end, len(content))): 29 | sample = content[i] 30 | concat = "" 31 | for s in sample["conversations"]: 32 | concat += s["value"] 33 | 34 | if len(concat) > args.max_length: 35 | continue 36 | 37 | new_content.append(sample) 38 | 39 | print(f"#in: {len(content)}, #out: {len(new_content)}") 40 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 41 | -------------------------------------------------------------------------------- /tests/test_openai_langchain.py: -------------------------------------------------------------------------------- 1 | # Usage: 2 | # python3 -m chat.server.model_worker --model-path lmsys/vicuna-7b-v1.3 --model-names gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002 3 | # export OPENAI_API_BASE=http://localhost:8000/v1 4 | # export OPENAI_API_KEY=EMPTY 5 | # wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt 6 | 7 | import os 8 | 9 | from langchain.chat_models import ChatOpenAI 10 | from langchain.document_loaders import TextLoader 11 | from langchain.embeddings import OpenAIEmbeddings 12 | from langchain.indexes import VectorstoreIndexCreator 13 | 14 | 15 | def test_chain(): 16 | embedding = OpenAIEmbeddings(model="text-embedding-ada-002") 17 | loader = TextLoader("state_of_the_union.txt") 18 | index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) 19 | 20 | llm = ChatOpenAI(model="gpt-3.5-turbo") 21 | 22 | questions = [ 23 | "Who is the speaker", 24 | "What did the president say about 
Ketanji Brown Jackson", 25 | "What are the threats to America", 26 | "Who are mentioned in the speech", 27 | "Who is the vice president", 28 | "How many projects were announced", 29 | ] 30 | 31 | for query in questions: 32 | print("Query:", query) 33 | print("Answer:", index.query(query, llm=llm)) 34 | 35 | 36 | if __name__ == "__main__": 37 | os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1" 38 | os.environ["OPENAI_API_KEY"] = "empty" 39 | test_chain() 40 | -------------------------------------------------------------------------------- /chat/server/monitor/leaderboard_csv_to_html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a leaderboard csv file to html table used in the blog. 3 | 4 | Usage: 5 | python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv 6 | """ 7 | import argparse 8 | 9 | import numpy as np 10 | 11 | from chat.server.monitor.monitor import load_leaderboard_table_csv 12 | 13 | 14 | def model_hyperlink(model_name, link): 15 | return f' {model_name} ' 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--input", type=str, required=True) 21 | args = parser.parse_args() 22 | 23 | data = load_leaderboard_table_csv(args.input, add_hyperlink=False) 24 | headers = [ 25 | "Model", 26 | "MT-bench (score)", 27 | "Arena Elo rating", 28 | "MMLU", 29 | "License", 30 | ] 31 | values = [] 32 | for item in data: 33 | row = [] 34 | for key in headers: 35 | value = item[key] 36 | row.append(value) 37 | row[0] = model_hyperlink(item["Model"], item["Link"]) 38 | values.append(row) 39 | values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) 40 | 41 | for value in values: 42 | row = "" 43 | for x in value: 44 | try: 45 | if np.isnan(x): 46 | x = "-" 47 | except TypeError: 48 | pass 49 | row += f" {x} " 50 | row += "" 51 | print(row) 52 | -------------------------------------------------------------------------------- /chat/model/upload_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload weights to huggingface. 
3 | 4 | Usage: 5 | python3 -m chat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 6 | """ 7 | import argparse 8 | import tempfile 9 | 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def upload_hub(model_path, hub_repo_id, component, private): 15 | if component == "all": 16 | components = ["model", "tokenizer"] 17 | else: 18 | components = [component] 19 | 20 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} 21 | 22 | if "model" in components: 23 | model = AutoModelForCausalLM.from_pretrained( 24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 25 | ) 26 | with tempfile.TemporaryDirectory() as tmp_path: 27 | model.save_pretrained(tmp_path, **kwargs) 28 | 29 | if "tokenizer" in components: 30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 31 | with tempfile.TemporaryDirectory() as tmp_path: 32 | tokenizer.save_pretrained(tmp_path, **kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model-path", type=str, required=True) 38 | parser.add_argument("--hub-repo-id", type=str, required=True) 39 | parser.add_argument( 40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all" 41 | ) 42 | parser.add_argument("--private", action="store_true") 43 | args = parser.parse_args() 44 | 45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) 46 | -------------------------------------------------------------------------------- /chat/model/apply_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the LoRA weights on top of a base model. 
3 | 4 | Usage: 5 | python3 -m chat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B 6 | 7 | Dependency: 8 | pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b 9 | """ 10 | import argparse 11 | 12 | import torch 13 | from peft import PeftModel 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | 16 | 17 | def apply_lora(base_model_path, target_model_path, lora_path): 18 | print(f"Loading the base model from {base_model_path}") 19 | base = AutoModelForCausalLM.from_pretrained( 20 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 21 | ) 22 | base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) 23 | 24 | print(f"Loading the LoRA adapter from {lora_path}") 25 | 26 | lora_model = PeftModel.from_pretrained( 27 | base, 28 | lora_path, 29 | # torch_dtype=torch.float16 30 | ) 31 | 32 | print("Applying the LoRA") 33 | model = lora_model.merge_and_unload() 34 | 35 | print(f"Saving the target model to {target_model_path}") 36 | model.save_pretrained(target_model_path) 37 | base_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--lora-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_lora(args.base_model_path, args.target_model_path, args.lora_path) 49 | -------------------------------------------------------------------------------- /chat/server/monitor/tag_openai_moderation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add OpenAI moderation API results to all conversations. 3 | """ 4 | import argparse 5 | from concurrent.futures import ThreadPoolExecutor 6 | import json 7 | import os 8 | import time 9 | 10 | import openai 11 | import requests 12 | from tqdm import tqdm 13 | 14 | 15 | API_MAX_RETRY = 16 16 | API_RETRY_SLEEP = 10 17 | API_ERROR_OUTPUT = "$ERROR$" 18 | 19 | 20 | def tag_moderation(text): 21 | result = API_ERROR_OUTPUT 22 | for _ in range(API_MAX_RETRY): 23 | try: 24 | result = openai.Moderation.create(input=text)["results"][0] 25 | break 26 | except openai.error.OpenAIError as e: 27 | print(type(e), e) 28 | time.sleep(API_RETRY_SLEEP) 29 | 30 | return result 31 | 32 | 33 | def tag_openai_moderation(x): 34 | conv = x["conversation_a"] 35 | user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) 36 | result = tag_moderation(user_prompts) 37 | x["openai_moderation"] = result 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--input", type=str, required=True) 43 | parser.add_argument( 44 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 
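# Higher parallelism speeds up tagging but makes OpenAI rate-limit errors more likely;
# failed calls are retried after a short sleep (see API_RETRY_SLEEP above).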
45 | ) 46 | parser.add_argument("--first-n", type=int) 47 | args = parser.parse_args() 48 | 49 | battles = json.load(open(args.input)) 50 | 51 | if args.first_n: 52 | battles = battles[: args.first_n] 53 | 54 | with ThreadPoolExecutor(args.parallel) as executor: 55 | for line in tqdm( 56 | executor.map(tag_openai_moderation, battles), total=len(battles) 57 | ): 58 | pass 59 | 60 | output = args.input.replace(".json", "_tagged.json") 61 | with open(output, "w") as fout: 62 | json.dump(battles, fout, indent=2, ensure_ascii=False) 63 | print(f"Write cleaned data to {output}") 64 | -------------------------------------------------------------------------------- /chat/server/gateway/README.md: -------------------------------------------------------------------------------- 1 | # fastchat Nginx Gateway 2 | 3 | ## Purpose of the Gateway 4 | 5 | The Nginx gateway serves the following purposes: 6 | 7 | 1. Protects Gradio servers by acting as a firewall. 8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers. 9 | 3. Provides load balancing for Gradio servers. 10 | 4. Offers additional security features, such as total connection limit. 11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving. 12 | 13 | ## Deployment and Updating of the Gateway 14 | 15 | ### Installing Nginx 16 | 17 | On Debian-based distributions (e.g., Ubuntu): 18 | 19 | ```bash 20 | sudo apt update 21 | sudo apt install nginx 22 | ``` 23 | On Red Hat-based distributions (e.g., CentOS, Fedora): 24 | 25 | ```bash 26 | sudo yum install epel-release 27 | sudo yum install nginx 28 | ``` 29 | 30 | ### Deployment 31 | 32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). 33 | 34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. 35 | 36 | Modify `upstream websocket` to configure Gradio servers behind the gateway. 37 | 38 | Lastly, update Nginx. 39 | 40 | 41 | ### HTTPS Deployment with a Public Domain URL 42 | 43 | Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. 44 | 45 | Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. 46 | 47 | If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. 48 | 49 | ### Updating 50 | 51 | Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: 52 | 53 | ```bash 54 | sudo nginx -t # check `/etc/nginx/nginx.conf` 55 | sudo systemctl reload nginx # restart Nginx service to load the new config 56 | sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). 
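# If the config test fails, check the Nginx error log for details
# (default location on most distros; adjust the path if yours differs):
sudo tail -n 50 /var/log/nginx/error.log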
57 | ``` 58 | -------------------------------------------------------------------------------- /chat/data/prepare_all.py: -------------------------------------------------------------------------------- 1 | """Prepare all datasets.""" 2 | 3 | import argparse 4 | import os 5 | 6 | from chat.utils import run_cmd 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521") 12 | parser.add_argument( 13 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 14 | ) 15 | parser.add_argument("--seq-len", type=int, default=4096) 16 | args = parser.parse_args() 17 | 18 | in_prefix = args.prefix 19 | model_path = args.model_name_or_path 20 | seq_len = args.seq_len 21 | prefix = ( 22 | f"{in_prefix}_{seq_len}".replace("4096", "4k") 23 | .replace("8192", "8k") 24 | .replace("16384", "16k") 25 | ) 26 | 27 | cmd_list = [ 28 | f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json", 29 | f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko", 30 | f"python3 -m fastchat.data.split_long_conversation --in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}", 31 | f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json", 32 | f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99", 33 | f"python3 -m fastchat.data.hardcoded_questions", 34 | f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json", 35 | f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json", 36 | f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json", 37 | ] 38 | 39 | for cmd in cmd_list: 40 | ret = run_cmd(cmd) 41 | if ret != 0: 42 | exit(ret) 43 | -------------------------------------------------------------------------------- /chat/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make the delta weights by subtracting base weights. 
3 | 4 | Usage: 5 | python3 -m chat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 6 | """ 7 | import argparse 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path): 15 | print(f"Loading the base model from {base_model_path}") 16 | base = AutoModelForCausalLM.from_pretrained( 17 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | 20 | print(f"Loading the target model from {target_model_path}") 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) 25 | 26 | print("Calculating the delta") 27 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 28 | assert name in base.state_dict() 29 | param.data -= base.state_dict()[name] 30 | 31 | print(f"Saving the delta to {delta_path}") 32 | if args.hub_repo_id: 33 | kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} 34 | else: 35 | kwargs = {} 36 | target.save_pretrained(delta_path, **kwargs) 37 | target_tokenizer.save_pretrained(delta_path, **kwargs) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | parser.add_argument("--hub-repo-id", type=str) 46 | args = parser.parse_args() 47 | 48 | make_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /chat/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global constants. 3 | """ 4 | 5 | from enum import IntEnum 6 | import os 7 | 8 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 9 | 10 | ##### For the gradio web server 11 | SERVER_ERROR_MSG = ( 12 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | ) 14 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN." 15 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 16 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 17 | # Maximum input length 18 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560)) 19 | # Maximum conversation turns 20 | CONVERSATION_TURN_LIMIT = 50 21 | # Session expiration time 22 | SESSION_EXPIRATION_TIME = 3600 23 | # The output dir of log files 24 | LOGDIR = os.getenv("LOGDIR", ".") 25 | # CPU Instruction Set Architecture 26 | CPU_ISA = os.getenv("CPU_ISA") 27 | 28 | 29 | ##### For the controller and workers (could be overwritten through ENV variables.) 
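# Example shell override before launching a server (the values below are illustrative):
#   export FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION=120
#   export FASTCHAT_WORKER_API_TIMEOUT=200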
30 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 31 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 32 | ) 33 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 34 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 35 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 36 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 37 | ) 38 | 39 | 40 | class ErrorCode(IntEnum): 41 | """ 42 | https://platform.openai.com/docs/guides/error-codes/api-errors 43 | """ 44 | 45 | VALIDATION_TYPE_ERROR = 40001 46 | 47 | INVALID_AUTH_KEY = 40101 48 | INCORRECT_AUTH_KEY = 40102 49 | NO_PERMISSION = 40103 50 | 51 | INVALID_MODEL = 40301 52 | PARAM_OUT_OF_RANGE = 40302 53 | CONTEXT_OVERFLOW = 40303 54 | 55 | RATE_LIMIT = 42901 56 | QUOTA_EXCEEDED = 42902 57 | ENGINE_OVERLOADED = 42903 58 | 59 | INTERNAL_ERROR = 50001 60 | CUDA_OUT_OF_MEMORY = 50002 61 | GRADIO_REQUEST_ERROR = 50003 62 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 63 | CONTROLLER_NO_WORKER = 50005 64 | CONTROLLER_WORKER_TIMEOUT = 50006 65 | -------------------------------------------------------------------------------- /chat/server/huggingface_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use FastChat with Hugging Face generation APIs. 3 | 4 | Usage: 5 | python3 -m chat.server.huggingface_api --model lmsys/vicuna-7b-v1.3 6 | python3 -m chat.server.huggingface_api --model lmsys/chat-t5-3b-v1.0 7 | """ 8 | import argparse 9 | 10 | import torch 11 | 12 | from chat.model import load_model, get_conversation_template, add_model_args 13 | 14 | 15 | @torch.inference_mode() 16 | def main(args): 17 | model, tokenizer = load_model( 18 | args.model_path, 19 | device=args.device, 20 | num_gpus=args.num_gpus, 21 | max_gpu_memory=args.max_gpu_memory, 22 | load_8bit=args.load_8bit, 23 | cpu_offloading=args.cpu_offloading, 24 | revision=args.revision, 25 | debug=args.debug, 26 | ) 27 | 28 | msg = args.message 29 | 30 | conv = get_conversation_template(args.model_path) 31 | conv.append_message(conv.roles[0], msg) 32 | conv.append_message(conv.roles[1], None) 33 | prompt = conv.get_prompt() 34 | 35 | inputs = tokenizer([prompt], return_tensors="pt").to(args.device) 36 | output_ids = model.generate( 37 | **inputs, 38 | do_sample=True if args.temperature > 1e-5 else False, 39 | temperature=args.temperature, 40 | repetition_penalty=args.repetition_penalty, 41 | max_new_tokens=args.max_new_tokens, 42 | ) 43 | 44 | if model.config.is_encoder_decoder: 45 | output_ids = output_ids[0] 46 | else: 47 | output_ids = output_ids[0][len(inputs["input_ids"][0]) :] 48 | outputs = tokenizer.decode( 49 | output_ids, skip_special_tokens=True, spaces_between_special_tokens=False 50 | ) 51 | 52 | print(f"{conv.roles[0]}: {msg}") 53 | print(f"{conv.roles[1]}: {outputs}") 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | add_model_args(parser) 59 | parser.add_argument("--temperature", type=float, default=0.7) 60 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 61 | parser.add_argument("--max-new-tokens", type=int, default=512) 62 | parser.add_argument("--debug", action="store_true") 63 | parser.add_argument("--message", type=str, default="Hello! Who are you?") 64 | args = parser.parse_args() 65 | 66 | # Reset default repetition penalty for T5 models. 
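# T5-based checkpoints tend to produce repetitive text with a penalty of 1.0,
# so bump it to 1.2 unless the user explicitly set a different value.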
67 | if "t5" in args.model_path and args.repetition_penalty == 1.0: 68 | args.repetition_penalty = 1.2 69 | 70 | main(args) 71 | -------------------------------------------------------------------------------- /chat/server/monitor/summarize_cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Usage: 4 | python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 5 | """ 6 | import argparse 7 | import pickle 8 | 9 | from chat.llm_judge.common import ( 10 | chat_compeletion_openai, 11 | chat_compeletion_anthropic, 12 | ) 13 | from chat.conversation import get_conv_template 14 | 15 | 16 | def truncate_string(s, l): 17 | half = int(l // 2) 18 | return s[:half] + s[-half:] if len(s) > l else s 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--input-file", type=str, required=True) 24 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 25 | parser.add_argument("--num-prompts", type=int, default=100) 26 | args = parser.parse_args() 27 | 28 | model = args.model 29 | 30 | cluster_infos = pickle.load(open(args.input_file, "rb")) 31 | num_total_prompts = sum([x[0] for x in cluster_infos]) 32 | 33 | topics = [] 34 | percentages = [] 35 | for i, info in enumerate(cluster_infos): 36 | num_samples, prompts = info 37 | percentage = num_samples / num_total_prompts 38 | print( 39 | f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" 40 | ) 41 | instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." 42 | prompt = "\n".join( 43 | [truncate_string(x, l=200) for x in prompts[: args.num_prompts]] 44 | ) 45 | prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." 
46 | 47 | if "gpt" in model: 48 | template_name = "chatgpt" 49 | completion_func = chat_compeletion_openai 50 | elif "claude" in model: 51 | template_name = "claude" 52 | completion_func = chat_compeletion_anthropic 53 | 54 | conv = get_conv_template(template_name) 55 | conv.set_system_message(instruct) 56 | conv.append_message(conv.roles[0], prompt) 57 | conv.append_message(conv.roles[1], None) 58 | 59 | topic = completion_func(model, conv, temperature=0, max_tokens=256) 60 | print(topic) 61 | 62 | topics.append(topic) 63 | percentages.append(round(percentage, 6)) 64 | 65 | print() 66 | print(f"topics: {topics}") 67 | print(f"percentages: {percentages}") 68 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """Test command line interface for model inference.""" 2 | import argparse 3 | import os 4 | 5 | from chat.utils import run_cmd 6 | 7 | 8 | def test_single_gpu(): 9 | models = [ 10 | "lmsys/vicuna-7b-v1.3", 11 | "lmsys/longchat-7b-16k", 12 | "lmsys/chat-t5-3b-v1.0", 13 | "THUDM/chatglm-6b", 14 | "THUDM/chatglm2-6b", 15 | "mosaicml/mpt-7b-chat", 16 | "project-baize/baize-v2-7b", 17 | "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b", 18 | "tiiuae/falcon-7b-instruct", 19 | "~/model_weights/alpaca-7b", 20 | "~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth", 21 | ] 22 | 23 | for model_path in models: 24 | if "model_weights" in model_path and not os.path.exists( 25 | os.path.expanduser(model_path) 26 | ): 27 | continue 28 | cmd = ( 29 | f"python3 -m chat.serve.cli --model-path {model_path} " 30 | f"--style programmatic < test_cli_inputs.txt" 31 | ) 32 | ret = run_cmd(cmd) 33 | if ret != 0: 34 | return 35 | 36 | print("") 37 | 38 | 39 | def test_multi_gpu(): 40 | models = [ 41 | "lmsys/vicuna-13b-v1.3", 42 | ] 43 | 44 | for model_path in models: 45 | cmd = ( 46 | f"python3 -m chat.serve.cli --model-path {model_path} " 47 | f"--style programmatic --num-gpus 2 --max-gpu-memory 14Gib < test_cli_inputs.txt" 48 | ) 49 | ret = run_cmd(cmd) 50 | if ret != 0: 51 | return 52 | print("") 53 | 54 | 55 | def test_8bit(): 56 | models = [ 57 | "lmsys/vicuna-13b-v1.3", 58 | ] 59 | 60 | for model_path in models: 61 | cmd = ( 62 | f"python3 -m chat.serve.cli --model-path {model_path} " 63 | f"--style programmatic --load-8bit < test_cli_inputs.txt" 64 | ) 65 | ret = run_cmd(cmd) 66 | if ret != 0: 67 | return 68 | print("") 69 | 70 | 71 | def test_hf_api(): 72 | models = [ 73 | "lmsys/vicuna-7b-v1.3", 74 | "lmsys/chat-t5-3b-v1.0", 75 | ] 76 | 77 | for model_path in models: 78 | cmd = f"python3 -m chat.serve.huggingface_api --model-path {model_path}" 79 | ret = run_cmd(cmd) 80 | if ret != 0: 81 | return 82 | print("") 83 | 84 | 85 | if __name__ == "__main__": 86 | test_single_gpu() 87 | test_multi_gpu() 88 | test_8bit() 89 | test_hf_api() 90 | -------------------------------------------------------------------------------- /chat/modules/gptq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import os 3 | from os.path import isdir, isfile 4 | from pathlib import Path 5 | import sys 6 | 7 | from transformers import AutoTokenizer 8 | 9 | 10 | @dataclass 11 | class GptqConfig: 12 | ckpt: str = field( 13 | default=None, 14 | metadata={ 15 | "help": "Load quantized model. The path to the local GPTQ checkpoint." 
16 | }, 17 | ) 18 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 19 | groupsize: int = field( 20 | default=-1, 21 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 22 | ) 23 | act_order: bool = field( 24 | default=True, 25 | metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, 26 | ) 27 | 28 | 29 | def load_gptq_quantized(model_name, gptq_config: GptqConfig): 30 | print("Loading GPTQ quantized model...") 31 | 32 | try: 33 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") 35 | 36 | sys.path.insert(0, module_path) 37 | from llama import load_quant 38 | except ImportError as e: 39 | print(f"Error: Failed to load GPTQ-for-LLaMa. {e}") 40 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") 41 | sys.exit(-1) 42 | 43 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 44 | # only `fastest-inference-4bit` branch cares about `act_order` 45 | if gptq_config.act_order: 46 | model = load_quant( 47 | model_name, 48 | find_gptq_ckpt(gptq_config), 49 | gptq_config.wbits, 50 | gptq_config.groupsize, 51 | act_order=gptq_config.act_order, 52 | ) 53 | else: 54 | # other branches 55 | model = load_quant( 56 | model_name, 57 | find_gptq_ckpt(gptq_config), 58 | gptq_config.wbits, 59 | gptq_config.groupsize, 60 | ) 61 | 62 | return model, tokenizer 63 | 64 | 65 | def find_gptq_ckpt(gptq_config: GptqConfig): 66 | if Path(gptq_config.ckpt).is_file(): 67 | return gptq_config.ckpt 68 | 69 | for ext in ["*.pt", "*.safetensors"]: 70 | matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) 71 | if len(matched_result) > 0: 72 | return str(matched_result[-1]) 73 | 74 | print("Error: gptq checkpoint not found") 75 | sys.exit(1) 76 | -------------------------------------------------------------------------------- /chat/data/optional_replace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional replace of bos/eos/pad/unk. 
3 | 4 | Usage: 5 | python3 -m chat.data.optional_replace --in input.json --out output.json --model-name-or-path 6 | 7 | Requirement: 8 | pip3 install transformers tqdm 9 | """ 10 | import argparse 11 | import json 12 | import traceback 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def replace_special_tokens( 19 | tokenizer: transformers.PreTrainedTokenizer, text: str 20 | ) -> str: 21 | if not text: 22 | return text 23 | 24 | def _insert_vline(token: str) -> str: 25 | if len(token) < 2: 26 | return " " 27 | elif len(token) == 2: 28 | return f"{token[0]}|{token[1]}" 29 | else: 30 | return f"{token[:1]}|{token[1:-1]}|{token[-1:]}" 31 | 32 | if tokenizer.bos_token: 33 | text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token)) 34 | if tokenizer.eos_token: 35 | text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token)) 36 | if tokenizer.pad_token: 37 | text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token)) 38 | if tokenizer.unk_token: 39 | text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token)) 40 | return text 41 | 42 | 43 | def replace(conv, tokenizer): 44 | # Replace bos/eos/pad/unk tokens 45 | if tokenizer: 46 | try: 47 | for sentence in conv["conversations"]: 48 | sentence["value"] = replace_special_tokens(tokenizer, sentence["value"]) 49 | except Exception as e: 50 | traceback.print_exc() 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--in-file", type=str, required=True) 56 | parser.add_argument("--out-file", type=str) 57 | parser.add_argument( 58 | "--model-name-or-path", 59 | type=str, 60 | help="The directory or address where the model token is stored.", 61 | ) 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | tokenizer = None 67 | if args.model_name_or_path: 68 | tokenizer = transformers.AutoTokenizer.from_pretrained( 69 | args.model_name_or_path, 70 | trust_remote_code=True, 71 | use_fast=False, 72 | ) 73 | 74 | if out_file is None: 75 | out_file = f"{in_file}_replace.json" 76 | 77 | content = json.load(open(in_file, "r")) 78 | 79 | for conv in tqdm(content): 80 | replace(conv, tokenizer) 81 | 82 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 83 | -------------------------------------------------------------------------------- /chat/data/get_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Get stats of a dataset. 
3 | 4 | Usage: python3 -m chat.data.get_stats --in sharegpt.json 5 | """ 6 | 7 | import argparse 8 | from concurrent.futures import ProcessPoolExecutor 9 | import json 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | 15 | K = 1e3 16 | M = 1e6 17 | 18 | 19 | def tokenize_one_sample(c): 20 | for i in range(len(c["conversations"])): 21 | v = c["conversations"][i]["value"] 22 | c["conversations"][i]["value"] = tokenizer.tokenize(v) 23 | return c 24 | 25 | 26 | def tokenize_dataset(content): 27 | processed = [] 28 | with ProcessPoolExecutor() as executor: 29 | for result in tqdm( 30 | executor.map(tokenize_one_sample, content), total=len(content) 31 | ): 32 | processed.append(result) 33 | 34 | return processed 35 | 36 | 37 | def compute_stats(content): 38 | sample_lens = [] 39 | sample_turns = [] 40 | prompt_lens = [] 41 | res_lens = [] 42 | 43 | for c in content: 44 | sample_len = 0 45 | sample_turns.append(len(c["conversations"]) // 2) 46 | for i in range(len(c["conversations"]) // 2): 47 | p = c["conversations"][i * 2]["value"] 48 | r = c["conversations"][i * 2 + 1]["value"] 49 | 50 | turn_len = len(p) + len(r) 51 | sample_len += turn_len 52 | prompt_lens.append(len(p)) 53 | res_lens.append(len(r)) 54 | sample_lens.append(sample_len) 55 | 56 | return sample_lens, sample_turns, prompt_lens, res_lens 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--in-file", type=str) 62 | parser.add_argument( 63 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 64 | ) 65 | args = parser.parse_args() 66 | 67 | content = json.load(open(args.in_file, "r")) 68 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 69 | content = tokenize_dataset(content) 70 | 71 | sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) 72 | print(f"#sequence: {len(content)/K:.2f} K") 73 | print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") 74 | print(f"avg. turns: {np.mean(sample_turns):.2f}") 75 | print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") 76 | print(f"avg. 
response length: {np.mean(res_lens):.2f}") 77 | 78 | print("\n- Histogram -") 79 | bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] 80 | hist = np.histogram(sample_lens, bins=bin_edges)[0] 81 | for i in range(len(hist)): 82 | print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}") 83 | -------------------------------------------------------------------------------- /chat/server/test_message.py: -------------------------------------------------------------------------------- 1 | """Send a test message.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | 7 | from chat.model.model_adapter import get_conversation_template 8 | 9 | 10 | def main(): 11 | model_name = args.model_name 12 | 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | print(f"No available workers for {model_name}") 31 | return 32 | 33 | conv = get_conversation_template(model_name) 34 | conv.append_message(conv.roles[0], args.message) 35 | conv.append_message(conv.roles[1], None) 36 | prompt = conv.get_prompt() 37 | 38 | headers = {"User-Agent": "FastChat Client"} 39 | gen_params = { 40 | "model": model_name, 41 | "prompt": prompt, 42 | "temperature": args.temperature, 43 | "max_new_tokens": args.max_new_tokens, 44 | "stop": conv.stop_str, 45 | "stop_token_ids": conv.stop_token_ids, 46 | "echo": False, 47 | } 48 | response = requests.post( 49 | worker_addr + "/worker_generate_stream", 50 | headers=headers, 51 | json=gen_params, 52 | stream=True, 53 | ) 54 | 55 | print(f"{conv.roles[0]}: {args.message}") 56 | print(f"{conv.roles[1]}: ", end="") 57 | prev = 0 58 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 59 | if chunk: 60 | data = json.loads(chunk.decode()) 61 | output = data["text"].strip() 62 | print(output[prev:], end="", flush=True) 63 | prev = len(output) 64 | print("") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "--controller-address", type=str, default="http://localhost:21001" 71 | ) 72 | parser.add_argument("--worker-address", type=str) 73 | parser.add_argument("--model-name", type=str, required=True) 74 | parser.add_argument("--temperature", type=float, default=0.0) 75 | parser.add_argument("--max-new-tokens", type=int, default=32) 76 | parser.add_argument( 77 | "--message", type=str, default="Tell me a story with more than 1000 words." 78 | ) 79 | args = parser.parse_args() 80 | 81 | main() 82 | -------------------------------------------------------------------------------- /chat/model/rwkv_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from types import SimpleNamespace 3 | import warnings 4 | 5 | import torch 6 | 7 | os.environ["RWKV_JIT_ON"] = "1" 8 | os.environ["RWKV_CUDA_ON"] = "1" 9 | 10 | from rwkv.model import RWKV 11 | from rwkv.utils import PIPELINE, PIPELINE_ARGS 12 | 13 | 14 | class RwkvModel: 15 | def __init__(self, model_path): 16 | warnings.warn( 17 | "Experimental support. 
Please use ChatRWKV if you want to chat with RWKV" 18 | ) 19 | self.config = SimpleNamespace(is_encoder_decoder=False) 20 | self.model = RWKV(model=model_path, strategy="cuda fp16") 21 | # two GPUs 22 | # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") 23 | 24 | self.tokenizer = None 25 | self.model_path = model_path 26 | 27 | def to(self, target): 28 | assert target == "cuda" 29 | 30 | def __call__(self, input_ids, use_cache, past_key_values=None): 31 | assert use_cache == True 32 | input_ids = input_ids[0].detach().cpu().numpy() 33 | # print(input_ids) 34 | logits, state = self.model.forward(input_ids, past_key_values) 35 | # print(logits) 36 | logits = logits.unsqueeze(0).unsqueeze(0) 37 | out = SimpleNamespace(logits=logits, past_key_values=state) 38 | return out 39 | 40 | def generate( 41 | self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 42 | ): 43 | # This function is used by chat.llm_judge. 44 | # Because RWKV does not support huggingface generation API, 45 | # we reuse chat.server.inference.generate_stream as a workaround. 46 | from transformers import AutoTokenizer 47 | 48 | from chat.server.inference import generate_stream 49 | from chat.conversation import get_conv_template 50 | 51 | if self.tokenizer is None: 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | "EleutherAI/pythia-160m", use_fast=True 54 | ) 55 | prompt = self.tokenizer.decode(input_ids[0].tolist()) 56 | conv = get_conv_template("rwkv") 57 | 58 | gen_params = { 59 | "model": self.model_path, 60 | "prompt": prompt, 61 | "temperature": temperature, 62 | "repetition_penalty": repetition_penalty, 63 | "max_new_tokens": max_new_tokens, 64 | "stop": conv.stop_str, 65 | "stop_token_ids": conv.stop_token_ids, 66 | "echo": False, 67 | } 68 | res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") 69 | 70 | for res in res_iter: 71 | pass 72 | 73 | output = res["text"] 74 | output_ids = self.tokenizer.encode(output) 75 | 76 | return [input_ids[0].tolist() + output_ids] 77 | -------------------------------------------------------------------------------- /chat/modules/awq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | import sys 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils 7 | 8 | 9 | @dataclass 10 | class AWQConfig: 11 | ckpt: str = field( 12 | default=None, 13 | metadata={ 14 | "help": "Load quantized model. The path to the local AWQ checkpoint." 15 | }, 16 | ) 17 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 18 | groupsize: int = field( 19 | default=-1, 20 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 21 | ) 22 | 23 | 24 | def load_awq_quantized(model_name, awq_config: AWQConfig, device): 25 | print("Loading AWQ quantized model...") 26 | 27 | try: 28 | from tinychat.utils import load_quant 29 | from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp 30 | except ImportError as e: 31 | print(f"Error: Failed to import tinychat. 
{e}") 32 | print("Please double check if you have successfully installed AWQ") 33 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") 34 | sys.exit(-1) 35 | 36 | config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | model_name, use_fast=False, trust_remote_code=True 39 | ) 40 | 41 | def skip(*args, **kwargs): 42 | pass 43 | 44 | torch.nn.init.kaiming_uniform_ = skip 45 | torch.nn.init.kaiming_normal_ = skip 46 | torch.nn.init.uniform_ = skip 47 | torch.nn.init.normal_ = skip 48 | modeling_utils._init_weights = False 49 | 50 | torch.set_default_dtype(torch.half) 51 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 52 | 53 | if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): 54 | model = load_quant.load_awq_llama_fast( 55 | model, 56 | find_awq_ckpt(awq_config), 57 | awq_config.wbits, 58 | awq_config.groupsize, 59 | device, 60 | ) 61 | make_quant_attn(model, device) 62 | make_quant_norm(model) 63 | make_fused_mlp(model) 64 | else: 65 | model = load_quant.load_awq_model( 66 | model, 67 | find_awq_ckpt(awq_config), 68 | awq_config.wbits, 69 | awq_config.groupsize, 70 | device, 71 | ) 72 | return model, tokenizer 73 | 74 | 75 | def find_awq_ckpt(awq_config: AWQConfig): 76 | if Path(awq_config.ckpt).is_file(): 77 | return awq_config.ckpt 78 | 79 | for ext in ["*.pt", "*.safetensors"]: 80 | matched_result = sorted(Path(awq_config.ckpt).glob(ext)) 81 | if len(matched_result) > 0: 82 | return str(matched_result[-1]) 83 | 84 | print("Error: AWQ checkpoint not found") 85 | sys.exit(1) 86 | -------------------------------------------------------------------------------- /chat/data/optional_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional cleaning (e.g., remove some languages). 
3 | 4 | Usage: 5 | python3 -m chat.data.optional_clean --in input.json --out output.json --keep-lang en 6 | python3 -m chat.data.optional_clean --in input.json --out output.json --skip-lang en 7 | 8 | Requirement: 9 | pip3 install polyglot pyicu pycld2 10 | """ 11 | import argparse 12 | import json 13 | import re 14 | 15 | import polyglot 16 | from polyglot.detect import Detector 17 | import pycld2 18 | from tqdm import tqdm 19 | 20 | 21 | def skip(conv, args): 22 | # Remove certain languages 23 | if args.keep_lang != "all" or args.skip_lang is not None: 24 | text = "\n".join([x["value"] for x in conv["conversations"]]) 25 | try: 26 | lang_code = Detector(text).language.code 27 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 28 | lang_code = "unknown" 29 | 30 | if args.keep_lang != "all" and lang_code != args.keep_lang: 31 | return True 32 | 33 | if lang_code == args.skip_lang: 34 | return True 35 | 36 | # Remove repetitive numbers 37 | if args.reduce_rep: 38 | for sentence in conv["conversations"]: 39 | val = sentence["value"] 40 | sub = re.search(r"(\d)\1{8}", val) 41 | if sub is not None: 42 | return True 43 | 44 | return False 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--in-file", type=str, required=True) 50 | parser.add_argument("--out-file", type=str) 51 | parser.add_argument( 52 | "--keep-lang", 53 | type=str, 54 | default="all", 55 | choices=["all", "en"], 56 | help="Only keep certain langauges.", 57 | ) 58 | parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") 59 | # NOTE: Be careful about reduce_rep which may remove some good data. 60 | # For example, addresses could have long consecutive 0's 61 | parser.add_argument("--reduce-rep", action="store_true") 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | keep_lang = args.keep_lang 67 | skip_lang = args.skip_lang 68 | reduce_rep = args.reduce_rep 69 | assert keep_lang == "all" or skip_lang is None 70 | 71 | if out_file is None: 72 | out_file = "sharegpt_clean" 73 | if keep_lang != "all": 74 | out_file += "_" + keep_lang 75 | if skip_lang is not None: 76 | out_file += "_skip_" + skip_lang 77 | if reduce_rep: 78 | out_file += "_reduce_rep" 79 | out_file += ".json" 80 | 81 | content = json.load(open(in_file, "r")) 82 | num_conv = len(content) 83 | 84 | new_content = [] 85 | for conv in tqdm(content): 86 | if not skip(conv, args): 87 | new_content.append(conv) 88 | 89 | print(f"#in: {len(content)}, #out: {len(new_content)}") 90 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 91 | -------------------------------------------------------------------------------- /chat/server/monitor/inspect_conv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import code 3 | import datetime 4 | import json 5 | import os 6 | from pytz import timezone 7 | import time 8 | 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | 13 | def get_log_files(max_num_files=None): 14 | dates = [] 15 | for month in [4, 5]: 16 | for day in range(1, 32): 17 | dates.append(f"2023-{month:02d}-{day:02d}") 18 | 19 | num_servers = 14 20 | filenames = [] 21 | for d in dates: 22 | for i in range(num_servers): 23 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 24 | if os.path.exists(name): 25 | filenames.append(name) 26 | max_num_files = max_num_files or len(filenames) 27 | filenames = filenames[-max_num_files:] 
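    # Keep only the newest `max_num_files` logs; when --max-num-files is omitted,
    # the `or len(filenames)` fallback above keeps every matching file.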
28 | return filenames 29 | 30 | 31 | def pretty_print_conversation(messages): 32 | for role, msg in messages: 33 | print(f"[[{role}]]: {msg}") 34 | 35 | 36 | def inspect_convs(log_files): 37 | data = [] 38 | for filename in tqdm(log_files, desc="read files"): 39 | for retry in range(5): 40 | try: 41 | lines = open(filename).readlines() 42 | break 43 | except FileNotFoundError: 44 | time.sleep(2) 45 | 46 | for l in lines: 47 | row = json.loads(l) 48 | 49 | if "states" not in row: 50 | continue 51 | if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: 52 | continue 53 | 54 | model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] 55 | if row["type"] == "leftvote": 56 | winner, loser = model_names[0], model_names[1] 57 | winner_conv, loser_conv = row["states"][0], row["states"][1] 58 | elif row["type"] == "rightvote": 59 | loser, winner = model_names[0], model_names[1] 60 | loser_conv, winner_conv = row["states"][0], row["states"][1] 61 | 62 | if loser == "bard" and winner == "vicuna-13b": 63 | print("=" * 20) 64 | print(f"Winner: {winner}") 65 | pretty_print_conversation(winner_conv["messages"]) 66 | print(f"Loser: {loser}") 67 | pretty_print_conversation(loser_conv["messages"]) 68 | print("=" * 20) 69 | input() 70 | 71 | # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: 72 | # print("=" * 20) 73 | # print(f"Model A: {model_names[0]}") 74 | # pretty_print_conversation(row["states"][0]["messages"]) 75 | # print(f"Model B: {model_names[1]}") 76 | # pretty_print_conversation(row["states"][1]["messages"]) 77 | # print("=" * 20) 78 | # input() 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("--max-num-files", type=int) 84 | args = parser.parse_args() 85 | 86 | log_files = get_log_files(args.max_num_files) 87 | inspect_convs(log_files) 88 | -------------------------------------------------------------------------------- /chat/model/llama_condense_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import transformers 7 | import transformers.models.llama.modeling_llama 8 | 9 | 10 | class CondenseRotaryEmbedding(torch.nn.Module): 11 | def __init__( 12 | self, dim, ratio, max_position_embeddings=2048, base=10000, device=None 13 | ): 14 | super().__init__() 15 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 16 | self.register_buffer("inv_freq", inv_freq) 17 | 18 | # Build here to make `torch.jit.trace` work. 
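        # Condensed RoPE: the cos/sin cache is built for `ratio` times more
        # positions, and every position index is divided by `ratio` below, so the
        # rotary angles are interpolated rather than extrapolated. A model trained
        # with 2048 positions can then cover roughly 2048 * ratio tokens.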
19 | self.ratio = ratio 20 | max_position_embeddings *= ratio 21 | self.max_seq_len_cached = max_position_embeddings 22 | # print(f"Monkey Patching condense ratio {ratio}") 23 | t = ( 24 | torch.arange( 25 | self.max_seq_len_cached, 26 | device=self.inv_freq.device, 27 | dtype=self.inv_freq.dtype, 28 | ) 29 | / ratio 30 | ) 31 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 32 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 33 | emb = torch.cat((freqs, freqs), dim=-1) 34 | dtype = torch.get_default_dtype() 35 | self.register_buffer( 36 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 37 | ) 38 | self.register_buffer( 39 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 40 | ) 41 | 42 | def forward(self, x, seq_len=None): 43 | # x: [bs, num_attention_heads, seq_len, head_size] 44 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 45 | if seq_len > self.max_seq_len_cached: 46 | self.max_seq_len_cached = seq_len 47 | t = ( 48 | torch.arange( 49 | self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype 50 | ) 51 | / self.ratio 52 | ) 53 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 54 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 55 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 56 | self.register_buffer( 57 | "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False 58 | ) 59 | self.register_buffer( 60 | "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False 61 | ) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def replace_llama_with_condense(ratio): 69 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( 70 | CondenseRotaryEmbedding, ratio=ratio 71 | ) 72 | -------------------------------------------------------------------------------- /tests/test_openai_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the OpenAI compatible server 3 | 4 | Launch: 5 | python3 launch_openai_api_test_server.py 6 | """ 7 | 8 | import openai 9 | 10 | from chat.utils import run_cmd 11 | 12 | openai.api_key = "EMPTY" # Not support yet 13 | openai.api_base = "http://localhost:8000/v1" 14 | 15 | 16 | def test_list_models(): 17 | model_list = openai.Model.list() 18 | names = [x["id"] for x in model_list["data"]] 19 | return names 20 | 21 | 22 | def test_completion(model): 23 | prompt = "Once upon a time" 24 | completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64) 25 | print(prompt + completion.choices[0].text) 26 | 27 | 28 | def test_completion_stream(model): 29 | prompt = "Once upon a time" 30 | res = openai.Completion.create( 31 | model=model, prompt=prompt, max_tokens=64, stream=True 32 | ) 33 | print(prompt, end="") 34 | for chunk in res: 35 | content = chunk["choices"][0]["text"] 36 | print(content, end="", flush=True) 37 | print() 38 | 39 | 40 | def test_embedding(model): 41 | embedding = openai.Embedding.create(model=model, input="Hello world!") 42 | print(f"embedding len: {len(embedding['data'][0]['embedding'])}") 43 | print(f"embedding value[:5]: {embedding['data'][0]['embedding'][:5]}") 44 | 45 | 46 | def test_chat_completion(model): 47 | completion = openai.ChatCompletion.create( 48 | model=model, messages=[{"role": 
"user", "content": "Hello! What is your name?"}] 49 | ) 50 | print(completion.choices[0].message.content) 51 | 52 | 53 | def test_chat_completion_stream(model): 54 | messages = [{"role": "user", "content": "Hello! What is your name?"}] 55 | res = openai.ChatCompletion.create(model=model, messages=messages, stream=True) 56 | for chunk in res: 57 | content = chunk["choices"][0]["delta"].get("content", "") 58 | print(content, end="", flush=True) 59 | print() 60 | 61 | 62 | def test_openai_curl(model): 63 | run_cmd("curl http://localhost:8000/v1/models") 64 | 65 | run_cmd( 66 | """ 67 | curl http://localhost:8000/v1/chat/completions \ 68 | -H "Content-Type: application/json" \ 69 | -d '{ 70 | "model": "vicuna-7b-v1.3", 71 | "messages": [{"role": "user", "content": "Hello! What is your name?"}] 72 | }' 73 | """ 74 | ) 75 | 76 | run_cmd( 77 | """ 78 | curl http://localhost:8000/v1/completions \ 79 | -H "Content-Type: application/json" \ 80 | -d '{ 81 | "model": "vicuna-7b-v1.3", 82 | "prompt": "Once upon a time", 83 | "max_tokens": 41, 84 | "temperature": 0.5 85 | }' 86 | """ 87 | ) 88 | 89 | run_cmd( 90 | """ 91 | curl http://localhost:8000/v1/embeddings \ 92 | -H "Content-Type: application/json" \ 93 | -d '{ 94 | "model": "vicuna-7b-v1.3", 95 | "input": "Hello world!" 96 | }' 97 | """ 98 | ) 99 | 100 | 101 | if __name__ == "__main__": 102 | models = test_list_models() 103 | print(f"models: {models}") 104 | 105 | for model in models: 106 | print(f"===== Test {model} ======") 107 | test_completion(model) 108 | test_completion_stream(model) 109 | test_embedding(model) 110 | test_chat_completion(model) 111 | test_chat_completion_stream(model) 112 | 113 | print("===== Test curl =====") 114 | test_openai_curl("vicuna-7b-v1.3") 115 | -------------------------------------------------------------------------------- /chat/model/model_chatglm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference code for ChatGLM. 3 | Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. 
4 | """ 5 | import re 6 | 7 | import torch 8 | from transformers.generation.logits_process import LogitsProcessor 9 | 10 | 11 | class InvalidScoreLogitsProcessor(LogitsProcessor): 12 | def __call__( 13 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 14 | ) -> torch.FloatTensor: 15 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 16 | scores.zero_() 17 | scores[..., 5] = 5e4 18 | return scores 19 | 20 | 21 | invalid_score_processor = InvalidScoreLogitsProcessor() 22 | 23 | 24 | def process_response(response): 25 | response = response.strip() 26 | response = response.replace("[[训练时间]]", "2023年") 27 | punkts = [ 28 | [",", ","], 29 | ["!", "!"], 30 | [":", ":"], 31 | [";", ";"], 32 | ["\?", "?"], 33 | ] 34 | for item in punkts: 35 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 36 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 37 | return response 38 | 39 | 40 | @torch.inference_mode() 41 | def generate_stream_chatglm( 42 | model, 43 | tokenizer, 44 | params, 45 | device, 46 | context_len=2048, 47 | stream_interval=2, 48 | judge_sent_end=False, 49 | ): 50 | prompt = params["prompt"] 51 | temperature = float(params.get("temperature", 1.0)) 52 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 53 | top_p = float(params.get("top_p", 1.0)) 54 | max_new_tokens = int(params.get("max_new_tokens", 256)) 55 | echo = params.get("echo", True) 56 | 57 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) 58 | input_echo_len = len(inputs["input_ids"][0]) 59 | 60 | gen_kwargs = { 61 | "max_length": max_new_tokens + input_echo_len, 62 | "do_sample": True if temperature > 1e-5 else False, 63 | "top_p": top_p, 64 | "repetition_penalty": repetition_penalty, 65 | "logits_processor": [invalid_score_processor], 66 | } 67 | if temperature > 1e-5: 68 | gen_kwargs["temperature"] = temperature 69 | 70 | total_len = 0 71 | for total_ids in model.stream_generate(**inputs, **gen_kwargs): 72 | total_ids = total_ids.tolist()[0] 73 | total_len = len(total_ids) 74 | if echo: 75 | output_ids = total_ids 76 | else: 77 | output_ids = total_ids[input_echo_len:] 78 | response = tokenizer.decode(output_ids) 79 | response = process_response(response) 80 | 81 | yield { 82 | "text": response, 83 | "usage": { 84 | "prompt_tokens": input_echo_len, 85 | "completion_tokens": total_len - input_echo_len, 86 | "total_tokens": total_len, 87 | }, 88 | "finish_reason": None, 89 | } 90 | 91 | # TODO: ChatGLM stop when it reach max length 92 | # Only last stream result contains finish_reason, we set finish_reason as stop 93 | ret = { 94 | "text": response, 95 | "usage": { 96 | "prompt_tokens": input_echo_len, 97 | "completion_tokens": total_len - input_echo_len, 98 | "total_tokens": total_len, 99 | }, 100 | "finish_reason": "stop", 101 | } 102 | yield ret 103 | -------------------------------------------------------------------------------- /chat/model/model_codet5p.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | import transformers 5 | from transformers import ( 6 | GenerationConfig, 7 | StoppingCriteria, 8 | StoppingCriteriaList, 9 | TextIteratorStreamer, 10 | ) 11 | 12 | 13 | @torch.inference_mode() 14 | def generate_stream_codet5p( 15 | model, 16 | tokenizer, 17 | params, 18 | device, 19 | context_len=2048, 20 | stream_interval=2, 21 | judge_sent_end=False, 22 | ): 23 | prompt = params["prompt"] 24 | 
temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 1024)) 29 | stop_token_ids = params.get("stop_token_ids", None) or [] 30 | stop_token_ids.append(tokenizer.eos_token_id) 31 | 32 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 33 | streamer = TextIteratorStreamer(tokenizer, **decode_config) 34 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 35 | input_ids = encoding.input_ids 36 | encoding["decoder_input_ids"] = encoding["input_ids"].clone() 37 | input_echo_len = len(input_ids) 38 | 39 | generation_config = GenerationConfig( 40 | max_new_tokens=max_new_tokens, 41 | do_sample=temperature >= 1e-5, 42 | temperature=temperature, 43 | repetition_penalty=repetition_penalty, 44 | no_repeat_ngram_size=10, 45 | top_p=top_p, 46 | top_k=top_k, 47 | eos_token_id=stop_token_ids, 48 | ) 49 | 50 | class CodeBlockStopper(StoppingCriteria): 51 | def __call__( 52 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 53 | ) -> bool: 54 | # Code-completion is open-end generation. 55 | # We check \n\n to stop at end of a code block. 56 | if list(input_ids[0][-2:]) == [628, 198]: 57 | return True 58 | return False 59 | 60 | gen_kwargs = dict( 61 | **encoding, 62 | streamer=streamer, 63 | generation_config=generation_config, 64 | stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 65 | ) 66 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 67 | thread.start() 68 | i = 0 69 | output = "" 70 | for new_text in streamer: 71 | i += 1 72 | output += new_text 73 | if i % stream_interval == 0 or i == max_new_tokens - 1: 74 | yield { 75 | "text": output, 76 | "usage": { 77 | "prompt_tokens": input_echo_len, 78 | "completion_tokens": i, 79 | "total_tokens": input_echo_len + i, 80 | }, 81 | "finish_reason": None, 82 | } 83 | if i >= max_new_tokens: 84 | break 85 | 86 | if i >= max_new_tokens: 87 | finish_reason = "length" 88 | else: 89 | finish_reason = "stop" 90 | 91 | yield { 92 | "text": output, 93 | "usage": { 94 | "prompt_tokens": input_echo_len, 95 | "completion_tokens": i, 96 | "total_tokens": input_echo_len + i, 97 | }, 98 | "finish_reason": finish_reason, 99 | } 100 | thread.join() 101 | 102 | # clean 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | if device == "xpu": 106 | torch.xpu.empty_cache() 107 | -------------------------------------------------------------------------------- /chat/server/api_provider.py: -------------------------------------------------------------------------------- 1 | """Call API providers.""" 2 | 3 | import os 4 | import random 5 | import time 6 | 7 | from chat.utils import build_logger 8 | from chat.constants import WORKER_API_TIMEOUT 9 | 10 | 11 | logger = build_logger("gradio_web_server", "gradio_web_server.log") 12 | 13 | 14 | def openai_api_stream_iter( 15 | model_name, 16 | messages, 17 | temperature, 18 | top_p, 19 | max_new_tokens, 20 | api_base=None, 21 | api_key=None, 22 | ): 23 | import openai 24 | 25 | openai.api_base = api_base or "https://api.openai.com/v1" 26 | openai.api_key = api_key or os.environ["OPENAI_API_KEY"] 27 | 28 | # Make requests 29 | gen_params = { 30 | "model": model_name, 31 | "prompt": messages, 32 | "temperature": temperature, 33 | "top_p": top_p, 34 | "max_new_tokens": max_new_tokens, 35 | } 36 | logger.info(f"==== request 
====\n{gen_params}") 37 | 38 | res = openai.ChatCompletion.create( 39 | model=model_name, 40 | messages=messages, 41 | temperature=temperature, 42 | max_tokens=max_new_tokens, 43 | stream=True, 44 | ) 45 | text = "" 46 | for chunk in res: 47 | text += chunk["choices"][0]["delta"].get("content", "") 48 | data = { 49 | "text": text, 50 | "error_code": 0, 51 | } 52 | yield data 53 | 54 | 55 | def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): 56 | import anthropic 57 | 58 | c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) 59 | 60 | # Make requests 61 | gen_params = { 62 | "model": model_name, 63 | "prompt": prompt, 64 | "temperature": temperature, 65 | "top_p": top_p, 66 | "max_new_tokens": max_new_tokens, 67 | } 68 | logger.info(f"==== request ====\n{gen_params}") 69 | 70 | res = c.completions.create( 71 | prompt=prompt, 72 | stop_sequences=[anthropic.HUMAN_PROMPT], 73 | max_tokens_to_sample=max_new_tokens, 74 | temperature=temperature, 75 | top_p=top_p, 76 | model=model_name, 77 | stream=True, 78 | ) 79 | text = "" 80 | for chunk in res: 81 | text += chunk.completion 82 | data = { 83 | "text": text, 84 | "error_code": 0, 85 | } 86 | yield data 87 | 88 | 89 | def init_palm_chat(model_name): 90 | import vertexai # pip3 install google-cloud-aiplatform 91 | from vertexai.preview.language_models import ChatModel 92 | 93 | project_id = os.environ["GCP_PROJECT_ID"] 94 | location = "us-central1" 95 | vertexai.init(project=project_id, location=location) 96 | 97 | chat_model = ChatModel.from_pretrained(model_name) 98 | chat = chat_model.start_chat(examples=[]) 99 | return chat 100 | 101 | 102 | def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens): 103 | parameters = { 104 | "temperature": temperature, 105 | "top_p": top_p, 106 | "max_output_tokens": max_new_tokens, 107 | } 108 | gen_params = { 109 | "model": "palm-2", 110 | "prompt": message, 111 | } 112 | gen_params.update(parameters) 113 | logger.info(f"==== request ====\n{gen_params}") 114 | 115 | response = chat.send_message(message, **parameters) 116 | content = response.text 117 | 118 | pos = 0 119 | while pos < len(content): 120 | # This is a fancy way to simulate token generation latency combined 121 | # with a Poisson process. 
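        # Each iteration releases a 10-20 character chunk of the finished PaLM
        # response, then sleeps for an exponentially distributed interval with
        # rate 50 (mean ~20 ms), so the replay looks like token streaming.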
122 | pos += random.randint(10, 20) 123 | time.sleep(random.expovariate(50)) 124 | data = { 125 | "text": content[:pos], 126 | "error_code": 0, 127 | } 128 | yield data 129 | -------------------------------------------------------------------------------- /chat/server/gateway/nginx.conf: -------------------------------------------------------------------------------- 1 | user www-data; 2 | worker_processes auto; 3 | pid /run/nginx.pid; 4 | include /etc/nginx/modules-enabled/*.conf; 5 | 6 | events { 7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently 8 | # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle 9 | 10 | } 11 | 12 | http { 13 | ## 14 | # Basic Settings 15 | ## 16 | 17 | sendfile on; # enable sendfile for performance optimization 18 | tcp_nopush on; # enable TCP no-pushing 19 | tcp_nodelay on; # enable TCP no-delay 20 | keepalive_timeout 65; # sets the timeout for keep-alive connections 21 | types_hash_max_size 2048; # maximum size of the types hash table 22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security 23 | 24 | # server_names_hash_bucket_size 64; 25 | # server_name_in_redirect off; 26 | 27 | include /etc/nginx/mime.types; # include MIME types file 28 | default_type application/octet-stream; # default MIME type for unknown file types 29 | 30 | ## 31 | # SSL Settings 32 | ## 33 | 34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use 35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers 36 | 37 | ## 38 | # Logging Settings 39 | ## 40 | 41 | access_log /var/log/nginx/access.log; # path to access log file 42 | error_log /var/log/nginx/error.log; # path to error log file 43 | 44 | ## 45 | # Gzip Settings 46 | ## 47 | gzip on; # enable Gzip compression 48 | 49 | ## 50 | # Virtual Host Configs 51 | ## 52 | 53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory 54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files 55 | 56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ 57 | map $http_upgrade $connection_upgrade { 58 | default upgrade; 59 | '' close; 60 | } 61 | 62 | upstream websocket { 63 | ip_hash; # load balancing by IP to guarantee session persistence 64 | server localhost:7860; # The port should be the gradio web server port 65 | # server localhost:7861; # extra gradio server if more than one 66 | } 67 | 68 | limit_conn_status 429; 69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP 70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server 71 | 72 | server { 73 | listen 443 ssl; # the listening port of our server 74 | ssl_certificate [PATH_TO_SSL_CERT]; 75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY]; 76 | server_name chat.lmsys.org; # replace the url with your own domain url 77 | limit_conn perserver 1024; # connections per server 78 | location / { 79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server 80 | limit_conn perip 5; # connections per IP 81 | proxy_set_header Host $host; # set the Host header for the upstream server 82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server 83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the 
client IP addresses in the X-Forwarded-For header 84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication 85 | proxy_set_header Upgrade $http_upgrade; 86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication 87 | } 88 | } 89 | 90 | # the following block routes all HTTP traffic to HTTPS via nginx 91 | server { 92 | listen 80; 93 | server_name chat.lmsys.org; 94 | return 301 https://chat.lmsys.org$request_uri; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /chat/data/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split long conversations based on certain max length. 3 | 4 | Usage: python3 -m chat.data.split_long_conversation \ 5 | --in sharegpt_clean.json \ 6 | --out sharegpt_split.json \ 7 | --model-name-or-path $ 8 | """ 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | from typing import Dict, Sequence, Optional 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def make_sample(sample, start_idx, end_idx): 19 | assert (end_idx - start_idx) % 2 == 0 20 | return { 21 | "id": sample["id"] + "_" + str(start_idx), 22 | "model": sample.get("model", ""), 23 | "conversations": sample["conversations"][start_idx:end_idx], 24 | } 25 | 26 | 27 | tokenizer = max_length = None 28 | 29 | 30 | def split_one_sample(sample): 31 | tokenized_lens = [] 32 | conversations = sample["conversations"] 33 | conversations = conversations[: len(conversations) // 2 * 2] 34 | for c in conversations: 35 | length = len(tokenizer(c["value"]).input_ids) + 6 36 | tokenized_lens.append(length) 37 | 38 | start_idx = 0 39 | cur_len = 0 40 | 41 | if len(conversations) % 2 != 0 or len(conversations) < 2: 42 | return [] 43 | 44 | new_samples = [] 45 | for i in range(0, len(conversations), 2): 46 | tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] 47 | if cur_len + tmp_len > max_length: 48 | new_samples.append(make_sample(sample, start_idx, i)) 49 | start_idx = i 50 | cur_len = 0 51 | elif i == len(conversations) - 2: 52 | new_samples.append(make_sample(sample, start_idx, i + 2)) 53 | 54 | cur_len += tmp_len 55 | 56 | return new_samples 57 | 58 | 59 | def worker(input_data): 60 | result = [] 61 | for sample in input_data: 62 | result.extend(split_one_sample(sample)) 63 | return result 64 | 65 | 66 | def split_all(content, begin, end, tokenizer_, max_length_): 67 | """ 68 | Keep the maximum round of conversations within the max token length constraint 69 | """ 70 | global tokenizer, max_length 71 | tokenizer = tokenizer_ 72 | max_length = max_length_ 73 | 74 | content = content[begin:end] 75 | new_content = [] 76 | 77 | # Split content into chunks 78 | chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)] 79 | with ProcessPoolExecutor() as executor: 80 | for result in tqdm(executor.map(worker, chunks), total=len(chunks)): 81 | new_content.extend(result) 82 | 83 | return new_content 84 | 85 | 86 | def filter_invalid_roles(content): 87 | new_content = [] 88 | for i, c in enumerate(content): 89 | roles = ["human", "gpt"] 90 | if len(c["conversations"]) <= 0: 91 | continue 92 | 93 | valid = True 94 | for j, s in enumerate(c["conversations"]): 95 | if s["from"] != roles[j % 2]: 96 | valid = False 97 | break 98 | 99 | if valid: 100 | new_content.append(c) 101 | 102 | return new_content 103 | 104 | 105 | def main(args): 106 | content = 
json.load(open(args.in_file, "r")) 107 | tokenizer = transformers.AutoTokenizer.from_pretrained( 108 | args.model_name_or_path, 109 | model_max_length=args.max_length, 110 | padding_side="right", 111 | use_fast=False, 112 | ) 113 | new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) 114 | new_content = filter_invalid_roles(new_content) 115 | 116 | print(f"#in: {len(content)}, #out: {len(new_content)}") 117 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--in-file", type=str, required=True) 123 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json") 124 | parser.add_argument("--begin", type=int) 125 | parser.add_argument("--end", type=int) 126 | parser.add_argument("--model-name-or-path", type=str, required=True) 127 | parser.add_argument("--max-length", type=int, default=2048) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /chat/server/test_throughput.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script to test the throughput of serving workers.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | import threading 7 | import time 8 | 9 | from chat.conversation import get_conv_template 10 | 11 | 12 | def main(): 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": args.model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | return 31 | 32 | conv = get_conv_template("vicuna_v1.1") 33 | conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") 34 | prompt_template = conv.get_prompt() 35 | prompts = [prompt_template for _ in range(args.n_thread)] 36 | 37 | headers = {"User-Agent": "chat Client"} 38 | ploads = [ 39 | { 40 | "model": args.model_name, 41 | "prompt": prompts[i], 42 | "max_new_tokens": args.max_new_tokens, 43 | "temperature": 0.0, 44 | # "stop": conv.sep, 45 | } 46 | for i in range(len(prompts)) 47 | ] 48 | 49 | def send_request(results, i): 50 | if args.test_dispatch: 51 | ret = requests.post( 52 | controller_addr + "/get_worker_address", json={"model": args.model_name} 53 | ) 54 | thread_worker_addr = ret.json()["address"] 55 | else: 56 | thread_worker_addr = worker_addr 57 | print(f"thread {i} goes to {thread_worker_addr}") 58 | response = requests.post( 59 | thread_worker_addr + "/worker_generate_stream", 60 | headers=headers, 61 | json=ploads[i], 62 | stream=False, 63 | ) 64 | k = list( 65 | response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") 66 | ) 67 | # print(k) 68 | response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 69 | error_code = json.loads(k[-2].decode("utf-8"))["error_code"] 70 | # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") 71 | results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 72 | 73 | # use N threads to prompt the backend 74 | tik = time.time() 75 | 
threads = [] 76 | results = [None] * args.n_thread 77 | for i in range(args.n_thread): 78 | t = threading.Thread(target=send_request, args=(results, i)) 79 | t.start() 80 | # time.sleep(0.5) 81 | threads.append(t) 82 | 83 | for t in threads: 84 | t.join() 85 | 86 | print(f"Time (POST): {time.time() - tik} s") 87 | # n_words = 0 88 | # for i, response in enumerate(results): 89 | # # print(prompt[i].replace(conv.sep, "\n"), end="") 90 | # # make sure the streaming finishes at EOS or stopping criteria 91 | # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) 92 | # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 93 | # # print(response_new_words) 94 | # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 95 | n_words = sum(results) 96 | time_seconds = time.time() - tik 97 | print( 98 | f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " 99 | f"throughput: {n_words / time_seconds} words/s." 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument( 106 | "--controller-address", type=str, default="http://localhost:21001" 107 | ) 108 | parser.add_argument("--worker-address", type=str) 109 | parser.add_argument("--model-name", type=str, default="vicuna") 110 | parser.add_argument("--max-new-tokens", type=int, default=2048) 111 | parser.add_argument("--n-thread", type=int, default=8) 112 | parser.add_argument("--test-dispatch", action="store_true") 113 | args = parser.parse_args() 114 | 115 | main() 116 | -------------------------------------------------------------------------------- /chat/model/monkey_patch_non_inplace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patch the llama implementation in the huggingface/transformers library. 3 | Avoid bugs in mps backend by not using in-place operations. 
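# --- Illustrative usage for chat/server/test_throughput.py above (added
# --- annotation, not a repository file). The controller port is the script's
# --- default; the model name and thread count are assumptions for the example.
#
#   python3 -m chat.server.test_throughput \
#       --controller-address http://localhost:21001 \
#       --model-name vicuna-7b-v1.3 \
#       --n-thread 16 --test-dispatch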
4 | """ 5 | import math 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch import nn 10 | import transformers 11 | 12 | 13 | def rotate_half(x): 14 | """Rotates half the hidden dims of the input.""" 15 | x1 = x[..., : x.shape[-1] // 2].clone() 16 | x2 = x[..., x.shape[-1] // 2 :].clone() 17 | return torch.cat((-x2, x1), dim=-1) 18 | 19 | 20 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 21 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 22 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 23 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 24 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 25 | q_embed = (q * cos) + (rotate_half(q) * sin) 26 | k_embed = (k * cos) + (rotate_half(k) * sin) 27 | return q_embed, k_embed 28 | 29 | 30 | def forward( 31 | self, 32 | hidden_states: torch.Tensor, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.LongTensor] = None, 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 36 | output_attentions: bool = False, 37 | use_cache: bool = False, 38 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 39 | bsz, q_len, _ = hidden_states.size() 40 | 41 | query_states = ( 42 | self.q_proj(hidden_states) 43 | .view(bsz, q_len, self.num_heads, self.head_dim) 44 | .transpose(1, 2) 45 | ) 46 | key_states = ( 47 | self.k_proj(hidden_states) 48 | .view(bsz, q_len, self.num_heads, self.head_dim) 49 | .transpose(1, 2) 50 | ) 51 | value_states = ( 52 | self.v_proj(hidden_states) 53 | .view(bsz, q_len, self.num_heads, self.head_dim) 54 | .transpose(1, 2) 55 | ) 56 | 57 | kv_seq_len = key_states.shape[-2] 58 | if past_key_value is not None: 59 | kv_seq_len += past_key_value[0].shape[-2] 60 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 61 | query_states, key_states = apply_rotary_pos_emb( 62 | query_states, key_states, cos, sin, position_ids 63 | ) 64 | # [bsz, nh, t, hd] 65 | 66 | if past_key_value is not None: 67 | # reuse k, v, self_attention 68 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 69 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 70 | 71 | past_key_value = (key_states, value_states) if use_cache else None 72 | 73 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( 74 | self.head_dim 75 | ) 76 | 77 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 78 | raise ValueError( 79 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 80 | f" {attn_weights.size()}" 81 | ) 82 | 83 | if attention_mask is not None: 84 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 85 | raise ValueError( 86 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 87 | ) 88 | attn_weights = attn_weights + attention_mask 89 | attn_weights = torch.max( 90 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 91 | ) 92 | 93 | # upcast attention to fp32 94 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 95 | query_states.dtype 96 | ) 97 | attn_output = torch.matmul(attn_weights, value_states) 98 | 99 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 100 | raise ValueError( 101 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 102 | f" 
{attn_output.size()}" 103 | ) 104 | 105 | attn_output = attn_output.transpose(1, 2) 106 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 107 | 108 | attn_output = self.o_proj(attn_output) 109 | 110 | if not output_attentions: 111 | attn_weights = None 112 | 113 | return attn_output, attn_weights, past_key_value 114 | 115 | 116 | def replace_llama_attn_with_non_inplace_operations(): 117 | """Avoid bugs in mps backend by not using in-place operations.""" 118 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 119 | -------------------------------------------------------------------------------- /chat/model/model_falcon.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | from typing import Iterable 4 | 5 | import torch 6 | import transformers 7 | from transformers import TextIteratorStreamer, GenerationConfig 8 | 9 | from chat.utils import is_partial_stop 10 | 11 | 12 | @torch.inference_mode() 13 | def generate_stream_falcon( 14 | model, 15 | tokenizer, 16 | params, 17 | device, 18 | context_len=2048, 19 | stream_interval=2, 20 | judge_sent_end=False, 21 | ): 22 | prompt = params["prompt"] 23 | len_prompt = len(prompt) 24 | temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 256)) 29 | stop_str = params.get("stop", None) 30 | echo = bool(params.get("echo", True)) 31 | stop_token_ids = params.get("stop_token_ids", None) or [] 32 | stop_token_ids.append(tokenizer.eos_token_id) 33 | 34 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 35 | input_ids = inputs["input_ids"] 36 | attention_mask = inputs["attention_mask"] 37 | 38 | max_src_len = context_len - max_new_tokens - 8 39 | 40 | input_ids = input_ids[-max_src_len:] # truncate from the left 41 | attention_mask = attention_mask[-max_src_len:] # truncate from the left 42 | input_echo_len = len(input_ids) 43 | 44 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 45 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 46 | 47 | generation_config = GenerationConfig( 48 | max_new_tokens=max_new_tokens, 49 | do_sample=temperature >= 1e-5, 50 | temperature=temperature, 51 | repetition_penalty=repetition_penalty, 52 | no_repeat_ngram_size=10, 53 | top_p=top_p, 54 | top_k=top_k, 55 | eos_token_id=stop_token_ids, 56 | ) 57 | 58 | generation_kwargs = dict( 59 | inputs=input_ids, 60 | attention_mask=attention_mask, 61 | streamer=streamer, 62 | generation_config=generation_config, 63 | ) 64 | 65 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 66 | thread.start() 67 | 68 | if echo: 69 | # means keep the prompt 70 | output = prompt 71 | else: 72 | output = "" 73 | 74 | for i, new_text in enumerate(streamer): 75 | output += new_text 76 | if i % stream_interval == 0: 77 | if echo: 78 | rfind_start = len_prompt 79 | else: 80 | rfind_start = 0 81 | 82 | partially_stopped = False 83 | if stop_str: 84 | if isinstance(stop_str, str): 85 | pos = output.rfind(stop_str, rfind_start) 86 | if pos != -1: 87 | output = output[:pos] 88 | else: 89 | partially_stopped = is_partial_stop(output, stop_str) 90 | elif isinstance(stop_str, Iterable): 91 | for each_stop in stop_str: 92 | pos = output.rfind(each_stop, rfind_start) 93 | 
if pos != -1: 94 | output = output[:pos] 95 | break 96 | else: 97 | partially_stopped = is_partial_stop(output, each_stop) 98 | if partially_stopped: 99 | break 100 | else: 101 | raise ValueError("Invalid stop field type.") 102 | 103 | # prevent yielding partial stop sequence 104 | if not partially_stopped: 105 | yield { 106 | "text": output, 107 | "usage": { 108 | "prompt_tokens": input_echo_len, 109 | "completion_tokens": i, 110 | "total_tokens": input_echo_len + i, 111 | }, 112 | "finish_reason": None, 113 | } 114 | output = output.strip() 115 | 116 | # finish stream event, which contains finish reason 117 | if i == max_new_tokens - 1: 118 | finish_reason = "length" 119 | elif partially_stopped: 120 | finish_reason = None 121 | else: 122 | finish_reason = "stop" 123 | 124 | yield { 125 | "text": output, 126 | "usage": { 127 | "prompt_tokens": input_echo_len, 128 | "completion_tokens": i, 129 | "total_tokens": input_echo_len + i, 130 | }, 131 | "finish_reason": finish_reason, 132 | } 133 | 134 | # clean 135 | gc.collect() 136 | torch.cuda.empty_cache() 137 | if device == "xpu": 138 | torch.xpu.empty_cache() 139 | -------------------------------------------------------------------------------- /chat/server/monitor/conv_release_scripts/filter_bad_conv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations for release. 3 | 4 | Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json 5 | """ 6 | import argparse 7 | from collections import defaultdict 8 | from enum import Enum, auto 9 | import json 10 | import os 11 | import random 12 | 13 | from tqdm import tqdm 14 | 15 | BLOCKED_WORDS_FILENAME = "blocked_words.json" 16 | blocked_words = [] 17 | frequency = defaultdict(lambda: 0) 18 | 19 | 20 | class TypeCode(Enum): 21 | CORRECT = auto() 22 | ANONYMIZED = auto() 23 | REDACTED = auto() 24 | BAD_FORMAT = auto() 25 | BLOCKED_WORD = auto() 26 | BLOCKED_MODEL = auto() 27 | TOO_SHORT = auto() 28 | TOO_FREQUENT = auto() 29 | 30 | 31 | def detect_type(conv): 32 | for key in ["conversation_a", "conversation_b"]: 33 | messages = [row["content"] for row in conv[key]] 34 | for msg in messages: 35 | if not isinstance(msg, str): 36 | return TypeCode.BAD_FORMAT 37 | 38 | user_prompts = [ 39 | row["content"].lower().strip() for row in conv[key] if row["role"] == "user" 40 | ] 41 | if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): 42 | return TypeCode.TOO_SHORT 43 | 44 | if all(x in frequent_prompts for x in user_prompts): 45 | return TypeCode.TOO_FREQUENT 46 | 47 | for msg in messages: 48 | msg = msg.lower() 49 | if "" in msg: 50 | return TypeCode.ANONYMIZED 51 | if "" in msg: 52 | return TypeCode.REDACTED 53 | 54 | for w in blocked_words: 55 | if w in msg: 56 | return TypeCode.BLOCKED_WORD 57 | 58 | for key in ["model_a", "model_b"]: 59 | if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: 60 | return TypeCode.BLOCKED_MODEL 61 | 62 | return TypeCode.CORRECT 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--in-file", type=str, required=True) 68 | parser.add_argument("--sample", type=int) 69 | args = parser.parse_args() 70 | 71 | # Read conversations 72 | convs = json.load(open(args.in_file)) 73 | print(f"#conv: {len(convs)}") 74 | 75 | # Read blocked words 76 | if os.path.exists(BLOCKED_WORDS_FILENAME): 77 | blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) 78 | 79 | # Count frequency 80 | for conv in convs: 81 | for key in 
["conversation_a", "conversation_b"]: 82 | messages = [row["content"] for row in conv[key] if row["role"] == "user"] 83 | for msg in messages: 84 | if not isinstance(msg, str): 85 | continue 86 | msg = msg.lower().strip() 87 | frequency[msg] += 1 88 | 89 | keys = list(frequency.keys()) 90 | keys.sort(key=lambda x: -frequency[x]) 91 | frequent_prompts = keys[:10] 92 | frequent_prompts = set(frequent_prompts) 93 | frequent_prompts.add("") 94 | 95 | # Start filter 96 | ct_bad_format = 0 97 | ct_anonymized = 0 98 | ct_redacted = 0 99 | ct_error = 0 100 | ct_lang_filter = 0 101 | ct_flagged = 0 102 | ct_blocked_word = 0 103 | ct_blocked_model = 0 104 | ct_too_short = 0 105 | ct_too_frequent = 0 106 | 107 | new_convs = [] 108 | for conv in tqdm(convs): 109 | type_code = detect_type(conv) 110 | 111 | if type_code == TypeCode.BAD_FORMAT: 112 | ct_bad_format += 1 113 | continue 114 | 115 | if type_code == TypeCode.ANONYMIZED: 116 | ct_anonymized += 1 117 | continue 118 | elif type_code == TypeCode.REDACTED: 119 | ct_redacted += 1 120 | continue 121 | elif type_code == TypeCode.BLOCKED_WORD: 122 | ct_blocked_word += 1 123 | continue 124 | elif type_code == TypeCode.BLOCKED_MODEL: 125 | ct_blocked_model += 1 126 | continue 127 | elif type_code == TypeCode.TOO_SHORT: 128 | ct_too_short += 1 129 | continue 130 | elif type_code == TypeCode.TOO_FREQUENT: 131 | ct_too_frequent += 1 132 | continue 133 | 134 | if conv["openai_moderation"]["flagged"]: 135 | ct_flagged += 1 136 | continue 137 | 138 | if type_code in [TypeCode.CORRECT]: 139 | new_convs.append(conv) 140 | 141 | if args.sample: 142 | # random.seed(0) 143 | # random.shuffle(new_convs) 144 | new_convs = new_convs[: args.sample] 145 | 146 | print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") 147 | print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") 148 | print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") 149 | print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") 150 | print(f"new_conv: {len(new_convs)}") 151 | 152 | out_file = args.in_file.replace(".json", ".out.json") 153 | print(f"Output to {out_file}") 154 | with open(out_file, "w") as fout: 155 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 156 | -------------------------------------------------------------------------------- /chat/protocol/api_protocol.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, List, Dict, Any, Union 2 | 3 | import time 4 | 5 | import shortuuid 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class ErrorResponse(BaseModel): 10 | object: str = "error" 11 | message: str 12 | code: int 13 | 14 | 15 | class ModelPermission(BaseModel): 16 | id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") 17 | object: str = "model_permission" 18 | created: int = Field(default_factory=lambda: int(time.time())) 19 | allow_create_engine: bool = False 20 | allow_sampling: bool = True 21 | allow_logprobs: bool = True 22 | allow_search_indices: bool = True 23 | allow_view: bool = True 24 | allow_fine_tuning: bool = False 25 | organization: str = "*" 26 | group: Optional[str] = None 27 | is_blocking: str = False 28 | 29 | 30 | class ModelCard(BaseModel): 31 | id: str 32 | object: str = "model" 33 | created: int = Field(default_factory=lambda: int(time.time())) 34 | owned_by: str = "chat" 35 | root: Optional[str] = None 36 | parent: Optional[str] = None 37 | permission: List[ModelPermission] = [] 38 
| 39 | 40 | class ModelList(BaseModel): 41 | object: str = "list" 42 | data: List[ModelCard] = [] 43 | 44 | 45 | class UsageInfo(BaseModel): 46 | prompt_tokens: int = 0 47 | total_tokens: int = 0 48 | completion_tokens: Optional[int] = 0 49 | 50 | 51 | class APIChatCompletionRequest(BaseModel): 52 | model: str 53 | messages: Union[str, List[Dict[str, str]]] 54 | temperature: Optional[float] = 0.7 55 | top_p: Optional[float] = 1.0 56 | n: Optional[int] = 1 57 | max_tokens: Optional[int] = None 58 | stop: Optional[Union[str, List[str]]] = None 59 | stream: Optional[bool] = False 60 | user: Optional[str] = None 61 | repetition_penalty: Optional[float] = 1.0 62 | 63 | 64 | class ChatMessage(BaseModel): 65 | role: str 66 | content: str 67 | 68 | 69 | class ChatCompletionResponseChoice(BaseModel): 70 | index: int 71 | message: ChatMessage 72 | finish_reason: Optional[Literal["stop", "length"]] = None 73 | 74 | 75 | class ChatCompletionResponse(BaseModel): 76 | id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") 77 | object: str = "chat.completion" 78 | created: int = Field(default_factory=lambda: int(time.time())) 79 | model: str 80 | choices: List[ChatCompletionResponseChoice] 81 | usage: UsageInfo 82 | 83 | 84 | class DeltaMessage(BaseModel): 85 | role: Optional[str] = None 86 | content: Optional[str] = None 87 | 88 | 89 | class ChatCompletionResponseStreamChoice(BaseModel): 90 | index: int 91 | delta: DeltaMessage 92 | finish_reason: Optional[Literal["stop", "length"]] = None 93 | 94 | 95 | class ChatCompletionStreamResponse(BaseModel): 96 | id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") 97 | object: str = "chat.completion.chunk" 98 | created: int = Field(default_factory=lambda: int(time.time())) 99 | model: str 100 | choices: List[ChatCompletionResponseStreamChoice] 101 | 102 | 103 | class APITokenCheckRequestItem(BaseModel): 104 | model: str 105 | prompt: str 106 | max_tokens: int 107 | 108 | 109 | class APITokenCheckRequest(BaseModel): 110 | prompts: List[APITokenCheckRequestItem] 111 | 112 | 113 | class APITokenCheckResponseItem(BaseModel): 114 | fits: bool 115 | tokenCount: int 116 | contextLength: int 117 | 118 | 119 | class APITokenCheckResponse(BaseModel): 120 | prompts: List[APITokenCheckResponseItem] 121 | 122 | 123 | class CompletionRequest(BaseModel): 124 | model: str 125 | prompt: Union[str, List[Any]] 126 | suffix: Optional[str] = None 127 | temperature: Optional[float] = 0.7 128 | n: Optional[int] = 1 129 | max_tokens: Optional[int] = 16 130 | stop: Optional[Union[str, List[str]]] = None 131 | stream: Optional[bool] = False 132 | top_p: Optional[float] = 1.0 133 | logprobs: Optional[int] = None 134 | echo: Optional[bool] = False 135 | presence_penalty: Optional[float] = 0.0 136 | frequency_penalty: Optional[float] = 0.0 137 | user: Optional[str] = None 138 | 139 | 140 | class CompletionResponseChoice(BaseModel): 141 | index: int 142 | text: str 143 | logprobs: Optional[int] = None 144 | finish_reason: Optional[Literal["stop", "length"]] = None 145 | 146 | 147 | class CompletionResponse(BaseModel): 148 | id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") 149 | object: str = "text_completion" 150 | created: int = Field(default_factory=lambda: int(time.time())) 151 | model: str 152 | choices: List[CompletionResponseChoice] 153 | usage: UsageInfo 154 | 155 | 156 | class CompletionResponseStreamChoice(BaseModel): 157 | index: int 158 | text: str 159 | logprobs: Optional[float] = None 160 | finish_reason: 
Optional[Literal["stop", "length"]] = None 161 | 162 | 163 | class CompletionStreamResponse(BaseModel): 164 | id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") 165 | object: str = "text_completion" 166 | created: int = Field(default_factory=lambda: int(time.time())) 167 | model: str 168 | choices: List[CompletionResponseStreamChoice] 169 | -------------------------------------------------------------------------------- /chat/server/monitor/clean_chat_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clean chatbot arena chat log. 3 | 4 | Usage: 5 | python3 clean_chat_data.py --mode conv_release 6 | """ 7 | import argparse 8 | import datetime 9 | import json 10 | import os 11 | from pytz import timezone 12 | import time 13 | 14 | from tqdm import tqdm 15 | 16 | from chat.server.monitor.basic_stats import NUM_SERVERS 17 | from chat.server.monitor.clean_battle_data import ( 18 | to_openai_format, 19 | replace_model_name, 20 | ) 21 | from chat.utils import detect_language 22 | 23 | 24 | NETWORK_ERROR_MSG = ( 25 | "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() 26 | ) 27 | 28 | 29 | def get_log_files(max_num_files=None): 30 | dates = [] 31 | for month in [4, 5, 6, 7]: 32 | for day in range(1, 32): 33 | dates.append(f"2023-{month:02d}-{day:02d}") 34 | 35 | for month in [8]: 36 | for day in range(1, 32): 37 | dates.append(f"2023-{month:02d}-{day:02d}") 38 | 39 | filenames = [] 40 | for d in dates: 41 | for i in range(NUM_SERVERS): 42 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 43 | if os.path.exists(name): 44 | filenames.append(name) 45 | max_num_files = max_num_files or len(filenames) 46 | # filenames = list(reversed(filenames)) 47 | filenames = filenames[-max_num_files:] 48 | return filenames 49 | 50 | 51 | def clean_chat_data(log_files): 52 | raw_data = [] 53 | for filename in tqdm(log_files, desc="read files"): 54 | for retry in range(5): 55 | try: 56 | lines = open(filename).readlines() 57 | break 58 | except FileNotFoundError: 59 | time.sleep(2) 60 | 61 | for l in lines: 62 | row = json.loads(l) 63 | if row["type"] == "chat": 64 | raw_data.append(row) 65 | 66 | all_models = set() 67 | all_ips = dict() 68 | chats = [] 69 | ct_invalid_conv_id = 0 70 | ct_invalid = 0 71 | ct_network_error = 0 72 | for row in raw_data: 73 | if "conv_id" not in row["state"]: 74 | ct_invalid_conv_id += 1 75 | continue 76 | 77 | conversation_id = row["state"]["conv_id"] 78 | if conversation_id is None: 79 | ct_invalid_conv_id += 1 80 | continue 81 | 82 | state = row["state"] 83 | conversation = to_openai_format(state["messages"][state["offset"] :]) 84 | model = row["model"] 85 | if not isinstance(model, str): 86 | ct_invalid += 1 87 | continue 88 | model = replace_model_name(model) 89 | 90 | try: 91 | lang_code = detect_language(state["messages"][state["offset"]][1]) 92 | except IndexError: 93 | ct_invalid += 1 94 | continue 95 | 96 | if not all(isinstance(x["content"], str) for x in conversation): 97 | ct_invalid += 1 98 | continue 99 | 100 | messages = "".join([x["content"] for x in conversation]).lower() 101 | if NETWORK_ERROR_MSG in messages: 102 | ct_network_error += 1 103 | continue 104 | 105 | ip = row["ip"] 106 | if ip not in all_ips: 107 | all_ips[ip] = len(all_ips) 108 | user_id = all_ips[ip] 109 | 110 | chats.append( 111 | dict( 112 | conversation_id=conversation_id, 113 | model=model, 114 | conversation=conversation, 115 | turn=len(conversation) // 2, 116 | 
language=lang_code, 117 | user_id=user_id, 118 | tstamp=row["tstamp"], 119 | ) 120 | ) 121 | 122 | all_models.update([model]) 123 | 124 | chats.sort(key=lambda x: x["tstamp"]) 125 | last_updated_tstamp = chats[-1]["tstamp"] 126 | last_updated_datetime = datetime.datetime.fromtimestamp( 127 | last_updated_tstamp, tz=timezone("US/Pacific") 128 | ).strftime("%Y-%m-%d %H:%M:%S %Z") 129 | 130 | # Deduplication 131 | dedup_chats = [] 132 | visited_conv_ids = set() 133 | for i in reversed(range(len(chats))): 134 | if chats[i]["conversation_id"] in visited_conv_ids: 135 | continue 136 | visited_conv_ids.add(chats[i]["conversation_id"]) 137 | dedup_chats.append(chats[i]) 138 | 139 | print( 140 | f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" 141 | ) 142 | print( 143 | f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" 144 | ) 145 | print(f"#models: {len(all_models)}, {all_models}") 146 | print(f"last-updated: {last_updated_datetime}") 147 | 148 | return list(reversed(dedup_chats)) 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument("--max-num-files", type=int) 154 | args = parser.parse_args() 155 | 156 | log_files = get_log_files(args.max_num_files) 157 | chats = clean_chat_data(log_files) 158 | last_updated_tstamp = chats[-1]["tstamp"] 159 | cutoff_date = datetime.datetime.fromtimestamp( 160 | last_updated_tstamp, tz=timezone("US/Pacific") 161 | ).strftime("%Y%m%d") 162 | 163 | output = f"clean_chat_conv_{cutoff_date}.json" 164 | with open(output, "w") as fout: 165 | json.dump(chats, fout, indent=2, ensure_ascii=False) 166 | print(f"Write cleaned data to {output}") 167 | -------------------------------------------------------------------------------- /chat/protocol/openai_api_protocol.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, List, Dict, Any, Union 2 | 3 | import time 4 | 5 | import shortuuid 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class ErrorResponse(BaseModel): 10 | object: str = "error" 11 | message: str 12 | code: int 13 | 14 | 15 | class ModelPermission(BaseModel): 16 | id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") 17 | object: str = "model_permission" 18 | created: int = Field(default_factory=lambda: int(time.time())) 19 | allow_create_engine: bool = False 20 | allow_sampling: bool = True 21 | allow_logprobs: bool = True 22 | allow_search_indices: bool = True 23 | allow_view: bool = True 24 | allow_fine_tuning: bool = False 25 | organization: str = "*" 26 | group: Optional[str] = None 27 | is_blocking: str = False 28 | 29 | 30 | class ModelCard(BaseModel): 31 | id: str 32 | object: str = "model" 33 | created: int = Field(default_factory=lambda: int(time.time())) 34 | owned_by: str = "chat" 35 | root: Optional[str] = None 36 | parent: Optional[str] = None 37 | permission: List[ModelPermission] = [] 38 | 39 | 40 | class ModelList(BaseModel): 41 | object: str = "list" 42 | data: List[ModelCard] = [] 43 | 44 | 45 | class UsageInfo(BaseModel): 46 | prompt_tokens: int = 0 47 | total_tokens: int = 0 48 | completion_tokens: Optional[int] = 0 49 | 50 | 51 | class ChatCompletionRequest(BaseModel): 52 | model: str 53 | messages: Union[str, List[Dict[str, str]]] 54 | temperature: Optional[float] = 0.7 55 | top_p: Optional[float] = 1.0 56 | n: Optional[int] = 1 57 | max_tokens: Optional[int] = None 58 | stop: Optional[Union[str, 
List[str]]] = None 59 | stream: Optional[bool] = False 60 | presence_penalty: Optional[float] = 0.0 61 | frequency_penalty: Optional[float] = 0.0 62 | user: Optional[str] = None 63 | 64 | 65 | class ChatMessage(BaseModel): 66 | role: str 67 | content: str 68 | 69 | 70 | class ChatCompletionResponseChoice(BaseModel): 71 | index: int 72 | message: ChatMessage 73 | finish_reason: Optional[Literal["stop", "length"]] = None 74 | 75 | 76 | class ChatCompletionResponse(BaseModel): 77 | id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") 78 | object: str = "chat.completion" 79 | created: int = Field(default_factory=lambda: int(time.time())) 80 | model: str 81 | choices: List[ChatCompletionResponseChoice] 82 | usage: UsageInfo 83 | 84 | 85 | class DeltaMessage(BaseModel): 86 | role: Optional[str] = None 87 | content: Optional[str] = None 88 | 89 | 90 | class ChatCompletionResponseStreamChoice(BaseModel): 91 | index: int 92 | delta: DeltaMessage 93 | finish_reason: Optional[Literal["stop", "length"]] = None 94 | 95 | 96 | class ChatCompletionStreamResponse(BaseModel): 97 | id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") 98 | object: str = "chat.completion.chunk" 99 | created: int = Field(default_factory=lambda: int(time.time())) 100 | model: str 101 | choices: List[ChatCompletionResponseStreamChoice] 102 | 103 | 104 | class TokenCheckRequestItem(BaseModel): 105 | model: str 106 | prompt: str 107 | max_tokens: int 108 | 109 | 110 | class TokenCheckRequest(BaseModel): 111 | prompts: List[TokenCheckRequestItem] 112 | 113 | 114 | class TokenCheckResponseItem(BaseModel): 115 | fits: bool 116 | tokenCount: int 117 | contextLength: int 118 | 119 | 120 | class TokenCheckResponse(BaseModel): 121 | prompts: List[TokenCheckResponseItem] 122 | 123 | 124 | class EmbeddingsRequest(BaseModel): 125 | model: Optional[str] = None 126 | engine: Optional[str] = None 127 | input: Union[str, List[Any]] 128 | user: Optional[str] = None 129 | encoding_format: Optional[str] = None 130 | 131 | 132 | class EmbeddingsResponse(BaseModel): 133 | object: str = "list" 134 | data: List[Dict[str, Any]] 135 | model: str 136 | usage: UsageInfo 137 | 138 | 139 | class CompletionRequest(BaseModel): 140 | model: str 141 | prompt: Union[str, List[Any]] 142 | suffix: Optional[str] = None 143 | temperature: Optional[float] = 0.7 144 | n: Optional[int] = 1 145 | max_tokens: Optional[int] = 16 146 | stop: Optional[Union[str, List[str]]] = None 147 | stream: Optional[bool] = False 148 | top_p: Optional[float] = 1.0 149 | logprobs: Optional[int] = None 150 | echo: Optional[bool] = False 151 | presence_penalty: Optional[float] = 0.0 152 | frequency_penalty: Optional[float] = 0.0 153 | user: Optional[str] = None 154 | 155 | 156 | class CompletionResponseChoice(BaseModel): 157 | index: int 158 | text: str 159 | logprobs: Optional[int] = None 160 | finish_reason: Optional[Literal["stop", "length"]] = None 161 | 162 | 163 | class CompletionResponse(BaseModel): 164 | id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") 165 | object: str = "text_completion" 166 | created: int = Field(default_factory=lambda: int(time.time())) 167 | model: str 168 | choices: List[CompletionResponseChoice] 169 | usage: UsageInfo 170 | 171 | 172 | class CompletionResponseStreamChoice(BaseModel): 173 | index: int 174 | text: str 175 | logprobs: Optional[float] = None 176 | finish_reason: Optional[Literal["stop", "length"]] = None 177 | 178 | 179 | class CompletionStreamResponse(BaseModel): 180 | id: str = 
Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") 181 | object: str = "text_completion" 182 | created: int = Field(default_factory=lambda: int(time.time())) 183 | model: str 184 | choices: List[CompletionResponseStreamChoice] 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Test-Agent: Your Intelligent Testing Assistant 2 |

3 | <!-- Badge images/links stripped in this dump: GitHub stars, GitHub forks, License: MIT, Open Issues --> 4 |
20 | 21 | ### Demo on a local Mac M1 22 | ![Screenshot](https://github.com/codefuse-ai/Test-Agent/assets/103973989/8dba860f-c1bb-49d5-b9dd-a58e541562a6) 23 | 24 | ### Demo on ModelScope 25 | ModelScope model link: [ModelScope TestGPT-7B](https://modelscope.cn/models/codefuse-ai/TestGPT-7B/summary) 26 | ![MS](https://github.com/codefuse-ai/Test-Agent/assets/103973989/0e50b258-44f9-4dc6-8e30-0a01cf62d02b) 27 | 28 | 29 | ## What is Test Agent? 30 | 31 | **Test Agent** aims to build an intelligent agent for the testing domain that combines large models with quality-engineering technology to drive a generational upgrade of quality practice. We hope to work with the community to create innovative testing solutions and a 24/7 online testing-assistant service that makes testing feel effortless. 32 | ## Features in this release 33 | 34 | * **Model** This release open-sources TestGPT-7B, a testing-domain model built on CodeLlama-7B as the base model and fine-tuned on the relevant downstream tasks: 35 | * **Multi-language test case generation (Java/Python/Javascript)** Test case generation has long been a focus of both academia and industry, and new products and tools such as EvoSuite, Randoop and SmartUnit keep appearing. Traditional case generation, however, has pain points that are hard to overcome; LLM-based test case generation outperforms traditional tools in readability, completeness of test scenarios and multi-language support. This release focuses on multi-language generation and first includes Java, Python and Javascript; Go, C++ and other languages will be opened up gradually in later versions. 36 | * **Assert completion for test cases** While analyzing the current state of test suites, we found that a certain share of existing test cases in code repositories contain no Assert. Test cases without Assert may pass during regression runs, yet they cannot catch problems, so we added automatic Assert completion as a scenario. Combined with some engineering support, this capability enables batch completion of asserts across an entire repository and intelligently raises the project's quality level. 37 | 38 | * **Engineering framework** An engineering framework for quickly publishing and trying out local models 39 | - ChatBot page 40 | - Fast model startup 41 | - Private deployment: a local GPT-style large model interacts with your data and environment, with no risk of data leakage, 100% safe 42 | 43 | **We will keep iterating on the model and the engineering capabilities:** 44 | - Continuously adding more exciting testing-domain scenarios, such as domain knowledge Q&A and test-scenario analysis 45 | - Opening up a copilot engineering framework for testing scenarios, such as intelligent embedding of testing-domain knowledge, a common testing-tool API layer, and intelligent testing Agents; stay tuned! 46 | - Starting from 7B and gradually scaling to 13B and 34B models. Follow us for updates! 47 | 48 | ## The best-performing 7B testing-domain model 49 | TestAgent currently uses the TestGPT-7B model by default. Compared with existing open-source models, **TestGPT-7B is at the leading level of the industry in test execution pass rate (pass@1) and test-scenario coverage (average number of test scenarios).** 50 | The evaluation results for the core capabilities of TestGPT-7B are as follows: 51 | - Multi-language test case generation 52 | For the three supported languages (Java, Python, Javascript), the Pass@1 results are: 53 | 54 | | Model | Java pass@1 | Java Average number of test scenarios | Python pass@1 | Python Average number of test scenarios | Javascript pass@1 | Javascript Average number of test scenarios | 55 | | --- | --- | --- | --- | --- | --- | --- | 56 | | TestGPT-7B | 48.6% | 4.37 | 35.67% | 3.56 | 36% | 2.76 | 57 | | CodeLlama-13B-Instruct | 40.54% | 1.08 | 30.57% | 1.65 | 31.7% | 3.13 | 58 | | Qwen-14B-Chat | 10.81% | 2.78 | 15.9% | 1.32 | 9.15% | 4.22 | 59 | | Baichuan2-13B-Chat | 13.5% | 2.24 | 12.7% | 2.12 | 6.1% | 3.31 | 60 | 61 | 62 | - Assert completion for test cases 63 | The model currently supports Assert completion for Java test cases; the Pass@1 results are: 64 | 65 | | Model | pass@1 | Percentage of strong validation | 66 | | --- | --- | --- | 67 | | Codefuse-TestGPT-7B | 71.1% | 100% | 68 | 69 | 70 | ## Engineering Architecture 71 | ![JG](https://github.com/codefuse-ai/Test-Agent/assets/103973989/1b61beff-df59-4ab3-843c-266413c8dbc4) 72 | 73 | The era of large models has begun: testing-domain models keep evolving, and the rich world knowledge accumulated during pre-training gives them remarkable reasoning and decision-making ability in complex interactive environments. 74 | 75 | Despite the notable results of foundation models in testing, some limitations remain: domain-specific testing tasks usually require specialized tools or domain knowledge. For example, a foundation model can handle one-shot test code generation and test text generation from its pre-trained knowledge, but complex integration-test generation, domain-specific case generation and interaction with a testing pipeline require more specialized tools and domain knowledge. Integrating dedicated tools with foundation models therefore lets each side play to its strengths: dedicated tools compensate for stale knowledge, add expertise, and improve interpretability and robustness, while the foundation model contributes human-like reasoning and planning, understands complex data and scenarios, and interacts with the real world. 76 | 77 | Building on the model deployment tooling and ChatBot opened up in this release, we will keep investing deeply in open-source testing. Together with like-minded developers in the community, we want to build the most advanced tool engineering system, intelligent testing assistant and open-source testing projects for the field!
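Before moving on to the QuickStart below, here is a small, editor-added sketch (not part of the original project documentation) of how a deployed TestGPT-7B service can be called from a script. It assumes an OpenAI-compatible API endpoint is already running locally at `http://localhost:8000/v1` and that the worker registered the model under the name `TestGPT-7B`; the URL, port, model name and prompt are assumptions to adapt to your own deployment. The request and response shapes follow the `ChatCompletionRequest` / `ChatCompletionResponse` models in `chat/protocol/openai_api_protocol.py`.

```python
import requests

API_BASE = "http://localhost:8000/v1"  # assumed address of your OpenAI-compatible server
MODEL = "TestGPT-7B"                   # assumed model name registered by the worker

# Hypothetical prompt: ask the model to generate unit tests for a small Java method.
prompt = (
    "Write JUnit test cases for the following Java method:\n"
    "public int add(int a, int b) { return a + b; }"
)

resp = requests.post(
    f"{API_BASE}/chat/completions",
    json={
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.2,
        "max_tokens": 512,
    },
    timeout=600,
)
resp.raise_for_status()

# The JSON body mirrors ChatCompletionResponse: choices[0].message.content holds the tests.
print(resp.json()["choices"][0]["message"]["content"])
```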
78 | 79 | ## QuickStart 80 | ### Prerequisites 81 | 82 | #### Model download 83 | 84 | You can find the model details and download the model files on [modelscope](https://modelscope.cn/models/codefuse-ai/TestGPT-7B) or [huggingface](https://huggingface.co/codefuse-ai/TestGPT-7B). 85 | Please note: 86 | 1) If you download the model via modelscope, see the [download instructions](https://www.modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%B8%8B%E8%BD%BD#%E4%BD%BF%E7%94%A8Git%E4%B8%8B%E8%BD%BD%E6%A8%A1%E5%9E%8B); 87 | 2) If you download the model via huggingface, make sure you can access huggingface normally. 88 | 89 | #### Environment setup 90 | 91 | - python>=3.8 92 | - transformers==4.33.2 93 | 94 | ```plain 95 | git clone https://github.com/codefuse-ai/Test-Agent 96 | cd Test-Agent 97 | pip install -r requirements.txt 98 | ``` 99 | 100 | Before running the TestGPT-7B model, make sure your environment has about 14GB of GPU memory. 101 | ### Start the services 102 | 103 | The project ships a web UI that shows model interaction and results more intuitively; a few simple commands bring up the front-end page and call the model in real time. From the project directory, start the following services in order: 104 | 105 | 1. **Start the controller** 106 | ![controller](https://github.com/codefuse-ai/Test-Agent/assets/103973989/e68ce187-c9f1-4ce8-9d59-ff9d8348d0ac) 107 | python3 -m chat.server.controller 108 | 109 | 2. **Start the model worker** 110 | ![work](https://github.com/codefuse-ai/Test-Agent/assets/103973989/073e4e79-4005-4c98-87f7-0eaa0b2b1e22) 111 | python3 -m chat.server.model_worker --model-path models/TestGPT-7B --device mps 112 | 113 | (models/TestGPT-7B is the actual path to the model files) 114 | 115 | Depending on your setup, you can choose among the following launch options: 116 | - --device mps enables GPU acceleration on Mac computers (Apple Silicon or AMD GPUs); 117 | - --device xpu enables acceleration on Intel XPU (Intel Data Center and Arc A-Series GPUs); 118 | - install the [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html) 119 | - set the OneAPI environment variables: source /opt/intel/oneapi/setvars.sh 120 | - --device npu enables acceleration on Huawei AI processors; 121 | - install the [Ascend PyTorch Adapter](https://github.com/Ascend/pytorch) 122 | - set the CANN environment variables: source /usr/local/Ascend/ascend-toolkit/set_env.sh 123 | - --device cpu runs on CPU only, without a GPU; 124 | - --num-gpus 2 specifies the number of GPUs to run concurrently. 125 | 126 | 3. **Start the web service** 127 | python3 -m chat.server.gradio_testgpt 128 | ![web](https://github.com/codefuse-ai/Test-Agent/assets/103973989/340dae35-573b-4046-a3e8-e87a91453601) 129 | Once the services are ready, open the locally served web address http://0.0.0.0:7860 to see the full front-end page. At the bottom of the page there are two examples, [Unit test generation] and [Assert completion]; clicking a button fills a sample prompt into the input box, and clicking Send runs the model. After a short wait (run time depends on your machine), the complete answer appears. 130 | ![demo](https://github.com/codefuse-ai/Test-Agent/assets/103973989/fd24274c-729b-4ce7-8763-a083b39300fb) 131 | 132 | ## 🤗 Acknowledgements 133 | This project is built on top of [FastChat](https://github.com/lm-sys/FastChat); many thanks to them for their open-source contribution! 134 | 135 | ## Contact us 136 | ![testagent_wechat_3](https://github.com/codefuse-ai/Test-Agent/assets/106229399/dd803960-8952-4fbb-90b2-877ff792d2e3) 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /chat/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the delta weights on top of a base model.
3 | 4 | Usage: 5 | python3 -m chat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 6 | """ 7 | import argparse 8 | import gc 9 | import glob 10 | import json 11 | import os 12 | import shutil 13 | import tempfile 14 | 15 | from huggingface_hub import snapshot_download 16 | import torch 17 | from torch import nn 18 | from tqdm import tqdm 19 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 20 | 21 | 22 | GB = 1 << 30 23 | 24 | 25 | def split_files(model_path, tmp_path, split_size): 26 | if not os.path.exists(model_path): 27 | model_path = snapshot_download(repo_id=model_path) 28 | if not os.path.exists(tmp_path): 29 | os.makedirs(tmp_path) 30 | 31 | file_pattern = os.path.join(model_path, "pytorch_model-*.bin") 32 | files = glob.glob(file_pattern) 33 | 34 | part = 0 35 | try: 36 | for file_path in tqdm(files): 37 | state_dict = torch.load(file_path) 38 | new_state_dict = {} 39 | 40 | current_size = 0 41 | for name, param in state_dict.items(): 42 | param_size = param.numel() * param.element_size() 43 | 44 | if current_size + param_size > split_size: 45 | new_file_name = f"pytorch_model-{part}.bin" 46 | new_file_path = os.path.join(tmp_path, new_file_name) 47 | torch.save(new_state_dict, new_file_path) 48 | current_size = 0 49 | new_state_dict = None 50 | gc.collect() 51 | new_state_dict = {} 52 | part += 1 53 | 54 | new_state_dict[name] = param 55 | current_size += param_size 56 | 57 | new_file_name = f"pytorch_model-{part}.bin" 58 | new_file_path = os.path.join(tmp_path, new_file_name) 59 | torch.save(new_state_dict, new_file_path) 60 | new_state_dict = None 61 | gc.collect() 62 | new_state_dict = {} 63 | part += 1 64 | except Exception as e: 65 | print(f"An error occurred during split_files: {e}") 66 | shutil.rmtree(tmp_path) 67 | raise 68 | 69 | 70 | def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): 71 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 72 | delta_config = AutoConfig.from_pretrained(delta_path) 73 | 74 | if os.path.exists(target_model_path): 75 | shutil.rmtree(target_model_path) 76 | os.makedirs(target_model_path) 77 | 78 | split_size = 4 * GB 79 | 80 | with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: 81 | print(f"Split files for the base model to {tmp_base_path}") 82 | split_files(base_model_path, tmp_base_path, split_size) 83 | print(f"Split files for the delta weights to {tmp_delta_path}") 84 | split_files(delta_path, tmp_delta_path, split_size) 85 | 86 | base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") 87 | base_files = glob.glob(base_pattern) 88 | delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") 89 | delta_files = glob.glob(delta_pattern) 90 | delta_state_dict = torch.load(delta_files[0]) 91 | 92 | print("Applying the delta") 93 | weight_map = {} 94 | total_size = 0 95 | 96 | for i, base_file in tqdm(enumerate(base_files)): 97 | state_dict = torch.load(base_file) 98 | file_name = f"pytorch_model-{i}.bin" 99 | for name, param in state_dict.items(): 100 | if name not in delta_state_dict: 101 | for delta_file in delta_files: 102 | delta_state_dict = torch.load(delta_file) 103 | gc.collect() 104 | if name in delta_state_dict: 105 | break 106 | 107 | state_dict[name] += delta_state_dict[name] 108 | weight_map[name] = file_name 109 | total_size += param.numel() * param.element_size() 110 | gc.collect() 111 | torch.save(state_dict, 
os.path.join(target_model_path, file_name)) 112 | 113 | with open( 114 | os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" 115 | ) as f: 116 | json.dump( 117 | {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f 118 | ) 119 | 120 | print(f"Saving the target model to {target_model_path}") 121 | delta_tokenizer.save_pretrained(target_model_path) 122 | delta_config.save_pretrained(target_model_path) 123 | 124 | 125 | def apply_delta(base_model_path, target_model_path, delta_path): 126 | print(f"Loading the delta weights from {delta_path}") 127 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 128 | delta = AutoModelForCausalLM.from_pretrained( 129 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 130 | ) 131 | 132 | print(f"Loading the base model from {base_model_path}") 133 | base = AutoModelForCausalLM.from_pretrained( 134 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 135 | ) 136 | 137 | print("Applying the delta") 138 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): 139 | assert name in delta.state_dict() 140 | param.data += delta.state_dict()[name] 141 | 142 | print(f"Saving the target model to {target_model_path}") 143 | base.save_pretrained(target_model_path) 144 | delta_tokenizer.save_pretrained(target_model_path) 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--base-model-path", type=str, required=True) 150 | parser.add_argument("--target-model-path", type=str, required=True) 151 | parser.add_argument("--delta-path", type=str, required=True) 152 | parser.add_argument( 153 | "--low-cpu-mem", 154 | action="store_true", 155 | help="Lower the cpu memory usage. This will split large files and use " 156 | "disk as swap to reduce the memory usage below 10GB.", 157 | ) 158 | args = parser.parse_args() 159 | 160 | if args.low_cpu_mem: 161 | apply_delta_low_cpu_mem( 162 | args.base_model_path, args.target_model_path, args.delta_path 163 | ) 164 | else: 165 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 166 | -------------------------------------------------------------------------------- /chat/data/hardcoded_questions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hardcoded question and answers. 
3 | """ 4 | import json 5 | 6 | 7 | def identity_questions(): 8 | """ " 9 | Adapted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py 10 | """ 11 | content = [] 12 | 13 | name = "Vicuna" 14 | org = "Large Model Systems Organization (LMSYS)" 15 | 16 | def generate_conversations(questions, answers): 17 | for q in questions: 18 | for a in answers: 19 | content.append( 20 | { 21 | "id": f"identity_{len(content)}", 22 | "conversations": [ 23 | {"from": "human", "value": q}, 24 | {"from": "gpt", "value": a}, 25 | ], 26 | } 27 | ) 28 | 29 | questions = [ 30 | "Who are you?", 31 | "What is your name?", 32 | "Can you introduce yourself?", 33 | "Can you tell me a little bit about yourself?", 34 | "What's your name?", 35 | "What are you called?", 36 | "What are you?", 37 | "Tell me your name.", 38 | "Tell me about yourself.", 39 | "Tell me about you.", 40 | "Tell me who you are.", 41 | "Please introduce yourself.", 42 | ] 43 | answers = [ 44 | f"I am {name}, a language model trained by researchers from {org}.", 45 | f"My name is {name}, and I'm a language model developed by {org}.", 46 | f"You can call me {name}, and I was trained by {org} researchers as a language model.", 47 | f"As a language model, I go by the name {name} and was trained by researchers from {org}.", 48 | f"I'm a language model called {name}, and I was trained by {org} researchers.", 49 | f"You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.", 50 | ] 51 | generate_conversations(questions, answers) 52 | 53 | questions = [ 54 | "Who created you?", 55 | "Who made you?", 56 | "Who built you?", 57 | "Who programmed you?", 58 | "Who trained you?", 59 | "Who taught you?", 60 | "Who developed you?", 61 | ] 62 | answers = [ 63 | f"Researchers from {org} created me.", 64 | f"I'm created by {org}.", 65 | f"I'm built by researchers from {org}.", 66 | f"I am a language model trained by researchers from {org}.", 67 | f"I'm a language model developed by {org}.", 68 | f"I'm a language model created by researchers from {org}.", 69 | f"My creators are researchers from {org}.", 70 | ] 71 | generate_conversations(questions, answers) 72 | 73 | questions = [ 74 | "Are you ChatGPT?", 75 | "Are you GPT-2?", 76 | "Are you GPT-3?", 77 | "Are you GPT-4?", 78 | "Are you davinci?", 79 | "Are you davinci-001?", 80 | "Are you davinci-002?", 81 | "Are you davinci-003?", 82 | "Are you curie?", 83 | "Are you based on ChatGPT?", 84 | "Are you based on GPT-2?", 85 | "Are you based on GPT-3?", 86 | "Are you based on GPT-4?", 87 | "Are you based on davinci?", 88 | "Are you based on davinci-001?", 89 | "Are you based on davinci-002?", 90 | "Are you based on davinci-003?", 91 | "Are you based on curie?", 92 | "Are you trained by OpenAI?", 93 | "Are you trained by Google?", 94 | "Are you trained by Microsoft?", 95 | "Are you trained by Meta?", 96 | "Are you trained by IBM?", 97 | "Do you call OpenAI APIs?", 98 | "Do you call Google APIs?", 99 | "Do you call Microsoft APIs?", 100 | "Do you call Meta APIs?", 101 | "Do you call IBM APIs?", 102 | "Are you created by OpenAI?", 103 | "Are you created by Google?", 104 | "Are you created by Microsoft?", 105 | "Are you created by Meta?", 106 | "Are you created by IBM?", 107 | "Are you developed by OpenAI?", 108 | "Are you developed by Google?", 109 | "Are you developed by Microsoft?", 110 | "Are you developed by Meta?", 111 | "Are you developed by IBM?", 112 | "Are you trained on OpenAI data?", 113 | "Are you trained on Google data?", 114 | 
"Are you trained on Microsoft data?", 115 | "Are you trained on Meta data?", 116 | "Are you trained on IBM data?", 117 | "Are you trained with OpenAI data?", 118 | "Are you trained with Google data?", 119 | "Are you trained with Microsoft data?", 120 | "Are you trained with Meta data?", 121 | "Are you trained with IBM data?", 122 | "Have you been trained with OpenAI data?", 123 | "Have you been trained with Google data?", 124 | "Have you been trained with Microsoft data?", 125 | "Have you been trained with Meta data?", 126 | "Have you been trained with IBM data?", 127 | "Are you finetuned on OpenAI data?", 128 | "Are you finetuned on Google data?", 129 | "Are you finetuned on Microsoft data?", 130 | "Are you finetuned on Meta data?", 131 | "Are you finetuned on IBM data?", 132 | "Are you finetuned with OpenAI data?", 133 | "Are you finetuned with Google data?", 134 | "Are you finetuned with Microsoft data?", 135 | "Are you finetuned with Meta data?", 136 | "Are you finetuned with IBM data?", 137 | "Have you been finetuned with OpenAI data?", 138 | "Have you been finetuned with Google data?", 139 | "Have you been finetuned with Microsoft data?", 140 | "Have you been finetuned with Meta data?", 141 | "Have you been finetuned with IBM data?", 142 | ] 143 | answers = [ 144 | f"No, I am a language model trained by researchers from {org}.", 145 | f"No, I am a language model developed by researchers from {org}.", 146 | f"No, I am a language model created by researchers from {org}.", 147 | f"No, I am trained by researchers from {org}.", 148 | f"No, I am developed by researchers from {org}.", 149 | f"No, I am created by researchers from {org}.", 150 | f"No, I'm a language model trained by researchers from {org}.", 151 | f"No, I'm a language model developed by researchers from {org}.", 152 | f"No, I'm a language model created by researchers from {org}.", 153 | f"No, I'm trained by researchers from {org}.", 154 | f"No, I'm developed by researchers from {org}.", 155 | f"No, I'm created by researchers from {org}.", 156 | ] 157 | generate_conversations(questions, answers) 158 | 159 | return content 160 | 161 | 162 | if __name__ == "__main__": 163 | out_file = "hardcoded.json" 164 | 165 | content = [] 166 | content.extend(identity_questions()) 167 | 168 | json.dump(content, open(out_file, "w"), indent=2) 169 | -------------------------------------------------------------------------------- /chat/data/clean_sharegpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | - Convert html to markdown with basic data cleaning. 3 | - Deduplication. 
4 | 5 | Usage: 6 | python3 -m chat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json 7 | """ 8 | import argparse 9 | from concurrent.futures import ProcessPoolExecutor 10 | import json 11 | import logging 12 | import re 13 | from typing import Dict, Union 14 | 15 | import bs4 16 | import markdownify # == 0.11.6 17 | from tqdm import tqdm 18 | 19 | 20 | div_pattern = re.compile("") 21 | span_pattern = re.compile("") 22 | code_lang_pattern = re.compile( 23 | "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL 24 | ) 25 | code_lang_format = "```\g<1>\n\g<2>\n```" 26 | regenerate_pattern = re.compile("\d+ / \d+") 27 | copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words") 28 | copy_code_pattern = re.compile("```(.*?)Copy code\s*```") 29 | 30 | 31 | def reformat_code(val: str) -> str: 32 | # Input code format is: 33 | # ``` 34 | # $Copy code$ 35 | # 36 | # ``` 37 | # This function convert it into the correct markdown format 38 | return re.sub(code_lang_pattern, code_lang_format, val) 39 | 40 | 41 | def html_to_markdown(val: str) -> str: 42 | # Remove all
. This is required to make intent work in code blocks. 43 | val = re.sub(div_pattern, "", val) 44 | # Remove all . This is required to make underscores work in code blocks. 45 | val = re.sub(span_pattern, "", val) 46 | # Markdown to html 47 | val = markdownify.markdownify(val).strip() 48 | # Reformat code 49 | val = reformat_code(val) 50 | 51 | # Remove noisy "[number] / [number]" at the beginning 52 | noise = re.search(regenerate_pattern, val) 53 | if noise and noise.start() == 0: 54 | val = val[noise.end() :] 55 | # Remove noisy "Copy[number] chars / [number] words" 56 | val = re.sub(copy_chars_pattern, "", val) 57 | # Remove empty code block ```\nCopy code\n``` 58 | val = re.sub(copy_code_pattern, "", val) 59 | 60 | # Strip 61 | val = val.replace("\n\n\n", "\n").strip() 62 | 63 | return val 64 | 65 | 66 | def contain_blocked_words(val: str) -> bool: 67 | blocked_words = ["openai", "chatgpt"] 68 | for w in blocked_words: 69 | if w in val.lower(): 70 | return True 71 | return False 72 | 73 | 74 | def clean_html_one_sample(sample): 75 | roles = ["human", "gpt"] 76 | 77 | if len(sample["conversations"]) <= 1: 78 | return (sample, 1) 79 | 80 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4 81 | if sample["conversations"][0]["from"] != "human": 82 | sample["conversations"] = sample["conversations"][1:] 83 | if len(sample["conversations"]) <= 1: 84 | return (sample, 1) 85 | 86 | if sample["conversations"][-1]["from"] == "human": 87 | sample["conversations"] = sample["conversations"][:-1] 88 | if len(sample["conversations"]) <= 1: 89 | return (sample, 1) 90 | 91 | char_count = 0 92 | new_conversations = [] 93 | for i, c in enumerate(sample["conversations"]): 94 | if c["from"] != roles[i % 2]: 95 | return (sample, 2) 96 | 97 | if contain_blocked_words(c["value"]): 98 | return (sample, 3) 99 | 100 | try: 101 | new_val = html_to_markdown(c["value"]) 102 | except (bs4.builder.ParserRejectedMarkup, AssertionError): 103 | return (sample, 4) 104 | 105 | # Filter empty answers like https://sharegpt.com/c/mrllZ6u 106 | if not new_val or not new_val[0].isprintable(): 107 | break 108 | 109 | char_count += len(new_val) 110 | new_conversations.append( 111 | { 112 | "from": c["from"], 113 | "value": new_val, 114 | } 115 | ) 116 | 117 | new_conversations = new_conversations[: len(new_conversations) // 2 * 2] 118 | sample["conversations"] = new_conversations 119 | 120 | if char_count < 16 or len(sample["conversations"]) <= 0: 121 | return (sample, 1) 122 | 123 | return (sample, 0) 124 | 125 | 126 | def clean_html_all(content, begin, end): 127 | """ 128 | Clean the source html files. 
129 | """ 130 | cnt_skip = 0 131 | cnt_blocked_words = 0 132 | cnt_wrong_format = 0 133 | cnt_parser_error = 0 134 | cnt_too_short = 0 135 | cnt_id_duplication = 0 136 | cnt_value_duplication = 0 137 | cnt_plugin = 0 138 | cnt_tag = 0 139 | 140 | content = content[begin:end] 141 | processed = [] 142 | with ProcessPoolExecutor() as executor: 143 | for result in tqdm( 144 | executor.map(clean_html_one_sample, content), total=len(content) 145 | ): 146 | processed.append(result) 147 | 148 | visited = {} 149 | new_content = [] 150 | for sample, error_code in processed: 151 | cid = sample["id"] 152 | skipped = True 153 | 154 | if error_code != 0: 155 | if error_code == 1: 156 | print(f"id {cid} is too short") 157 | cnt_too_short += 1 158 | elif error_code == 2: 159 | print(f"id {cid} has a wrong format") 160 | cnt_wrong_format += 1 161 | elif error_code == 3: 162 | print(f"id {cid} contains blocked words") 163 | cnt_blocked_words += 1 164 | elif error_code == 4: 165 | print(f"id {cid} contains parser errors") 166 | cnt_parser_error += 1 167 | else: 168 | raise ValueError(f"Invalid error_code: {error_code}") 169 | elif cid in visited: 170 | print(f"id {cid} is an id duplication of {visited[cid]}") 171 | cnt_id_duplication += 1 172 | elif sample.get("plugins", None) is not None: 173 | print(f"id {cid} contains plugin") 174 | cnt_plugin += 1 175 | else: 176 | key = ( 177 | sample["conversations"][0]["value"], 178 | sample["conversations"][1]["value"], 179 | ) 180 | if key in visited: 181 | print(f"id {cid} is a value duplication of {visited[key]}") 182 | cnt_value_duplication += 1 183 | else: 184 | visited[cid] = visited[key] = cid 185 | skipped = False 186 | 187 | if not skipped: 188 | new_content.append(sample) 189 | else: 190 | cnt_skip += 1 191 | 192 | print( 193 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, " 194 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, " 195 | f"cnt_wrong_format: {cnt_wrong_format}, " 196 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, " 197 | f"cnt_value_duplication: {cnt_value_duplication}, cnt_plugin: {cnt_plugin}" 198 | ) 199 | 200 | return new_content 201 | 202 | 203 | def main(args): 204 | content = json.load(open(args["in_file"], "r")) 205 | content = clean_html_all(content, args["begin"], args["end"]) 206 | json.dump(content, open(args["out_file"], "w"), indent=2, ensure_ascii=False) 207 | 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument("--in-file", type=str, required=True) 212 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json") 213 | parser.add_argument("--begin", type=int) 214 | parser.add_argument("--end", type=int) 215 | parser.add_argument("--debug", action="store_true") 216 | args = parser.parse_args() 217 | main(vars(args)) 218 | -------------------------------------------------------------------------------- /chat/server/monitor/basic_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import code 3 | import datetime 4 | import json 5 | import os 6 | from pytz import timezone 7 | import time 8 | 9 | import pandas as pd # pandas>=2.0.3 10 | import plotly.express as px 11 | import plotly.graph_objects as go 12 | from tqdm import tqdm 13 | 14 | 15 | NUM_SERVERS = 14 16 | 17 | 18 | def get_log_files(max_num_files=None): 19 | dates = [] 20 | for month in range(4, 9): 21 | for day in range(1, 33): 22 | 
dates.append(f"2023-{month:02d}-{day:02d}") 23 | 24 | filenames = [] 25 | for d in dates: 26 | for i in range(NUM_SERVERS): 27 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 28 | if os.path.exists(name): 29 | filenames.append(name) 30 | max_num_files = max_num_files or len(filenames) 31 | filenames = filenames[-max_num_files:] 32 | return filenames 33 | 34 | 35 | def load_log_files(log_files): 36 | data = [] 37 | for filename in tqdm(log_files, desc="read files"): 38 | for retry in range(5): 39 | try: 40 | lines = open(filename).readlines() 41 | break 42 | except FileNotFoundError: 43 | time.sleep(2) 44 | 45 | for l in lines: 46 | row = json.loads(l) 47 | 48 | data.append( 49 | dict( 50 | type=row["type"], 51 | tstamp=row["tstamp"], 52 | model=row.get("model", ""), 53 | models=row.get("models", ["", ""]), 54 | ) 55 | ) 56 | 57 | return data 58 | 59 | 60 | def get_anony_vote_df(df): 61 | anony_vote_df = df[ 62 | df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) 63 | ] 64 | anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")] 65 | return anony_vote_df 66 | 67 | 68 | def merge_counts(series, on, names): 69 | ret = pd.merge(series[0], series[1], on=on) 70 | for i in range(2, len(series)): 71 | ret = pd.merge(ret, series[i], on=on) 72 | ret = ret.reset_index() 73 | old_names = list(ret.columns)[-len(series) :] 74 | rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} 75 | ret = ret.rename(columns=rename) 76 | return ret 77 | 78 | 79 | def report_basic_stats(log_files): 80 | df_all = load_log_files(log_files) 81 | df_all = pd.DataFrame(df_all) 82 | now_t = df_all["tstamp"].max() 83 | df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] 84 | df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] 85 | anony_vote_df_all = get_anony_vote_df(df_all) 86 | 87 | # Chat trends 88 | chat_dates = [ 89 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( 90 | "%Y-%m-%d" 91 | ) 92 | for x in df_all[df_all["type"] == "chat"]["tstamp"] 93 | ] 94 | chat_dates_counts = pd.value_counts(chat_dates) 95 | vote_dates = [ 96 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( 97 | "%Y-%m-%d" 98 | ) 99 | for x in anony_vote_df_all["tstamp"] 100 | ] 101 | vote_dates_counts = pd.value_counts(vote_dates) 102 | chat_dates_bar = go.Figure( 103 | data=[ 104 | go.Bar( 105 | name="Anony. 
Vote", 106 | x=vote_dates_counts.index, 107 | y=vote_dates_counts, 108 | text=[f"{val:.0f}" for val in vote_dates_counts], 109 | textposition="auto", 110 | ), 111 | go.Bar( 112 | name="Chat", 113 | x=chat_dates_counts.index, 114 | y=chat_dates_counts, 115 | text=[f"{val:.0f}" for val in chat_dates_counts], 116 | textposition="auto", 117 | ), 118 | ] 119 | ) 120 | chat_dates_bar.update_layout( 121 | barmode="stack", 122 | xaxis_title="Dates", 123 | yaxis_title="Count", 124 | height=300, 125 | width=1200, 126 | ) 127 | 128 | # Model call counts 129 | model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() 130 | model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() 131 | model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() 132 | model_hist = merge_counts( 133 | [model_hist_all, model_hist_1_day, model_hist_1_hour], 134 | on="model", 135 | names=["All", "Last Day", "Last Hour"], 136 | ) 137 | model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") 138 | 139 | # Action counts 140 | action_hist_all = df_all["type"].value_counts() 141 | action_hist_1_day = df_1_day["type"].value_counts() 142 | action_hist_1_hour = df_1_hour["type"].value_counts() 143 | action_hist = merge_counts( 144 | [action_hist_all, action_hist_1_day, action_hist_1_hour], 145 | on="type", 146 | names=["All", "Last Day", "Last Hour"], 147 | ) 148 | action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") 149 | 150 | # Anony vote counts 151 | anony_vote_hist_all = anony_vote_df_all["type"].value_counts() 152 | anony_vote_df_1_day = get_anony_vote_df(df_1_day) 153 | anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() 154 | # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) 155 | # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() 156 | anony_vote_hist = merge_counts( 157 | [anony_vote_hist_all, anony_vote_hist_1_day], 158 | on="type", 159 | names=["All", "Last Day"], 160 | ) 161 | anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") 162 | 163 | # Last 24 hours 164 | chat_1_day = df_1_day[df_1_day["type"] == "chat"] 165 | num_chats_last_24_hours = [] 166 | base = df_1_day["tstamp"].min() 167 | for i in range(24, 0, -1): 168 | left = base + (i - 1) * 3600 169 | right = base + i * 3600 170 | num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() 171 | num_chats_last_24_hours.append(num) 172 | times = [ 173 | datetime.datetime.fromtimestamp( 174 | base + i * 3600, tz=timezone("US/Pacific") 175 | ).strftime("%Y-%m-%d %H:%M:%S %Z") 176 | for i in range(24, 0, -1) 177 | ] 178 | last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) 179 | last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") 180 | 181 | # Last update datetime 182 | last_updated_tstamp = now_t 183 | last_updated_datetime = datetime.datetime.fromtimestamp( 184 | last_updated_tstamp, tz=timezone("US/Pacific") 185 | ).strftime("%Y-%m-%d %H:%M:%S %Z") 186 | 187 | # code.interact(local=locals()) 188 | 189 | return { 190 | "chat_dates_bar": chat_dates_bar, 191 | "model_hist_md": model_hist_md, 192 | "action_hist_md": action_hist_md, 193 | "anony_vote_hist_md": anony_vote_hist_md, 194 | "num_chats_last_24_hours": last_24_hours_md, 195 | "last_updated_datetime": last_updated_datetime, 196 | } 197 | 198 | 199 | if __name__ == "__main__": 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument("--max-num-files", type=int) 202 | 
args = parser.parse_args() 203 | 204 | log_files = get_log_files(args.max_num_files) 205 | basic_stats = report_basic_stats(log_files) 206 | 207 | print(basic_stats["action_hist_md"] + "\n") 208 | print(basic_stats["model_hist_md"] + "\n") 209 | print(basic_stats["anony_vote_hist_md"] + "\n") 210 | print(basic_stats["num_chats_last_24_hours"] + "\n") 211 | -------------------------------------------------------------------------------- /chat/server/vllm_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker that executes the model based on vLLM. 3 | 4 | See documentations at docs/vllm_integration.md 5 | """ 6 | 7 | import argparse 8 | import asyncio 9 | import json 10 | from typing import List 11 | 12 | from fastapi import FastAPI, Request, BackgroundTasks 13 | from fastapi.responses import StreamingResponse, JSONResponse 14 | import torch 15 | import uvicorn 16 | from vllm import AsyncLLMEngine 17 | from vllm.engine.arg_utils import AsyncEngineArgs 18 | from vllm.sampling_params import SamplingParams 19 | from vllm.utils import random_uuid 20 | 21 | from chat.server.model_worker import ( 22 | BaseModelWorker, 23 | logger, 24 | worker_id, 25 | ) 26 | from chat.utils import get_context_length 27 | 28 | 29 | app = FastAPI() 30 | 31 | 32 | class VLLMWorker(BaseModelWorker): 33 | def __init__( 34 | self, 35 | controller_addr: str, 36 | worker_addr: str, 37 | worker_id: str, 38 | model_path: str, 39 | model_names: List[str], 40 | limit_worker_concurrency: int, 41 | no_register: bool, 42 | llm_engine: AsyncLLMEngine, 43 | conv_template: str, 44 | ): 45 | super().__init__( 46 | controller_addr, 47 | worker_addr, 48 | worker_id, 49 | model_path, 50 | model_names, 51 | limit_worker_concurrency, 52 | conv_template, 53 | ) 54 | 55 | logger.info( 56 | f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." 
57 | ) 58 | self.tokenizer = llm_engine.engine.tokenizer 59 | self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) 60 | 61 | if not no_register: 62 | self.init_heart_beat() 63 | 64 | async def generate_stream(self, params): 65 | self.call_ct += 1 66 | 67 | context = params.pop("prompt") 68 | request_id = params.pop("request_id") 69 | temperature = float(params.get("temperature", 1.0)) 70 | top_p = float(params.get("top_p", 1.0)) 71 | max_new_tokens = params.get("max_new_tokens", 256) 72 | stop_str = params.get("stop", None) 73 | stop_token_ids = params.get("stop_token_ids", None) or [] 74 | if self.tokenizer.eos_token_id is not None: 75 | stop_token_ids.append(self.tokenizer.eos_token_id) 76 | echo = params.get("echo", True) 77 | 78 | # Handle stop_str 79 | stop = set() 80 | if isinstance(stop_str, str) and stop_str != "": 81 | stop.add(stop_str) 82 | elif isinstance(stop_str, list) and stop_str != []: 83 | stop.update(stop_str) 84 | 85 | for tid in stop_token_ids: 86 | if tid is not None: 87 | stop.add(self.tokenizer.decode(tid)) 88 | 89 | # make sampling params in vllm 90 | top_p = max(top_p, 1e-5) 91 | if temperature <= 1e-5: 92 | top_p = 1.0 93 | sampling_params = SamplingParams( 94 | n=1, 95 | temperature=temperature, 96 | top_p=top_p, 97 | use_beam_search=False, 98 | stop=list(stop), 99 | max_tokens=max_new_tokens, 100 | ) 101 | results_generator = engine.generate(context, sampling_params, request_id) 102 | 103 | async for request_output in results_generator: 104 | prompt = request_output.prompt 105 | if echo: 106 | text_outputs = [ 107 | prompt + output.text for output in request_output.outputs 108 | ] 109 | else: 110 | text_outputs = [output.text for output in request_output.outputs] 111 | text_outputs = " ".join(text_outputs) 112 | # Note: usage is not supported yet 113 | ret = {"text": text_outputs, "error_code": 0, "usage": {}} 114 | yield (json.dumps(ret) + "\0").encode() 115 | 116 | async def generate(self, params): 117 | async for x in self.generate_stream(params): 118 | pass 119 | return json.loads(x[:-1].decode()) 120 | 121 | 122 | def release_worker_semaphore(): 123 | worker.semaphore.release() 124 | 125 | 126 | def acquire_worker_semaphore(): 127 | if worker.semaphore is None: 128 | worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) 129 | return worker.semaphore.acquire() 130 | 131 | 132 | def create_background_tasks(request_id): 133 | async def abort_request() -> None: 134 | await engine.abort(request_id) 135 | 136 | background_tasks = BackgroundTasks() 137 | background_tasks.add_task(release_worker_semaphore) 138 | background_tasks.add_task(abort_request) 139 | return background_tasks 140 | 141 | 142 | @app.post("/worker_generate_stream") 143 | async def api_generate_stream(request: Request): 144 | params = await request.json() 145 | await acquire_worker_semaphore() 146 | request_id = random_uuid() 147 | params["request_id"] = request_id 148 | generator = worker.generate_stream(params) 149 | background_tasks = create_background_tasks(request_id) 150 | return StreamingResponse(generator, background=background_tasks) 151 | 152 | 153 | @app.post("/worker_generate") 154 | async def api_generate(request: Request): 155 | params = await request.json() 156 | await acquire_worker_semaphore() 157 | request_id = random_uuid() 158 | params["request_id"] = request_id 159 | output = await worker.generate(params) 160 | release_worker_semaphore() 161 | await engine.abort(request_id) 162 | return JSONResponse(output) 163 | 164 | 165 | 
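# --- Illustrative sketch added by the editor; not part of the original worker. ---
# The /worker_generate_stream endpoint above streams JSON chunks terminated by a null
# byte ("\0"), each of the form {"text": ..., "error_code": 0, "usage": {}}. A minimal
# client (the worker address, prompt, and sampling values below are assumptions) could
# consume the stream roughly like this:
def _example_stream_client(worker_addr="http://localhost:21002", prompt="Hello"):
    import requests  # local import so this example does not change the module imports

    resp = requests.post(
        f"{worker_addr}/worker_generate_stream",
        json={"prompt": prompt, "temperature": 0.7, "max_new_tokens": 64},
        stream=True,
    )
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())  # `json` is imported at the top of this module
            print(data["text"], flush=True)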
@app.post("/worker_get_status") 166 | async def api_get_status(request: Request): 167 | return worker.get_status() 168 | 169 | 170 | @app.post("/count_token") 171 | async def api_count_token(request: Request): 172 | params = await request.json() 173 | return worker.count_token(params) 174 | 175 | 176 | @app.post("/worker_get_conv_template") 177 | async def api_get_conv(request: Request): 178 | return worker.get_conv_template() 179 | 180 | 181 | @app.post("/model_details") 182 | async def api_model_details(request: Request): 183 | return {"context_length": worker.context_len} 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--host", type=str, default="localhost") 189 | parser.add_argument("--port", type=int, default=21002) 190 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002") 191 | parser.add_argument( 192 | "--controller-address", type=str, default="http://localhost:21001" 193 | ) 194 | parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3") 195 | parser.add_argument( 196 | "--model-names", 197 | type=lambda s: s.split(","), 198 | help="Optional display comma separated names", 199 | ) 200 | parser.add_argument("--limit-worker-concurrency", type=int, default=1024) 201 | parser.add_argument("--no-register", action="store_true") 202 | parser.add_argument("--num-gpus", type=int, default=1) 203 | parser.add_argument( 204 | "--conv-template", type=str, default=None, help="Conversation prompt template." 205 | ) 206 | 207 | parser = AsyncEngineArgs.add_cli_args(parser) 208 | args = parser.parse_args() 209 | if args.model_path: 210 | args.model = args.model_path 211 | if args.num_gpus > 1: 212 | args.tensor_parallel_size = args.num_gpus 213 | 214 | engine_args = AsyncEngineArgs.from_cli_args(args) 215 | engine = AsyncLLMEngine.from_engine_args(engine_args) 216 | worker = VLLMWorker( 217 | args.controller_address, 218 | args.worker_address, 219 | worker_id, 220 | args.model_path, 221 | args.model_names, 222 | args.limit_worker_concurrency, 223 | args.no_register, 224 | engine, 225 | args.conv_template, 226 | ) 227 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 228 | -------------------------------------------------------------------------------- /chat/server/monitor/clean_battle_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clean chatbot arena battle log. 3 | 4 | Usage: 5 | python3 clean_battle_data.py --mode conv_release 6 | """ 7 | import argparse 8 | import datetime 9 | import json 10 | import os 11 | from pytz import timezone 12 | import time 13 | 14 | from tqdm import tqdm 15 | 16 | from chat.server.monitor.basic_stats import get_log_files, NUM_SERVERS 17 | from chat.utils import detect_language 18 | 19 | 20 | VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] 21 | IDENTITY_WORDS = [ 22 | "vicuna", 23 | "lmsys", 24 | "koala", 25 | "uc berkeley", 26 | "open assistant", 27 | "laion", 28 | "chatglm", 29 | "chatgpt", 30 | "openai", 31 | "anthropic", 32 | "claude", 33 | "bard", 34 | "palm", 35 | "lamda", 36 | "google", 37 | "NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.", 38 | ] 39 | 40 | for i in range(len(IDENTITY_WORDS)): 41 | IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() 42 | 43 | 44 | def get_log_files(max_num_files=None): 45 | dates = [] 46 | for month in [4, 5, 6, 7]: 47 | for day in range(1, 32): 48 | dates.append(f"2023-{month:02d}-{day:02d}") 49 | 50 | for month in [8]: 51 | for day in range(1, 32): 52 | dates.append(f"2023-{month:02d}-{day:02d}") 53 | 54 | filenames = [] 55 | for d in dates: 56 | for i in range(NUM_SERVERS): 57 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 58 | if os.path.exists(name): 59 | filenames.append(name) 60 | max_num_files = max_num_files or len(filenames) 61 | filenames = filenames[-max_num_files:] 62 | return filenames 63 | 64 | 65 | def remove_html(raw): 66 | if raw.startswith("

"): 67 | return raw[raw.find(": ") + 2 : -len("

\n")] 68 | return raw 69 | 70 | 71 | def to_openai_format(messages): 72 | roles = ["user", "assistant"] 73 | ret = [] 74 | for i, x in enumerate(messages): 75 | ret.append({"role": roles[i % 2], "content": x[1]}) 76 | return ret 77 | 78 | 79 | def replace_model_name(old_name): 80 | return ( 81 | old_name.replace("bard", "palm-2") 82 | .replace("claude-v1", "claude-1") 83 | .replace("claude-instant-v1", "claude-instant-1") 84 | .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b") 85 | ) 86 | 87 | 88 | def clean_battle_data(log_files): 89 | data = [] 90 | for filename in tqdm(log_files, desc="read files"): 91 | for retry in range(5): 92 | try: 93 | lines = open(filename).readlines() 94 | break 95 | except FileNotFoundError: 96 | time.sleep(2) 97 | 98 | for l in lines: 99 | row = json.loads(l) 100 | if row["type"] in VOTES: 101 | data.append(row) 102 | 103 | convert_type = { 104 | "leftvote": "model_a", 105 | "rightvote": "model_b", 106 | "tievote": "tie", 107 | "bothbad_vote": "tie (bothbad)", 108 | } 109 | 110 | all_models = set() 111 | all_ips = dict() 112 | ct_anony = 0 113 | ct_invalid = 0 114 | ct_leaked_identity = 0 115 | battles = [] 116 | for row in data: 117 | if row["models"][0] is None or row["models"][1] is None: 118 | continue 119 | 120 | # Resolve model names 121 | models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] 122 | if "model_name" in row["states"][0]: 123 | models_hidden = [ 124 | row["states"][0]["model_name"], 125 | row["states"][1]["model_name"], 126 | ] 127 | if models_hidden[0] is None: 128 | models_hidden = models_public 129 | else: 130 | models_hidden = models_public 131 | 132 | if (models_public[0] == "" and models_public[1] != "") or ( 133 | models_public[1] == "" and models_public[0] != "" 134 | ): 135 | ct_invalid += 1 136 | continue 137 | 138 | if models_public[0] == "" or models_public[0] == "Model A": 139 | anony = True 140 | models = models_hidden 141 | ct_anony += 1 142 | else: 143 | anony = False 144 | models = models_public 145 | if not models_public == models_hidden: 146 | ct_invalid += 1 147 | continue 148 | 149 | # Detect langauge 150 | state = row["states"][0] 151 | if state["offset"] >= len(state["messages"]): 152 | ct_invalid += 1 153 | continue 154 | lang_code = detect_language(state["messages"][state["offset"]][1]) 155 | 156 | # Drop conversations if the model names are leaked 157 | leaked_identity = False 158 | messages = "" 159 | for i in range(2): 160 | state = row["states"][i] 161 | for role, msg in state["messages"][state["offset"] :]: 162 | if msg: 163 | messages += msg.lower() 164 | for word in IDENTITY_WORDS: 165 | if word in messages: 166 | leaked_identity = True 167 | break 168 | 169 | if leaked_identity: 170 | ct_leaked_identity += 1 171 | continue 172 | 173 | # Replace bard with palm 174 | models = [replace_model_name(m) for m in models] 175 | 176 | question_id = row["states"][0]["conv_id"] 177 | conversation_a = to_openai_format( 178 | row["states"][0]["messages"][row["states"][0]["offset"] :] 179 | ) 180 | conversation_b = to_openai_format( 181 | row["states"][1]["messages"][row["states"][1]["offset"] :] 182 | ) 183 | 184 | ip = row["ip"] 185 | if ip not in all_ips: 186 | all_ips[ip] = len(all_ips) 187 | user_id = all_ips[ip] 188 | 189 | # Save the result 190 | battles.append( 191 | dict( 192 | question_id=question_id, 193 | model_a=models[0], 194 | model_b=models[1], 195 | winner=convert_type[row["type"]], 196 | judge=f"arena_user_{user_id}", 197 | conversation_a=conversation_a, 198 | 
conversation_b=conversation_b, 199 | turn=len(conversation_a) // 2, 200 | anony=anony, 201 | language=lang_code, 202 | tstamp=row["tstamp"], 203 | ) 204 | ) 205 | 206 | all_models.update(models_hidden) 207 | battles.sort(key=lambda x: x["tstamp"]) 208 | last_updated_tstamp = battles[-1]["tstamp"] 209 | 210 | last_updated_datetime = datetime.datetime.fromtimestamp( 211 | last_updated_tstamp, tz=timezone("US/Pacific") 212 | ).strftime("%Y-%m-%d %H:%M:%S %Z") 213 | 214 | print( 215 | f"#votes: {len(data)}, #invalid votes: {ct_invalid}, " 216 | f"#leaked_identity: {ct_leaked_identity}" 217 | ) 218 | print(f"#battles: {len(battles)}, #anony: {ct_anony}") 219 | print(f"#models: {len(all_models)}, {all_models}") 220 | print(f"last-updated: {last_updated_datetime}") 221 | 222 | return battles 223 | 224 | 225 | if __name__ == "__main__": 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument("--max-num-files", type=int) 228 | parser.add_argument( 229 | "--mode", type=str, choices=["simple", "conv_release"], default="simple" 230 | ) 231 | args = parser.parse_args() 232 | 233 | log_files = get_log_files(args.max_num_files) 234 | battles = clean_battle_data(log_files) 235 | last_updated_tstamp = battles[-1]["tstamp"] 236 | cutoff_date = datetime.datetime.fromtimestamp( 237 | last_updated_tstamp, tz=timezone("US/Pacific") 238 | ).strftime("%Y%m%d") 239 | 240 | if args.mode == "simple": 241 | for x in battles: 242 | for key in [ 243 | "conversation_a", 244 | "conversation_b", 245 | "question_id", 246 | ]: 247 | del x[key] 248 | print("Samples:") 249 | for i in range(4): 250 | print(battles[i]) 251 | output = f"clean_battle_{cutoff_date}.json" 252 | elif args.mode == "conv_release": 253 | new_battles = [] 254 | for x in battles: 255 | if not x["anony"]: 256 | continue 257 | for key in []: 258 | del x[key] 259 | new_battles.append(x) 260 | battles = new_battles 261 | output = f"clean_battle_conv_{cutoff_date}.json" 262 | 263 | with open(output, "w") as fout: 264 | json.dump(battles, fout, indent=2, ensure_ascii=False) 265 | print(f"Write cleaned data to {output}") 266 | -------------------------------------------------------------------------------- /chat/server/monitor/topic_clustering.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Usage: 4 | python3 topic_clustering.py --in arena.json --english-only --min-length 32 5 | python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1024 6 | """ 7 | import argparse 8 | import json 9 | import pickle 10 | import string 11 | import time 12 | 13 | import numpy as np 14 | from sentence_transformers import SentenceTransformer 15 | from sentence_transformers.util import cos_sim 16 | from sklearn.cluster import KMeans, AgglomerativeClustering 17 | import torch 18 | from tqdm import tqdm 19 | 20 | from chat.utils import detect_language 21 | 22 | 23 | def remove_punctuation(input_string): 24 | # Make a translator object to remove all punctuation 25 | translator = str.maketrans("", "", string.punctuation) 26 | 27 | # Use the translator object to remove the punctuation 28 | no_punct = input_string.translate(translator) 29 | return no_punct 30 | 31 | 32 | def read_texts(input_file, min_length, max_length, english_only): 33 | visited = set() 34 | texts = [] 35 | 36 | lines = json.load(open(input_file, "r")) 37 | 38 | for l in tqdm(lines): 39 | if "text" in l: 40 | line_texts = [l["text"]] 41 | elif "conversation_a" in l: 42 | line_texts = [ 43 | x["content"] 
for x in l["conversation_a"] if x["role"] == "user" 44 | ] 45 | elif "conversation" in l: 46 | line_texts = [ 47 | x["content"] for x in l["conversation"] if x["role"] == "user" 48 | ] 49 | 50 | for text in line_texts: 51 | text = text.strip() 52 | 53 | # Filter language 54 | if english_only: 55 | lang = detect_language(text) 56 | if lang != "English": 57 | continue 58 | 59 | # Filter short or long prompts 60 | if min_length: 61 | if len(text) < min_length: 62 | continue 63 | 64 | if max_length: 65 | if len(text) > max_length: 66 | continue 67 | 68 | # De-duplication 69 | words = sorted([x.lower() for x in remove_punctuation(text).split(" ")]) 70 | words = "".join(words) 71 | if words in visited: 72 | continue 73 | 74 | visited.add(words) 75 | texts.append(text) 76 | return np.array(texts) 77 | 78 | 79 | def get_embeddings(texts, model_name, batch_size): 80 | model = SentenceTransformer(model_name) 81 | embeddings = model.encode( 82 | texts, 83 | batch_size=batch_size, 84 | show_progress_bar=True, 85 | device="cuda", 86 | convert_to_tensor=True, 87 | ) 88 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 89 | return embeddings.cpu() 90 | 91 | 92 | def run_k_means(embeddings, num_clusters): 93 | np.random.seed(0) 94 | clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") 95 | clustering_model.fit(embeddings.numpy()) 96 | centers = torch.from_numpy(clustering_model.cluster_centers_) 97 | labels = torch.from_numpy(clustering_model.labels_) 98 | 99 | # Sort labels 100 | classes, counts = np.unique(labels, return_counts=True) 101 | indices = np.argsort(counts)[::-1] 102 | classes = [classes[i] for i in indices] 103 | new_labels = torch.empty_like(labels) 104 | new_centers = torch.empty_like(centers) 105 | for i, c in enumerate(classes): 106 | new_labels[labels == c] = i 107 | new_centers[i] = centers[c] 108 | return new_centers, new_labels 109 | 110 | 111 | def run_agg_cluster(embeddings, num_clusters): 112 | np.random.seed(0) 113 | clustering_model = AgglomerativeClustering(n_clusters=num_clusters) 114 | clustering_model.fit(embeddings) 115 | labels = torch.from_numpy(clustering_model.labels_) 116 | 117 | # Sort labels 118 | classes, counts = np.unique(labels, return_counts=True) 119 | indices = np.argsort(counts)[::-1] 120 | classes = [classes[i] for i in indices] 121 | new_labels = torch.empty_like(labels) 122 | for i, c in enumerate(classes): 123 | new_labels[labels == c] = i 124 | 125 | # Compute centers 126 | centers = [] 127 | for i in range(clustering_model.n_clusters_): 128 | centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) 129 | centers = torch.cat(centers) 130 | return centers, new_labels 131 | 132 | 133 | def get_topk_indices(centers, labels, embeddings, topk): 134 | indices = [] 135 | arange = torch.arange(len(labels)) 136 | counts = torch.unique(labels, return_counts=True)[1] 137 | topk = min(topk, counts.min().item()) 138 | for i in range(len(centers)): 139 | tmp_indices = labels == i 140 | tmp_arange = arange[tmp_indices] 141 | tmp_embeddings = embeddings[tmp_indices] 142 | 143 | scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] 144 | sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) 145 | indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0)) 146 | return torch.cat(indices) 147 | 148 | 149 | def print_topk(texts, labels, topk_indices, show_cut_off): 150 | ret = "" 151 | for k in range(len(topk_indices)): 152 | num_samples = torch.sum(labels == k).item() 153 | 154 | ret += "=" * 20 + f" cluster 
{k}, #samples: {num_samples} " + "=" * 20 + "\n" 155 | for idx in topk_indices[k]: 156 | ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n" 157 | ret += "=" * 40 + "\n\n" 158 | 159 | return ret 160 | 161 | 162 | def get_cluster_info(texts, labels, topk_indices): 163 | cluster_info = [] 164 | for k in range(len(topk_indices)): 165 | num_samples = torch.sum(labels == k).item() 166 | prompts = [] 167 | for idx in topk_indices[k]: 168 | prompts.append(texts[idx]) 169 | cluster_info.append((num_samples, prompts)) 170 | 171 | return cluster_info 172 | 173 | 174 | if __name__ == "__main__": 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument("--input-file", type=str, required=True) 177 | parser.add_argument("--model", type=str, default="all-mpnet-base-v2") 178 | # default="all-MiniLM-L12-v2") 179 | # default="multi-qa-distilbert-cos-v1") 180 | parser.add_argument("--batch-size", type=int, default=256) 181 | parser.add_argument("--min-length", type=int) 182 | parser.add_argument("--max-length", type=int) 183 | parser.add_argument("--english-only", action="store_true") 184 | parser.add_argument("--num-clusters", type=int, default=20) 185 | parser.add_argument( 186 | "--cluster-alg", type=str, choices=["kmeans", "aggcls"], default="kmeans" 187 | ) 188 | parser.add_argument("--show-top-k", type=int, default=200) 189 | parser.add_argument("--show-cut-off", type=int, default=512) 190 | args = parser.parse_args() 191 | 192 | num_clusters = args.num_clusters 193 | show_top_k = args.show_top_k 194 | show_cut_off = args.show_cut_off 195 | 196 | texts = read_texts( 197 | args.input_file, args.min_length, args.max_length, args.english_only 198 | ) 199 | print(f"#text: {len(texts)}") 200 | 201 | embeddings = get_embeddings(texts, args.model, args.batch_size) 202 | if args.cluster_alg == "kmeans": 203 | centers, labels = run_k_means(embeddings, num_clusters) 204 | elif args.cluster_alg == "aggcls": 205 | centers, labels = run_agg_cluster(embeddings, num_clusters) 206 | else: 207 | raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}") 208 | 209 | topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k) 210 | topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off) 211 | num_clusters = len(centers) 212 | 213 | cluster_info = get_cluster_info(texts, labels, topk_indices) 214 | 215 | # Dump results 216 | filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}" 217 | print(topk_str) 218 | with open(filename_prefix + "_topk.txt", "w") as fout: 219 | fout.write(topk_str) 220 | 221 | with open(filename_prefix + "_all.txt", "w") as fout: 222 | for i in range(len(centers)): 223 | tmp_indices = labels == i 224 | tmp_embeddings = embeddings[tmp_indices] 225 | tmp_texts = texts[tmp_indices] 226 | 227 | scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] 228 | sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) 229 | 230 | for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]): 231 | obj = {"cluster": i, "text": text, "sim": score.item()} 232 | fout.write(json.dumps(obj, ensure_ascii=False) + "\n") 233 | 234 | with open(filename_prefix + "_cluster.pkl", "wb") as fout: 235 | pickle.dump(cluster_info, fout) 236 | -------------------------------------------------------------------------------- /chat/server/multi_model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A multi-model worker that contains multiple sub-works one for each model. 
This 3 | supports running a list of models on the same machine so that they can 4 | (potentially) share the same background weights. 5 | 6 | Each model can have one or more model names. 7 | 8 | This multi-model worker assumes the models shares some underlying weights and 9 | thus reports the combined queue lengths for health checks. 10 | 11 | We recommend using this with multiple Peft models (with `peft` in the name) 12 | where all Peft models are trained on the exact same base model. 13 | """ 14 | import argparse 15 | import asyncio 16 | import dataclasses 17 | import logging 18 | import json 19 | import os 20 | import time 21 | from typing import List, Union 22 | import threading 23 | import uuid 24 | 25 | from fastapi import FastAPI, Request, BackgroundTasks 26 | from fastapi.responses import StreamingResponse, JSONResponse 27 | import requests 28 | 29 | try: 30 | from transformers import ( 31 | AutoTokenizer, 32 | AutoModelForCausalLM, 33 | LlamaTokenizer, 34 | AutoModel, 35 | ) 36 | except ImportError: 37 | from transformers import ( 38 | AutoTokenizer, 39 | AutoModelForCausalLM, 40 | LLaMATokenizer, 41 | AutoModel, 42 | ) 43 | import torch 44 | import torch.nn.functional as F 45 | import uvicorn 46 | 47 | from chat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG 48 | from chat.model.model_adapter import ( 49 | load_model, 50 | add_model_args, 51 | get_conversation_template, 52 | ) 53 | from chat.model.model_chatglm import generate_stream_chatglm 54 | from chat.model.model_falcon import generate_stream_falcon 55 | from chat.model.model_codet5p import generate_stream_codet5p 56 | from chat.modules.gptq import GptqConfig 57 | from chat.server.inference import generate_stream 58 | from chat.server.model_worker import ModelWorker, worker_id, logger 59 | from chat.utils import build_logger, pretty_print_semaphore, get_context_length 60 | 61 | 62 | # We store both the underlying workers and a mapping from their model names to 63 | # the worker instance. This makes it easy to fetch the appropriate worker for 64 | # each API call. 65 | workers = [] 66 | worker_map = {} 67 | app = FastAPI() 68 | 69 | 70 | def release_worker_semaphore(): 71 | workers[0].semaphore.release() 72 | 73 | 74 | def acquire_worker_semaphore(): 75 | if workers[0].semaphore is None: 76 | # Share the same semaphore for all workers because 77 | # all workers share the same GPU. 78 | semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency) 79 | for w in workers: 80 | w.semaphore = semaphore 81 | return workers[0].semaphore.acquire() 82 | 83 | 84 | def create_background_tasks(): 85 | background_tasks = BackgroundTasks() 86 | background_tasks.add_task(release_worker_semaphore) 87 | return background_tasks 88 | 89 | 90 | # Note: for all the calls below, we make a hard assumption that the caller 91 | # includes the model name in the payload, otherwise we can't figure out which 92 | # underlying sub-worker to call. 
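# A minimal illustrative payload for these endpoints (only the "model" key is
# relied on here; the other fields are hypothetical generation parameters, not
# a fixed schema):
#
#     {"model": "vicuna-7b-v1.3", "prompt": "Hello!", "temperature": 0.7}
#
# worker_map[params["model"]] then resolves which sub-worker serves the request.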
93 | 94 | 95 | @app.post("/worker_generate_stream") 96 | async def api_generate_stream(request: Request): 97 | params = await request.json() 98 | await acquire_worker_semaphore() 99 | worker = worker_map[params["model"]] 100 | generator = worker.generate_stream_gate(params) 101 | background_tasks = create_background_tasks() 102 | return StreamingResponse(generator, background=background_tasks) 103 | 104 | 105 | @app.post("/worker_generate") 106 | async def api_generate(request: Request): 107 | params = await request.json() 108 | await acquire_worker_semaphore() 109 | worker = worker_map[params["model"]] 110 | output = worker.generate_gate(params) 111 | release_worker_semaphore() 112 | return JSONResponse(output) 113 | 114 | 115 | @app.post("/worker_get_embeddings") 116 | async def api_get_embeddings(request: Request): 117 | params = await request.json() 118 | await acquire_worker_semaphore() 119 | worker = worker_map[params["model"]] 120 | embedding = worker.get_embeddings(params) 121 | background_tasks = create_background_tasks() 122 | return JSONResponse(content=embedding, background=background_tasks) 123 | 124 | 125 | @app.post("/worker_get_status") 126 | async def api_get_status(request: Request): 127 | return { 128 | "model_names": [m for w in workers for m in w.model_names], 129 | "speed": 1, 130 | "queue_length": sum([w.get_queue_length() for w in workers]), 131 | } 132 | 133 | 134 | @app.post("/count_token") 135 | async def api_count_token(request: Request): 136 | params = await request.json() 137 | worker = worker_map[params["model"]] 138 | return worker.count_token(params) 139 | 140 | 141 | @app.post("/worker_get_conv_template") 142 | async def api_get_conv(request: Request): 143 | params = await request.json() 144 | worker = worker_map[params["model"]] 145 | return worker.get_conv_template() 146 | 147 | 148 | @app.post("/model_details") 149 | async def api_model_details(request: Request): 150 | params = await request.json() 151 | worker = worker_map[params["model"]] 152 | return {"context_length": worker.context_len} 153 | 154 | 155 | def create_multi_model_worker(): 156 | # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST 157 | # of the model args but we'll override one to have an append action that 158 | # supports multiple values. 159 | parser = argparse.ArgumentParser(conflict_handler="resolve") 160 | parser.add_argument("--host", type=str, default="localhost") 161 | parser.add_argument("--port", type=int, default=21002) 162 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002") 163 | parser.add_argument( 164 | "--controller-address", type=str, default="http://localhost:21001" 165 | ) 166 | add_model_args(parser) 167 | # Override the model path to be repeated and align it with model names. 168 | parser.add_argument( 169 | "--model-path", 170 | type=str, 171 | default=[], 172 | action="append", 173 | help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", 174 | ) 175 | parser.add_argument( 176 | "--model-names", 177 | type=lambda s: s.split(","), 178 | action="append", 179 | help="One or more model names. 
Values must be aligned with `--model-path` values.", 180 | ) 181 | parser.add_argument("--limit-worker-concurrency", type=int, default=5) 182 | parser.add_argument("--stream-interval", type=int, default=2) 183 | parser.add_argument("--no-register", action="store_true") 184 | args = parser.parse_args() 185 | logger.info(f"args: {args}") 186 | 187 | if args.gpus: 188 | if len(args.gpus.split(",")) < args.num_gpus: 189 | raise ValueError( 190 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 191 | ) 192 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 193 | 194 | gptq_config = GptqConfig( 195 | ckpt=args.gptq_ckpt or args.model_path, 196 | wbits=args.gptq_wbits, 197 | groupsize=args.gptq_groupsize, 198 | act_order=args.gptq_act_order, 199 | ) 200 | 201 | if args.model_names is None: 202 | args.model_names = [[x.split("/")[-1]] for x in args.model_path] 203 | 204 | # Launch all workers 205 | workers = [] 206 | for model_path, model_names in zip(args.model_path, args.model_names): 207 | w = ModelWorker( 208 | args.controller_address, 209 | args.worker_address, 210 | worker_id, 211 | model_path, 212 | model_names, 213 | args.limit_worker_concurrency, 214 | args.no_register, 215 | device=args.device, 216 | num_gpus=args.num_gpus, 217 | max_gpu_memory=args.max_gpu_memory, 218 | load_8bit=args.load_8bit, 219 | cpu_offloading=args.cpu_offloading, 220 | gptq_config=gptq_config, 221 | stream_interval=args.stream_interval, 222 | ) 223 | workers.append(w) 224 | for model_name in model_names: 225 | worker_map[model_name] = w 226 | 227 | # Register all models 228 | url = args.controller_address + "/register_worker" 229 | data = { 230 | "worker_name": workers[0].worker_addr, 231 | "check_heart_beat": not args.no_register, 232 | "worker_status": { 233 | "model_names": [m for w in workers for m in w.model_names], 234 | "speed": 1, 235 | "queue_length": sum([w.get_queue_length() for w in workers]), 236 | }, 237 | } 238 | r = requests.post(url, json=data) 239 | assert r.status_code == 200 240 | 241 | return args, workers 242 | 243 | 244 | if __name__ == "__main__": 245 | args, workers = create_multi_model_worker() 246 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 247 | -------------------------------------------------------------------------------- /chat/server/launch_all_serve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" 3 | 4 | Workers are listed in format of `model-path`@`host`@`port` 5 | 6 | The key mechanism behind this scripts is: 7 | 1, execute shell cmd to launch the controller/worker/openai-api-server; 8 | 2, check the log of controller/worker/openai-api-server to ensure that the server is launched properly. 9 | Note that a few of non-critical `chat.server` cmd options are not supported currently. 
10 | """ 11 | import sys 12 | import os 13 | 14 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 15 | 16 | import subprocess 17 | import re 18 | import argparse 19 | 20 | LOGDIR = "./logs/" 21 | 22 | if not os.path.exists(LOGDIR): 23 | os.makedirs(LOGDIR) 24 | 25 | parser = argparse.ArgumentParser() 26 | # ------multi worker----------------- 27 | parser.add_argument( 28 | "--model-path-address", 29 | default="THUDM/chatglm2-6b@localhost@20002", 30 | nargs="+", 31 | type=str, 32 | help="model path, host, and port, formatted as model-path@host@port", 33 | ) 34 | # ---------------controller------------------------- 35 | 36 | parser.add_argument("--controller-host", type=str, default="localhost") 37 | parser.add_argument("--controller-port", type=int, default=21001) 38 | parser.add_argument( 39 | "--dispatch-method", 40 | type=str, 41 | choices=["lottery", "shortest_queue"], 42 | default="shortest_queue", 43 | ) 44 | controller_args = ["controller-host", "controller-port", "dispatch-method"] 45 | 46 | # ----------------------worker------------------------------------------ 47 | 48 | parser.add_argument("--worker-host", type=str, default="localhost") 49 | parser.add_argument("--worker-port", type=int, default=21002) 50 | # parser.add_argument("--worker-address", type=str, default="http://localhost:21002") 51 | # parser.add_argument( 52 | # "--controller-address", type=str, default="http://localhost:21001" 53 | # ) 54 | parser.add_argument( 55 | "--model-path", 56 | type=str, 57 | default="lmsys/vicuna-7b-v1.3", 58 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 59 | ) 60 | parser.add_argument( 61 | "--revision", 62 | type=str, 63 | default="main", 64 | help="Hugging Face Hub model revision identifier", 65 | ) 66 | parser.add_argument( 67 | "--device", 68 | type=str, 69 | choices=["cpu", "cuda", "mps", "xpu"], 70 | default="cuda", 71 | help="The device type", 72 | ) 73 | parser.add_argument( 74 | "--gpus", 75 | type=str, 76 | default="0", 77 | help="A single GPU like 1 or multiple GPUs like 0,2", 78 | ) 79 | parser.add_argument("--num-gpus", type=int, default=1) 80 | parser.add_argument( 81 | "--max-gpu-memory", 82 | type=str, 83 | help="The maximum memory per gpu. Use a string like '13Gib'", 84 | ) 85 | parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") 86 | parser.add_argument( 87 | "--cpu-offloading", 88 | action="store_true", 89 | help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", 90 | ) 91 | parser.add_argument( 92 | "--gptq-ckpt", 93 | type=str, 94 | default=None, 95 | help="Load quantized model. 
The path to the local GPTQ checkpoint.", 96 | ) 97 | parser.add_argument( 98 | "--gptq-wbits", 99 | type=int, 100 | default=16, 101 | choices=[2, 3, 4, 8, 16], 102 | help="#bits to use for quantization", 103 | ) 104 | parser.add_argument( 105 | "--gptq-groupsize", 106 | type=int, 107 | default=-1, 108 | help="Groupsize to use for quantization; default uses full row.", 109 | ) 110 | parser.add_argument( 111 | "--gptq-act-order", 112 | action="store_true", 113 | help="Whether to apply the activation order GPTQ heuristic", 114 | ) 115 | parser.add_argument( 116 | "--model-names", 117 | type=lambda s: s.split(","), 118 | help="Optional display comma separated names", 119 | ) 120 | parser.add_argument( 121 | "--limit-worker-concurrency", 122 | type=int, 123 | default=5, 124 | help="Limit the model concurrency to prevent OOM.", 125 | ) 126 | parser.add_argument("--stream-interval", type=int, default=2) 127 | parser.add_argument("--no-register", action="store_true") 128 | 129 | worker_args = [ 130 | "worker-host", 131 | "worker-port", 132 | "model-path", 133 | "revision", 134 | "device", 135 | "gpus", 136 | "num-gpus", 137 | "max-gpu-memory", 138 | "load-8bit", 139 | "cpu-offloading", 140 | "gptq-ckpt", 141 | "gptq-wbits", 142 | "gptq-groupsize", 143 | "gptq-act-order", 144 | "model-names", 145 | "limit-worker-concurrency", 146 | "stream-interval", 147 | "no-register", 148 | "controller-address", 149 | ] 150 | # -----------------openai server--------------------------- 151 | 152 | parser.add_argument("--server-host", type=str, default="localhost", help="host name") 153 | parser.add_argument("--server-port", type=int, default=8001, help="port number") 154 | parser.add_argument( 155 | "--allow-credentials", action="store_true", help="allow credentials" 156 | ) 157 | # parser.add_argument( 158 | # "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" 159 | # ) 160 | # parser.add_argument( 161 | # "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" 162 | # ) 163 | # parser.add_argument( 164 | # "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" 165 | # ) 166 | parser.add_argument( 167 | "--api-keys", 168 | type=lambda s: s.split(","), 169 | help="Optional list of comma separated API keys", 170 | ) 171 | server_args = [ 172 | "server-host", 173 | "server-port", 174 | "allow-credentials", 175 | "api-keys", 176 | "controller-address", 177 | ] 178 | 179 | args = parser.parse_args() 180 | 181 | args = argparse.Namespace( 182 | **vars(args), 183 | **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"}, 184 | ) 185 | 186 | if args.gpus: 187 | if len(args.gpus.split(",")) < args.num_gpus: 188 | raise ValueError( 189 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 190 | ) 191 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 192 | 193 | # 0,controller, model_worker, openai_api_server 194 | # 1, cmd options 195 | # 2,LOGDIR 196 | # 3, log file name 197 | base_launch_sh = "nohup python3 -m chat.server.{0} {1} >{2}/{3}.log 2>&1 &" 198 | 199 | # 0 LOGDIR 200 | #! 
1 log file name 201 | # 2 controller, worker, openai_api_server 202 | base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do 203 | sleep 1s; 204 | echo "wait {2} running" 205 | done 206 | echo '{2} running' """ 207 | 208 | 209 | def string_args(args, args_list): 210 | args_str = "" 211 | for key, value in args._get_kwargs(): 212 | key = key.replace("_", "-") 213 | if key not in args_list: 214 | continue 215 | 216 | key = key.split("-")[-1] if re.search("port|host", key) else key 217 | if not value: 218 | pass 219 | # 1==True -> True 220 | elif isinstance(value, bool) and value == True: 221 | args_str += f" --{key} " 222 | elif ( 223 | isinstance(value, list) 224 | or isinstance(value, tuple) 225 | or isinstance(value, set) 226 | ): 227 | value = " ".join(value) 228 | args_str += f" --{key} {value} " 229 | else: 230 | args_str += f" --{key} {value} " 231 | 232 | return args_str 233 | 234 | 235 | def launch_worker(item): 236 | log_name = ( 237 | item.split("/")[-1] 238 | .split("\\")[-1] 239 | .replace("-", "_") 240 | .replace("@", "_") 241 | .replace(".", "_") 242 | ) 243 | 244 | args.model_path, args.worker_host, args.worker_port = item.split("@") 245 | print("*" * 80) 246 | worker_str_args = string_args(args, worker_args) 247 | print(worker_str_args) 248 | worker_sh = base_launch_sh.format( 249 | "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}" 250 | ) 251 | worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker") 252 | subprocess.run(worker_sh, shell=True, check=True) 253 | subprocess.run(worker_check_sh, shell=True, check=True) 254 | 255 | 256 | def launch_all(): 257 | controller_str_args = string_args(args, controller_args) 258 | controller_sh = base_launch_sh.format( 259 | "controller", controller_str_args, LOGDIR, "controller" 260 | ) 261 | controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller") 262 | subprocess.run(controller_sh, shell=True, check=True) 263 | subprocess.run(controller_check_sh, shell=True, check=True) 264 | 265 | if isinstance(args.model_path_address, str): 266 | launch_worker(args.model_path_address) 267 | else: 268 | for idx, item in enumerate(args.model_path_address): 269 | print(f"loading {idx}th model:{item}") 270 | launch_worker(item) 271 | 272 | server_str_args = string_args(args, server_args) 273 | server_sh = base_launch_sh.format( 274 | "openai_api_server", server_str_args, LOGDIR, "openai_api_server" 275 | ) 276 | server_check_sh = base_check_sh.format( 277 | LOGDIR, "openai_api_server", "openai_api_server" 278 | ) 279 | subprocess.run(server_sh, shell=True, check=True) 280 | subprocess.run(server_check_sh, shell=True, check=True) 281 | 282 | 283 | if __name__ == "__main__": 284 | launch_all() 285 | -------------------------------------------------------------------------------- /chat/server/monitor/elo_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import datetime 4 | import json 5 | import math 6 | import pickle 7 | from pytz import timezone 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import plotly.express as px 12 | from tqdm import tqdm 13 | 14 | from chat.model.model_registry import get_model_info 15 | from chat.server.monitor.basic_stats import get_log_files 16 | from chat.server.monitor.clean_battle_data import clean_battle_data 17 | 18 | 19 | pd.options.display.float_format = "{:.2f}".format 20 | 21 | 22 | def compute_elo(battles, K=4, 
SCALE=400, BASE=10, INIT_RATING=1000): 23 | rating = defaultdict(lambda: INIT_RATING) 24 | 25 | for rd, model_a, model_b, winner in battles[ 26 | ["model_a", "model_b", "winner"] 27 | ].itertuples(): 28 | ra = rating[model_a] 29 | rb = rating[model_b] 30 | ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) 31 | eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) 32 | if winner == "model_a": 33 | sa = 1 34 | elif winner == "model_b": 35 | sa = 0 36 | elif winner == "tie" or winner == "tie (bothbad)": 37 | sa = 0.5 38 | else: 39 | raise Exception(f"unexpected vote {winner}") 40 | rating[model_a] += K * (sa - ea) 41 | rating[model_b] += K * (1 - sa - eb) 42 | 43 | return dict(rating) 44 | 45 | 46 | def get_bootstrap_result(battles, func_compute_elo, num_round=1000): 47 | rows = [] 48 | for i in tqdm(range(num_round), desc="bootstrap"): 49 | tmp_battles = battles.sample(frac=1.0, replace=True) 50 | rows.append(func_compute_elo(tmp_battles)) 51 | df = pd.DataFrame(rows) 52 | return df[df.median().sort_values(ascending=False).index] 53 | 54 | 55 | def get_median_elo_from_bootstrap(bootstrap_df): 56 | median = dict(bootstrap_df.quantile(0.5)) 57 | median = {k: int(v + 0.5) for k, v in median.items()} 58 | return median 59 | 60 | 61 | def compute_pairwise_win_fraction(battles, model_order): 62 | # Times each model wins as Model A 63 | a_win_ptbl = pd.pivot_table( 64 | battles[battles["winner"] == "model_a"], 65 | index="model_a", 66 | columns="model_b", 67 | aggfunc="size", 68 | fill_value=0, 69 | ) 70 | 71 | # Table counting times each model wins as Model B 72 | b_win_ptbl = pd.pivot_table( 73 | battles[battles["winner"] == "model_b"], 74 | index="model_a", 75 | columns="model_b", 76 | aggfunc="size", 77 | fill_value=0, 78 | ) 79 | 80 | # Table counting number of A-B pairs 81 | num_battles_ptbl = pd.pivot_table( 82 | battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 83 | ) 84 | 85 | # Computing the proportion of wins for each model as A and as B 86 | # against all other models 87 | row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( 88 | num_battles_ptbl + num_battles_ptbl.T 89 | ) 90 | 91 | if model_order is None: 92 | prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) 93 | model_order = list(prop_wins.keys()) 94 | 95 | # Arrange ordering according to proprition of wins 96 | row_beats_col = row_beats_col_freq.loc[model_order, model_order] 97 | return row_beats_col 98 | 99 | 100 | def visualize_leaderboard_table(rating): 101 | models = list(rating.keys()) 102 | models.sort(key=lambda k: -rating[k]) 103 | 104 | emoji_dict = { 105 | 1: "🥇", 106 | 2: "🥈", 107 | 3: "🥉", 108 | } 109 | 110 | md = "" 111 | md += "| Rank | Model | Elo Rating | Description |\n" 112 | md += "| --- | --- | --- | --- |\n" 113 | for i, model in enumerate(models): 114 | rank = i + 1 115 | minfo = get_model_info(model) 116 | emoji = emoji_dict.get(rank, "") 117 | md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" 118 | 119 | return md 120 | 121 | 122 | def visualize_pairwise_win_fraction(battles, model_order): 123 | row_beats_col = compute_pairwise_win_fraction(battles, model_order) 124 | fig = px.imshow( 125 | row_beats_col, 126 | color_continuous_scale="RdBu", 127 | text_auto=".2f", 128 | height=700, 129 | width=700, 130 | ) 131 | fig.update_layout( 132 | xaxis_title="Model B", 133 | yaxis_title="Model A", 134 | xaxis_side="top", 135 | title_y=0.07, 136 | title_x=0.5, 137 | ) 138 | fig.update_traces( 139 | hovertemplate="Model A: %{y}
Model B: %{x}<br>
Fraction of A Wins: %{z}" 140 | ) 141 | 142 | return fig 143 | 144 | 145 | def visualize_battle_count(battles, model_order): 146 | ptbl = pd.pivot_table( 147 | battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 148 | ) 149 | battle_counts = ptbl + ptbl.T 150 | fig = px.imshow( 151 | battle_counts.loc[model_order, model_order], 152 | text_auto=True, 153 | height=700, 154 | width=700, 155 | ) 156 | fig.update_layout( 157 | xaxis_title="Model B", 158 | yaxis_title="Model A", 159 | xaxis_side="top", 160 | title_y=0.07, 161 | title_x=0.5, 162 | ) 163 | fig.update_traces( 164 | hovertemplate="Model A: %{y}
Model B: %{x}<br>
Count: %{z}" 165 | ) 166 | return fig 167 | 168 | 169 | def visualize_average_win_rate(battles): 170 | row_beats_col_freq = compute_pairwise_win_fraction(battles, None) 171 | fig = px.bar( 172 | row_beats_col_freq.mean(axis=1).sort_values(ascending=False), 173 | text_auto=".2f", 174 | height=500, 175 | width=700, 176 | ) 177 | fig.update_layout( 178 | yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False 179 | ) 180 | return fig 181 | 182 | 183 | def visualize_bootstrap_elo_rating(df): 184 | bars = ( 185 | pd.DataFrame( 186 | dict( 187 | lower=df.quantile(0.025), 188 | rating=df.quantile(0.5), 189 | upper=df.quantile(0.975), 190 | ) 191 | ) 192 | .reset_index(names="model") 193 | .sort_values("rating", ascending=False) 194 | ) 195 | bars["error_y"] = bars["upper"] - bars["rating"] 196 | bars["error_y_minus"] = bars["rating"] - bars["lower"] 197 | bars["rating_rounded"] = np.round(bars["rating"], 2) 198 | fig = px.scatter( 199 | bars, 200 | x="model", 201 | y="rating", 202 | error_y="error_y", 203 | error_y_minus="error_y_minus", 204 | text="rating_rounded", 205 | height=500, 206 | width=700, 207 | ) 208 | fig.update_layout(xaxis_title="Model", yaxis_title="Rating") 209 | return fig 210 | 211 | 212 | def report_elo_analysis_results(battles_json): 213 | battles = pd.DataFrame(battles_json) 214 | battles = battles.sort_values(ascending=True, by=["tstamp"]) 215 | # Only use anonymous votes 216 | battles = battles[battles["anony"]].reset_index(drop=True) 217 | battles_no_ties = battles[~battles["winner"].str.contains("tie")] 218 | 219 | # Online update 220 | elo_rating_online = compute_elo(battles) 221 | 222 | # Bootstrap 223 | bootstrap_df = get_bootstrap_result(battles, compute_elo) 224 | elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) 225 | model_order = list(elo_rating_median.keys()) 226 | model_order.sort(key=lambda k: -elo_rating_median[k]) 227 | 228 | # Plots 229 | leaderboard_table = visualize_leaderboard_table(elo_rating_median) 230 | win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order) 231 | battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order) 232 | average_win_rate_bar = visualize_average_win_rate(battles_no_ties) 233 | bootstrap_elo_rating = visualize_bootstrap_elo_rating(bootstrap_df) 234 | 235 | last_updated_tstamp = battles["tstamp"].max() 236 | last_updated_datetime = datetime.datetime.fromtimestamp( 237 | last_updated_tstamp, tz=timezone("US/Pacific") 238 | ).strftime("%Y-%m-%d %H:%M:%S %Z") 239 | 240 | return { 241 | "elo_rating_online": elo_rating_online, 242 | "elo_rating_median": elo_rating_median, 243 | "leaderboard_table": leaderboard_table, 244 | "win_fraction_heatmap": win_fraction_heatmap, 245 | "battle_count_heatmap": battle_count_heatmap, 246 | "average_win_rate_bar": average_win_rate_bar, 247 | "bootstrap_elo_rating": bootstrap_elo_rating, 248 | "last_updated_datetime": last_updated_datetime, 249 | "last_updated_tstamp": last_updated_tstamp, 250 | } 251 | 252 | 253 | def pretty_print_elo_rating(rating): 254 | model_order = list(rating.keys()) 255 | model_order.sort(key=lambda k: -rating[k]) 256 | for i, model in enumerate(model_order): 257 | print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") 258 | 259 | 260 | if __name__ == "__main__": 261 | parser = argparse.ArgumentParser() 262 | parser.add_argument("--clean-battle-file", type=str) 263 | parser.add_argument("--max-num-files", type=int) 264 | args = parser.parse_args() 265 | 266 | np.random.seed(42) 267 | 268 | if 
args.clean_battle_file: 269 | # Read data from a cleaned battle files 270 | battles = pd.read_json(args.clean_battle_file) 271 | else: 272 | # Read data from all log files 273 | log_files = get_log_files(args.max_num_files) 274 | battles = clean_battle_data(log_files) 275 | 276 | results = report_elo_analysis_results(battles) 277 | 278 | print("# Online") 279 | pretty_print_elo_rating(results["elo_rating_online"]) 280 | print("# Median") 281 | pretty_print_elo_rating(results["elo_rating_median"]) 282 | print(f"last update : {results['last_updated_datetime']}") 283 | 284 | last_updated_tstamp = results["last_updated_tstamp"] 285 | cutoff_date = datetime.datetime.fromtimestamp( 286 | last_updated_tstamp, tz=timezone("US/Pacific") 287 | ).strftime("%Y%m%d") 288 | 289 | with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: 290 | pickle.dump(results, fout) 291 | --------------------------------------------------------------------------------
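As a quick illustration of the per-battle rating update performed inside compute_elo (chat/server/monitor/elo_analysis.py) above, here is a minimal worked example using that function's default parameters; the starting ratings and the vote are illustrative values only, not output of the scripts in this repository:

# Two previously unseen models both start at INIT_RATING; suppose model_a wins
# (a "leftvote" is converted to winner "model_a" by clean_battle_data.py).
K, SCALE, BASE, INIT_RATING = 4, 400, 10, 1000
ra, rb = INIT_RATING, INIT_RATING           # 1000 each
ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))  # expected score for model_a = 0.5
eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))  # expected score for model_b = 0.5
sa = 1                                      # actual score for model_a
ra = ra + K * (sa - ea)                     # 1000 + 4 * (1 - 0.5)     = 1002.0
rb = rb + K * (1 - sa - eb)                 # 1000 + 4 * (1 - 1 - 0.5) =  998.0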