├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── python-package.yml ├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── assets ├── demo_narrow.gif ├── qa_browser.png ├── screenshot_cli.png ├── screenshot_gui.png ├── server_arch.png └── vicuna_logo.jpeg ├── data └── dummy_conversation.json ├── docker ├── Dockerfile └── docker-compose.yml ├── docs ├── arena.md ├── awq.md ├── commands │ ├── conv_release.md │ ├── data_cleaning.md │ ├── leaderboard.md │ ├── local_cluster.md │ ├── pypi.md │ └── webserver.md ├── dashinfer_integration.md ├── dataset_release.md ├── exllama_v2.md ├── gptq.md ├── langchain_integration.md ├── lightllm_integration.md ├── mlx_integration.md ├── model_support.md ├── openai_api.md ├── server_arch.md ├── third_party_ui.md ├── training.md ├── vicuna_weights_version.md ├── vllm_integration.md └── xFasterTransformer.md ├── fastchat ├── __init__.py ├── constants.py ├── conversation.py ├── data │ ├── __init__.py │ ├── clean_sharegpt.py │ ├── convert_alpaca.py │ ├── extract_gpt4_only.py │ ├── extract_single_round.py │ ├── filter_wrong_format.py │ ├── get_stats.py │ ├── hardcoded_questions.py │ ├── inspect_data.py │ ├── merge.py │ ├── optional_clean.py │ ├── optional_replace.py │ ├── prepare_all.py │ ├── pretty_json.py │ ├── sample.py │ ├── split_long_conversation.py │ └── split_train_test.py ├── llm_judge │ ├── README.md │ ├── clean_judgment.py │ ├── common.py │ ├── compute_agreement.py │ ├── data │ │ ├── judge_prompts.jsonl │ │ ├── mt_bench │ │ │ ├── misc │ │ │ │ └── radar.png │ │ │ ├── question.jsonl │ │ │ └── reference_answer │ │ │ │ └── gpt-4.jsonl │ │ └── vicuna_bench │ │ │ ├── question.jsonl │ │ │ └── reference_answer │ │ │ └── gpt-4.jsonl │ ├── download_mt_bench_pregenerated.py │ ├── gen_api_answer.py │ ├── gen_judgment.py │ ├── gen_model_answer.py │ ├── qa_browser.py │ └── show_result.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── apply_lora.py │ ├── compression.py │ ├── convert_fp16.py │ ├── llama_condense_monkey_patch.py │ ├── make_delta.py │ ├── model_adapter.py │ ├── model_chatglm.py │ ├── model_cllm.py │ ├── model_codet5p.py │ ├── model_exllama.py │ ├── model_falcon.py │ ├── model_registry.py │ ├── model_xfastertransformer.py │ ├── model_yuan2.py │ ├── monkey_patch_non_inplace.py │ ├── rwkv_model.py │ └── upload_hub.py ├── modules │ ├── __init__.py │ ├── awq.py │ ├── exllama.py │ ├── gptq.py │ └── xfastertransformer.py ├── protocol │ ├── api_protocol.py │ └── openai_api_protocol.py ├── serve │ ├── __init__.py │ ├── api_provider.py │ ├── base_model_worker.py │ ├── call_monitor.py │ ├── cli.py │ ├── controller.py │ ├── dashinfer_worker.py │ ├── example_images │ │ ├── distracted.jpg │ │ └── fridge.jpg │ ├── gateway │ │ ├── README.md │ │ └── nginx.conf │ ├── gradio_block_arena_anony.py │ ├── gradio_block_arena_named.py │ ├── gradio_block_arena_vision.py │ ├── gradio_block_arena_vision_anony.py │ ├── gradio_block_arena_vision_named.py │ ├── gradio_global_state.py │ ├── gradio_web_server.py │ ├── gradio_web_server_multi.py │ ├── huggingface_api.py │ ├── huggingface_api_worker.py │ ├── inference.py │ ├── launch_all_serve.py │ ├── lightllm_worker.py │ ├── mlx_worker.py │ ├── model_worker.py │ ├── monitor │ │ ├── add_markdown_info.py │ │ ├── basic_stats.py │ │ ├── classify │ │ │ ├── README.md │ │ │ ├── category.py │ │ │ ├── config.yaml │ │ │ ├── display_score.py │ │ │ ├── label.py │ │ │ └── vision_config.yaml │ │ ├── clean_battle_data.py │ │ ├── clean_chat_data.py │ │ ├── copilot_arena.py │ │ ├── criteria_labeling.py │ │ ├── 
dataset_release_scripts │ │ │ ├── arena_33k │ │ │ │ ├── count_unique_users.py │ │ │ │ ├── filter_bad_conv.py │ │ │ │ ├── merge_field.py │ │ │ │ ├── sample.py │ │ │ │ └── upload_hf_dataset.py │ │ │ └── lmsys_chat_1m │ │ │ │ ├── approve_all.py │ │ │ │ ├── compute_stats.py │ │ │ │ ├── filter_bad_conv.py │ │ │ │ ├── final_post_processing.py │ │ │ │ ├── instructions.md │ │ │ │ ├── merge_oai_tag.py │ │ │ │ ├── process_all.sh │ │ │ │ ├── sample.py │ │ │ │ └── upload_hf_dataset.py │ │ ├── deduplication.py │ │ ├── elo_analysis.py │ │ ├── inspect_conv.py │ │ ├── intersect_conv_file.py │ │ ├── leaderboard_csv_to_html.py │ │ ├── monitor.py │ │ ├── monitor_md.py │ │ ├── rating_systems.py │ │ ├── summarize_cluster.py │ │ ├── tag_openai_moderation.py │ │ ├── topic_clustering.py │ │ └── vote_time_stats │ │ │ ├── README.md │ │ │ ├── analyze_data.py │ │ │ └── plot.py │ ├── multi_model_worker.py │ ├── openai_api_server.py │ ├── register_worker.py │ ├── remote_logger.py │ ├── sglang_worker.py │ ├── shutdown_serve.py │ ├── test_message.py │ ├── test_throughput.py │ ├── vision │ │ ├── create_vqa_examples_dir.py │ │ ├── create_vqa_examples_json.py │ │ └── image.py │ └── vllm_worker.py ├── train │ ├── llama2_flash_attn_monkey_patch.py │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── train.py │ ├── train_baichuan.py │ ├── train_flant5.py │ ├── train_lora.py │ ├── train_lora_t5.py │ ├── train_mem.py │ ├── train_with_template.py │ ├── train_xformers.py │ └── train_yuan2.py └── utils.py ├── format.sh ├── playground ├── FastChat_API_GoogleColab.ipynb ├── __init__.py ├── benchmark │ └── benchmark_api_provider.py ├── deepspeed_config_s2.json ├── deepspeed_config_s3.json └── test_embedding │ ├── README.md │ ├── test_classification.py │ ├── test_semantic_search.py │ └── test_sentence_similarity.py ├── pyproject.toml ├── scripts ├── build-api.sh ├── test_readme_train.sh ├── train_lora.sh ├── train_vicuna_13b.sh ├── train_vicuna_7b.sh └── upload_pypi.sh ├── tests ├── README.md ├── killall_python.sh ├── launch_openai_api_test_server.py ├── load_test.py ├── test_cli.py ├── test_cli_inputs.txt ├── test_image_utils.py ├── test_openai_api.py ├── test_openai_langchain.py └── test_openai_vision_api.py └── theme.json /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ## Why are these changes needed? 6 | 7 | 8 | 9 | ## Related issue number (if applicable) 10 | 11 | 12 | 13 | ## Checks 14 | 15 | - [ ] I've run `format.sh` to lint the changes in this PR. 16 | - [ ] I've included any doc changes needed. 17 | - [ ] I've made sure the relevant tests are passing (if applicable). 
18 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | python-version: ["3.10"] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | cache: 'pip' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install -e '.[dev]' 25 | - name: Run linter 26 | run: | 27 | pylint -d all -e E0602 ./fastchat/ 28 | - name: Check formatting 29 | run: | 30 | black --check . 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | .venv 7 | 8 | # Log 9 | *.log 10 | *.log.* 11 | *.json 12 | !playground/deepspeed_config_s2.json 13 | !playground/deepspeed_config_s3.json 14 | 15 | # Editor 16 | .idea 17 | *.swp 18 | 19 | # Other 20 | .DS_Store 21 | wandb 22 | output 23 | checkpoints_flant5_3b 24 | 25 | # Data 26 | *.pkl 27 | *.csv 28 | tests/state_of_the_union.txt 29 | 30 | # Build 31 | build 32 | 33 | # Image data 34 | serve_images 35 | val2014 36 | vqa_examples -------------------------------------------------------------------------------- /assets/demo_narrow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/demo_narrow.gif -------------------------------------------------------------------------------- /assets/qa_browser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/qa_browser.png -------------------------------------------------------------------------------- /assets/screenshot_cli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/screenshot_cli.png -------------------------------------------------------------------------------- /assets/screenshot_gui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/screenshot_gui.png -------------------------------------------------------------------------------- /assets/server_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/server_arch.png -------------------------------------------------------------------------------- /assets/vicuna_logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/assets/vicuna_logo.jpeg -------------------------------------------------------------------------------- /docker/Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 2 | 3 | RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl 4 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py 5 | RUN python3.9 get-pip.py 6 | RUN pip3 install fschat 7 | RUN pip3 install fschat[model_worker,webui] -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | fastchat-controller: 5 | build: 6 | context: . 7 | dockerfile: Dockerfile 8 | image: fastchat:latest 9 | ports: 10 | - "21001:21001" 11 | entrypoint: ["python3.9", "-m", "fastchat.serve.controller", "--host", "0.0.0.0", "--port", "21001"] 12 | fastchat-model-worker: 13 | build: 14 | context: . 15 | dockerfile: Dockerfile 16 | volumes: 17 | - huggingface:/root/.cache/huggingface 18 | image: fastchat:latest 19 | deploy: 20 | resources: 21 | reservations: 22 | devices: 23 | - driver: nvidia 24 | count: 1 25 | capabilities: [gpu] 26 | entrypoint: ["python3.9", "-m", "fastchat.serve.model_worker", "--model-names", "${FASTCHAT_WORKER_MODEL_NAMES:-vicuna-7b-v1.5}", "--model-path", "${FASTCHAT_WORKER_MODEL_PATH:-lmsys/vicuna-7b-v1.5}", "--worker-address", "http://fastchat-model-worker:21002", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "21002"] 27 | fastchat-api-server: 28 | build: 29 | context: . 30 | dockerfile: Dockerfile 31 | image: fastchat:latest 32 | ports: 33 | - "8000:8000" 34 | entrypoint: ["python3.9", "-m", "fastchat.serve.openai_api_server", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "8000"] 35 | volumes: 36 | huggingface: 37 | -------------------------------------------------------------------------------- /docs/arena.md: -------------------------------------------------------------------------------- 1 | # Chatbot Arena 2 | Chatbot Arena is an LLM benchmark platform featuring anonymous, randomized battles, available at https://lmarena.ai. 3 | We invite the entire community to join this benchmarking effort by contributing your votes and models. 4 | 5 | ## How to add a new model 6 | If you want to see a specific model in the arena, you can follow the methods below. 7 | 8 | ### Method 1: Hosted by 3rd party API providers or yourself 9 | If you have a model hosted by a 3rd party API provider or yourself, please give us the access to an API endpoint. 10 | - We prefer OpenAI-compatible APIs, so we can reuse our [code](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py) for calling OpenAI models. 11 | - If you have your own API protocol, please follow the [instructions](model_support.md) to add them. Contribute your code by sending a pull request. 12 | 13 | ### Method 2: Hosted by LMSYS 14 | 1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md). 15 | 2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. 16 | 17 | 18 | ## How to launch vision arena 19 | 20 | 1. 
Run `python3 -m fastchat.serve.controller` to start the controller and begin registering local model workers and API-provided workers. 21 | 2. Run `python3 -m fastchat.serve.sglang_worker --model-path --tokenizer-path ` to run local vision-language models. Currently supported models include the LLaVA and Yi-VL series. 22 | 3. If you are using a 3rd party model with an API provider (e.g. GPT-4-V, Gemini 1.5), please follow the instructions [model_support.md](model_support.md) to add a json file `api_endpoints.json`. 23 | 4. Run the gradio server with the `--vision-arena` flag on. 24 | 5. To run and store images into a remote directory, add the flag: `--use-remote-storage` 25 | 6. To run and allow samples of random questions, add `--random_questions metadata_sampled.json`. Check sections below for how to generate this. 26 | 27 | Example command: 28 | ``` 29 | python3 -m fastchat.serve.gradio_web_server_multi --share --register-api-endpoint-file api_endpoints.json --vision-arena --use-remote-storage --random-questions metadata_sampled.json 30 | ``` 31 | 32 | ### NSFW and CSAM Detection 33 | 1. Adding NSFW Endpoint and API key: Please add the following environment variables to run the NSFW moderation filter for images: 34 | - `AZURE_IMG_MODERATION_ENDPOINT`: This is the endpoint that the NSFW moderator is hosted (e.g. https://{endpoint}/contentmoderator/moderate/v1.0/ProcessImage/Evaluate). Change the `endpoint` to your own. 35 | - `AZURE_IMG_MODERATION_API_KEY`: Your API key to run this endpoint. 36 | 2. Adding CSAM API key: 37 | - `PHOTODNA_API_KEY`: The API key that runs the CSAM detector endpoint. 38 | 39 | Example in `~/.bashrc`: 40 | ``` 41 | export AZURE_IMG_MODERATION_ENDPOINT=https:///contentmoderator/moderate/v1.0/ProcessImage/Evaluate 42 | export AZURE_IMG_MODERATION_API_KEY= 43 | export PHOTODNA_API_KEY= 44 | ``` 45 | 46 | ### Adding Random Samples for VQA 47 | We provide random samples of example images for users to interact with coming from various datasets including DocVQA, RealWorldQA, ChartQA and VizWiz-VQA. 48 | 1. Download the images and generate random questions file by running `python fastchat/serve/vision/create_vqa_examples_dir.py` -------------------------------------------------------------------------------- /docs/awq.md: -------------------------------------------------------------------------------- 1 | # AWQ 4bit Inference 2 | 3 | We integrated [AWQ](https://github.com/mit-han-lab/llm-awq) into FastChat to provide **efficient and accurate** 4bit LLM inference. 4 | 5 | ## Install AWQ 6 | 7 | Setup environment (please refer to [this link](https://github.com/mit-han-lab/llm-awq#install) for more details): 8 | ```bash 9 | conda create -n fastchat-awq python=3.10 -y 10 | conda activate fastchat-awq 11 | # cd /path/to/FastChat 12 | pip install --upgrade pip # enable PEP 660 support 13 | pip install -e . # install fastchat 14 | 15 | git clone https://github.com/mit-han-lab/llm-awq repositories/llm-awq 16 | cd repositories/llm-awq 17 | pip install -e . 
# install awq package 18 | 19 | cd awq/kernels 20 | python setup.py install # install awq CUDA kernels 21 | ``` 22 | 23 | ## Chat with the CLI 24 | 25 | ```bash 26 | # Download quantized model from huggingface 27 | # Make sure you have git-lfs installed (https://git-lfs.com) 28 | git lfs install 29 | git clone https://huggingface.co/mit-han-lab/vicuna-7b-v1.3-4bit-g128-awq 30 | 31 | # You can specify which quantized model to use by setting --awq-ckpt 32 | python3 -m fastchat.serve.cli \ 33 | --model-path models/vicuna-7b-v1.3-4bit-g128-awq \ 34 | --awq-wbits 4 \ 35 | --awq-groupsize 128 36 | ``` 37 | 38 | ## Benchmark 39 | 40 | * Through **4-bit weight quantization**, AWQ helps to run larger language models within the device memory restriction and prominently accelerates token generation. All benchmarks are done with group_size 128. 41 | 42 | * Benchmark on NVIDIA RTX A6000: 43 | 44 | | Model | Bits | Max Memory (MiB) | Speed (ms/token) | AWQ Speedup | 45 | | --------------- | ---- | ---------------- | ---------------- | ----------- | 46 | | vicuna-7b | 16 | 13543 | 26.06 | / | 47 | | vicuna-7b | 4 | 5547 | 12.43 | 2.1x | 48 | | llama2-7b-chat | 16 | 13543 | 27.14 | / | 49 | | llama2-7b-chat | 4 | 5547 | 12.44 | 2.2x | 50 | | vicuna-13b | 16 | 25647 | 44.91 | / | 51 | | vicuna-13b | 4 | 9355 | 17.30 | 2.6x | 52 | | llama2-13b-chat | 16 | 25647 | 47.28 | / | 53 | | llama2-13b-chat | 4 | 9355 | 20.28 | 2.3x | 54 | 55 | * NVIDIA RTX 4090: 56 | 57 | | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | 58 | | --------------- | ------------------------- | --------------------- | ----------- | 59 | | vicuna-7b | 8.61 | 19.09 | 2.2x | 60 | | llama2-7b-chat | 8.66 | 19.97 | 2.3x | 61 | | vicuna-13b | 12.17 | OOM | / | 62 | | llama2-13b-chat | 13.54 | OOM | / | 63 | 64 | * NVIDIA Jetson Orin: 65 | 66 | | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | 67 | | --------------- | ------------------------- | --------------------- | ----------- | 68 | | vicuna-7b | 65.34 | 93.12 | 1.4x | 69 | | llama2-7b-chat | 75.11 | 104.71 | 1.4x | 70 | | vicuna-13b | 115.40 | OOM | / | 71 | | llama2-13b-chat | 136.81 | OOM | / | 72 | -------------------------------------------------------------------------------- /docs/commands/conv_release.md: -------------------------------------------------------------------------------- 1 | ## Chatbot Arena Conversations 2 | 3 | 1. Gather battles 4 | ``` 5 | python3 clean_battle_data.py --max-num 10 --mode conv_release 6 | ``` 7 | 8 | 2. Tag OpenAI moderation 9 | ``` 10 | python3 tag_openai_moderation.py --in clean_battle_conv_20230814.json 11 | ``` 12 | 13 | 3. Clean PII 14 | 15 | 4. Filter additional blocked words 16 | 17 | ``` 18 | python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json 19 | ``` 20 | 21 | 5. Add additional toxicity tag 22 | 23 | 24 | ## All Conversations 25 | 26 | 1. Gather chats 27 | ``` 28 | python3 clean_chat_data.py 29 | ``` 30 | 31 | 2. 
Sample 32 | ``` 33 | python3 conv_release_scripts/sample.py 34 | ``` 35 | 36 | 37 | ## Prompt distribution 38 | 39 | -------------------------------------------------------------------------------- /docs/commands/data_cleaning.md: -------------------------------------------------------------------------------- 1 | ## Data cleaning 2 | 3 | ## Requirements 4 | ``` 5 | pip3 install bs4 markdownify 6 | pip3 install polyglot pyicu pycld2 7 | ``` 8 | 9 | ## Steps 10 | ``` 11 | # Convert html to markdown 12 | python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json 13 | 14 | # Keep or remove specific languages 15 | python3 -m fastchat.data.optional_clean --in sharegpt_clean.json --out sharegpt_clean_lang.json --skip-lang SOME_LANGUAGE_CODE 16 | 17 | # Split long conversations 18 | python3 -m fastchat.data.split_long_conversation --in sharegpt_clean_lang.json --out sharegpt_clean_lang_split.json --model-name /home/ubuntu/model_weights/llama-7b/ 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/commands/leaderboard.md: -------------------------------------------------------------------------------- 1 | ### Get logs 2 | ``` 3 | gsutil -m rsync -r gs://fastchat_logs ~/fastchat_logs/ 4 | ``` 5 | 6 | ### Clean battle data 7 | ``` 8 | cd ~/FastChat/fastchat/serve/monitor 9 | python3 clean_battle_data.py 10 | ``` 11 | 12 | ### Run Elo analysis 13 | ``` 14 | python3 elo_analysis.py --clean-battle-file clean_battle_20230905.json 15 | ``` 16 | 17 | ### Copy files to HF space 18 | 1. update plots 19 | ``` 20 | scp atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/elo_results_20230905.pkl . 21 | ``` 22 | 23 | 2. update table 24 | ``` 25 | wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/raw/main/leaderboard_table_20230905.csv 26 | ``` 27 | 28 | ### Update files on webserver 29 | ``` 30 | DATE=20231002 31 | 32 | rm -rf elo_results.pkl leaderboard_table.csv 33 | wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/elo_results_$DATE.pkl 34 | wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/leaderboard_table_$DATE.csv 35 | ln -s leaderboard_table_$DATE.csv leaderboard_table.csv 36 | ln -s elo_results_$DATE.pkl elo_results.pkl 37 | ``` 38 | -------------------------------------------------------------------------------- /docs/commands/local_cluster.md: -------------------------------------------------------------------------------- 1 | ### Local GPU cluster 2 | node-01 3 | ``` 4 | python3 -m fastchat.serve.controller --host 0.0.0.0 --port 10002 5 | 6 | CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 7 | CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 8 | 9 | CUDA_VISIBLE_DEVICES=2,3 ray start --head 10 | python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-33b-v1.3 --model-name vicuna-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 11 | ``` 12 | 13 | node-02 14 | ``` 15 | CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller 
http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --tokenizer meta-llama/Llama-2-7b-chat-hf 16 | CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 --tokenizer meta-llama/Llama-2-7b-chat-hf 17 | CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-7b-chat-hf --model-name llama-2-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --tokenizer meta-llama/Llama-2-7b-chat-hf 18 | CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path WizardLM/WizardLM-13B-V1.1 --model-name wizardlm-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 19 | ``` 20 | 21 | node-03 22 | ``` 23 | python3 -m fastchat.serve.vllm_worker --model-path mosaicml/mpt-30b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --num-gpus 2 24 | python3 -m fastchat.serve.vllm_worker --model-path timdettmers/guanaco-33b-merged --model-name guanaco-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 --tokenizer hf-internal-testing/llama-tokenizer 25 | ``` 26 | 27 | node-04 28 | ``` 29 | CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.multi_model_worker --model-path ~/model_weights/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth --model-name RWKV-4-Raven-14B --model-path lmsys/fastchat-t5-3b-v1.0 --model-name fastchat-t5-3b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker http://$(hostname):31000 --limit 4 30 | CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.multi_model_worker --model-path OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5 --model-name oasst-pythia-12b --model-path mosaicml/mpt-7b-chat --model-name mpt-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker http://$(hostname):31001 --limit 4 31 | CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.multi_model_worker --model-path lmsys/vicuna-7b-v1.5 --model-name vicuna-7b --model-path THUDM/chatglm-6b --model-name chatglm-6b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker http://$(hostname):31002 --limit 4 32 | CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path ~/model_weights/alpaca-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 33 | ``` 34 | 35 | test 36 | ``` 37 | python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:10002 38 | ``` 39 | -------------------------------------------------------------------------------- /docs/commands/pypi.md: -------------------------------------------------------------------------------- 1 | ### Requirement 2 | ``` 3 | python3 -m pip install twine 4 | python3 -m pip install --upgrade pip 5 | pip3 install build 6 | ``` 7 | 8 | ### Upload 9 | ``` 10 | bash scripts/upload_pypi.sh 11 | ``` 12 | -------------------------------------------------------------------------------- /docs/commands/webserver.md: -------------------------------------------------------------------------------- 1 | ### Install 2 | ``` 3 | sudo apt update 4 | sudo apt install tmux htop 5 | 6 | wget 
https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh 7 | bash Anaconda3-2022.10-Linux-x86_64.sh 8 | 9 | conda create -n fastchat python=3.9 10 | conda activate fastchat 11 | 12 | git clone https://github.com/lm-sys/FastChat.git 13 | cd FastChat 14 | pip3 install -e . 15 | ``` 16 | 17 | 18 | ### Launch servers 19 | ``` 20 | cd fastchat_logs/controller 21 | python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 22 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name https:// 23 | python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:21001 24 | 25 | cd fastchat_logs/server0 26 | 27 | python3 -m fastchat.serve.huggingface_api_worker --model-info-file ~/elo_results/register_hf_api_models.json 28 | 29 | export OPENAI_API_KEY= 30 | export ANTHROPIC_API_KEY= 31 | export GCP_PROJECT_ID= 32 | 33 | python3 -m fastchat.serve.gradio_web_server_multi --controller http://localhost:21001 --concurrency 50 --add-chatgpt --add-claude --add-palm --elo ~/elo_results/elo_results.pkl --leaderboard-table-file ~/elo_results/leaderboard_table.csv --register ~/elo_results/register_oai_models.json --show-terms 34 | 35 | python3 backup_logs.py 36 | ``` 37 | 38 | 39 | ### Check the launch time 40 | ``` 41 | for i in $(seq 0 11); do cat fastchat_logs/server$i/gradio_web_server.log | grep "Running on local URL" | tail -n 1; done 42 | ``` 43 | 44 | 45 | ### Increase the limit of max open files 46 | One process (do not need reboot) 47 | ``` 48 | sudo prlimit --nofile=1048576:1048576 --pid=$id 49 | 50 | for id in $(ps -ef | grep gradio_web_server | awk '{print $2}'); do echo $id; prlimit --nofile=1048576:1048576 --pid=$id; done 51 | ``` 52 | 53 | System (need reboot): Add the lines below to `/etc/security/limits.conf` 54 | ``` 55 | * hard nofile 65535 56 | * soft nofile 65535 57 | ``` 58 | 59 | 60 | ### Gradio edit (3.35.2) 61 | 1. gtag and canvas 62 | ``` 63 | vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/index.html 64 | ``` 65 | 66 | ``` 67 | 68 | 75 | 76 | ``` 77 | 78 | 2. deprecation warnings 79 | ``` 80 | vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/deprecation.py 81 | ``` 82 | 83 | ``` 84 | def check_deprecated_parameters( 85 | ``` 86 | 87 | 3. Loading 88 | ``` 89 | vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/assets/index-188ef5e8.js 90 | ``` 91 | 92 | ``` 93 | %s/"Loading..."/"Loading...(Please refresh if it takes more than 30 seconds)"/g 94 | ``` 95 | -------------------------------------------------------------------------------- /docs/dashinfer_integration.md: -------------------------------------------------------------------------------- 1 | # dash-infer Integration 2 | [DashInfer](https://github.com/modelscope/dash-infer) is a high-performance inference engine specifically optimized for CPU environments, delivering exceptional performance boosts for LLM inference tasks. It supports acceleration for a variety of models including Llama, Qwen, and ChatGLM, making it a versatile choice as a performant worker in FastChat. Notably, DashInfer exhibits significant performance enhancements on both Intel x64 and ARMv9 processors, catering to a wide spectrum of hardware platforms. 
Its efficient design and optimization techniques ensure rapid and accurate inference capabilities, making it an ideal solution for deploying large language models in resource-constrained environments or scenarios where CPU utilization is preferred over GPU acceleration. 3 | 4 | ## Instructions 5 | 1. Install dash-infer. 6 | ``` 7 | pip install dashinfer 8 | ``` 9 | 10 | 2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the dash-infer worker (`fastchat.serve.dashinfer_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same. 11 | ``` 12 | python3 -m fastchat.serve.dashinfer_worker --model-path qwen/Qwen-7B-Chat --revision=master /path/to/dashinfer-model-generation-config.json 13 | ``` 14 | Here is an example: 15 | ``` 16 | python3 -m fastchat.serve.dashinfer_worker --model-path qwen/Qwen-7B-Chat --revision=master dash-infer/examples/python/model_config/config_qwen_v10_7b.json 17 | ``` 18 | 19 | If you use an already downloaded model, try replacing `--model-path` with the local path and choosing a conversation template via the `--conv-template` option: 20 | ``` 21 | python3 -m fastchat.serve.dashinfer_worker --model-path ~/.cache/modelscope/hub/qwen/Qwen-7B-Chat --conv-template qwen-7b-chat /path/to/dashinfer-model-generation-config.json 22 | ``` 23 | All available conversation templates are listed in [fastchat/conversation.py](../fastchat/conversation.py). 24 | -------------------------------------------------------------------------------- /docs/dataset_release.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | We release the following datasets based on our projects and websites. 3 | 4 | - [LMSYS-Chat-1M: A Large-Scale Real-World LLM Conversation Dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) 5 | - [LMSYS-Human-Preference-55k](https://huggingface.co/datasets/lmsys/lmsys-arena-human-preference-55k) 6 | - [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations) 7 | - [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments) 8 | -------------------------------------------------------------------------------- /docs/exllama_v2.md: -------------------------------------------------------------------------------- 1 | # ExllamaV2 GPTQ Inference Framework 2 | 3 | We integrated the customized [ExllamaV2](https://github.com/turboderp/exllamav2) kernel into FastChat to provide **faster** GPTQ inference. 4 | 5 | **Note: ExllamaV2 does not yet support the embedding REST API.** 6 | 7 | ## Install ExllamaV2 8 | 9 | Setup environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details): 10 | 11 | ```bash 12 | git clone https://github.com/turboderp/exllamav2 13 | cd exllamav2 14 | pip install -e . 15 | ``` 16 | 17 | Chat with the CLI: 18 | ```bash 19 | python3 -m fastchat.serve.cli \ 20 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 21 | --enable-exllama 22 | ``` 23 | 24 | Start model worker: 25 | ```bash 26 | # Download quantized model from huggingface 27 | # Make sure you have git-lfs installed (https://git-lfs.com) 28 | git lfs install 29 | git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g 30 | 31 | # Load model with default configuration (max sequence length 4096, no GPU split setting). 
32 | python3 -m fastchat.serve.model_worker \ 33 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 34 | --enable-exllama 35 | 36 | #Load model with max sequence length 2048, allocate 18 GB to CUDA:0 and 24 GB to CUDA:1. 37 | python3 -m fastchat.serve.model_worker \ 38 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 39 | --enable-exllama \ 40 | --exllama-max-seq-len 2048 \ 41 | --exllama-gpu-split 18,24 42 | ``` 43 | 44 | `--exllama-cache-8bit` can be used to enable 8-bit caching with exllama and save some VRAM. 45 | 46 | ## Performance 47 | 48 | Reference: https://github.com/turboderp/exllamav2#performance 49 | 50 | 51 | | Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 | 52 | |------------|--------------|-------|-------|-----|------------|----------|------------|-------------| 53 | | Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s | 54 | | Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s | 55 | | Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s | 56 | | OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s | 57 | | CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s | 58 | | Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s | 59 | | Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s | 60 | | Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s | 61 | | Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s | 62 | | TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s | 63 | | TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s | 64 | -------------------------------------------------------------------------------- /docs/gptq.md: -------------------------------------------------------------------------------- 1 | # GPTQ 4bit Inference 2 | 3 | Support GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). 4 | 5 | 1. Window user: use the `old-cuda` branch. 6 | 2. Linux user: recommend the `fastest-inference-4bit` branch. 
7 | 8 | ## Install 9 | 10 | Setup environment: 11 | ```bash 12 | # cd /path/to/FastChat 13 | git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git repositories/GPTQ-for-LLaMa 14 | cd repositories/GPTQ-for-LLaMa 15 | # Window's user should use the `old-cuda` branch 16 | git switch fastest-inference-4bit 17 | # Install `quant-cuda` package in FastChat's virtualenv 18 | python3 setup_cuda.py install 19 | pip3 install texttable 20 | ``` 21 | 22 | Chat with the CLI: 23 | ```bash 24 | python3 -m fastchat.serve.cli \ 25 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 26 | --gptq-wbits 4 \ 27 | --gptq-groupsize 128 28 | ``` 29 | 30 | Start model worker: 31 | ```bash 32 | # Download quantized model from huggingface 33 | # Make sure you have git-lfs installed (https://git-lfs.com) 34 | git lfs install 35 | git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g 36 | 37 | python3 -m fastchat.serve.model_worker \ 38 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 39 | --gptq-wbits 4 \ 40 | --gptq-groupsize 128 41 | 42 | # You can specify which quantized model to use 43 | python3 -m fastchat.serve.model_worker \ 44 | --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ 45 | --gptq-ckpt models/vicuna-7B-1.1-GPTQ-4bit-128g/vicuna-7B-1.1-GPTQ-4bit-128g.safetensors \ 46 | --gptq-wbits 4 \ 47 | --gptq-groupsize 128 \ 48 | --gptq-act-order 49 | ``` 50 | 51 | ## Benchmark 52 | 53 | | LLaMA-13B | branch | Bits | group-size | memory(MiB) | PPL(c4) | Median(s/token) | act-order | speed up | 54 | | --------- | ---------------------- | ---- | ---------- | ----------- | ------- | --------------- | --------- | -------- | 55 | | FP16 | fastest-inference-4bit | 16 | - | 26634 | 6.96 | 0.0383 | - | 1x | 56 | | GPTQ | triton | 4 | 128 | 8590 | 6.97 | 0.0551 | - | 0.69x | 57 | | GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 6.97 | 0.0429 | true | 0.89x | 58 | | GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 7.03 | 0.0287 | false | 1.33x | 59 | | GPTQ | fastest-inference-4bit | 4 | -1 | 8448 | 7.12 | 0.0284 | false | 1.44x | 60 | -------------------------------------------------------------------------------- /docs/langchain_integration.md: -------------------------------------------------------------------------------- 1 | # Local LangChain with FastChat 2 | 3 | [LangChain](https://python.langchain.com/en/latest/index.html) is a library that facilitates the development of applications by leveraging large language models (LLMs) and enabling their composition with other sources of computation or knowledge. 4 | FastChat's OpenAI-compatible [API server](openai_api.md) enables using LangChain with open models seamlessly. 5 | 6 | ## Launch RESTful API Server 7 | 8 | Here are the steps to launch a local OpenAI API server for LangChain. 9 | 10 | First, launch the controller 11 | 12 | ```bash 13 | python3 -m fastchat.serve.controller 14 | ``` 15 | 16 | LangChain uses OpenAI model names by default, so we need to assign some faux OpenAI model names to our local model. 17 | Here, we use Vicuna as an example and use it for three endpoints: chat completion, completion, and embedding. 18 | `--model-path` can be a local folder or a Hugging Face repo name. 19 | See a full list of supported models [here](../README.md#supported-models). 
20 | 21 | ```bash 22 | python3 -m fastchat.serve.model_worker --model-names "gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002" --model-path lmsys/vicuna-7b-v1.5 23 | ``` 24 | 25 | Finally, launch the RESTful API server 26 | 27 | ```bash 28 | python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 29 | ``` 30 | 31 | ## Set OpenAI Environment 32 | 33 | You can set your environment with the following commands. 34 | 35 | Set OpenAI base url 36 | 37 | ```bash 38 | export OPENAI_API_BASE=http://localhost:8000/v1 39 | ``` 40 | 41 | Set OpenAI API key 42 | 43 | ```bash 44 | export OPENAI_API_KEY=EMPTY 45 | ``` 46 | 47 | If you meet the following OOM error while creating embeddings, please set a smaller batch size by using environment variables. 48 | 49 | ~~~bash 50 | openai.error.APIError: Invalid response object from API: '{"object":"error","message":"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**\\n\\n(CUDA out of memory. Tried to allocate xxx MiB (GPU 0; xxx GiB total capacity; xxx GiB already allocated; xxx MiB free; xxx GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF)","code":50002}' (HTTP response code was 400) 51 | ~~~ 52 | 53 | You can try `export FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE=1`. 54 | 55 | ## Try local LangChain 56 | 57 | Here is a question answerting example. 58 | 59 | Download a text file. 60 | 61 | ```bash 62 | wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt 63 | ``` 64 | 65 | Run LangChain. 66 | 67 | ~~~py 68 | from langchain.chat_models import ChatOpenAI 69 | from langchain.document_loaders import TextLoader 70 | from langchain.embeddings import OpenAIEmbeddings 71 | from langchain.indexes import VectorstoreIndexCreator 72 | 73 | embedding = OpenAIEmbeddings(model="text-embedding-ada-002") 74 | loader = TextLoader("state_of_the_union.txt") 75 | index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) 76 | llm = ChatOpenAI(model="gpt-3.5-turbo") 77 | 78 | questions = [ 79 | "Who is the speaker", 80 | "What did the president say about Ketanji Brown Jackson", 81 | "What are the threats to America", 82 | "Who are mentioned in the speech", 83 | "Who is the vice president", 84 | "How many projects were announced", 85 | ] 86 | 87 | for query in questions: 88 | print("Query:", query) 89 | print("Answer:", index.query(query, llm=llm)) 90 | ~~~ 91 | -------------------------------------------------------------------------------- /docs/lightllm_integration.md: -------------------------------------------------------------------------------- 1 | # LightLLM Integration 2 | You can use [LightLLM](https://github.com/ModelTC/lightllm) as an optimized worker implementation in FastChat. 3 | It offers advanced continuous batching and a much higher (~10x) throughput. 4 | See the supported models [here](https://github.com/ModelTC/lightllm?tab=readme-ov-file#supported-model-list). 5 | 6 | ## Instructions 7 | 1. Please refer to the [Get started](https://github.com/ModelTC/lightllm?tab=readme-ov-file#get-started) to install LightLLM. Or use [Pre-built image](https://github.com/ModelTC/lightllm?tab=readme-ov-file#container) 8 | 9 | 2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the LightLLM worker (`fastchat.serve.lightllm_worker`). 
All other commands such as controller, gradio web server, and OpenAI API server are kept the same. Refer to [--max_total_token_num](https://github.com/ModelTC/lightllm/blob/4a9824b6b248f4561584b8a48ae126a0c8f5b000/docs/ApiServerArgs.md?plain=1#L23) to understand how to calculate the `--max_total_token_num` argument. 10 | ``` 11 | python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000 12 | ``` 13 | 14 | If you what to use quantized weight and kv cache for inference, try 15 | 16 | ``` 17 | python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000 --mode triton_int8weight triton_int8kv 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/mlx_integration.md: -------------------------------------------------------------------------------- 1 | # Apple MLX Integration 2 | 3 | You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat. 4 | 5 | It runs models efficiently on Apple Silicon 6 | 7 | See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models). 8 | 9 | Note that for Apple Silicon Macs with less memory, smaller models (or quantized models) are recommended. 10 | 11 | ## Instructions 12 | 13 | 1. Install MLX. 14 | 15 | ``` 16 | pip install "mlx-lm>=0.0.6" 17 | ``` 18 | 19 | 2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`). Remember to launch a model worker after you have launched the controller ([instructions](../README.md)) 20 | 21 | ``` 22 | python3 -m fastchat.serve.mlx_worker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 23 | ``` 24 | -------------------------------------------------------------------------------- /docs/server_arch.md: -------------------------------------------------------------------------------- 1 | # FastChat Server Architecture 2 | ![server arch](../assets/server_arch.png) 3 | -------------------------------------------------------------------------------- /docs/third_party_ui.md: -------------------------------------------------------------------------------- 1 | # Third Party UI 2 | If you want to host it on your own UI or third party UI, you can launch the [OpenAI compatible server](openai_api.md) and host with a tunnelling service such as Tunnelmole or ngrok, and then enter the credentials appropriately. 3 | 4 | You can find suitable UIs from third party repos: 5 | - [WongSaang's ChatGPT UI](https://github.com/WongSaang/chatgpt-ui) 6 | - [McKayWrigley's Chatbot UI](https://github.com/mckaywrigley/chatbot-ui) 7 | 8 | - Please note that some third-party providers only offer the standard `gpt-3.5-turbo`, `gpt-4`, etc., so you will have to add your own custom model inside the code. [Here is an example of how to create a UI with any custom model name](https://github.com/ztjhz/BetterChatGPT/pull/461). 9 | 10 | ##### Using Tunnelmole 11 | Tunnelmole is an open source tunnelling tool. You can find its source code on [Github](https://github.com/robbie-cahill/tunnelmole-client). Here's how you can use Tunnelmole: 12 | 1. Install Tunnelmole with `curl -O https://install.tunnelmole.com/9Wtxu/install && sudo bash install`. (On Windows, download [tmole.exe](https://tunnelmole.com/downloads/tmole.exe)). 
Head over to the [README](https://github.com/robbie-cahill/tunnelmole-client) for other methods such as `npm` or building from source. 13 | 2. Run `tmole 7860` (replace `7860` with your listening port if it is different from 7860). The output will display two URLs: one HTTP and one HTTPS. It's best to use the HTTPS URL for better privacy and security. 14 | ``` 15 | ➜ ~ tmole 7860 16 | http://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 17 | https://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 18 | ``` 19 | 20 | ##### Using ngrok 21 | ngrok is a popular closed source tunnelling tool. First download and install it from [ngrok.com](https://ngrok.com/downloads). Here's how to use it to expose port 7860. 22 | ``` 23 | ngrok http 7860 24 | ``` 25 | -------------------------------------------------------------------------------- /docs/vllm_integration.md: -------------------------------------------------------------------------------- 1 | # vLLM Integration 2 | You can use [vLLM](https://vllm.ai/) as an optimized worker implementation in FastChat. 3 | It offers advanced continuous batching and a much higher (~10x) throughput. 4 | See the supported models [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html). 5 | 6 | ## Instructions 7 | 1. Install vLLM. 8 | ``` 9 | pip install vllm 10 | ``` 11 | 12 | 2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the vLLM worker (`fastchat.serve.vllm_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same. 13 | ``` 14 | python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 15 | ``` 16 | 17 | If you see tokenizer errors, try 18 | ``` 19 | python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer hf-internal-testing/llama-tokenizer 20 | ``` 21 | 22 | If you use an AWQ quantized model, try 23 | ''' 24 | python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq 25 | ''' 26 | -------------------------------------------------------------------------------- /docs/xFasterTransformer.md: -------------------------------------------------------------------------------- 1 | # xFasterTransformer Inference Framework 2 | 3 | Integrated [xFasterTransformer](https://github.com/intel/xFasterTransformer) customized framework into Fastchat to provide **Faster** inference speed on Intel CPU. 4 | 5 | ## Install xFasterTransformer 6 | 7 | Setup environment (please refer to [this link](https://github.com/intel/xFasterTransformer#installation) for more details): 8 | 9 | ```bash 10 | pip install xfastertransformer 11 | ``` 12 | 13 | ## Prepare models 14 | 15 | Prepare Model (please refer to [this link](https://github.com/intel/xFasterTransformer#prepare-model) for more details): 16 | ```bash 17 | python ./tools/chatglm_convert.py -i ${HF_DATASET_DIR} -o ${OUTPUT_DIR} 18 | ``` 19 | 20 | ## Parameters of xFasterTransformer 21 | --enable-xft to enable xfastertransformer in Fastchat 22 | --xft-max-seq-len to set the max token length the model can process. max token length include input token length. 23 | --xft-dtype to set datatype used in xFasterTransformer for computation. xFasterTransformer can support fp32, fp16, int8, bf16 and hybrid data types like : bf16_fp16, bf16_int8. 
For datatype details please refer to [this link](https://github.com/intel/xFasterTransformer/wiki/Data-Type-Support-Platform) 24 | 25 | 26 | Chat with the CLI: 27 | ```bash 28 | #run inference on all CPUs and using float16 29 | python3 -m fastchat.serve.cli \ 30 | --model-path /path/to/models \ 31 | --enable-xft \ 32 | --xft-dtype fp16 33 | ``` 34 | or with numactl on multi-socket server for better performance 35 | ```bash 36 | #run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) 37 | numactl -N 0 --localalloc \ 38 | python3 -m fastchat.serve.cli \ 39 | --model-path /path/to/models/chatglm2_6b_cpu/ \ 40 | --enable-xft \ 41 | --xft-dtype bf16_fp16 42 | ``` 43 | or using MPI to run inference on 2 sockets for better performance 44 | ```bash 45 | #run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) 46 | OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \ 47 | -n 1 numactl -N 0 --localalloc \ 48 | python -m fastchat.serve.cli \ 49 | --model-path /path/to/models/chatglm2_6b_cpu/ \ 50 | --enable-xft \ 51 | --xft-dtype bf16_fp16 : \ 52 | -n 1 numactl -N 1 --localalloc \ 53 | python -m fastchat.serve.cli \ 54 | --model-path /path/to/models/chatglm2_6b_cpu/ \ 55 | --enable-xft \ 56 | --xft-dtype bf16_fp16 57 | ``` 58 | 59 | 60 | Start model worker: 61 | ```bash 62 | # Load model with default configuration (max sequence length 4096, no GPU split setting). 63 | python3 -m fastchat.serve.model_worker \ 64 | --model-path /path/to/models \ 65 | --enable-xft \ 66 | --xft-dtype bf16_fp16 67 | ``` 68 | or with numactl on multi-socket server for better performance 69 | ```bash 70 | #run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) 71 | numactl -N 0 --localalloc python3 -m fastchat.serve.model_worker \ 72 | --model-path /path/to/models \ 73 | --enable-xft \ 74 | --xft-dtype bf16_fp16 75 | ``` 76 | or using MPI to run inference on 2 sockets for better performance 77 | ```bash 78 | #run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) 79 | OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \ 80 | -n 1 numactl -N 0 --localalloc python -m fastchat.serve.model_worker \ 81 | --model-path /path/to/models \ 82 | --enable-xft \ 83 | --xft-dtype bf16_fp16 : \ 84 | -n 1 numactl -N 1 --localalloc python -m fastchat.serve.model_worker \ 85 | --model-path /path/to/models \ 86 | --enable-xft \ 87 | --xft-dtype bf16_fp16 88 | ``` 89 | 90 | For more details, please refer to [this link](https://github.com/intel/xFasterTransformer#how-to-run) 91 | -------------------------------------------------------------------------------- /fastchat/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.36" 2 | -------------------------------------------------------------------------------- /fastchat/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global constants. 3 | """ 4 | 5 | from enum import IntEnum 6 | import os 7 | 8 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 9 | 10 | # Survey Link URL (to be removed) #00729c 11 | # SURVEY_LINK = """
12 | #
13 | # New Launch! Copilot Arena: VS Code Extension to compare Top LLMs 14 | #
15 | #
""" 16 | # SURVEY_LINK = "" 17 | 18 | COLOR = "#008B8B" 19 | SURVEY_LINK = f"""
20 |
21 | New Arena UI at BETA.lmarena.ai! Check it out & give feedback! 22 |
23 |
""" 24 | 25 | ##### For the gradio web server 26 | SERVER_ERROR_MSG = ( 27 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 28 | ) 29 | TEXT_MODERATION_MSG = ( 30 | "$MODERATION$ YOUR TEXT VIOLATES OUR CONTENT MODERATION GUIDELINES." 31 | ) 32 | IMAGE_MODERATION_MSG = ( 33 | "$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES." 34 | ) 35 | MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." 36 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 37 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 38 | SLOW_MODEL_MSG = ( 39 | "⚠️ Models are thinking. Please stay patient as it may take over a minute." 40 | ) 41 | RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE [BATTLE MODE](https://lmarena.ai) (the 1st tab).**" 42 | # Maximum input length 43 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) 44 | BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int( 45 | os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 30000) 46 | ) 47 | # Maximum conversation turns 48 | CONVERSATION_TURN_LIMIT = 50 49 | # Session expiration time 50 | SESSION_EXPIRATION_TIME = 3600 51 | # The output dir of log files 52 | LOGDIR = os.getenv("LOGDIR", ".") 53 | # CPU Instruction Set Architecture 54 | CPU_ISA = os.getenv("CPU_ISA") 55 | 56 | 57 | ##### For the controller and workers (could be overwritten through ENV variables.) 58 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 59 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 60 | ) 61 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 62 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 63 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 64 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 65 | ) 66 | 67 | 68 | class ErrorCode(IntEnum): 69 | """ 70 | https://platform.openai.com/docs/guides/error-codes/api-errors 71 | """ 72 | 73 | VALIDATION_TYPE_ERROR = 40001 74 | 75 | INVALID_AUTH_KEY = 40101 76 | INCORRECT_AUTH_KEY = 40102 77 | NO_PERMISSION = 40103 78 | 79 | INVALID_MODEL = 40301 80 | PARAM_OUT_OF_RANGE = 40302 81 | CONTEXT_OVERFLOW = 40303 82 | 83 | RATE_LIMIT = 42901 84 | QUOTA_EXCEEDED = 42902 85 | ENGINE_OVERLOADED = 42903 86 | 87 | INTERNAL_ERROR = 50001 88 | CUDA_OUT_OF_MEMORY = 50002 89 | GRADIO_REQUEST_ERROR = 50003 90 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 91 | CONTROLLER_NO_WORKER = 50005 92 | CONTROLLER_WORKER_TIMEOUT = 50006 93 | -------------------------------------------------------------------------------- /fastchat/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/data/__init__.py -------------------------------------------------------------------------------- /fastchat/data/convert_alpaca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert alpaca dataset into sharegpt format. 
3 | 4 | Usage: python3 -m fastchat.data.convert_alpaca --in alpaca_data.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | import numpy as np 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--in-file", type=str) 17 | parser.add_argument("--out-file", type=str) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | new_content = [] 22 | for i, c in enumerate(content): 23 | if len(c["input"].strip()) > 1: 24 | q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"] 25 | else: 26 | q, a = c["instruction"], c["output"] 27 | new_content.append( 28 | { 29 | "id": f"alpaca_{i}", 30 | "conversations": [ 31 | {"from": "human", "value": q}, 32 | {"from": "gpt", "value": a}, 33 | ], 34 | } 35 | ) 36 | 37 | print(f"#out: {len(new_content)}") 38 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 39 | -------------------------------------------------------------------------------- /fastchat/data/extract_gpt4_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the conversations generated by GPT-4 only. 3 | 4 | Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | new_content = [] 21 | for c in content: 22 | model = c.get("model", None) 23 | if model == "gpt4" or model is None: 24 | new_content.append(c) 25 | 26 | if args.out_file: 27 | out_file = args.out_file 28 | else: 29 | out_file = args.in_file.replace(".json", "_gpt4.json") 30 | 31 | print(f"#in: {len(content)}, #out: {len(new_content)}") 32 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 33 | -------------------------------------------------------------------------------- /fastchat/data/extract_single_round.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the first round of the conversations. 
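Keeps only the first two messages of each conversation (one "human" turn and one "gpt" turn). If --out-file is not given, the result is written next to the input as "<name>_single.json", for example:

    python3 -m fastchat.data.extract_single_round --in sharegpt.json --out-file sharegpt_single.json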
3 | 4 | Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | for c in content: 21 | c["conversations"] = c["conversations"][:2] 22 | 23 | if args.out_file: 24 | out_file = args.out_file 25 | else: 26 | out_file = args.in_file.replace(".json", "_single.json") 27 | 28 | print(f"#in: {len(content)}, #out: {len(content)}") 29 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 30 | -------------------------------------------------------------------------------- /fastchat/data/filter_wrong_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations with wrong formats. 3 | 4 | Usage: 5 | python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json 6 | 7 | """ 8 | import argparse 9 | import json 10 | import re 11 | 12 | from tqdm import tqdm 13 | 14 | wrong_indices_pattern = re.compile("\n1\. [^2]*\n1\. ") 15 | 16 | 17 | def should_skip(conv): 18 | # Filter wrong list indices like https://sharegpt.com/c/1pREAGO 19 | for sentence in conv["conversations"]: 20 | val = sentence["value"] 21 | sub = re.search(wrong_indices_pattern, val) 22 | if sub is not None: 23 | return True 24 | 25 | return False 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--in-file", type=str, required=True) 31 | parser.add_argument("--out-file", type=str, required=True) 32 | args = parser.parse_args() 33 | 34 | content = json.load(open(args.in_file, "r")) 35 | 36 | new_content = [] 37 | for conv in tqdm(content): 38 | if should_skip(conv): 39 | print(f"{conv['id']} contains a wrong format.") 40 | else: 41 | new_content.append(conv) 42 | 43 | print(f"#in: {len(content)}, #out: {len(new_content)}") 44 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 45 | -------------------------------------------------------------------------------- /fastchat/data/get_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Get stats of a dataset. 
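Reports the number of samples, the total token count, the average turns per sample, the average prompt/response lengths (in tokens of the given tokenizer), and a histogram of per-sample token lengths.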
3 | 4 | Usage: python3 -m fastchat.data.get_stats --in sharegpt.json 5 | """ 6 | 7 | import argparse 8 | from concurrent.futures import ProcessPoolExecutor 9 | import json 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | 15 | K = 1e3 16 | M = 1e6 17 | 18 | 19 | def tokenize_one_sample(c): 20 | for i in range(len(c["conversations"])): 21 | v = c["conversations"][i]["value"] 22 | c["conversations"][i]["value"] = tokenizer.tokenize(v) 23 | return c 24 | 25 | 26 | def tokenize_dataset(content): 27 | processed = [] 28 | with ProcessPoolExecutor() as executor: 29 | for result in tqdm( 30 | executor.map(tokenize_one_sample, content), total=len(content) 31 | ): 32 | processed.append(result) 33 | 34 | return processed 35 | 36 | 37 | def compute_stats(content): 38 | sample_lens = [] 39 | sample_turns = [] 40 | prompt_lens = [] 41 | res_lens = [] 42 | 43 | for c in content: 44 | sample_len = 0 45 | sample_turns.append(len(c["conversations"]) // 2) 46 | for i in range(len(c["conversations"]) // 2): 47 | p = c["conversations"][i * 2]["value"] 48 | r = c["conversations"][i * 2 + 1]["value"] 49 | 50 | turn_len = len(p) + len(r) 51 | sample_len += turn_len 52 | prompt_lens.append(len(p)) 53 | res_lens.append(len(r)) 54 | sample_lens.append(sample_len) 55 | 56 | return sample_lens, sample_turns, prompt_lens, res_lens 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--in-file", type=str) 62 | parser.add_argument( 63 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 64 | ) 65 | args = parser.parse_args() 66 | 67 | content = json.load(open(args.in_file, "r")) 68 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 69 | content = tokenize_dataset(content) 70 | 71 | sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) 72 | print(f"#sequence: {len(content)/K:.2f} K") 73 | print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") 74 | print(f"avg. turns: {np.mean(sample_turns):.2f}") 75 | print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") 76 | print(f"avg. 
response length: {np.mean(res_lens):.2f}") 77 | 78 | print("\n- Histogram -") 79 | bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] 80 | hist = np.histogram(sample_lens, bins=bin_edges)[0] 81 | for i in range(len(hist)): 82 | print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}") 83 | -------------------------------------------------------------------------------- /fastchat/data/inspect_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json 4 | """ 5 | import argparse 6 | import json 7 | import random 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--begin", type=int) 14 | parser.add_argument("--random-n", type=int) 15 | args = parser.parse_args() 16 | 17 | content = json.load(open(args.in_file, "r")) 18 | 19 | if args.random_n: 20 | indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)] 21 | elif args.begin: 22 | indices = range(args.begin, len(content)) 23 | else: 24 | indices = range(0, len(content)) 25 | 26 | for idx in indices: 27 | sample = content[idx] 28 | print("=" * 40) 29 | print(f"no: {idx}, id: {sample['id']}") 30 | for conv in sample["conversations"]: 31 | print(conv["from"] + ": ") 32 | print(conv["value"]) 33 | input() 34 | -------------------------------------------------------------------------------- /fastchat/data/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge two conversation files into one 3 | 4 | Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in-file", type=str, required=True, nargs="+") 14 | parser.add_argument("--out-file", type=str, default="merged.json") 15 | args = parser.parse_args() 16 | 17 | new_content = [] 18 | for in_file in args.in_file: 19 | content = json.load(open(in_file, "r")) 20 | new_content.extend(content) 21 | 22 | print(f"#out: {len(new_content)}") 23 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 24 | -------------------------------------------------------------------------------- /fastchat/data/optional_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional cleaning (e.g., remove some languages). 
3 | 4 | Usage: 5 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en 6 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en 7 | 8 | Requirement: 9 | pip3 install polyglot pyicu pycld2 10 | """ 11 | import argparse 12 | import json 13 | import re 14 | 15 | import polyglot 16 | from polyglot.detect import Detector 17 | import pycld2 18 | from tqdm import tqdm 19 | 20 | 21 | def skip(conv, args): 22 | # Remove certain languages 23 | if args.keep_lang != "all" or args.skip_lang is not None: 24 | text = "\n".join([x["value"] for x in conv["conversations"]]) 25 | try: 26 | lang_code = Detector(text).language.code 27 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 28 | lang_code = "unknown" 29 | 30 | if args.keep_lang != "all" and lang_code != args.keep_lang: 31 | return True 32 | 33 | if lang_code == args.skip_lang: 34 | return True 35 | 36 | # Remove repetitive numbers 37 | if args.reduce_rep: 38 | for sentence in conv["conversations"]: 39 | val = sentence["value"] 40 | sub = re.search(r"(\d)\1{8}", val) 41 | if sub is not None: 42 | return True 43 | 44 | return False 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--in-file", type=str, required=True) 50 | parser.add_argument("--out-file", type=str) 51 | parser.add_argument( 52 | "--keep-lang", 53 | type=str, 54 | default="all", 55 | choices=["all", "en"], 56 | help="Only keep certain langauges.", 57 | ) 58 | parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") 59 | # NOTE: Be careful about reduce_rep which may remove some good data. 60 | # For example, addresses could have long consecutive 0's 61 | parser.add_argument("--reduce-rep", action="store_true") 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | keep_lang = args.keep_lang 67 | skip_lang = args.skip_lang 68 | reduce_rep = args.reduce_rep 69 | assert keep_lang == "all" or skip_lang is None 70 | 71 | if out_file is None: 72 | out_file = "sharegpt_clean" 73 | if keep_lang != "all": 74 | out_file += "_" + keep_lang 75 | if skip_lang is not None: 76 | out_file += "_skip_" + skip_lang 77 | if reduce_rep: 78 | out_file += "_reduce_rep" 79 | out_file += ".json" 80 | 81 | content = json.load(open(in_file, "r")) 82 | num_conv = len(content) 83 | 84 | new_content = [] 85 | for conv in tqdm(content): 86 | if not skip(conv, args): 87 | new_content.append(conv) 88 | 89 | print(f"#in: {len(content)}, #out: {len(new_content)}") 90 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 91 | -------------------------------------------------------------------------------- /fastchat/data/optional_replace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional replace of bos/eos/pad/unk. 
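Any bos/eos/pad/unk token string found inside conversation values is defused by inserting "|" characters into it (e.g. "</s>" becomes "<|/s|>"), so the text can no longer be parsed as a special token.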
3 | 4 | Usage: 5 | python3 -m fastchat.data.optional_replace --in input.json --out output.json --model-name-or-path 6 | 7 | Requirement: 8 | pip3 install transformers tqdm 9 | """ 10 | import argparse 11 | import json 12 | import traceback 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def replace_special_tokens( 19 | tokenizer: transformers.PreTrainedTokenizer, text: str 20 | ) -> str: 21 | if not text: 22 | return text 23 | 24 | def _insert_vline(token: str) -> str: 25 | if len(token) < 2: 26 | return " " 27 | elif len(token) == 2: 28 | return f"{token[0]}|{token[1]}" 29 | else: 30 | return f"{token[:1]}|{token[1:-1]}|{token[-1:]}" 31 | 32 | if tokenizer.bos_token: 33 | text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token)) 34 | if tokenizer.eos_token: 35 | text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token)) 36 | if tokenizer.pad_token: 37 | text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token)) 38 | if tokenizer.unk_token: 39 | text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token)) 40 | return text 41 | 42 | 43 | def replace(conv, tokenizer): 44 | # Replace bos/eos/pad/unk tokens 45 | if tokenizer: 46 | try: 47 | for sentence in conv["conversations"]: 48 | sentence["value"] = replace_special_tokens(tokenizer, sentence["value"]) 49 | except Exception as e: 50 | traceback.print_exc() 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--in-file", type=str, required=True) 56 | parser.add_argument("--out-file", type=str) 57 | parser.add_argument( 58 | "--model-name-or-path", 59 | type=str, 60 | help="The directory or address where the model token is stored.", 61 | ) 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | tokenizer = None 67 | if args.model_name_or_path: 68 | tokenizer = transformers.AutoTokenizer.from_pretrained( 69 | args.model_name_or_path, 70 | trust_remote_code=True, 71 | use_fast=False, 72 | ) 73 | 74 | if out_file is None: 75 | out_file = f"{in_file}_replace.json" 76 | 77 | content = json.load(open(in_file, "r")) 78 | 79 | for conv in tqdm(content): 80 | replace(conv, tokenizer) 81 | 82 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 83 | -------------------------------------------------------------------------------- /fastchat/data/prepare_all.py: -------------------------------------------------------------------------------- 1 | """Prepare all datasets.""" 2 | 3 | import argparse 4 | import os 5 | 6 | from fastchat.utils import run_cmd 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521") 12 | parser.add_argument( 13 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 14 | ) 15 | parser.add_argument("--seq-len", type=int, default=4096) 16 | args = parser.parse_args() 17 | 18 | in_prefix = args.prefix 19 | model_path = args.model_name_or_path 20 | seq_len = args.seq_len 21 | prefix = ( 22 | f"{in_prefix}_{seq_len}".replace("4096", "4k") 23 | .replace("8192", "8k") 24 | .replace("16384", "16k") 25 | ) 26 | 27 | cmd_list = [ 28 | f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json", 29 | f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko", 30 | f"python3 -m fastchat.data.split_long_conversation 
--in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}", 31 | f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json", 32 | f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99", 33 | f"python3 -m fastchat.data.hardcoded_questions", 34 | f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json", 35 | f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json", 36 | f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json", 37 | ] 38 | 39 | for cmd in cmd_list: 40 | ret = run_cmd(cmd) 41 | if ret != 0: 42 | exit(ret) 43 | -------------------------------------------------------------------------------- /fastchat/data/pretty_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 pretty_json.py --in in.json --out out.json 4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | with open(args.in_file, "r") as fin: 17 | data = json.load(fin) 18 | 19 | with open(args.out_file, "w") as fout: 20 | json.dump(data, fout, indent=2, ensure_ascii=False) 21 | -------------------------------------------------------------------------------- /fastchat/data/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample some conversations from a file. 3 | 4 | Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--out-file", type=str, default="sampled.json") 16 | parser.add_argument("--begin", type=int, default=0) 17 | parser.add_argument("--end", type=int, default=100) 18 | parser.add_argument("--max-length", type=int, default=1024) 19 | parser.add_argument("--keep-order", action="store_true") 20 | args = parser.parse_args() 21 | 22 | content = json.load(open(args.in_file, "r")) 23 | if not args.keep_order: 24 | np.random.seed(42) 25 | np.random.shuffle(content) 26 | 27 | new_content = [] 28 | for i in range(args.begin, min(args.end, len(content))): 29 | sample = content[i] 30 | concat = "" 31 | for s in sample["conversations"]: 32 | concat += s["value"] 33 | 34 | if len(concat) > args.max_length: 35 | continue 36 | 37 | new_content.append(sample) 38 | 39 | print(f"#in: {len(content)}, #out: {len(new_content)}") 40 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 41 | -------------------------------------------------------------------------------- /fastchat/data/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split long conversations based on certain max length. 
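Turns are grouped into human/gpt pairs and their token lengths are accumulated greedily; whenever adding the next pair would exceed --max-length, the turns accumulated so far are emitted as one sub-conversation and a new chunk is started at that pair. Conversations whose roles do not strictly alternate human/gpt are filtered out afterwards.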
3 | 4 | Usage: python3 -m fastchat.data.split_long_conversation \ 5 | --in sharegpt_clean.json \ 6 | --out sharegpt_split.json \ 7 | --model-name-or-path $ 8 | """ 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | from typing import Dict, Sequence, Optional 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def make_sample(sample, start_idx, end_idx): 19 | assert (end_idx - start_idx) % 2 == 0 20 | return { 21 | "id": sample["id"] + "_" + str(start_idx), 22 | "model": sample.get("model", ""), 23 | "conversations": sample["conversations"][start_idx:end_idx], 24 | } 25 | 26 | 27 | tokenizer = max_length = None 28 | 29 | 30 | def split_one_sample(sample): 31 | tokenized_lens = [] 32 | conversations = sample["conversations"] 33 | conversations = conversations[: len(conversations) // 2 * 2] 34 | for c in conversations: 35 | length = len(tokenizer(c["value"]).input_ids) + 6 36 | tokenized_lens.append(length) 37 | 38 | start_idx = 0 39 | cur_len = 0 40 | 41 | if len(conversations) % 2 != 0 or len(conversations) < 2: 42 | return [] 43 | 44 | new_samples = [] 45 | for i in range(0, len(conversations), 2): 46 | tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] 47 | if cur_len + tmp_len > max_length: 48 | new_samples.append(make_sample(sample, start_idx, i)) 49 | start_idx = i 50 | cur_len = 0 51 | elif i == len(conversations) - 2: 52 | new_samples.append(make_sample(sample, start_idx, i + 2)) 53 | 54 | cur_len += tmp_len 55 | 56 | return new_samples 57 | 58 | 59 | def worker(input_data): 60 | result = [] 61 | for sample in input_data: 62 | result.extend(split_one_sample(sample)) 63 | return result 64 | 65 | 66 | def split_all(content, begin, end, tokenizer_, max_length_): 67 | """ 68 | Keep the maximum round of conversations within the max token length constraint 69 | """ 70 | global tokenizer, max_length 71 | tokenizer = tokenizer_ 72 | max_length = max_length_ 73 | 74 | content = content[begin:end] 75 | new_content = [] 76 | 77 | # Split content into chunks 78 | chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)] 79 | with ProcessPoolExecutor() as executor: 80 | for result in tqdm(executor.map(worker, chunks), total=len(chunks)): 81 | new_content.extend(result) 82 | 83 | return new_content 84 | 85 | 86 | def filter_invalid_roles(content): 87 | new_content = [] 88 | for i, c in enumerate(content): 89 | roles = ["human", "gpt"] 90 | if len(c["conversations"]) <= 0: 91 | continue 92 | 93 | valid = True 94 | for j, s in enumerate(c["conversations"]): 95 | if s["from"] != roles[j % 2]: 96 | valid = False 97 | break 98 | 99 | if valid: 100 | new_content.append(c) 101 | 102 | return new_content 103 | 104 | 105 | def main(args): 106 | content = json.load(open(args.in_file, "r")) 107 | tokenizer = transformers.AutoTokenizer.from_pretrained( 108 | args.model_name_or_path, 109 | model_max_length=args.max_length, 110 | padding_side="right", 111 | use_fast=False, 112 | ) 113 | new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) 114 | new_content = filter_invalid_roles(new_content) 115 | 116 | print(f"#in: {len(content)}, #out: {len(new_content)}") 117 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--in-file", type=str, required=True) 123 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json") 124 | 
parser.add_argument("--begin", type=int) 125 | parser.add_argument("--end", type=int) 126 | parser.add_argument("--model-name-or-path", type=str, required=True) 127 | parser.add_argument("--max-length", type=int, default=2048) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /fastchat/data/split_train_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split the dataset into training and test set. 3 | 4 | Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--begin", type=int, default=0) 16 | parser.add_argument("--end", type=int, default=100) 17 | parser.add_argument("--ratio", type=float, default=0.9) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | np.random.seed(0) 22 | 23 | perm = np.random.permutation(len(content)) 24 | content = [content[i] for i in perm] 25 | split = int(args.ratio * len(content)) 26 | 27 | train_set = content[:split] 28 | test_set = content[split:] 29 | 30 | print(f"#train: {len(train_set)}, #test: {len(test_set)}") 31 | train_name = args.in_file.replace(".json", "_train.json") 32 | test_name = args.in_file.replace(".json", "_test.json") 33 | json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False) 34 | json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /fastchat/llm_judge/clean_judgment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clean model judgment files. 
3 | """ 4 | import argparse 5 | import json 6 | 7 | selected_models = [ 8 | "alpaca-13b", 9 | "baize-v2-13b", 10 | "chatglm-6b", 11 | "claude-instant-v1", 12 | "claude-v1", 13 | "dolly-v2-12b", 14 | "falcon-40b-instruct", 15 | "fastchat-t5-3b", 16 | "gpt-3.5-turbo", 17 | "gpt-4", 18 | "gpt4all-13b-snoozy", 19 | "guanaco-33b", 20 | "guanaco-65b", 21 | "h2ogpt-oasst-open-llama-13b", 22 | "koala-13b", 23 | "llama-13b", 24 | "mpt-30b-chat", 25 | "mpt-30b-instruct", 26 | "mpt-7b-chat", 27 | "nous-hermes-13b", 28 | "oasst-sft-4-pythia-12b", 29 | "oasst-sft-7-llama-30b", 30 | "palm-2-chat-bison-001", 31 | "rwkv-4-raven-14b", 32 | "stablelm-tuned-alpha-7b", 33 | "tulu-30b", 34 | "vicuna-13b-v1.3", 35 | "vicuna-33b-v1.3", 36 | "vicuna-7b-v1.3", 37 | "wizardlm-13b", 38 | "wizardlm-30b", 39 | ] 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--infile", type=str) 45 | args = parser.parse_args() 46 | 47 | infile = args.infile 48 | outfile = infile.replace(".jsonl", "_clean.jsonl") 49 | 50 | raw_lines = open(infile).readlines() 51 | rets = [] 52 | models = set() 53 | visited = set() 54 | for line in raw_lines: 55 | obj = json.loads(line) 56 | 57 | if "model_1" in obj: # pair 58 | model = obj["model_1"] 59 | key = ( 60 | obj["model_1"], 61 | obj["model_2"], 62 | obj["question_id"], 63 | tuple(obj["judge"]), 64 | ) 65 | else: # single 66 | model = obj["model"] 67 | key = (obj["model"], obj["question_id"], tuple(obj["judge"])) 68 | 69 | if key in visited: 70 | continue 71 | visited.add(key) 72 | 73 | if model not in selected_models: 74 | continue 75 | models.add(model) 76 | rets.append(obj) 77 | 78 | models = sorted(list(models)) 79 | missing_models = [x for x in selected_models if x not in models] 80 | print(f"in models: {models}, number: {len(models)}") 81 | print(f"missing models: {missing_models}") 82 | print(f"#in: {len(raw_lines)}, #out: {len(rets)}") 83 | rets.sort( 84 | key=lambda x: ( 85 | x["model"] if "model" in x else x["model_1"], 86 | x["question_id"], 87 | x["turn"], 88 | ) 89 | ) 90 | 91 | with open(outfile, "w") as fout: 92 | for x in rets: 93 | fout.write(json.dumps(x) + "\n") 94 | -------------------------------------------------------------------------------- /fastchat/llm_judge/data/mt_bench/misc/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/llm_judge/data/mt_bench/misc/radar.png -------------------------------------------------------------------------------- /fastchat/llm_judge/download_mt_bench_pregenerated.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the pre-generated model answers and judgments for MT-bench. 
3 | """ 4 | import os 5 | 6 | from fastchat.utils import run_cmd 7 | 8 | filenames = [ 9 | "data/mt_bench/model_answer/alpaca-13b.jsonl", 10 | "data/mt_bench/model_answer/baize-v2-13b.jsonl", 11 | "data/mt_bench/model_answer/chatglm-6b.jsonl", 12 | "data/mt_bench/model_answer/claude-instant-v1.jsonl", 13 | "data/mt_bench/model_answer/claude-v1.jsonl", 14 | "data/mt_bench/model_answer/dolly-v2-12b.jsonl", 15 | "data/mt_bench/model_answer/falcon-40b-instruct.jsonl", 16 | "data/mt_bench/model_answer/fastchat-t5-3b.jsonl", 17 | "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl", 18 | "data/mt_bench/model_answer/gpt-4.jsonl", 19 | "data/mt_bench/model_answer/gpt4all-13b-snoozy.jsonl", 20 | "data/mt_bench/model_answer/guanaco-33b.jsonl", 21 | "data/mt_bench/model_answer/guanaco-65b.jsonl", 22 | "data/mt_bench/model_answer/h2ogpt-oasst-open-llama-13b.jsonl", 23 | "data/mt_bench/model_answer/koala-13b.jsonl", 24 | "data/mt_bench/model_answer/llama-13b.jsonl", 25 | "data/mt_bench/model_answer/mpt-30b-chat.jsonl", 26 | "data/mt_bench/model_answer/mpt-30b-instruct.jsonl", 27 | "data/mt_bench/model_answer/mpt-7b-chat.jsonl", 28 | "data/mt_bench/model_answer/nous-hermes-13b.jsonl", 29 | "data/mt_bench/model_answer/oasst-sft-4-pythia-12b.jsonl", 30 | "data/mt_bench/model_answer/oasst-sft-7-llama-30b.jsonl", 31 | "data/mt_bench/model_answer/palm-2-chat-bison-001.jsonl", 32 | "data/mt_bench/model_answer/rwkv-4-raven-14b.jsonl", 33 | "data/mt_bench/model_answer/stablelm-tuned-alpha-7b.jsonl", 34 | "data/mt_bench/model_answer/tulu-30b.jsonl", 35 | "data/mt_bench/model_answer/vicuna-13b-v1.3.jsonl", 36 | "data/mt_bench/model_answer/vicuna-33b-v1.3.jsonl", 37 | "data/mt_bench/model_answer/vicuna-7b-v1.3.jsonl", 38 | "data/mt_bench/model_answer/wizardlm-13b.jsonl", 39 | "data/mt_bench/model_answer/wizardlm-30b.jsonl", 40 | "data/mt_bench/model_judgment/gpt-4_single.jsonl", 41 | "data/mt_bench/model_judgment/gpt-4_pair.jsonl", 42 | ] 43 | 44 | 45 | if __name__ == "__main__": 46 | prefix = "https://huggingface.co/spaces/lmsys/mt-bench/resolve/main/" 47 | 48 | for name in filenames: 49 | os.makedirs(os.path.dirname(name), exist_ok=True) 50 | ret = run_cmd(f"wget -q --show-progress -O {name} {prefix + name}") 51 | assert ret == 0 52 | -------------------------------------------------------------------------------- /fastchat/model/__init__.py: -------------------------------------------------------------------------------- 1 | from fastchat.model.model_adapter import ( 2 | load_model, 3 | get_conversation_template, 4 | add_model_args, 5 | ) 6 | -------------------------------------------------------------------------------- /fastchat/model/apply_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the LoRA weights on top of a base model. 
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B 6 | 7 | Dependency: 8 | pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b 9 | """ 10 | import argparse 11 | 12 | import torch 13 | from peft import PeftModel 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | 16 | 17 | def apply_lora(base_model_path, target_model_path, lora_path): 18 | print(f"Loading the base model from {base_model_path}") 19 | base = AutoModelForCausalLM.from_pretrained( 20 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 21 | ) 22 | base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) 23 | 24 | print(f"Loading the LoRA adapter from {lora_path}") 25 | 26 | lora_model = PeftModel.from_pretrained( 27 | base, 28 | lora_path, 29 | # torch_dtype=torch.float16 30 | ) 31 | 32 | print("Applying the LoRA") 33 | model = lora_model.merge_and_unload() 34 | 35 | print(f"Saving the target model to {target_model_path}") 36 | model.save_pretrained(target_model_path) 37 | base_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--lora-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_lora(args.base_model_path, args.target_model_path, args.lora_path) 49 | -------------------------------------------------------------------------------- /fastchat/model/convert_fp16.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder 4 | """ 5 | import argparse 6 | 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import torch 9 | 10 | 11 | def convert_fp16(in_checkpoint, out_checkpoint): 12 | tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) 13 | model = AutoModelForCausalLM.from_pretrained( 14 | in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True 15 | ) 16 | model.save_pretrained(out_checkpoint) 17 | tokenizer.save_pretrained(out_checkpoint) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--in-checkpoint", type=str, help="Path to the model") 23 | parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") 24 | args = parser.parse_args() 25 | 26 | convert_fp16(args.in_checkpoint, args.out_checkpoint) 27 | -------------------------------------------------------------------------------- /fastchat/model/llama_condense_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import transformers 7 | import transformers.models.llama.modeling_llama 8 | 9 | 10 | class CondenseRotaryEmbedding(torch.nn.Module): 11 | def __init__( 12 | self, dim, ratio, max_position_embeddings=2048, base=10000, device=None 13 | ): 14 | super().__init__() 15 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 16 | self.register_buffer("inv_freq", inv_freq) 17 | 18 | # Build 
here to make `torch.jit.trace` work. 19 | self.ratio = ratio 20 | max_position_embeddings *= ratio 21 | self.max_seq_len_cached = max_position_embeddings 22 | # print(f"Monkey Patching condense ratio {ratio}") 23 | t = ( 24 | torch.arange( 25 | self.max_seq_len_cached, 26 | device=self.inv_freq.device, 27 | dtype=self.inv_freq.dtype, 28 | ) 29 | / ratio 30 | ) 31 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 32 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 33 | emb = torch.cat((freqs, freqs), dim=-1) 34 | dtype = torch.get_default_dtype() 35 | self.register_buffer( 36 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 37 | ) 38 | self.register_buffer( 39 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 40 | ) 41 | 42 | def forward(self, x, seq_len=None): 43 | # x: [bs, num_attention_heads, seq_len, head_size] 44 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 45 | if seq_len > self.max_seq_len_cached: 46 | self.max_seq_len_cached = seq_len 47 | t = ( 48 | torch.arange( 49 | self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype 50 | ) 51 | / self.ratio 52 | ) 53 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 54 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 55 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 56 | self.register_buffer( 57 | "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False 58 | ) 59 | self.register_buffer( 60 | "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False 61 | ) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def replace_llama_with_condense(ratio): 69 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( 70 | CondenseRotaryEmbedding, ratio=ratio 71 | ) 72 | -------------------------------------------------------------------------------- /fastchat/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make the delta weights by subtracting base weights. 
3 | 4 | Usage: 5 | python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 6 | """ 7 | import argparse 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path): 15 | print(f"Loading the base model from {base_model_path}") 16 | base = AutoModelForCausalLM.from_pretrained( 17 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | 20 | print(f"Loading the target model from {target_model_path}") 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) 25 | 26 | print("Calculating the delta") 27 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 28 | assert name in base.state_dict() 29 | param.data -= base.state_dict()[name] 30 | 31 | print(f"Saving the delta to {delta_path}") 32 | if args.hub_repo_id: 33 | kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} 34 | else: 35 | kwargs = {} 36 | target.save_pretrained(delta_path, **kwargs) 37 | target_tokenizer.save_pretrained(delta_path, **kwargs) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | parser.add_argument("--hub-repo-id", type=str) 46 | args = parser.parse_args() 47 | 48 | make_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /fastchat/model/model_codet5p.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | import transformers 5 | from transformers import ( 6 | GenerationConfig, 7 | StoppingCriteria, 8 | StoppingCriteriaList, 9 | TextIteratorStreamer, 10 | ) 11 | 12 | 13 | @torch.inference_mode() 14 | def generate_stream_codet5p( 15 | model, 16 | tokenizer, 17 | params, 18 | device, 19 | context_len=2048, 20 | stream_interval=2, 21 | judge_sent_end=False, 22 | ): 23 | prompt = params["prompt"] 24 | temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 1024)) 29 | stop_token_ids = params.get("stop_token_ids", None) or [] 30 | stop_token_ids.append(tokenizer.eos_token_id) 31 | 32 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 33 | streamer = TextIteratorStreamer(tokenizer, **decode_config) 34 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 35 | input_ids = encoding.input_ids 36 | encoding["decoder_input_ids"] = encoding["input_ids"].clone() 37 | input_echo_len = len(input_ids) 38 | 39 | generation_config = GenerationConfig( 40 | max_new_tokens=max_new_tokens, 41 | do_sample=temperature >= 1e-5, 42 | temperature=temperature, 43 | repetition_penalty=repetition_penalty, 44 | no_repeat_ngram_size=10, 45 | top_p=top_p, 46 | 
top_k=top_k, 47 | eos_token_id=stop_token_ids, 48 | ) 49 | 50 | class CodeBlockStopper(StoppingCriteria): 51 | def __call__( 52 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 53 | ) -> bool: 54 | # Code-completion is open-end generation. 55 | # We check \n\n to stop at end of a code block. 56 | if list(input_ids[0][-2:]) == [628, 198]: 57 | return True 58 | return False 59 | 60 | gen_kwargs = dict( 61 | **encoding, 62 | streamer=streamer, 63 | generation_config=generation_config, 64 | stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 65 | ) 66 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 67 | thread.start() 68 | i = 0 69 | output = "" 70 | for new_text in streamer: 71 | i += 1 72 | output += new_text 73 | if i % stream_interval == 0 or i == max_new_tokens - 1: 74 | yield { 75 | "text": output, 76 | "usage": { 77 | "prompt_tokens": input_echo_len, 78 | "completion_tokens": i, 79 | "total_tokens": input_echo_len + i, 80 | }, 81 | "finish_reason": None, 82 | } 83 | if i >= max_new_tokens: 84 | break 85 | 86 | if i >= max_new_tokens: 87 | finish_reason = "length" 88 | else: 89 | finish_reason = "stop" 90 | 91 | yield { 92 | "text": output, 93 | "usage": { 94 | "prompt_tokens": input_echo_len, 95 | "completion_tokens": i, 96 | "total_tokens": input_echo_len + i, 97 | }, 98 | "finish_reason": finish_reason, 99 | } 100 | thread.join() 101 | 102 | # clean 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | if device == "xpu": 106 | torch.xpu.empty_cache() 107 | if device == "npu": 108 | torch.npu.empty_cache() 109 | -------------------------------------------------------------------------------- /fastchat/model/model_exllama.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import sys 3 | from typing import Dict 4 | 5 | import torch 6 | 7 | 8 | def generate_stream_exllama( 9 | model, 10 | tokenizer, 11 | params: Dict, 12 | device: str, 13 | context_len: int, 14 | stream_interval: int = 2, 15 | judge_sent_end: bool = False, 16 | ): 17 | try: 18 | from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler 19 | except ImportError as e: 20 | print(f"Error: Failed to load Exllamav2. 
{e}") 21 | sys.exit(-1) 22 | 23 | prompt = params["prompt"] 24 | 25 | generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) 26 | settings = ExLlamaV2Sampler.Settings() 27 | 28 | settings.temperature = float(params.get("temperature", 0.85)) 29 | settings.top_k = int(params.get("top_k", 50)) 30 | settings.top_p = float(params.get("top_p", 0.8)) 31 | settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) 32 | settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) 33 | 34 | max_new_tokens = int(params.get("max_new_tokens", 256)) 35 | 36 | generator.set_stop_conditions(params.get("stop_token_ids", None) or []) 37 | echo = bool(params.get("echo", True)) 38 | 39 | input_ids = generator.tokenizer.encode(prompt) 40 | prompt_tokens = input_ids.shape[-1] 41 | generator.begin_stream(input_ids, settings) 42 | 43 | generated_tokens = 0 44 | if echo: 45 | output = prompt 46 | else: 47 | output = "" 48 | while True: 49 | chunk, eos, _ = generator.stream() 50 | output += chunk 51 | generated_tokens += 1 52 | if generated_tokens == max_new_tokens: 53 | finish_reason = "length" 54 | break 55 | elif eos: 56 | finish_reason = "length" 57 | break 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": prompt_tokens, 62 | "completion_tokens": generated_tokens, 63 | "total_tokens": prompt_tokens + generated_tokens, 64 | }, 65 | "finish_reason": None, 66 | } 67 | 68 | yield { 69 | "text": output, 70 | "usage": { 71 | "prompt_tokens": prompt_tokens, 72 | "completion_tokens": generated_tokens, 73 | "total_tokens": prompt_tokens + generated_tokens, 74 | }, 75 | "finish_reason": finish_reason, 76 | } 77 | gc.collect() 78 | -------------------------------------------------------------------------------- /fastchat/model/model_xfastertransformer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | 4 | import torch 5 | from transformers import TextIteratorStreamer 6 | 7 | 8 | @torch.inference_mode() 9 | def generate_stream_xft( 10 | model, 11 | tokenizer, 12 | params, 13 | device, 14 | context_len=8192, 15 | stream_interval=2, 16 | judge_sent_end=False, 17 | ): 18 | prompt = params["prompt"] 19 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 20 | 21 | # unused now, and placehold for future. 
22 | # temperature = float(params.get("temperature", 1.0)) 23 | # top_p = float(params.get("top_p", 1.0)) 24 | 25 | max_new_tokens = int(params.get("max_new_tokens", 4096)) 26 | echo = params.get("echo", True) 27 | 28 | inputs = tokenizer( 29 | prompt, return_tensors="pt", padding=model.config.padding 30 | ).input_ids 31 | input_echo_len = len(inputs[0]) 32 | max_len = max_new_tokens + input_echo_len 33 | 34 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 35 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 36 | generation_kwargs = { 37 | "input_ids": inputs, 38 | "streamer": streamer, 39 | "max_length": max_len, 40 | "num_beams": model.config.beam_width, 41 | "length_penalty": repetition_penalty, 42 | "num_return_sequences": model.config.num_return_sequences, 43 | "early_stopping": model.config.early_stopping, 44 | "eos_token_id": model.config.eos_token_id, 45 | "pad_token_id": model.config.pad_token_id, 46 | } 47 | 48 | thread = Thread(target=model.model.generate, kwargs=generation_kwargs) 49 | thread.start() 50 | if echo: 51 | # means keep the prompt 52 | output = prompt 53 | else: 54 | output = "" 55 | i = 0 56 | for i, new_text in enumerate(streamer): 57 | output += new_text 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": input_echo_len, 62 | "completion_tokens": i, 63 | "total_tokens": input_echo_len + i, 64 | }, 65 | "finish_reason": None, 66 | } 67 | output = output.strip() 68 | if i == max_new_tokens - 1: 69 | finish_reason = "length" 70 | else: 71 | finish_reason = "stop" 72 | yield { 73 | "text": output, 74 | "usage": { 75 | "prompt_tokens": input_echo_len, 76 | "completion_tokens": i, 77 | "total_tokens": input_echo_len + i, 78 | }, 79 | "finish_reason": finish_reason, 80 | } 81 | gc.collect() 82 | -------------------------------------------------------------------------------- /fastchat/model/rwkv_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from types import SimpleNamespace 3 | import warnings 4 | 5 | import torch 6 | 7 | os.environ["RWKV_JIT_ON"] = "1" 8 | os.environ["RWKV_CUDA_ON"] = "1" 9 | 10 | from rwkv.model import RWKV 11 | from rwkv.utils import PIPELINE, PIPELINE_ARGS 12 | 13 | 14 | class RwkvModel: 15 | def __init__(self, model_path): 16 | warnings.warn( 17 | "Experimental support. Please use ChatRWKV if you want to chat with RWKV" 18 | ) 19 | self.config = SimpleNamespace(is_encoder_decoder=False) 20 | self.model = RWKV(model=model_path, strategy="cuda fp16") 21 | # two GPUs 22 | # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") 23 | 24 | self.tokenizer = None 25 | self.model_path = model_path 26 | 27 | def to(self, target): 28 | assert target == "cuda" 29 | 30 | def __call__(self, input_ids, use_cache, past_key_values=None): 31 | assert use_cache == True 32 | input_ids = input_ids[0].detach().cpu().numpy() 33 | # print(input_ids) 34 | logits, state = self.model.forward(input_ids, past_key_values) 35 | # print(logits) 36 | logits = logits.unsqueeze(0).unsqueeze(0) 37 | out = SimpleNamespace(logits=logits, past_key_values=state) 38 | return out 39 | 40 | def generate( 41 | self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 42 | ): 43 | # This function is used by fastchat.llm_judge. 44 | # Because RWKV does not support huggingface generation API, 45 | # we reuse fastchat.serve.inference.generate_stream as a workaround. 
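# The steps below: decode input_ids back into text, pull stop strings/tokens from the
# "rwkv" conversation template, run generate_stream to completion, then re-encode the
# generated text so the caller still receives token ids.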
46 | from transformers import AutoTokenizer 47 | 48 | from fastchat.serve.inference import generate_stream 49 | from fastchat.conversation import get_conv_template 50 | 51 | if self.tokenizer is None: 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | "EleutherAI/pythia-160m", use_fast=True 54 | ) 55 | prompt = self.tokenizer.decode(input_ids[0].tolist()) 56 | conv = get_conv_template("rwkv") 57 | 58 | gen_params = { 59 | "model": self.model_path, 60 | "prompt": prompt, 61 | "temperature": temperature, 62 | "repetition_penalty": repetition_penalty, 63 | "max_new_tokens": max_new_tokens, 64 | "stop": conv.stop_str, 65 | "stop_token_ids": conv.stop_token_ids, 66 | "echo": False, 67 | } 68 | res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") 69 | 70 | for res in res_iter: 71 | pass 72 | 73 | output = res["text"] 74 | output_ids = self.tokenizer.encode(output) 75 | 76 | return [input_ids[0].tolist() + output_ids] 77 | -------------------------------------------------------------------------------- /fastchat/model/upload_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload weights to huggingface. 3 | 4 | Usage: 5 | python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 6 | """ 7 | import argparse 8 | import tempfile 9 | 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def upload_hub(model_path, hub_repo_id, component, private): 15 | if component == "all": 16 | components = ["model", "tokenizer"] 17 | else: 18 | components = [component] 19 | 20 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} 21 | 22 | if "model" in components: 23 | model = AutoModelForCausalLM.from_pretrained( 24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 25 | ) 26 | with tempfile.TemporaryDirectory() as tmp_path: 27 | model.save_pretrained(tmp_path, **kwargs) 28 | 29 | if "tokenizer" in components: 30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 31 | with tempfile.TemporaryDirectory() as tmp_path: 32 | tokenizer.save_pretrained(tmp_path, **kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model-path", type=str, required=True) 38 | parser.add_argument("--hub-repo-id", type=str, required=True) 39 | parser.add_argument( 40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all" 41 | ) 42 | parser.add_argument("--private", action="store_true") 43 | args = parser.parse_args() 44 | 45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) 46 | -------------------------------------------------------------------------------- /fastchat/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/modules/__init__.py -------------------------------------------------------------------------------- /fastchat/modules/awq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | import sys 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils 7 | 8 | 9 | @dataclass 10 | class AWQConfig: 11 | ckpt: str = field( 12 | default=None, 13 | metadata={ 14 | "help": "Load 
quantized model. The path to the local AWQ checkpoint." 15 | }, 16 | ) 17 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 18 | groupsize: int = field( 19 | default=-1, 20 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 21 | ) 22 | 23 | 24 | def load_awq_quantized(model_name, awq_config: AWQConfig, device): 25 | print("Loading AWQ quantized model...") 26 | 27 | try: 28 | from tinychat.utils import load_quant 29 | from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp 30 | except ImportError as e: 31 | print(f"Error: Failed to import tinychat. {e}") 32 | print("Please double check if you have successfully installed AWQ") 33 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") 34 | sys.exit(-1) 35 | 36 | config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | model_name, use_fast=False, trust_remote_code=True 39 | ) 40 | 41 | def skip(*args, **kwargs): 42 | pass 43 | 44 | torch.nn.init.kaiming_uniform_ = skip 45 | torch.nn.init.kaiming_normal_ = skip 46 | torch.nn.init.uniform_ = skip 47 | torch.nn.init.normal_ = skip 48 | modeling_utils._init_weights = False 49 | 50 | torch.set_default_dtype(torch.half) 51 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 52 | 53 | if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): 54 | model = load_quant.load_awq_llama_fast( 55 | model, 56 | find_awq_ckpt(awq_config), 57 | awq_config.wbits, 58 | awq_config.groupsize, 59 | device, 60 | ) 61 | make_quant_attn(model, device) 62 | make_quant_norm(model) 63 | make_fused_mlp(model) 64 | else: 65 | model = load_quant.load_awq_model( 66 | model, 67 | find_awq_ckpt(awq_config), 68 | awq_config.wbits, 69 | awq_config.groupsize, 70 | device, 71 | ) 72 | return model, tokenizer 73 | 74 | 75 | def find_awq_ckpt(awq_config: AWQConfig): 76 | if Path(awq_config.ckpt).is_file(): 77 | return awq_config.ckpt 78 | 79 | for ext in ["*.pt", "*.safetensors"]: 80 | matched_result = sorted(Path(awq_config.ckpt).glob(ext)) 81 | if len(matched_result) > 0: 82 | return str(matched_result[-1]) 83 | 84 | print("Error: AWQ checkpoint not found") 85 | sys.exit(1) 86 | -------------------------------------------------------------------------------- /fastchat/modules/exllama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import sys 3 | 4 | 5 | @dataclass 6 | class ExllamaConfig: 7 | max_seq_len: int 8 | gpu_split: str = None 9 | cache_8bit: bool = False 10 | 11 | 12 | class ExllamaModel: 13 | def __init__(self, exllama_model, exllama_cache): 14 | self.model = exllama_model 15 | self.cache = exllama_cache 16 | self.config = self.model.config 17 | 18 | 19 | def load_exllama_model(model_path, exllama_config: ExllamaConfig): 20 | try: 21 | from exllamav2 import ( 22 | ExLlamaV2Config, 23 | ExLlamaV2Tokenizer, 24 | ExLlamaV2, 25 | ExLlamaV2Cache, 26 | ExLlamaV2Cache_8bit, 27 | ) 28 | except ImportError as e: 29 | print(f"Error: Failed to load Exllamav2. 
{e}") 30 | sys.exit(-1) 31 | 32 | exllamav2_config = ExLlamaV2Config() 33 | exllamav2_config.model_dir = model_path 34 | exllamav2_config.prepare() 35 | exllamav2_config.max_seq_len = exllama_config.max_seq_len 36 | exllamav2_config.cache_8bit = exllama_config.cache_8bit 37 | 38 | exllama_model = ExLlamaV2(exllamav2_config) 39 | tokenizer = ExLlamaV2Tokenizer(exllamav2_config) 40 | 41 | split = None 42 | if exllama_config.gpu_split: 43 | split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] 44 | exllama_model.load(split) 45 | 46 | cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache 47 | exllama_cache = cache_class(exllama_model) 48 | model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) 49 | 50 | return model, tokenizer 51 | -------------------------------------------------------------------------------- /fastchat/modules/gptq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import os 3 | from os.path import isdir, isfile 4 | from pathlib import Path 5 | import sys 6 | 7 | from transformers import AutoTokenizer 8 | 9 | 10 | @dataclass 11 | class GptqConfig: 12 | ckpt: str = field( 13 | default=None, 14 | metadata={ 15 | "help": "Load quantized model. The path to the local GPTQ checkpoint." 16 | }, 17 | ) 18 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 19 | groupsize: int = field( 20 | default=-1, 21 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 22 | ) 23 | act_order: bool = field( 24 | default=True, 25 | metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, 26 | ) 27 | 28 | 29 | def load_gptq_quantized(model_name, gptq_config: GptqConfig): 30 | print("Loading GPTQ quantized model...") 31 | 32 | try: 33 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") 35 | 36 | sys.path.insert(0, module_path) 37 | from llama import load_quant 38 | except ImportError as e: 39 | print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") 40 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") 41 | sys.exit(-1) 42 | 43 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 44 | # only `fastest-inference-4bit` branch cares about `act_order` 45 | if gptq_config.act_order: 46 | model = load_quant( 47 | model_name, 48 | find_gptq_ckpt(gptq_config), 49 | gptq_config.wbits, 50 | gptq_config.groupsize, 51 | act_order=gptq_config.act_order, 52 | ) 53 | else: 54 | # other branches 55 | model = load_quant( 56 | model_name, 57 | find_gptq_ckpt(gptq_config), 58 | gptq_config.wbits, 59 | gptq_config.groupsize, 60 | ) 61 | 62 | return model, tokenizer 63 | 64 | 65 | def find_gptq_ckpt(gptq_config: GptqConfig): 66 | if Path(gptq_config.ckpt).is_file(): 67 | return gptq_config.ckpt 68 | 69 | for ext in ["*.pt", "*.safetensors"]: 70 | matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) 71 | if len(matched_result) > 0: 72 | return str(matched_result[-1]) 73 | 74 | print("Error: gptq checkpoint not found") 75 | sys.exit(1) 76 | -------------------------------------------------------------------------------- /fastchat/modules/xfastertransformer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import sys 3 | 4 | 5 | @dataclass 6 | class XftConfig: 7 | max_seq_len: int = 4096 8 | beam_width: int = 1 9 | eos_token_id: int = -1 10 | pad_token_id: int = -1 11 | num_return_sequences: int = 1 12 | is_encoder_decoder: bool = False 13 | padding: bool = True 14 | early_stopping: bool = False 15 | data_type: str = "bf16_fp16" 16 | 17 | 18 | class XftModel: 19 | def __init__(self, xft_model, xft_config): 20 | self.model = xft_model 21 | self.config = xft_config 22 | 23 | 24 | def load_xft_model(model_path, xft_config: XftConfig): 25 | try: 26 | import xfastertransformer 27 | from transformers import AutoTokenizer 28 | except ImportError as e: 29 | print(f"Error: Failed to load xFasterTransformer. 
{e}") 30 | sys.exit(-1) 31 | 32 | if xft_config.data_type is None or xft_config.data_type == "": 33 | data_type = "bf16_fp16" 34 | else: 35 | data_type = xft_config.data_type 36 | tokenizer = AutoTokenizer.from_pretrained( 37 | model_path, use_fast=False, padding_side="left", trust_remote_code=True 38 | ) 39 | xft_model = xfastertransformer.AutoModel.from_pretrained( 40 | model_path, dtype=data_type 41 | ) 42 | model = XftModel(xft_model=xft_model, xft_config=xft_config) 43 | if model.model.rank > 0: 44 | while True: 45 | model.model.generate() 46 | return model, tokenizer 47 | -------------------------------------------------------------------------------- /fastchat/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/serve/__init__.py -------------------------------------------------------------------------------- /fastchat/serve/example_images/distracted.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/serve/example_images/distracted.jpg -------------------------------------------------------------------------------- /fastchat/serve/example_images/fridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/fastchat/serve/example_images/fridge.jpg -------------------------------------------------------------------------------- /fastchat/serve/gateway/README.md: -------------------------------------------------------------------------------- 1 | # fastchat Nginx Gateway 2 | 3 | ## Purpose of the Gateway 4 | 5 | The Nginx gateway serves the following purposes: 6 | 7 | 1. Protects Gradio servers by acting as a firewall. 8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers. 9 | 3. Provides load balancing for Gradio servers. 10 | 4. Offers additional security features, such as total connection limit. 11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving. 12 | 13 | ## Deployment and Updating of the Gateway 14 | 15 | ### Installing Nginx 16 | 17 | On Debian-based distributions (e.g., Ubuntu): 18 | 19 | ```bash 20 | sudo apt update 21 | sudo apt install nginx 22 | ``` 23 | On Red Hat-based distributions (e.g., CentOS, Fedora): 24 | 25 | ```bash 26 | sudo yum install epel-release 27 | sudo yum install nginx 28 | ``` 29 | 30 | ### Deployment 31 | 32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). 33 | 34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. 35 | 36 | Modify `upstream websocket` to configure Gradio servers behind the gateway. 37 | 38 | Lastly, update Nginx. 39 | 40 | 41 | ### HTTPS Deployment with a Public Domain URL 42 | 43 | Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. 44 | 45 | Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. 46 | 47 | If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. 
48 | 49 | ### Updating 50 | 51 | Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: 52 | 53 | ```bash 54 | sudo nginx -t # check `/etc/nginx/nginx.conf` 55 | sudo systemctl reload nginx # restart Nginx service to load the new config 56 | sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). 57 | ``` 58 | -------------------------------------------------------------------------------- /fastchat/serve/gateway/nginx.conf: -------------------------------------------------------------------------------- 1 | user www-data; 2 | worker_processes auto; 3 | pid /run/nginx.pid; 4 | include /etc/nginx/modules-enabled/*.conf; 5 | 6 | events { 7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently 8 | # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle 9 | 10 | } 11 | 12 | http { 13 | ## 14 | # Basic Settings 15 | ## 16 | 17 | sendfile on; # enable sendfile for performance optimization 18 | tcp_nopush on; # enable TCP no-pushing 19 | tcp_nodelay on; # enable TCP no-delay 20 | keepalive_timeout 65; # sets the timeout for keep-alive connections 21 | types_hash_max_size 2048; # maximum size of the types hash table 22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security 23 | 24 | # server_names_hash_bucket_size 64; 25 | # server_name_in_redirect off; 26 | 27 | include /etc/nginx/mime.types; # include MIME types file 28 | default_type application/octet-stream; # default MIME type for unknown file types 29 | 30 | ## 31 | # SSL Settings 32 | ## 33 | 34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use 35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers 36 | 37 | ## 38 | # Logging Settings 39 | ## 40 | 41 | access_log /var/log/nginx/access.log; # path to access log file 42 | error_log /var/log/nginx/error.log; # path to error log file 43 | 44 | ## 45 | # Gzip Settings 46 | ## 47 | gzip on; # enable Gzip compression 48 | 49 | ## 50 | # Virtual Host Configs 51 | ## 52 | 53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory 54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files 55 | 56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ 57 | map $http_upgrade $connection_upgrade { 58 | default upgrade; 59 | '' close; 60 | } 61 | 62 | upstream websocket { 63 | ip_hash; # load balancing by IP to guarantee session persistence 64 | server localhost:7860; # The port should be the gradio web server port 65 | # server localhost:7861; # extra gradio server if more than one 66 | } 67 | 68 | limit_conn_status 429; 69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP 70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server 71 | 72 | server { 73 | listen 443 ssl; # the listening port of our server 74 | ssl_certificate [PATH_TO_SSL_CERT]; 75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY]; 76 | server_name chat.lmsys.org; # replace the url with your own domain url 77 | limit_conn perserver 1024; # connections per server 78 | location / { 79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server 80 | limit_conn perip 5; # connections per IP 81 | proxy_set_header Host $host; # set the Host 
header for the upstream server 82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server 83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header 84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication 85 | proxy_set_header Upgrade $http_upgrade; 86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication 87 | } 88 | } 89 | 90 | # the following block routes all HTTP traffic to HTTPS via nginx 91 | server { 92 | listen 80; 93 | server_name chat.lmsys.org; 94 | return 301 https://chat.lmsys.org$request_uri; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /fastchat/serve/gradio_global_state.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | 5 | @dataclass 6 | class Context: 7 | text_models: List[str] = field(default_factory=list) 8 | all_text_models: List[str] = field(default_factory=list) 9 | vision_models: List[str] = field(default_factory=list) 10 | all_vision_models: List[str] = field(default_factory=list) 11 | models: List[str] = field(default_factory=list) 12 | all_models: List[str] = field(default_factory=list) 13 | -------------------------------------------------------------------------------- /fastchat/serve/huggingface_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use FastChat with Hugging Face generation APIs. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 6 | python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 7 | """ 8 | import argparse 9 | 10 | import torch 11 | 12 | from fastchat.model import load_model, get_conversation_template, add_model_args 13 | 14 | 15 | @torch.inference_mode() 16 | def main(args): 17 | # Load model 18 | model, tokenizer = load_model( 19 | args.model_path, 20 | device=args.device, 21 | num_gpus=args.num_gpus, 22 | max_gpu_memory=args.max_gpu_memory, 23 | load_8bit=args.load_8bit, 24 | cpu_offloading=args.cpu_offloading, 25 | revision=args.revision, 26 | debug=args.debug, 27 | ) 28 | 29 | # Build the prompt with a conversation template 30 | msg = args.message 31 | conv = get_conversation_template(args.model_path) 32 | conv.append_message(conv.roles[0], msg) 33 | conv.append_message(conv.roles[1], None) 34 | prompt = conv.get_prompt() 35 | 36 | # Run inference 37 | inputs = tokenizer([prompt], return_tensors="pt").to(args.device) 38 | output_ids = model.generate( 39 | **inputs, 40 | do_sample=True if args.temperature > 1e-5 else False, 41 | temperature=args.temperature, 42 | repetition_penalty=args.repetition_penalty, 43 | max_new_tokens=args.max_new_tokens, 44 | ) 45 | 46 | if model.config.is_encoder_decoder: 47 | output_ids = output_ids[0] 48 | else: 49 | output_ids = output_ids[0][len(inputs["input_ids"][0]) :] 50 | outputs = tokenizer.decode( 51 | output_ids, skip_special_tokens=True, spaces_between_special_tokens=False 52 | ) 53 | 54 | # Print results 55 | print(f"{conv.roles[0]}: {msg}") 56 | print(f"{conv.roles[1]}: {outputs}") 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | add_model_args(parser) 62 | parser.add_argument("--temperature", type=float, default=0.7) 63 | parser.add_argument("--repetition_penalty", type=float, 
default=1.0) 64 | parser.add_argument("--max-new-tokens", type=int, default=1024) 65 | parser.add_argument("--debug", action="store_true") 66 | parser.add_argument("--message", type=str, default="Hello! Who are you?") 67 | args = parser.parse_args() 68 | 69 | # Reset default repetition penalty for T5 models. 70 | if "t5" in args.model_path and args.repetition_penalty == 1.0: 71 | args.repetition_penalty = 1.2 72 | 73 | main(args) 74 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/add_markdown_info.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | import argparse 4 | 5 | from tqdm import tqdm 6 | 7 | tqdm.pandas() 8 | 9 | 10 | def count_markdown_elements(markdown_text, suffix): 11 | counters = { 12 | f"header_count{suffix}": { 13 | "h1": len(re.findall(r"^#{1}\s", markdown_text, re.MULTILINE)), 14 | "h2": len(re.findall(r"^#{2}\s", markdown_text, re.MULTILINE)), 15 | "h3": len(re.findall(r"^#{3}\s", markdown_text, re.MULTILINE)), 16 | "h4": len(re.findall(r"^#{4}\s", markdown_text, re.MULTILINE)), 17 | "h5": len(re.findall(r"^#{5}\s", markdown_text, re.MULTILINE)), 18 | "h6": len(re.findall(r"^#{6}\s", markdown_text, re.MULTILINE)), 19 | }, 20 | f"list_count{suffix}": { 21 | "ordered": len(re.findall(r"^\s*\d+\.\s", markdown_text, re.MULTILINE)), 22 | "unordered": len(re.findall(r"^\s*[-*+]\s", markdown_text, re.MULTILINE)), 23 | }, 24 | f"bold_count{suffix}": { 25 | "**": len(re.findall(r"\*\*[^*\n]+\*\*", markdown_text)), 26 | "__": len(re.findall(r"__[^_\n]+__", markdown_text)), 27 | }, 28 | } 29 | return counters 30 | 31 | 32 | def remove_pattern(answer, pattern): 33 | blocks = pattern.findall(answer) 34 | for block in blocks: 35 | answer = answer.replace(block, "") 36 | return answer 37 | 38 | 39 | def get_element_counts(df, column): 40 | pattern = re.compile("```([^`]*)```") 41 | answers = df[column].map( 42 | lambda convo: "\n".join( 43 | [turn["content"] for turn in convo if turn["role"] == "assistant"] 44 | ) 45 | ) 46 | results = answers.progress_map( 47 | lambda answer: count_markdown_elements( 48 | remove_pattern(answer, pattern), 49 | suffix=column[-2:], # Remove code block first 50 | ) 51 | ) 52 | 53 | return results.tolist() 54 | 55 | 56 | def add_markdown_meta(row): 57 | conv_meta = {k: v for k, v in row["conv_metadata"].items()} 58 | return conv_meta | row["markdown_meta_a"] | row["markdown_meta_b"] 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--input-file", type=str, required=True) 64 | parser.add_argument("--output-file", type=str, required=True) 65 | args = parser.parse_args() 66 | 67 | print("loading file...") 68 | data = pd.read_json(args.input_file) 69 | 70 | assert "conv_metadata" in data.columns 71 | 72 | temp = data[["question_id", "conv_metadata"]].copy() 73 | 74 | print("Processing conversation_a") 75 | temp["markdown_meta_a"] = get_element_counts(data, column="conversation_a") 76 | 77 | print("Processing conversation_b") 78 | temp["markdown_meta_b"] = get_element_counts(data, column="conversation_b") 79 | 80 | print("Post-processing...") 81 | data["conv_metadata"] = temp.apply(add_markdown_meta, axis=1) 82 | 83 | print("Saving to file...") 84 | data.to_json(args.output_file, orient="records", indent=4, force_ascii=False) 85 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/classify/README.md: 
-------------------------------------------------------------------------------- 1 | ## Download dataset 2 | We have pre-generated several category classifier benchmarks and ground truths. You can download them (with [`git-lfs`](https://git-lfs.com) installed) to the directory `classify/` by running 3 | ```console 4 | > git clone https://huggingface.co/datasets/lmarena-ai/categories-benchmark-eval 5 | // cd into classify/ and then copy the label_bench directory to the current directory 6 | > cp -r categories-benchmark-eval/label_bench . 7 | ``` 8 | Your label_bench directory should follow the structure: 9 | ```markdown 10 | ├── label_bench/ 11 | │ ├── creative_writing_bench/ 12 | │ │ ├── data/ 13 | │ │ │ └── llama-v3p1-70b-instruct.json 14 | │ │ └── test.json 15 | │ ├── ... 16 | │ ├── your_bench_name/ 17 | │ │ ├── data/ 18 | │ │ │ ├── your_classifier_data_1.json 19 | │ │ │ ├── your_classifier_data_2.json 20 | │ │ │ └── ... 21 | │ │ └── test.json (your ground truth) 22 | └── ... 23 | ``` 24 | 25 | ## How to evaluate your category classifier? 26 | 27 | To test your new classifier for a new category, you would have to make sure you created the category child class in `category.py`. Then, to generate classification labels, make the necessary edits in `config.yaml` and run 28 | ```console 29 | python label.py --config config.yaml --testing 30 | ``` 31 | 32 | If you are labeling a vision category, add the `--vision` flag to the command. This will add a new column to the input data called `image_path` that contains the path to the image corresponding to each conversation. Ensure that you update your config with the correct `image_dir` where the images are stored. 33 | 34 | Then, add your new category bench to `tag_names` in `display_score.py`. After making sure that you also have a correctly formatted ground truth json file, you can report the performance of your classifier by running 35 | ```console 36 | python display_score.py --bench 37 | ``` 38 | 39 | If you want to check out conflicts between your classifier and ground truth, use 40 | ```console 41 | python display_score.py --bench --display-conflict 42 | ``` 43 | 44 | Example output: 45 | ```console 46 | > python display_score.py --bench if_bench --display-conflict 47 | Model: gpt-4o-mini-2024-07-18 48 | Accuracy: 0.967 49 | Precision: 0.684 50 | Recall: 0.918 51 | 52 | ###### CONFLICT ###### 53 | 54 | Ground Truth = True; Pred = False 55 | \#################### 56 | ... 57 | 58 | Ground Truth = False; Pred = True 59 | \#################### 60 | ... 
61 | ``` 62 | 63 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/classify/config.yaml: -------------------------------------------------------------------------------- 1 | # Yaml config file for category classification 2 | 3 | input_file: null # json 4 | cache_file: null # json 5 | output_file: null # json line 6 | 7 | convert_to_json: True 8 | 9 | task_name: 10 | - criteria_v0.1 11 | - if_v0.1 12 | - math_v0.1 13 | - creative_writing_v0.1 14 | 15 | model_name: null 16 | name: llama-3-70b-instruct 17 | api_type: openai 18 | endpoints: 19 | - api_base: null 20 | api_key: null 21 | parallel: 50 22 | temperature: 0.0 23 | max_token: 512 24 | 25 | image_dir: null # directory where vision arena images are stored 26 | 27 | max_retry: 2 28 | retry_sleep: 10 29 | error_output: $ERROR$ -------------------------------------------------------------------------------- /fastchat/serve/monitor/classify/display_score.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | import os 4 | from glob import glob 5 | from sklearn.metrics import recall_score, precision_score 6 | 7 | tag_names = { 8 | "if_bench": ("if_v0.1", "if"), 9 | "math_bench": ("math_v0.1", "math"), 10 | "hard_bench": ("criteria_v0.1", "hard"), 11 | "creative_writing_bench": ("creative_writing_v0.1", "creative_writing"), 12 | } 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--bench", type=str, default="if_bench") 18 | parser.add_argument("--display-conflict", action="store_true") 19 | args = parser.parse_args() 20 | assert args.bench in tag_names, "Not valid bench argument, add bench if needed." 21 | 22 | test = pd.read_json(os.path.join("label_bench", args.bench, "test.json")) 23 | 24 | for file in glob(os.path.join("label_bench", args.bench, "data", "*.json")): 25 | output = pd.read_json(file) 26 | 27 | tag_map = ( 28 | output[["question_id", "category_tag"]] 29 | .set_index("question_id") 30 | .to_dict("index") 31 | ) 32 | 33 | tag_1, tag_2 = tag_names[args.bench] 34 | test["pred"] = test.question_id.map( 35 | lambda id: tag_map[id]["category_tag"][tag_1][tag_2] 36 | ) 37 | 38 | accuracy = (test.label == test.pred).mean() 39 | recall = recall_score(y_pred=test.pred, y_true=test.label) 40 | precision = precision_score(y_pred=test.pred, y_true=test.label) 41 | 42 | print(f"Model: {output.model[0]}") 43 | print(f"Accuracy: {round(accuracy, 3)}") 44 | print(f"Precision: {round(precision, 3)}") 45 | print(f"Recall: {round(recall, 3)}") 46 | 47 | if args.display_conflict: 48 | print() 49 | print("###### CONFLICT ######") 50 | print() 51 | conflict = test[test.label & ~test.pred] 52 | print("Ground Truth = True; Pred = False") 53 | prompts = ( 54 | conflict.conversation_a.map(lambda x: x[0]["content"]) 55 | .sample(n=5) 56 | .tolist() 57 | ) 58 | for prompt in prompts: 59 | print("####################") 60 | print(prompt) 61 | print() 62 | print() 63 | 64 | conflict = test[~test.label & test.pred] 65 | print("Ground Truth = False; Pred = True") 66 | prompts = ( 67 | conflict.conversation_a.map(lambda x: x[0]["content"]) 68 | .sample(n=5) 69 | .tolist() 70 | ) 71 | for prompt in prompts: 72 | print("####################") 73 | print(prompt) 74 | print() 75 | print() 76 | print() 77 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/classify/vision_config.yaml: 
-------------------------------------------------------------------------------- 1 | # Yaml config file for category classification 2 | 3 | input_file: null # json 4 | cache_file: null # json 5 | output_file: null # json line 6 | 7 | convert_to_json: True 8 | 9 | task_name: 10 | - captioning_v0.1 11 | - homework_v0.1 12 | - ocr_v0.1 13 | - humor_v0.1 14 | - entity_recognition_v0.1 15 | - creative_writing_vision_v0.1 16 | - diagram_v0.1 17 | 18 | 19 | model_name: null 20 | name: gemini-1.5-flash 21 | api_type: gemini 22 | endpoints: 23 | - api_base: null 24 | api_key: null 25 | 26 | parallel: 50 27 | temperature: 0.0 28 | max_token: 512 29 | 30 | image_dir: null # directory where vision arena images are stored 31 | 32 | max_retry: 2 33 | retry_sleep: 10 34 | error_output: $ERROR$ -------------------------------------------------------------------------------- /fastchat/serve/monitor/copilot_arena.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import pandas as pd 3 | import requests 4 | import os 5 | 6 | from fastchat.serve.monitor.monitor import recompute_final_ranking 7 | 8 | copilot_arena_leaderboard_url = os.getenv("COPILOT_ARENA_LEADERBOARD_URL") 9 | 10 | 11 | def process_copilot_arena_leaderboard(leaderboard): 12 | leaderboard = leaderboard.copy().loc[leaderboard["visibility"] == "public"] 13 | leaderboard["score"] = leaderboard["score"].round().astype(int) 14 | leaderboard["rating_q975"] = leaderboard["upper"].round().astype(int) 15 | leaderboard["rating_q025"] = leaderboard["lower"].round().astype(int) 16 | 17 | leaderboard["upper_diff"] = leaderboard["rating_q975"] - leaderboard["score"] 18 | leaderboard["lower_diff"] = leaderboard["score"] - leaderboard["rating_q025"] 19 | 20 | leaderboard["confidence_interval"] = ( 21 | "+" 22 | + leaderboard["upper_diff"].astype(str) 23 | + " / -" 24 | + leaderboard["lower_diff"].astype(str) 25 | ) 26 | 27 | rankings_ub = recompute_final_ranking(leaderboard) 28 | leaderboard.insert(loc=0, column="Rank* (UB)", value=rankings_ub) 29 | 30 | leaderboard = leaderboard.sort_values( 31 | by=["Rank* (UB)", "score"], ascending=[True, False] 32 | ) 33 | 34 | return leaderboard 35 | 36 | 37 | def build_copilot_arena_tab(): 38 | response = requests.get(copilot_arena_leaderboard_url) 39 | if response.status_code == 200: 40 | leaderboard = pd.DataFrame(response.json()["elo_data"]) 41 | leaderboard = process_copilot_arena_leaderboard(leaderboard) 42 | leaderboard = leaderboard.rename( 43 | columns={ 44 | "name": "Model", 45 | "confidence_interval": "Confidence Interval", 46 | "score": "Arena Score", 47 | "organization": "Organization", 48 | "votes": "Votes", 49 | } 50 | ) 51 | 52 | column_order = [ 53 | "Rank* (UB)", 54 | "Model", 55 | "Arena Score", 56 | "Confidence Interval", 57 | "Votes", 58 | "Organization", 59 | ] 60 | leaderboard = leaderboard[column_order] 61 | num_models = len(leaderboard) 62 | total_battles = int(leaderboard["Votes"].sum()) // 2 63 | md = f""" 64 | [Copilot Arena](https://blog.lmarena.ai/blog/2024/copilot-arena/) is a free AI coding assistant that provides paired responses from different state-of-the-art LLMs. This leaderboard contains the relative performance and ranking of {num_models} models over {total_battles} battles. 
65 | """ 66 | 67 | gr.Markdown(md, elem_id="leaderboard_markdown") 68 | gr.DataFrame( 69 | leaderboard, 70 | datatype=["str" for _ in leaderboard.columns], 71 | elem_id="arena_hard_leaderboard", 72 | height=600, 73 | wrap=True, 74 | interactive=False, 75 | column_widths=[70, 130, 60, 80, 50, 80], 76 | ) 77 | 78 | gr.Markdown( 79 | """ 80 | ***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. 81 | Model A is statistically better than model B when A's lower-bound score is greater than B's upper-bound score (in 95% confidence interval). \n 82 | **Confidence Interval**: represents the range of uncertainty around the Arena Score. It's displayed as +X / -Y, where X is the difference between the upper bound and the score, and Y is the difference between the score and the lower bound. 83 | """, 84 | elem_id="leaderboard_markdown", 85 | ) 86 | else: 87 | gr.Markdown("Error with fetching Copilot Arena data. Check back in later.") 88 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | args = parser.parse_args() 11 | 12 | lines = json.load(open(args.input)) 13 | ct_anony_votes = 0 14 | all_users = set() 15 | all_models = set() 16 | for l in lines: 17 | if not l["anony"]: 18 | continue 19 | all_users.add(l["judge"]) 20 | all_models.add(l["model_a"]) 21 | all_models.add(l["model_b"]) 22 | ct_anony_votes += 1 23 | 24 | print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") 25 | print(f"#model: {len(all_models)}") 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | parser.add_argument("--tag-file", type=str) 11 | args = parser.parse_args() 12 | 13 | # build index 14 | objs = json.load(open(args.tag_file)) 15 | new_field_dict = {} 16 | for obj in objs: 17 | new_field_dict[obj["question_id"]] = obj["toxic_chat"] 18 | 19 | objs = json.load(open(args.input)) 20 | for obj in objs: 21 | obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] 22 | 23 | output = args.input.replace(".json", "_added.json") 24 | with open(output, "w") as fout: 25 | json.dump(objs, fout, indent=2, ensure_ascii=False) 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Count the unique users in a battle log file. 
3 | 4 | Usage: 5 | python3 -input in.json --number 1000 6 | """ 7 | 8 | import argparse 9 | import json 10 | import random 11 | 12 | K = 1000 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--input", type=str) 17 | parser.add_argument("--number", type=int, nargs="+") 18 | args = parser.parse_args() 19 | 20 | convs = json.load(open(args.input)) 21 | random.seed(0) 22 | random.shuffle(convs) 23 | 24 | for number in args.number: 25 | new_convs = convs[:number] 26 | 27 | output = args.input.replace(".json", f"_{number//K}k.json") 28 | with open(output, "w") as fout: 29 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 30 | 31 | print(f"#in: {len(convs)}, #out: {len(new_convs)}") 32 | print(f"Write to file: {output}") 33 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload to huggingface. 3 | """ 4 | import json 5 | from datasets import Dataset, DatasetDict, load_dataset 6 | 7 | objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) 8 | data = Dataset.from_list(objs) 9 | data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) 10 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | headers = {"authorization": "Bearer hf_XXX"} 4 | 5 | url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" 6 | a = requests.get(url, headers=headers) 7 | 8 | for u in a.json(): 9 | user = u["user"]["user"] 10 | url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" 11 | ret = requests.post(url, headers=headers, json={"user": user}) 12 | print(user, ret.status_code) 13 | assert ret.status_code == 200 14 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | From colab: 3 | https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing 4 | """ 5 | import argparse 6 | import datetime 7 | import json 8 | import os 9 | from pytz import timezone 10 | import time 11 | 12 | import kaleido 13 | import numpy as np 14 | import pandas as pd 15 | import plotly.express as px 16 | import plotly.graph_objects as go 17 | from tqdm import tqdm 18 | 19 | import plotly.io as pio 20 | 21 | pio.kaleido.scope.mathjax = None 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--in-file", type=str, required=True) 25 | parser.add_argument("--scale", type=int, required=True) 26 | args = parser.parse_args() 27 | 28 | filename = args.in_file 29 | scale = args.scale 30 | convs = json.load(open(filename)) 31 | df = pd.DataFrame(convs) 32 | df 33 | 34 | print(f"#ips: {df['user_id'].nunique() * scale}") 35 | print(f"#models: {df['model'].nunique()}") 36 | print(f"#language: {df['language'].nunique()}") 37 | print(f"#turns: {df['turn'].mean()}") 38 | 39 | model_counts = df["model"].value_counts() * scale 40 | # print("model counts", model_counts) 41 | fig = px.bar(x=model_counts.index, y=model_counts) 42 | 
fig.update_layout( 43 | xaxis_title=None, 44 | yaxis_title="Count", 45 | height=200, 46 | width=950, 47 | margin=dict(l=0, r=0, t=0, b=0), 48 | ) 49 | fig.show() 50 | fig.write_image("model_count.pdf") 51 | 52 | 53 | model_counts = df["language"].value_counts().head(25) * scale 54 | fig = px.bar(x=model_counts.index, y=model_counts) 55 | fig.update_layout( 56 | xaxis_title=None, 57 | yaxis_title="Count", 58 | height=200, 59 | width=950, 60 | margin=dict(l=0, r=0, t=0, b=0), 61 | ) 62 | fig.show() 63 | fig.write_image("language_count.pdf") 64 | 65 | chat_dates = [ 66 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") 67 | for x in df["tstamp"] 68 | ] 69 | 70 | 71 | def to_remove(x): 72 | for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: 73 | if d in x: 74 | return True 75 | return False 76 | 77 | 78 | chat_dates = [x for x in chat_dates if not to_remove(x)] 79 | 80 | chat_dates_counts = pd.value_counts(chat_dates) * scale 81 | print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") 82 | 83 | fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) 84 | fig.update_layout( 85 | xaxis_title="Dates", 86 | yaxis_title="Count", 87 | height=200, 88 | width=950, 89 | margin=dict(l=0, r=0, t=0, b=0), 90 | ) 91 | fig.show() 92 | fig.write_image("daily_conversation_count.pdf") 93 | 94 | import transformers 95 | 96 | tokenizer = transformers.AutoTokenizer.from_pretrained( 97 | "lmsys/vicuna-7b-v1.5", use_fast=False 98 | ) 99 | 100 | prompts = [] 101 | responses = [] 102 | for conv in df["conversation"]: 103 | for row in conv: 104 | if row["role"] == "user": 105 | prompts.append(row["content"]) 106 | else: 107 | responses.append(row["content"]) 108 | 109 | print(f"#prompts: {len(prompts)}") 110 | print(f"#responses: {len(responses)}") 111 | 112 | 113 | prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] 114 | print() 115 | print(f"mean prompt len: {np.mean(prompt_lens):.2f}") 116 | 117 | response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] 118 | print() 119 | print(f"mean response len: {np.mean(response_lens):.2f}") 120 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--in-file", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | # Read conversations 14 | convs = json.load(open(args.in_file)) 15 | print(f"#conv: {len(convs)}") 16 | 17 | # Delete some fileds 18 | for c in convs: 19 | del c["tstamp"] 20 | del c["user_id"] 21 | 22 | # Write 23 | print(f"#out conv: {len(convs)}") 24 | out_file = args.in_file.replace(".json", ".s2.json") 25 | print(f"Output to {out_file}") 26 | with open(out_file, "w") as fout: 27 | json.dump(convs, fout, indent=2, ensure_ascii=False) 28 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md: -------------------------------------------------------------------------------- 1 | ``` 2 | export BASE=clean_conv_20230809_100k_pii 3 | export SCALE=10 4 | 5 | # filter words 6 | python3 filter_bad_conv.py --in $BASE.json 7 | 8 | # Clean up some fileds (e.g., timestamps) 9 
| python3 final_post_processing.py --in $BASE.s1.json 10 | 11 | # upload to hf 12 | python3 upload_hf_dataset.py --in $BASE.s1.s2.json 13 | 14 | # Make another version with openai moderation tag 15 | python3 merge_oai_tag.py --in $BASE.s1.s2.json 16 | 17 | # Make visualizations 18 | python3 compute_stats.py --in $BASE.s1.json --scale $SCALE 19 | 20 | # Copy figures 21 | scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" . 22 | ``` 23 | 24 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--in-file", type=str, required=True) 11 | parser.add_argument("--sample", type=int) 12 | args = parser.parse_args() 13 | 14 | tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" 15 | # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" 16 | in_file = args.in_file 17 | tic = time.time() 18 | 19 | # Load tags 20 | print("Load tags...") 21 | tag_data = json.load(open(tag_file)) 22 | tag_dict = {} 23 | for c in tqdm(tag_data): 24 | tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] 25 | print(f"elapsed: {time.time() - tic:.2f} s") 26 | 27 | # Append to input_file 28 | print("Load inputs...") 29 | input_data = json.load(open(in_file)) 30 | for c in tqdm(input_data): 31 | cid = c["conversation_id"] 32 | if cid in tag_dict: 33 | c["openai_moderation"] = tag_dict[cid] 34 | else: 35 | print(f"missing tag for conv {cid}") 36 | exit() 37 | print(f"elapsed: {time.time() - tic:.2f} s") 38 | 39 | # Write output 40 | print("Write outputs...") 41 | out_file = in_file.replace(".json", ".with_tag.json") 42 | print(f"Output to {out_file}") 43 | with open(out_file, "w") as fout: 44 | json.dump(input_data, fout, indent=2, ensure_ascii=False) 45 | print(f"elapsed: {time.time() - tic:.2f} s") 46 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh: -------------------------------------------------------------------------------- 1 | export BASE=clean_conv_20230809_1.5M_pii 2 | #export BASE=clean_conv_20230809_100k_pii 3 | export SCALE=1 4 | 5 | # Filter words 6 | python3 filter_bad_conv.py --in $BASE.json --sample 1000000 7 | 8 | # Clean up some fileds (e.g., timestamps) 9 | python3 final_post_processing.py --in $BASE.s1.json 10 | 11 | # Upload to hf 12 | python3 upload_hf_dataset.py --in $BASE.s1.s2.json 13 | 14 | # Make another version with openai moderation tag 15 | python3 merge_oai_tag.py --in $BASE.s1.s2.json 16 | 17 | # Make visualizations 18 | python3 compute_stats.py --in $BASE.s1.json --scale $SCALE 19 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Count the unique users in a battle log file. 
3 | 4 | Usage: 5 | python3 -input in.json --number 1000 6 | """ 7 | 8 | import argparse 9 | import json 10 | import random 11 | 12 | K = 1000 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--input", type=str) 17 | parser.add_argument("--number", type=int, nargs="+") 18 | args = parser.parse_args() 19 | 20 | convs = json.load(open(args.input)) 21 | random.seed(42) 22 | random.shuffle(convs) 23 | 24 | for number in args.number: 25 | new_convs = convs[:number] 26 | 27 | output = args.input.replace(".json", f"_{number//K}k.json") 28 | with open(output, "w") as fout: 29 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 30 | 31 | print(f"#in: {len(convs)}, #out: {len(new_convs)}") 32 | print(f"Write to file: {output}") 33 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload to huggingface. 3 | """ 4 | import argparse 5 | import json 6 | from datasets import Dataset, DatasetDict, load_dataset 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--in-file", type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | objs = json.load(open(args.in_file)) 15 | print(f"#convs: {len(objs)}") 16 | data = Dataset.from_list(objs) 17 | data.push_to_hub("lmsys/lmsys-chat-1m", private=True) 18 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/deduplication.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import ast 5 | 6 | import matplotlib.pyplot as plt 7 | from matplotlib import rcParams 8 | 9 | import argparse 10 | import seaborn as sns 11 | from tqdm import tqdm 12 | import matplotlib.pyplot as plt 13 | 14 | import numpy as np 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--output_dir", type=str, default="output") 19 | parser.add_argument("--model", type=str, default=None) 20 | parser.add_argument("--input_file", type=str, required=True) 21 | parser.add_argument("--percentile", type=float, default=0.9999) 22 | args = parser.parse_args() 23 | output_dir = args.output_dir 24 | input_file = args.input_file 25 | 26 | with open(input_file) as f: 27 | data = json.load(f) 28 | 29 | os.makedirs(output_dir, exist_ok=True) 30 | 31 | # Preprocessing 32 | all_convs_new = [] 33 | convs = [] 34 | for row in data: 35 | conv = "" 36 | for turns in row["conversation_a"]: 37 | if turns["role"] == "user": 38 | conv += f"{turns['content']}\n" 39 | 40 | convs.append(conv[:10000]) 41 | row["post_process_conv"] = conv[:10000] 42 | all_convs_new.append(row) 43 | 44 | df = pd.DataFrame(all_convs_new) 45 | print("Number of conversations: ", len(df)) 46 | 47 | prompt_counts = df["post_process_conv"].value_counts() 48 | # Select the top 20 most frequent prompts 49 | top_prompts = prompt_counts.head(20) 50 | print(top_prompts) 51 | 52 | # Determine the percentile count 53 | percentile_cutoff = prompt_counts.quantile(args.percentile) 54 | print(f"{args.percentile*100} percentile count: {percentile_cutoff}") 55 | 56 | # prompts that are more common than the percentile cutoff 57 | high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index 58 | print( 59 | f"Number of high 
frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}" 60 | ) 61 | 62 | # initialize a new column dedup_tag 63 | dedup_tags = np.array( 64 | [{"high_freq": False, "sampled": True} for _ in range(len(df))] 65 | ) 66 | high_freq_groups = df.groupby("post_process_conv") 67 | for prompt in tqdm(high_frequency_prompts): 68 | df_high_freq = high_freq_groups.get_group(prompt) 69 | sampled_indices = df_high_freq.sample( 70 | n=int(percentile_cutoff), random_state=42 71 | ).index 72 | dedup_tags[df_high_freq.index] = {"high_freq": True, "sampled": False} 73 | dedup_tags[sampled_indices] = {"high_freq": True, "sampled": True} 74 | 75 | df["dedup_tag"] = dedup_tags 76 | 77 | # drop intermediate columns (post_process_conv) 78 | df = df.drop(columns=["post_process_conv"]) 79 | 80 | df.to_json( 81 | os.path.join(output_dir, "dedup.json"), 82 | orient="records", 83 | indent=4, 84 | force_ascii=False, 85 | ) 86 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/inspect_conv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import code 3 | import datetime 4 | import json 5 | import os 6 | from pytz import timezone 7 | import time 8 | 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | 13 | def get_log_files(max_num_files=None): 14 | dates = [] 15 | for month in [4, 5]: 16 | for day in range(1, 32): 17 | dates.append(f"2023-{month:02d}-{day:02d}") 18 | 19 | num_servers = 14 20 | filenames = [] 21 | for d in dates: 22 | for i in range(num_servers): 23 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 24 | if os.path.exists(name): 25 | filenames.append(name) 26 | max_num_files = max_num_files or len(filenames) 27 | filenames = filenames[-max_num_files:] 28 | return filenames 29 | 30 | 31 | def pretty_print_conversation(messages): 32 | for role, msg in messages: 33 | print(f"[[{role}]]: {msg}") 34 | 35 | 36 | def inspect_convs(log_files): 37 | data = [] 38 | for filename in tqdm(log_files, desc="read files"): 39 | for retry in range(5): 40 | try: 41 | lines = open(filename).readlines() 42 | break 43 | except FileNotFoundError: 44 | time.sleep(2) 45 | 46 | for l in lines: 47 | row = json.loads(l) 48 | 49 | if "states" not in row: 50 | continue 51 | if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: 52 | continue 53 | 54 | model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] 55 | if row["type"] == "leftvote": 56 | winner, loser = model_names[0], model_names[1] 57 | winner_conv, loser_conv = row["states"][0], row["states"][1] 58 | elif row["type"] == "rightvote": 59 | loser, winner = model_names[0], model_names[1] 60 | loser_conv, winner_conv = row["states"][0], row["states"][1] 61 | 62 | if loser == "bard" and winner == "vicuna-13b": 63 | print("=" * 20) 64 | print(f"Winner: {winner}") 65 | pretty_print_conversation(winner_conv["messages"]) 66 | print(f"Loser: {loser}") 67 | pretty_print_conversation(loser_conv["messages"]) 68 | print("=" * 20) 69 | input() 70 | 71 | # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: 72 | # print("=" * 20) 73 | # print(f"Model A: {model_names[0]}") 74 | # pretty_print_conversation(row["states"][0]["messages"]) 75 | # print(f"Model B: {model_names[1]}") 76 | # pretty_print_conversation(row["states"][1]["messages"]) 77 | # print("=" * 20) 78 | # input() 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | 
parser.add_argument("--max-num-files", type=int) 84 | args = parser.parse_args() 85 | 86 | log_files = get_log_files(args.max_num_files) 87 | inspect_convs(log_files) 88 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/intersect_conv_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take the intersection of two conversation files. 3 | 4 | Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input", type=str, required=True) 14 | parser.add_argument("--conv-id", type=str, required=True) 15 | parser.add_argument("--out-file", type=str, default="intersect.json") 16 | args = parser.parse_args() 17 | 18 | conv_id_objs = json.load(open(args.conv_id, "r")) 19 | conv_ids = set(x["conversation_id"] for x in conv_id_objs) 20 | 21 | objs = json.load(open(args.input, "r")) 22 | after_objs = [x for x in objs if x["conversation_id"] in conv_ids] 23 | 24 | print(f"#in: {len(objs)}, #out: {len(after_objs)}") 25 | json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/leaderboard_csv_to_html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a leaderboard csv file to html table used in the blog. 3 | 4 | Usage: 5 | python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv 6 | """ 7 | import argparse 8 | 9 | import numpy as np 10 | 11 | from fastchat.serve.monitor.monitor import load_leaderboard_table_csv 12 | 13 | 14 | def model_hyperlink(model_name, link): 15 | return f' {model_name} ' 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--input", type=str, required=True) 21 | args = parser.parse_args() 22 | 23 | data = load_leaderboard_table_csv(args.input, add_hyperlink=False) 24 | headers = [ 25 | "Model", 26 | "MT-bench (score)", 27 | "Arena Elo rating", 28 | "MMLU", 29 | "License", 30 | ] 31 | values = [] 32 | for item in data: 33 | row = [] 34 | for key in headers: 35 | value = item[key] 36 | row.append(value) 37 | row[0] = model_hyperlink(item["Model"], item["Link"]) 38 | values.append(row) 39 | values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) 40 | 41 | for value in values: 42 | row = "" 43 | for x in value: 44 | try: 45 | if np.isnan(x): 46 | x = "-" 47 | except TypeError: 48 | pass 49 | row += f" {x} " 50 | row += "" 51 | print(row) 52 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/summarize_cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 4 | python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 5 | """ 6 | import argparse 7 | import pickle 8 | 9 | import pandas as pd 10 | 11 | from fastchat.llm_judge.common import ( 12 | chat_completion_openai, 13 | chat_completion_openai_azure, 14 | chat_completion_anthropic, 15 | ) 16 | from fastchat.conversation import get_conv_template 17 | 18 | 19 | def truncate_string(s, l): 20 | half = int(l // 2) 21 | 
return s[:half] + s[-half:] if len(s) > l else s 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--input-file", type=str, required=True) 27 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 28 | parser.add_argument("--num-prompts", type=int, default=100) 29 | args = parser.parse_args() 30 | 31 | model = args.model 32 | 33 | cluster_infos = pickle.load(open(args.input_file, "rb")) 34 | num_total_prompts = sum([x[0] for x in cluster_infos]) 35 | 36 | topics = [] 37 | percentages = [] 38 | for i, info in enumerate(cluster_infos): 39 | num_samples, topk_prompts, random_prompts = info 40 | percentage = num_samples / num_total_prompts 41 | print( 42 | f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" 43 | ) 44 | instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." 45 | split = int(args.num_prompts * 0.8) 46 | prompt = "\n".join( 47 | [truncate_string(x, l=200) for x in topk_prompts[:split]] 48 | + [ 49 | truncate_string(x, l=200) 50 | for x in random_prompts[: args.num_prompts - split] 51 | ] 52 | ) 53 | prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." 54 | 55 | if "azure-" in model: 56 | template_name = "chatgpt" 57 | completion_func = chat_completion_openai_azure 58 | elif "gpt" in model: 59 | template_name = "chatgpt" 60 | completion_func = chat_completion_openai 61 | elif "claude" in model: 62 | template_name = "claude" 63 | completion_func = chat_completion_anthropic 64 | 65 | conv = get_conv_template(template_name) 66 | conv.set_system_message(instruct) 67 | conv.append_message(conv.roles[0], prompt) 68 | conv.append_message(conv.roles[1], None) 69 | 70 | topic = completion_func(model, conv, temperature=0, max_tokens=256) 71 | print(topic) 72 | 73 | topics.append(topic) 74 | percentages.append(round(percentage, 6)) 75 | 76 | print() 77 | print(f"topics: {topics}") 78 | print(f"percentages: {percentages}") 79 | 80 | # save the informations 81 | df = pd.DataFrame() 82 | df["topic"] = topics 83 | df["percentage"] = percentages 84 | 85 | df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records") 86 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/tag_openai_moderation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add OpenAI moderation API results to all conversations. 
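Usage (for example; flags correspond to the argparse options defined below):
    python3 tag_openai_moderation.py --input in.json --parallel 8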
3 | """ 4 | import argparse 5 | from concurrent.futures import ThreadPoolExecutor 6 | import json 7 | import os 8 | import time 9 | 10 | import openai 11 | import requests 12 | from tqdm import tqdm 13 | 14 | 15 | API_MAX_RETRY = 16 16 | API_RETRY_SLEEP = 10 17 | API_ERROR_OUTPUT = "$ERROR$" 18 | 19 | 20 | def tag_moderation(text): 21 | result = API_ERROR_OUTPUT 22 | for _ in range(API_MAX_RETRY): 23 | try: 24 | result = openai.Moderation.create(input=text)["results"][0] 25 | break 26 | except openai.error.OpenAIError as e: 27 | print(type(e), e) 28 | time.sleep(API_RETRY_SLEEP) 29 | 30 | return result 31 | 32 | 33 | def tag_openai_moderation(x): 34 | conv = x["conversation_a"] 35 | user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) 36 | result = tag_moderation(user_prompts) 37 | x["openai_moderation"] = result 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--input", type=str, required=True) 43 | parser.add_argument( 44 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 45 | ) 46 | parser.add_argument("--first-n", type=int) 47 | args = parser.parse_args() 48 | 49 | battles = json.load(open(args.input)) 50 | 51 | if args.first_n: 52 | battles = battles[: args.first_n] 53 | 54 | with ThreadPoolExecutor(args.parallel) as executor: 55 | for line in tqdm( 56 | executor.map(tag_openai_moderation, battles), total=len(battles) 57 | ): 58 | pass 59 | 60 | output = args.input.replace(".json", "_tagged.json") 61 | with open(output, "w") as fout: 62 | json.dump(battles, fout, indent=2, ensure_ascii=False) 63 | print(f"Write cleaned data to {output}") 64 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/vote_time_stats/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | First run `analyze_data.py` to collect metadata of all votes. 4 | 5 | Then run `plot.py` to get the plot. You need to edit these files for proper input or output filename 6 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/vote_time_stats/analyze_data.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import glob 3 | import json 4 | from collections import deque 5 | import tqdm 6 | 7 | 8 | def _serialize_json(data): 9 | # Serialize JSON with sorted keys and no whitespace 10 | return json.dumps(data, sort_keys=True, separators=(",", ":")).encode("utf-8") 11 | 12 | 13 | types = { 14 | "share", 15 | "chat", 16 | "flag", 17 | "bothbad_vote", 18 | "downvote", 19 | "leftvote", 20 | "rightvote", 21 | "upvote", 22 | "tievote", 23 | } 24 | 25 | chat_dict = {} 26 | cache_queue = deque() 27 | 28 | 29 | def process_record(r): 30 | ip = r.pop("ip", None) 31 | tstamp = r.pop("tstamp") 32 | mtype = r.pop("type") 33 | start = r.pop("start", None) 34 | finish = r.pop("finish", None) 35 | 36 | # gabagge collect to save memory 37 | while len(cache_queue) > 100000: 38 | outdated = cache_queue.popleft() 39 | poped_item = chat_dict.pop(outdated["key"], None) 40 | if poped_item is None: 41 | # TODO: this sometimes happens, need to investigate what happens. 
in theory the chat dict should be synced with the queue, unless there are duplicated items 42 | print("Error: Key to GC does not exist.") 43 | 44 | assert mtype in types 45 | if mtype == "chat": 46 | key = _serialize_json(r["state"]) 47 | # TODO: add the string length of the last reply for analyzing voting time per character. 48 | chat_dict[key] = { 49 | "timestamp": tstamp, 50 | "start": start, 51 | "finish": finish, 52 | "conv_id": r["state"]["conv_id"], 53 | } 54 | cache_queue.append({"key": key, "timestamp": tstamp}) 55 | elif mtype in ("leftvote", "rightvote", "bothbad_vote", "tievote"): 56 | left_key = _serialize_json(r["states"][0]) 57 | right_key = _serialize_json(r["states"][1]) 58 | if left_key not in chat_dict: 59 | # TODO: this sometimes happens, it means we have the vote but we cannot find previous chat, need to investigate what happens 60 | print( 61 | f'WARNING: Cannot find vote context for conversation {r["states"][0]["conv_id"]}' 62 | ) 63 | return 64 | if right_key not in chat_dict: 65 | print( 66 | f'WARNING: Cannot find vote context for conversation {r["states"][1]["conv_id"]}' 67 | ) 68 | return 69 | vote_time_data = { 70 | "timestamp": tstamp, 71 | "type": mtype, 72 | "left": chat_dict[left_key], 73 | "right": chat_dict[right_key], 74 | "ip": ip, 75 | } 76 | return vote_time_data 77 | 78 | return None 79 | 80 | 81 | def process_file(infile: str, outfile: str): 82 | with open(infile) as f: 83 | records = [] 84 | for l in f.readlines(): 85 | l = l.strip() 86 | if l: 87 | try: 88 | r = json.loads(l) 89 | if r.get("tstamp") is not None: 90 | records.append(r) 91 | except Exception: 92 | pass 93 | # sort the record in case there are out-of-order records 94 | records.sort(key=lambda x: x["tstamp"]) 95 | 96 | with open(outfile, "a") as outfile: 97 | for r in records: 98 | try: 99 | output = process_record(r) 100 | if output is not None: 101 | outfile.write(json.dumps(output) + "\n") 102 | except Exception as e: 103 | import traceback 104 | 105 | print("Error:", e) 106 | traceback.print_exc() 107 | 108 | 109 | today = datetime.datetime.today().isoformat().split("T", 1)[0] 110 | # sort it to make sure the date is continuous for each server 111 | filelist = sorted(glob.glob("/mnt/disks/data/fastchat_logs/server*/202*-*-*-conv.json")) 112 | filelist = [ 113 | f for f in filelist if today not in f 114 | ] # skip today because date could be partial 115 | 116 | # TODO: change this to select different range of data 117 | filelist = [f for f in filelist if "2024-03-" in f] 118 | 119 | for f in tqdm.tqdm(filelist): 120 | process_file(f, "output.jsonl") 121 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/vote_time_stats/plot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import numpy as np 5 | 6 | 7 | infile = "output.jsonl" 8 | date = "2024-03" # used in the plot 9 | 10 | durations = [] 11 | 12 | with open(infile) as f: 13 | for line in f: 14 | data = json.loads(line) 15 | l = data["left"]["finish"] 16 | r = data["right"]["finish"] 17 | v = data["timestamp"] 18 | durations.append(v - max(l, r)) 19 | 20 | print( 21 | f"Avg: {np.mean(durations)}, Median: {np.median(durations)}, Max: {np.max(durations)}" 22 | ) 23 | 24 | # Define the new cutoff and number of bins 25 | cutoff = 200.0 # New cutoff value 26 | num_bins_inside_cutoff = 20 # Number of bins from 0 to cutoff 27 | 28 | for i, n in enumerate(durations): 
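    # Clamp outliers: durations above the cutoff are moved to the middle of the overflow bin.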
29 | if n > cutoff: 30 | durations[i] = cutoff + 0.5 * cutoff / num_bins_inside_cutoff 31 | 32 | # Create bin edges from 0 to cutoff, with the specified number of bins 33 | bin_edges = np.linspace(0, cutoff, num_bins_inside_cutoff + 1) 34 | 35 | # Adjusting the overflow bin to end at 110 36 | overflow_cap = ( 37 | cutoff + cutoff / num_bins_inside_cutoff 38 | ) # Adjust as needed based on distribution 39 | bin_edges = np.append(bin_edges, overflow_cap) 40 | 41 | # Create the plot with custom bins 42 | sns.histplot( 43 | durations, bins=bin_edges, kde=False 44 | ) # Turn off KDE for clearer bar visibility 45 | plt.title(f'Distribution of "time to vote" {date}') 46 | plt.xlabel("Duration (seconds)") 47 | plt.ylabel("Frequency") 48 | 49 | # Highlight the overflow bin 50 | plt.axvline(x=cutoff, color="red", linestyle="--") 51 | plt.text( 52 | cutoff + 1, plt.ylim()[1] * 0.9, "Overflow", color="red", ha="left" 53 | ) # Adjust text alignment 54 | 55 | # Customizing x-axis labels to hide the "110" 56 | ax = plt.gca() # Get current axis 57 | labels = [item.get_text() for item in ax.get_xticklabels()] 58 | if "110" in labels: 59 | labels[labels.index("110")] = "" # Replace "110" with an empty string 60 | ax.set_xticklabels(labels) 61 | 62 | # Ensure nothing is cut off in the plot 63 | plt.tight_layout() 64 | 65 | # Save the plot to a file with high resolution 66 | plt.savefig(f"duration_distribution_time_to_vote_{date}.png", dpi=300) 67 | -------------------------------------------------------------------------------- /fastchat/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | parser.add_argument("--multimodal", action="store_true") 18 | args = parser.parse_args() 19 | 20 | url = args.controller_address + "/register_worker" 21 | data = { 22 | "worker_name": args.worker_name, 23 | "check_heart_beat": args.check_heart_beat, 24 | "worker_status": None, 25 | "multimodal": args.multimodal, 26 | } 27 | r = requests.post(url, json=data) 28 | assert r.status_code == 200 29 | -------------------------------------------------------------------------------- /fastchat/serve/remote_logger.py: -------------------------------------------------------------------------------- 1 | # A JSON logger that sends data to remote endpoint. 2 | # Architecturally, it hosts a background thread that sends logs to a remote endpoint. 
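# Minimal usage sketch (the payload dict below is an arbitrary example, not a required schema):
#
#     from fastchat.serve.remote_logger import get_remote_logger
#     get_remote_logger().log({"type": "chat", "model": "vicuna-7b-v1.5"})
#
# If the REMOTE_LOGGER_URL environment variable is unset, get_remote_logger() returns an
# EmptyLogger and records are silently dropped; nested values are JSON-encoded before sending.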
3 | import os 4 | import json 5 | import requests 6 | import threading 7 | import queue 8 | import logging 9 | 10 | _global_logger = None 11 | 12 | 13 | def get_remote_logger(): 14 | global _global_logger 15 | if _global_logger is None: 16 | if url := os.environ.get("REMOTE_LOGGER_URL"): 17 | logging.info(f"Remote logger enabled, sending data to {url}") 18 | _global_logger = RemoteLogger(url=url) 19 | else: 20 | _global_logger = EmptyLogger() 21 | return _global_logger 22 | 23 | 24 | class EmptyLogger: 25 | """Dummy logger that does nothing.""" 26 | 27 | def __init__(self): 28 | pass 29 | 30 | def log(self, _data: dict): 31 | pass 32 | 33 | 34 | class RemoteLogger: 35 | """A JSON logger that sends data to remote endpoint.""" 36 | 37 | def __init__(self, url: str): 38 | self.url = url 39 | 40 | self.logs = queue.Queue() 41 | self.thread = threading.Thread(target=self._send_logs, daemon=True) 42 | self.thread.start() 43 | 44 | def log(self, data: dict): 45 | self.logs.put_nowait(data) 46 | 47 | def _send_logs(self): 48 | while True: 49 | data = self.logs.get() 50 | 51 | # process the data by keep only the top level fields, and turn any nested dict into a string 52 | for key, value in data.items(): 53 | if isinstance(value, (dict, list, tuple)): 54 | data[key] = json.dumps(value, ensure_ascii=False) 55 | 56 | try: 57 | requests.post(self.url, json=data) 58 | except Exception: 59 | logging.exception("Failed to send logs to remote endpoint") 60 | -------------------------------------------------------------------------------- /fastchat/serve/shutdown_serve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python shutdown_serve.py --down all 4 | options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers 5 | """ 6 | 7 | import argparse 8 | import os 9 | import subprocess 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--down", choices=["all", "controller", "model_worker", "openai_api_server"] 14 | ) 15 | args = parser.parse_args() 16 | base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" 17 | if args.down == "all": 18 | shell_script = base_shell.format("") 19 | else: 20 | serve = f".{args.down}" 21 | shell_script = base_shell.format(serve) 22 | print(f"execute shell cmd: {shell_script}") 23 | subprocess.run(shell_script, shell=True, check=True) 24 | print(f"{args.down} has been shutdown!") 25 | -------------------------------------------------------------------------------- /fastchat/serve/test_message.py: -------------------------------------------------------------------------------- 1 | """Send a test message.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | 7 | from fastchat.model.model_adapter import get_conversation_template 8 | 9 | 10 | def main(): 11 | model_name = args.model_name 12 | 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | print(f"No available workers for {model_name}") 
31 | return 32 | 33 | conv = get_conversation_template(model_name) 34 | conv.append_message(conv.roles[0], args.message) 35 | conv.append_message(conv.roles[1], None) 36 | prompt = conv.get_prompt() 37 | 38 | headers = {"User-Agent": "FastChat Client"} 39 | gen_params = { 40 | "model": model_name, 41 | "prompt": prompt, 42 | "temperature": args.temperature, 43 | "max_new_tokens": args.max_new_tokens, 44 | "stop": conv.stop_str, 45 | "stop_token_ids": conv.stop_token_ids, 46 | "echo": False, 47 | } 48 | response = requests.post( 49 | worker_addr + "/worker_generate_stream", 50 | headers=headers, 51 | json=gen_params, 52 | stream=True, 53 | ) 54 | 55 | print(f"{conv.roles[0]}: {args.message}") 56 | print(f"{conv.roles[1]}: ", end="") 57 | prev = 0 58 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 59 | if chunk: 60 | data = json.loads(chunk.decode()) 61 | output = data["text"].strip() 62 | print(output[prev:], end="", flush=True) 63 | prev = len(output) 64 | print("") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "--controller-address", type=str, default="http://localhost:21001" 71 | ) 72 | parser.add_argument("--worker-address", type=str) 73 | parser.add_argument("--model-name", type=str, required=True) 74 | parser.add_argument("--temperature", type=float, default=0.0) 75 | parser.add_argument("--max-new-tokens", type=int, default=32) 76 | parser.add_argument( 77 | "--message", type=str, default="Tell me a story with more than 1000 words." 78 | ) 79 | args = parser.parse_args() 80 | 81 | main() 82 | -------------------------------------------------------------------------------- /fastchat/serve/test_throughput.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script to test the throughput of serving workers.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | import threading 7 | import time 8 | 9 | from fastchat.conversation import get_conv_template 10 | 11 | 12 | def main(): 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": args.model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | return 31 | 32 | conv = get_conv_template("vicuna_v1.1") 33 | conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") 34 | prompt_template = conv.get_prompt() 35 | prompts = [prompt_template for _ in range(args.n_thread)] 36 | 37 | headers = {"User-Agent": "fastchat Client"} 38 | ploads = [ 39 | { 40 | "model": args.model_name, 41 | "prompt": prompts[i], 42 | "max_new_tokens": args.max_new_tokens, 43 | "temperature": 0.0, 44 | # "stop": conv.sep, 45 | } 46 | for i in range(len(prompts)) 47 | ] 48 | 49 | def send_request(results, i): 50 | if args.test_dispatch: 51 | ret = requests.post( 52 | controller_addr + "/get_worker_address", json={"model": args.model_name} 53 | ) 54 | thread_worker_addr = ret.json()["address"] 55 | else: 56 | thread_worker_addr = worker_addr 57 | print(f"thread {i} goes to {thread_worker_addr}") 58 | response = requests.post( 59 | 
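            # Note: the worker endpoint streams null-delimited JSON chunks; the code
            # below splits on b"\0" and reads the second-to-last element (the final
            # complete chunk) to get the full generated text and error_code.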
thread_worker_addr + "/worker_generate_stream", 60 | headers=headers, 61 | json=ploads[i], 62 | stream=False, 63 | ) 64 | k = list( 65 | response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") 66 | ) 67 | # print(k) 68 | response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 69 | error_code = json.loads(k[-2].decode("utf-8"))["error_code"] 70 | # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") 71 | results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 72 | 73 | # use N threads to prompt the backend 74 | tik = time.time() 75 | threads = [] 76 | results = [None] * args.n_thread 77 | for i in range(args.n_thread): 78 | t = threading.Thread(target=send_request, args=(results, i)) 79 | t.start() 80 | # time.sleep(0.5) 81 | threads.append(t) 82 | 83 | for t in threads: 84 | t.join() 85 | 86 | print(f"Time (POST): {time.time() - tik} s") 87 | # n_words = 0 88 | # for i, response in enumerate(results): 89 | # # print(prompt[i].replace(conv.sep, "\n"), end="") 90 | # # make sure the streaming finishes at EOS or stopping criteria 91 | # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) 92 | # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 93 | # # print(response_new_words) 94 | # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 95 | n_words = sum(results) 96 | time_seconds = time.time() - tik 97 | print( 98 | f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " 99 | f"throughput: {n_words / time_seconds} words/s." 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument( 106 | "--controller-address", type=str, default="http://localhost:21001" 107 | ) 108 | parser.add_argument("--worker-address", type=str) 109 | parser.add_argument("--model-name", type=str, default="vicuna") 110 | parser.add_argument("--max-new-tokens", type=int, default=2048) 111 | parser.add_argument("--n-thread", type=int, default=8) 112 | parser.add_argument("--test-dispatch", action="store_true") 113 | args = parser.parse_args() 114 | 115 | main() 116 | -------------------------------------------------------------------------------- /fastchat/serve/vision/create_vqa_examples_dir.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from datasets import load_dataset 3 | from PIL import Image 4 | from pathlib import Path 5 | import pandas as pd 6 | import os 7 | import json 8 | import tqdm 9 | import argparse 10 | import shutil 11 | import numpy as np 12 | 13 | np.random.seed(0) 14 | 15 | """ 16 | Creates a directory with images and JSON files for VQA examples. 
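Each dataset in datasets_info is downloaded via Hugging Face `datasets`; its images
are saved as <output_dir>/<dataset_name>/<id>.jpg and a per-dataset data.json manifest
(dataset, question, path) is written next to them before the cache is removed.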
Final json is located in metadata_sampled.json 17 | """ 18 | 19 | 20 | def download_images_and_create_json( 21 | dataset_info, cache_dir="~/vqa_examples_cache", base_dir="./vqa_examples" 22 | ): 23 | for dataset_name, info in dataset_info.items(): 24 | dataset_cache_dir = os.path.join(cache_dir, dataset_name) 25 | os.makedirs(dataset_cache_dir, exist_ok=True) 26 | 27 | if info["subset"]: 28 | dataset = load_dataset( 29 | info["path"], 30 | info["subset"], 31 | cache_dir=dataset_cache_dir, 32 | split=info["split"], 33 | ) 34 | else: 35 | dataset = load_dataset( 36 | info["path"], cache_dir=dataset_cache_dir, split=info["split"] 37 | ) 38 | dataset_dir = os.path.join(base_dir, dataset_name) 39 | os.makedirs(dataset_dir, exist_ok=True) 40 | 41 | json_data = [] 42 | for i, item in enumerate(tqdm.tqdm(dataset)): 43 | id_key = i if info["id_key"] == "index" else item[info["id_key"]] 44 | image_pil = item[info["image_key"]].convert("RGB") 45 | image_path = os.path.join(dataset_dir, f"{id_key}.jpg") 46 | image_pil.save(image_path) 47 | json_entry = { 48 | "dataset": dataset_name, 49 | "question": item[info["question_key"]], 50 | "path": image_path, 51 | } 52 | json_data.append(json_entry) 53 | 54 | with open(os.path.join(dataset_dir, "data.json"), "w") as json_file: 55 | json.dump(json_data, json_file, indent=4) 56 | # Delete the cache directory for the dataset 57 | shutil.rmtree(dataset_cache_dir, ignore_errors=True) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument("--data_dir", type=str, default="~/.cache") 63 | parser.add_argument("--output_dir", type=str, default="./vqa_examples") 64 | args = parser.parse_args() 65 | 66 | datasets_info = { 67 | "Memes": { 68 | "path": "not-lain/meme-dataset", 69 | "image_key": "image", 70 | "question_key": "name", 71 | "id_key": "index", 72 | "subset": False, 73 | "split": "train", 74 | }, 75 | "Floorplan": { 76 | "path": "umesh16071973/Floorplan_Dataset_21022024", 77 | "image_key": "image", 78 | "question_key": "caption", 79 | "id_key": "index", 80 | "subset": False, 81 | "split": "train", 82 | }, 83 | "Website": { 84 | "path": "Zexanima/website_screenshots_image_dataset", 85 | "image_key": "image", 86 | "question_key": "date_captured", 87 | "id_key": "index", 88 | "subset": False, 89 | "split": "train", 90 | }, 91 | "IllusionVQA": { 92 | "path": "csebuetnlp/illusionVQA-Comprehension", 93 | "image_key": "image", 94 | "question_key": "question", 95 | "id_key": "index", 96 | "subset": False, 97 | "split": "test", 98 | }, 99 | "NewYorker": { 100 | "path": "jmhessel/newyorker_caption_contest", 101 | "image_key": "image", 102 | "question_key": "questions", 103 | "id_key": "index", 104 | "subset": "explanation", 105 | "split": "train", 106 | }, 107 | } 108 | 109 | download_images_and_create_json( 110 | datasets_info, cache_dir=args.data_dir, base_dir=args.output_dir 111 | ) 112 | dataset_json = [] 113 | for dataset_name in datasets_info.keys(): 114 | with open(f"{args.output_dir}/{dataset_name}/data.json") as f: 115 | data = json.load(f) 116 | print(f"Dataset: {dataset_name}, Number of examples: {len(data)}") 117 | dataset_json.extend(data) 118 | 119 | with open(f"{args.output_dir}/metadata_sampled.json", "w") as f: 120 | json.dump(dataset_json, f, indent=4) 121 | -------------------------------------------------------------------------------- /fastchat/serve/vision/create_vqa_examples_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Changes proportion of 
examples in metadata_sampled.json 3 | 4 | Usage: 5 | 6 | python3 -m fastchat.serve.vision.create_vqa_examples_json 7 | """ 8 | 9 | import json 10 | import argparse 11 | import numpy as np 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--data_dir", type=str, default="~/.cache") 16 | parser.add_argument("--output_dir", type=str, default="./vqa_examples") 17 | args = parser.parse_args() 18 | 19 | dataset_prop = { 20 | "Memes": 500, 21 | "Floorplan": 500, 22 | "Website": 500, 23 | "IllusionVQA": 435, 24 | "NewYorker": 500, 25 | } 26 | 27 | dataset_json = [] 28 | for dataset_name in dataset_prop.keys(): 29 | with open(f"{args.output_dir}/{dataset_name}/data.json") as f: 30 | data = json.load(f) 31 | dataset_json.extend( 32 | np.random.choice( 33 | data, min(dataset_prop[dataset_name], len(data)), replace=False 34 | ) 35 | ) 36 | 37 | with open(f"{args.output_dir}/metadata_sampled.json", "w") as f: 38 | json.dump(dataset_json, f, indent=4) 39 | -------------------------------------------------------------------------------- /fastchat/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 2 | 3 | # Need to call this before importing transformers. 4 | from fastchat.train.llama2_flash_attn_monkey_patch import ( 5 | replace_llama_attn_with_flash_attn, 6 | ) 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from fastchat.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /fastchat/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from fastchat.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from fastchat.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Adapted from https://github.com/skypilot-org/skypilot/blob/master/format.sh 4 | 5 | # Cause the script to exit if a single command fails 6 | set -eo pipefail 7 | 8 | # this stops git rev-parse from failing if we run this from the .git directory 9 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 10 | ROOT="$(git rev-parse --show-toplevel)" 11 | builtin cd "$ROOT" || exit 1 12 | 13 | BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') 14 | PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') 15 | 16 | # # params: tool name, tool version, required version 17 | tool_version_check() { 18 | if [[ $2 != $3 ]]; then 19 | echo "Wrong $1 version installed: $3 is required, not $2." 20 | exit 1 21 | fi 22 | } 23 | 24 | tool_version_check "black" $BLACK_VERSION "23.3.0" 25 | tool_version_check "pylint" $PYLINT_VERSION "2.8.2" 26 | 27 | # Format files that differ from main branch. Ignores dirs that are not slated 28 | # for autoformat yet. 
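# format_changed() below diffs against the merge-base with origin/main and pipes any
# added/copied/modified *.py / *.pyi files to black; explicit paths or the whole
# fastchat/ package can be formatted instead via the --files / --all flags handled
# further down.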
29 | format_changed() { 30 | # The `if` guard ensures that the list of filenames is not empty, which 31 | # could cause yapf to receive 0 positional arguments, making it hang 32 | # waiting for STDIN. 33 | # 34 | # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that 35 | # exist on both branches. 36 | MERGEBASE="$(git merge-base origin/main HEAD)" 37 | 38 | if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then 39 | git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 black 40 | fi 41 | } 42 | 43 | ## This flag formats individual files. --files *must* be the first command line 44 | ## arg to use this option. 45 | if [[ "$1" == '--files' ]]; then 46 | black "${@:2}" 47 | # If `--all` is passed, then any further arguments are ignored and the 48 | # entire python directory is formatted. 49 | elif [[ "$1" == '--all' ]]; then 50 | # Format all files 51 | black fastchat 52 | else 53 | # Format only the files that changed in last commit. 54 | format_changed 55 | fi 56 | echo 'FastChat Black: Done' 57 | 58 | # Run Pylint 59 | echo 'FastChat Pylint:' 60 | pylint fastchat 61 | # TODO(suquark): disable 'pylint_quotes' for now due to too many inconsistent quotes 62 | # pylint --load-plugins pylint_quotes fastchat 63 | 64 | if ! git diff --quiet &>/dev/null; then 65 | echo 'Reformatted files. Please review and stage the changes.' 66 | echo 'Changes not staged for commit:' 67 | echo 68 | git --no-pager diff --name-only 69 | 70 | exit 1 71 | fi 72 | -------------------------------------------------------------------------------- /playground/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lm-sys/FastChat/9a295b64ce3491ff15901f2d00f5e304b0ee78dc/playground/__init__.py -------------------------------------------------------------------------------- /playground/deepspeed_config_s2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "contiguous_gradients": true, 8 | "overlap_comm": true 9 | }, 10 | "fp16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "gradient_accumulation_steps": "auto" 15 | } -------------------------------------------------------------------------------- /playground/deepspeed_config_s3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "zero_optimization": { 11 | "stage": 3, 12 | "offload_optimizer": { 13 | "device": "cpu", 14 | "pin_memory": true 15 | }, 16 | "offload_param": { 17 | "device": "cpu", 18 | "pin_memory": true 19 | }, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "stage3_max_live_parameters" : 1e9, 23 | "stage3_max_reuse_distance" : 1e9, 24 | "stage3_prefetch_bucket_size" : 5e8, 25 | "stage3_param_persistence_threshold" : 1e6, 26 | "sub_group_size" : 1e12, 27 | "stage3_gather_16bit_weights_on_model_save": true 28 | }, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "gradient_accumulation_steps": "auto" 32 | } -------------------------------------------------------------------------------- /playground/test_embedding/README.md: 
-------------------------------------------------------------------------------- 1 | ## Machine Learning with Embeddings 2 | You can use embeddings to 3 | - Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py) 4 | - Build your own classifier, see [test_classification.py](test_classification.py) 5 | - Search relative texts, see [test_semantic_search.py](test_semantic_search.py) 6 | 7 | To these tests, you need to download the data [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison. 8 | 9 | Run with: 10 | ```bash 11 | cd playground/test_embedding 12 | python3 test_classification.py 13 | ``` 14 | 15 | The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier. 16 | -------------------------------------------------------------------------------- /playground/test_embedding/test_classification.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import pandas as pd 7 | import requests 8 | from sklearn.ensemble import RandomForestClassifier 9 | from sklearn.model_selection import train_test_split 10 | from sklearn.metrics import classification_report, accuracy_score 11 | 12 | 13 | np.set_printoptions(threshold=10000) 14 | 15 | 16 | def get_embedding_from_api(word, model="vicuna-7b-v1.1"): 17 | if "ada" in model: 18 | resp = openai.Embedding.create( 19 | model=model, 20 | input=word, 21 | ) 22 | embedding = np.array(resp["data"][0]["embedding"]) 23 | return embedding 24 | 25 | url = "http://localhost:8000/v1/embeddings" 26 | headers = {"Content-Type": "application/json"} 27 | data = json.dumps({"model": model, "input": word}) 28 | 29 | response = requests.post(url, headers=headers, data=data) 30 | if response.status_code == 200: 31 | embedding = np.array(response.json()["data"][0]["embedding"]) 32 | return embedding 33 | else: 34 | print(f"Error: {response.status_code} - {response.text}") 35 | return None 36 | 37 | 38 | def create_embedding_data_frame(data_path, model, max_tokens=500): 39 | df = pd.read_csv(data_path, index_col=0) 40 | df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] 41 | df = df.dropna() 42 | df["combined"] = ( 43 | "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() 44 | ) 45 | top_n = 1000 46 | df = df.sort_values("Time").tail(top_n * 2) 47 | df.drop("Time", axis=1, inplace=True) 48 | 49 | df["n_tokens"] = df.combined.apply(lambda x: len(x)) 50 | df = df[df.n_tokens <= max_tokens].tail(top_n) 51 | df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) 52 | return df 53 | 54 | 55 | def train_random_forest(df): 56 | X_train, X_test, y_train, y_test = train_test_split( 57 | list(df.embedding.values), df.Score, test_size=0.2, random_state=42 58 | ) 59 | 60 | clf = RandomForestClassifier(n_estimators=100) 61 | clf.fit(X_train, y_train) 62 | preds = clf.predict(X_test) 63 | 64 | report = classification_report(y_test, preds) 65 | accuracy = accuracy_score(y_test, preds) 66 | return clf, accuracy, report 67 | 68 | 69 | input_datapath = "amazon_fine_food_review.csv" 70 | if not os.path.exists(input_datapath): 71 | raise Exception( 72 | f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" 73 | ) 74 | 75 | df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1") 76 | clf, 
accuracy, report = train_random_forest(df) 77 | print(f"Vicuna-7b-v1.1 accuracy:{accuracy}") 78 | df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001") 79 | clf, accuracy, report = train_random_forest(df) 80 | print(f"text-similarity-ada-001 accuracy:{accuracy}") 81 | df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002") 82 | clf, accuracy, report = train_random_forest(df) 83 | print(f"text-embedding-ada-002 accuracy:{accuracy}") 84 | -------------------------------------------------------------------------------- /playground/test_embedding/test_semantic_search.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import pandas as pd 7 | import requests 8 | from scipy.spatial.distance import cosine 9 | 10 | 11 | def cosine_similarity(vec1, vec2): 12 | try: 13 | return 1 - cosine(vec1, vec2) 14 | except: 15 | print(vec1.shape, vec2.shape) 16 | 17 | 18 | def get_embedding_from_api(word, model="vicuna-7b-v1.1"): 19 | if "ada" in model: 20 | resp = openai.Embedding.create( 21 | model=model, 22 | input=word, 23 | ) 24 | embedding = np.array(resp["data"][0]["embedding"]) 25 | return embedding 26 | 27 | url = "http://localhost:8000/v1/embeddings" 28 | headers = {"Content-Type": "application/json"} 29 | data = json.dumps({"model": model, "input": word}) 30 | 31 | response = requests.post(url, headers=headers, data=data) 32 | if response.status_code == 200: 33 | embedding = np.array(response.json()["data"][0]["embedding"]) 34 | return embedding 35 | else: 36 | print(f"Error: {response.status_code} - {response.text}") 37 | return None 38 | 39 | 40 | def create_embedding_data_frame(data_path, model, max_tokens=500): 41 | df = pd.read_csv(data_path, index_col=0) 42 | df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] 43 | df = df.dropna() 44 | df["combined"] = ( 45 | "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() 46 | ) 47 | top_n = 1000 48 | df = df.sort_values("Time").tail(top_n * 2) 49 | df.drop("Time", axis=1, inplace=True) 50 | 51 | df["n_tokens"] = df.combined.apply(lambda x: len(x)) 52 | df = df[df.n_tokens <= max_tokens].tail(top_n) 53 | df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) 54 | return df 55 | 56 | 57 | def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"): 58 | product_embedding = get_embedding_from_api(product_description, model=model) 59 | df["similarity"] = df.embedding.apply( 60 | lambda x: cosine_similarity(x, product_embedding) 61 | ) 62 | 63 | results = ( 64 | df.sort_values("similarity", ascending=False) 65 | .head(n) 66 | .combined.str.replace("Title: ", "") 67 | .str.replace("; Content:", ": ") 68 | ) 69 | if pprint: 70 | for r in results: 71 | print(r[:200]) 72 | print() 73 | return results 74 | 75 | 76 | def print_model_search(input_path, model): 77 | print(f"Model: {model}") 78 | df = create_embedding_data_frame(input_path, model) 79 | print("search: delicious beans") 80 | results = search_reviews(df, "delicious beans", n=5, model=model) 81 | print(results) 82 | print("search: whole wheat pasta") 83 | results = search_reviews(df, "whole wheat pasta", n=5, model=model) 84 | print(results) 85 | print("search: bad delivery") 86 | results = search_reviews(df, "bad delivery", n=5, model=model) 87 | print(results) 88 | 89 | 90 | input_datapath = "amazon_fine_food_review.csv" 91 | if not os.path.exists(input_datapath): 
92 | raise Exception( 93 | f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" 94 | ) 95 | 96 | 97 | print_model_search(input_datapath, "vicuna-7b-v1.1") 98 | print_model_search(input_datapath, "text-similarity-ada-001") 99 | print_model_search(input_datapath, "text-embedding-ada-002") 100 | -------------------------------------------------------------------------------- /playground/test_embedding/test_sentence_similarity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import requests 7 | from scipy.spatial.distance import cosine 8 | 9 | 10 | def get_embedding_from_api(word, model="vicuna-7b-v1.5"): 11 | if "ada" in model: 12 | resp = openai.Embedding.create( 13 | model=model, 14 | input=word, 15 | ) 16 | embedding = np.array(resp["data"][0]["embedding"]) 17 | return embedding 18 | 19 | url = "http://localhost:8000/v1/embeddings" 20 | headers = {"Content-Type": "application/json"} 21 | data = json.dumps({"model": model, "input": word}) 22 | 23 | response = requests.post(url, headers=headers, data=data) 24 | if response.status_code == 200: 25 | embedding = np.array(response.json()["data"][0]["embedding"]) 26 | return embedding 27 | else: 28 | print(f"Error: {response.status_code} - {response.text}") 29 | return None 30 | 31 | 32 | def cosine_similarity(vec1, vec2): 33 | return 1 - cosine(vec1, vec2) 34 | 35 | 36 | def print_cosine_similarity(embeddings, texts): 37 | for i in range(len(texts)): 38 | for j in range(i + 1, len(texts)): 39 | sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]]) 40 | print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}") 41 | 42 | 43 | texts = [ 44 | "The quick brown fox", 45 | "The quick brown dog", 46 | "The fast brown fox", 47 | "A completely different sentence", 48 | ] 49 | 50 | embeddings = {} 51 | for text in texts: 52 | embeddings[text] = get_embedding_from_api(text) 53 | 54 | print("Vicuna-7B:") 55 | print_cosine_similarity(embeddings, texts) 56 | 57 | for text in texts: 58 | embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001") 59 | 60 | print("text-similarity-ada-001:") 61 | print_cosine_similarity(embeddings, texts) 62 | 63 | for text in texts: 64 | embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002") 65 | 66 | print("text-embedding-ada-002:") 67 | print_cosine_similarity(embeddings, texts) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "fschat" 7 | version = "0.2.36" 8 | description = "An open platform for training, serving, and evaluating large language model based chatbots." 
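# The optional dependency groups declared under [project.optional-dependencies] below
# (model_worker, webui, train, llm_judge, dev) install as pip extras, e.g.
# `pip3 install "fschat[model_worker,webui]"`.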
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", 17 | "prompt_toolkit>=3.0.0", "pydantic<3,>=2.0.0", "pydantic-settings", "psutil", "requests", "rich>=10.0.0", 18 | "shortuuid", "tiktoken", "uvicorn", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"] 23 | webui = ["gradio>=4.10", "plotly", "scipy"] 24 | train = ["einops", "flash-attn>=2.0", "wandb"] 25 | llm_judge = ["openai<1", "anthropic>=0.3", "ray"] 26 | dev = ["black==23.3.0", "pylint==2.8.2"] 27 | 28 | [project.urls] 29 | "Homepage" = "https://github.com/lm-sys/fastchat" 30 | "Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" 31 | 32 | [tool.setuptools.packages.find] 33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 34 | 35 | [tool.wheel] 36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 37 | -------------------------------------------------------------------------------- /scripts/build-api.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # A rather convenient script for spinning up models behind screens 3 | 4 | 5 | # Variables 6 | PROJECT_DIR="$(pwd)" 7 | CONDA_ENV_NAME="fastchat" # 8 | 9 | MODEL_PATH="HuggingFaceH4/zephyr-7b-beta" #beta is better than the alpha version, base model w/o quantization 10 | MODEL_PATH="lmsys/vicuna-7b-v1.5" 11 | 12 | API_HOST="0.0.0.0" 13 | API_PORT_NUMBER=8000 14 | 15 | 16 | # init the screens 17 | check_and_create_screen() { 18 | local SCREENNAME="$1" 19 | if screen -list | grep -q "$SCREENNAME"; then 20 | echo "Screen session '$SCREENNAME' exists. Doing nothing." 21 | else 22 | echo "Screen session '$SCREENNAME' not found. Creating..." 23 | screen -d -m -S "$SCREENNAME" 24 | echo "created!" 
25 | fi 26 | } 27 | 28 | # convenience function for sending commands to named screens 29 | send_cmd() { 30 | local SCREENNAME="$1" 31 | local CMD="$2" 32 | screen -DRRS $SCREENNAME -X stuff '$2 \r' 33 | } 34 | 35 | # hardcoded names, for baby api 36 | SCREENNAMES=( 37 | "controller" 38 | "api" 39 | # Worker screens include the devices they are bound to, if 'd0' is only worker it has full GPU access 40 | "worker-d0" 41 | "worker-d1" 42 | ) 43 | 44 | for screen in "${SCREENNAMES[@]}"; do 45 | check_and_create_screen "$screen" 46 | sleep 0.1 47 | # also activate the conda compute environment for these 48 | screen -DRRS "$screen" -X stuff "conda deactivate \r" 49 | screen -DRRS "$screen" -X stuff "conda activate $CONDA_ENV_NAME \r" 50 | 51 | done 52 | 53 | 54 | # Send Commmands on a per Screen Basis 55 | screen -DRRS controller -X stuff "python3 -m fastchat.serve.controller \r" 56 | 57 | screen -DRRS worker-d0 -X stuff "CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --conv-template one_shot --limit-worker-concurrency 1 \r" 58 | screen -DRRS worker-d1 -X stuff "CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --port 21003 --worker-address http://localhost:21003 --conv-template one_shot --limit-worker-concurrency 1 \r" 59 | 60 | screen -DRRS api -X stuff "python3 -m fastchat.serve.openai_api_server --host $API_HOST --port $API_PORT_NUMBER \r" 61 | -------------------------------------------------------------------------------- /scripts/test_readme_train.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ 2 | --model_name_or_path meta-llama/Llama-2-7b-hf \ 3 | --data_path data/dummy_conversation.json \ 4 | --bf16 True \ 5 | --output_dir output_vicuna \ 6 | --num_train_epochs 3 \ 7 | --per_device_train_batch_size 2 \ 8 | --per_device_eval_batch_size 2 \ 9 | --gradient_accumulation_steps 16 \ 10 | --evaluation_strategy "no" \ 11 | --save_strategy "steps" \ 12 | --save_steps 1200 \ 13 | --save_total_limit 10 \ 14 | --learning_rate 2e-5 \ 15 | --weight_decay 0. \ 16 | --warmup_ratio 0.03 \ 17 | --lr_scheduler_type "cosine" \ 18 | --logging_steps 1 \ 19 | --fsdp "full_shard auto_wrap" \ 20 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 21 | --tf32 True \ 22 | --model_max_length 2048 \ 23 | --gradient_checkpointing True \ 24 | --lazy_preprocess True 25 | -------------------------------------------------------------------------------- /scripts/train_lora.sh: -------------------------------------------------------------------------------- 1 | deepspeed fastchat/train/train_lora.py \ 2 | --model_name_or_path lmsys/vicuna-7b-v1.5 \ 3 | --lora_r 8 \ 4 | --lora_alpha 16 \ 5 | --lora_dropout 0.05 \ 6 | --data_path $DATA_PATH \ 7 | --output_dir ./checkpoints \ 8 | --num_train_epochs 150 \ 9 | --fp16 True \ 10 | --per_device_train_batch_size 2 \ 11 | --per_device_eval_batch_size 2 \ 12 | --gradient_accumulation_steps 1 \ 13 | --evaluation_strategy "steps" \ 14 | --eval_steps 100 \ 15 | --save_strategy "steps" \ 16 | --save_steps 200 \ 17 | --save_total_limit 2 \ 18 | --learning_rate 2e-5 \ 19 | --weight_decay 0. 
\ 20 | --warmup_ratio 0.03 \ 21 | --lr_scheduler_type "cosine" \ 22 | --logging_strategy "steps" \ 23 | --logging_steps 1 \ 24 | --tf32 True \ 25 | --model_max_length 2048 \ 26 | --q_lora False \ 27 | --deepspeed $PATH_TO_DEEPSPEED_CONFIG \ 28 | --gradient_checkpointing True \ 29 | --flash_attn False 30 | -------------------------------------------------------------------------------- /scripts/train_vicuna_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \ 2 | --model_name_or_path ~/model_weights/llama-13b \ 3 | --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ 4 | --bf16 True \ 5 | --output_dir output_vicuna_13b \ 6 | --num_train_epochs 3 \ 7 | --per_device_train_batch_size 4 \ 8 | --per_device_eval_batch_size 32 \ 9 | --gradient_accumulation_steps 4 \ 10 | --evaluation_strategy "steps" \ 11 | --eval_steps 1500 \ 12 | --save_strategy "steps" \ 13 | --save_steps 1500 \ 14 | --save_total_limit 8 \ 15 | --learning_rate 2e-5 \ 16 | --weight_decay 0. \ 17 | --warmup_ratio 0.04 \ 18 | --lr_scheduler_type "cosine" \ 19 | --logging_steps 1 \ 20 | --fsdp "full_shard auto_wrap offload" \ 21 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 22 | --tf32 True \ 23 | --model_max_length 2048 \ 24 | --gradient_checkpointing True \ 25 | --lazy_preprocess True 26 | 27 | -------------------------------------------------------------------------------- /scripts/train_vicuna_7b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ 2 | --model_name_or_path ~/model_weights/llama-7b \ 3 | --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ 4 | --bf16 True \ 5 | --output_dir output_vicuna_7b \ 6 | --num_train_epochs 3 \ 7 | --per_device_train_batch_size 2 \ 8 | --per_device_eval_batch_size 16 \ 9 | --gradient_accumulation_steps 16 \ 10 | --evaluation_strategy "steps" \ 11 | --eval_steps 1500 \ 12 | --save_strategy "steps" \ 13 | --save_steps 1500 \ 14 | --save_total_limit 8 \ 15 | --learning_rate 2e-5 \ 16 | --weight_decay 0. 
\ 17 | --warmup_ratio 0.04 \ 18 | --lr_scheduler_type "cosine" \ 19 | --logging_steps 1 \ 20 | --fsdp "full_shard auto_wrap" \ 21 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 22 | --tf32 True \ 23 | --model_max_length 2048 \ 24 | --gradient_checkpointing True \ 25 | --lazy_preprocess True 26 | 27 | -------------------------------------------------------------------------------- /scripts/upload_pypi.sh: -------------------------------------------------------------------------------- 1 | rm -rf dist 2 | python3 -m build 3 | python3 -m twine upload dist/* 4 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ## Unit tests for FastChat 2 | 3 | ### Test CLI Inference 4 | 5 | ``` 6 | python3 test_cli.py 7 | ``` 8 | 9 | ### Test OpenAI API Server 10 | 11 | ``` 12 | python3 launch_openai_api_test_server.py 13 | ``` 14 | 15 | ``` 16 | python3 test_openai_api.py 17 | ``` 18 | 19 | ### Test GUI Serving 20 | 21 | ``` 22 | python3 -m fastchat.serve.controller 23 | ``` 24 | 25 | ``` 26 | CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/koala-13b --num-gpus 2 --port 30000 --worker http://localhost:30000 27 | CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/alpaca-13b --num-gpus 2 --port 30002 --worker http://localhost:30002 28 | CUDA_VISIBLE_DEVICES=4,5 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/vicuna-13b --port 30004 --worker http://localhost:30004 --num-gpus 2 29 | CUDA_VISIBLE_DEVICES=6,7 python3 -m fastchat.serve.model_worker --model-path OpenAssistant/oasst-sft-1-pythia-12b --port 30006 --worker http://localhost:30006 --num-gpus 2 30 | 31 | CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path StabilityAI/stablelm-tuned-alpha-7b --num-gpus 2 --port 30000 --worker http://localhost:30000 32 | CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path databricks/dolly-v2-12b --num-gpus 2 --port 30002 --worker http://localhost:30002 33 | CUDA_VISIBLE_DEVICES=4 python3 -m fastchat.serve.model_worker --model-path THUDM/chatglm-6b --port 30004 --worker http://localhost:30004 34 | CUDA_VISIBLE_DEVICES=5 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --port 30005 --worker http://localhost:30005 35 | CUDA_VISIBLE_DEVICES=6 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/baize-7b --port 30006 --worker http://localhost:30006 36 | CUDA_VISIBLE_DEVICES=7 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth --port 30007 --worker http://localhost:30007 37 | ``` 38 | 39 | ``` 40 | python3 -m fastchat.serve.gradio_web_server_multi 41 | ``` 42 | 43 | ### Test Peft Serving 44 | 45 | ``` 46 | python3 -m fastchat.serve.controller 47 | ``` 48 | 49 | ``` 50 | PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \ 51 | --model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \ 52 | --model-path SurfaceData/dummy_pythia160m_lora8_peft_chat 53 | ``` 54 | -------------------------------------------------------------------------------- /tests/killall_python.sh: -------------------------------------------------------------------------------- 1 | kill -9 $(ps aux | grep 'python' | grep 'fastchat' | grep -v 'grep' | awk '{print $2}') 2 | 
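# The one-liner above lists all processes, keeps those whose command line mentions both
# 'python' and 'fastchat', drops the grep processes themselves, extracts the PID column
# with awk, and force-kills the result with SIGKILL.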
-------------------------------------------------------------------------------- /tests/launch_openai_api_test_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Launch an OpenAI API server with multiple model workers. 3 | """ 4 | import os 5 | import argparse 6 | 7 | 8 | def launch_process(cmd): 9 | os.popen(cmd) 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--multimodal", action="store_true", default=False) 15 | args = parser.parse_args() 16 | 17 | launch_process("python3 -m fastchat.serve.controller") 18 | launch_process("python3 -m fastchat.serve.openai_api_server") 19 | 20 | if args.multimodal: 21 | models = [ 22 | ("liuhaotian/llava-v1.5-7b", "sglang_worker"), 23 | ] 24 | else: 25 | models = [ 26 | ("lmsys/vicuna-7b-v1.5", "model_worker"), 27 | ("lmsys/fastchat-t5-3b-v1.0", "model_worker"), 28 | ("THUDM/chatglm-6b", "model_worker"), 29 | ("mosaicml/mpt-7b-chat", "model_worker"), 30 | ("meta-llama/Llama-2-7b-chat-hf", "vllm_worker"), 31 | ] 32 | 33 | for i, (model_path, worker_name) in enumerate(models): 34 | cmd = ( 35 | f"CUDA_VISIBLE_DEVICES={i} python3 -m fastchat.serve.{worker_name} " 36 | f"--model-path {model_path} --port {40000+i} " 37 | f"--worker-address http://localhost:{40000+i} " 38 | ) 39 | 40 | if "llava" in model_path.lower(): 41 | cmd += f"--tokenizer-path llava-hf/llava-1.5-7b-hf" 42 | 43 | if worker_name == "vllm_worker": 44 | cmd += "--tokenizer hf-internal-testing/llama-tokenizer" 45 | 46 | launch_process(cmd) 47 | 48 | while True: 49 | pass 50 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """Test command line interface for model inference.""" 2 | import argparse 3 | import os 4 | 5 | from fastchat.utils import run_cmd 6 | 7 | 8 | def test_single_gpu(): 9 | models = [ 10 | "lmsys/vicuna-7b-v1.5", 11 | "lmsys/longchat-7b-16k", 12 | "lmsys/fastchat-t5-3b-v1.0", 13 | "meta-llama/Llama-2-7b-chat-hf", 14 | "THUDM/chatglm-6b", 15 | "THUDM/chatglm2-6b", 16 | "mosaicml/mpt-7b-chat", 17 | "tiiuae/falcon-7b-instruct", 18 | "~/model_weights/alpaca-7b", 19 | "~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth", 20 | ] 21 | 22 | for model_path in models: 23 | if "model_weights" in model_path and not os.path.exists( 24 | os.path.expanduser(model_path) 25 | ): 26 | continue 27 | cmd = ( 28 | f"python3 -m fastchat.serve.cli --model-path {model_path} " 29 | f"--style programmatic < test_cli_inputs.txt" 30 | ) 31 | ret = run_cmd(cmd) 32 | if ret != 0: 33 | return 34 | 35 | print("") 36 | 37 | 38 | def test_multi_gpu(): 39 | models = [ 40 | "lmsys/vicuna-13b-v1.3", 41 | ] 42 | 43 | for model_path in models: 44 | cmd = ( 45 | f"python3 -m fastchat.serve.cli --model-path {model_path} " 46 | f"--style programmatic --num-gpus 2 --max-gpu-memory 14Gib < test_cli_inputs.txt" 47 | ) 48 | ret = run_cmd(cmd) 49 | if ret != 0: 50 | return 51 | print("") 52 | 53 | 54 | def test_8bit(): 55 | models = [ 56 | "lmsys/vicuna-13b-v1.3", 57 | ] 58 | 59 | for model_path in models: 60 | cmd = ( 61 | f"python3 -m fastchat.serve.cli --model-path {model_path} " 62 | f"--style programmatic --load-8bit < test_cli_inputs.txt" 63 | ) 64 | ret = run_cmd(cmd) 65 | if ret != 0: 66 | return 67 | print("") 68 | 69 | 70 | def test_hf_api(): 71 | models = [ 72 | "lmsys/vicuna-7b-v1.5", 73 | "lmsys/fastchat-t5-3b-v1.0", 74 | ] 75 | 76 | for 
model_path in models: 77 | cmd = f"python3 -m fastchat.serve.huggingface_api --model-path {model_path}" 78 | ret = run_cmd(cmd) 79 | if ret != 0: 80 | return 81 | print("") 82 | 83 | 84 | if __name__ == "__main__": 85 | test_single_gpu() 86 | test_multi_gpu() 87 | test_8bit() 88 | test_hf_api() 89 | -------------------------------------------------------------------------------- /tests/test_cli_inputs.txt: -------------------------------------------------------------------------------- 1 | Who are you? __END_OF_A_MESSAGE_47582648__ 2 | Three tips for staying healthy. __END_OF_A_MESSAGE_47582648__ 3 | One more tip. __END_OF_A_MESSAGE_47582648__ 4 | !!exit __END_OF_A_MESSAGE_47582648__ 5 | -------------------------------------------------------------------------------- /tests/test_openai_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the OpenAI compatible server 3 | 4 | Launch: 5 | python3 launch_openai_api_test_server.py 6 | """ 7 | import warnings 8 | 9 | import openai 10 | from fastchat.utils import run_cmd 11 | 12 | 13 | openai.api_key = "EMPTY" # Not support yet 14 | openai.base_url = "http://localhost:8000/v1/" 15 | 16 | 17 | def test_list_models(): 18 | model_list = openai.models.list() 19 | names = [x.id for x in model_list.data] 20 | return names 21 | 22 | 23 | def test_completion(model, logprob): 24 | prompt = "Once upon a time" 25 | completion = openai.completions.create( 26 | model=model, 27 | prompt=prompt, 28 | logprobs=logprob, 29 | max_tokens=64, 30 | temperature=0, 31 | ) 32 | 33 | print(f"full text: {prompt + completion.choices[0].text}", flush=True) 34 | if completion.choices[0].logprobs is not None: 35 | print( 36 | f"logprobs: {completion.choices[0].logprobs.token_logprobs[:10]}", 37 | flush=True, 38 | ) 39 | 40 | 41 | def test_completion_stream(model): 42 | prompt = "Once upon a time" 43 | res = openai.completions.create( 44 | model=model, 45 | prompt=prompt, 46 | max_tokens=64, 47 | stream=True, 48 | temperature=0, 49 | ) 50 | print(prompt, end="") 51 | for chunk in res: 52 | content = chunk.choices[0].text 53 | print(content, end="", flush=True) 54 | print() 55 | 56 | 57 | def test_embedding(model): 58 | embedding = openai.embeddings.create(model=model, input="Hello world!") 59 | print(f"embedding len: {len(embedding.data[0].embedding)}") 60 | print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}") 61 | 62 | 63 | def test_chat_completion(model): 64 | completion = openai.chat.completions.create( 65 | model=model, 66 | messages=[{"role": "user", "content": "Hello! What is your name?"}], 67 | temperature=0, 68 | ) 69 | print(completion.choices[0].message.content) 70 | 71 | 72 | def test_chat_completion_stream(model): 73 | messages = [{"role": "user", "content": "Hello! What is your name?"}] 74 | res = openai.chat.completions.create( 75 | model=model, messages=messages, stream=True, temperature=0 76 | ) 77 | for chunk in res: 78 | try: 79 | content = chunk.choices[0].delta.content 80 | if content is None: 81 | content = "" 82 | except Exception as e: 83 | content = chunk.choices[0].delta.get("content", "") 84 | print(content, end="", flush=True) 85 | print() 86 | 87 | 88 | def test_openai_curl(): 89 | run_cmd("curl http://localhost:8000/v1/models") 90 | 91 | run_cmd( 92 | """ 93 | curl http://localhost:8000/v1/chat/completions \ 94 | -H "Content-Type: application/json" \ 95 | -d '{ 96 | "model": "vicuna-7b-v1.5", 97 | "messages": [{"role": "user", "content": "Hello! 
What is your name?"}] 98 | }' 99 | """ 100 | ) 101 | 102 | run_cmd( 103 | """ 104 | curl http://localhost:8000/v1/completions \ 105 | -H "Content-Type: application/json" \ 106 | -d '{ 107 | "model": "vicuna-7b-v1.5", 108 | "prompt": "Once upon a time", 109 | "max_tokens": 41, 110 | "temperature": 0.5 111 | }' 112 | """ 113 | ) 114 | 115 | run_cmd( 116 | """ 117 | curl http://localhost:8000/v1/embeddings \ 118 | -H "Content-Type: application/json" \ 119 | -d '{ 120 | "model": "vicuna-7b-v1.5", 121 | "input": "Hello world!" 122 | }' 123 | """ 124 | ) 125 | 126 | 127 | if __name__ == "__main__": 128 | models = test_list_models() 129 | print(f"models: {models}") 130 | 131 | for model in models: 132 | print(f"===== Test {model} ======") 133 | 134 | if model in ["fastchat-t5-3b-v1.0"]: 135 | logprob = None 136 | else: 137 | logprob = 1 138 | 139 | test_completion(model, logprob) 140 | test_completion_stream(model) 141 | test_chat_completion(model) 142 | test_chat_completion_stream(model) 143 | try: 144 | test_embedding(model) 145 | except openai.APIError as e: 146 | print(f"Embedding error: {e}") 147 | 148 | print("===== Test curl =====") 149 | test_openai_curl() 150 | -------------------------------------------------------------------------------- /tests/test_openai_langchain.py: -------------------------------------------------------------------------------- 1 | # Usage: 2 | # python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --model-names gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002 3 | # export OPENAI_API_BASE=http://localhost:8000/v1 4 | # export OPENAI_API_KEY=EMPTY 5 | # wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt 6 | 7 | import os 8 | 9 | from langchain.chat_models import ChatOpenAI 10 | from langchain.document_loaders import TextLoader 11 | from langchain.embeddings import OpenAIEmbeddings 12 | from langchain.indexes import VectorstoreIndexCreator 13 | 14 | 15 | def test_chain(): 16 | embedding = OpenAIEmbeddings(model="text-embedding-ada-002") 17 | loader = TextLoader("state_of_the_union.txt") 18 | index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) 19 | 20 | llm = ChatOpenAI(model="gpt-3.5-turbo") 21 | 22 | questions = [ 23 | "Who is the speaker", 24 | "What did the president say about Ketanji Brown Jackson", 25 | "What are the threats to America", 26 | "Who are mentioned in the speech", 27 | "Who is the vice president", 28 | "How many projects were announced", 29 | ] 30 | 31 | for query in questions: 32 | print("Query:", query) 33 | print("Answer:", index.query(query, llm=llm)) 34 | 35 | 36 | if __name__ == "__main__": 37 | os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1" 38 | os.environ["OPENAI_API_KEY"] = "empty" 39 | test_chain() 40 | --------------------------------------------------------------------------------