├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arena_elo ├── LICENSE ├── README.md ├── arena_elo.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── edition_model_info.json ├── elo_rating │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── basic_stats.cpython-39.pyc │ │ ├── clean_battle_data.cpython-39.pyc │ │ ├── elo_analysis.cpython-39.pyc │ │ ├── generate_leaderboard.cpython-39.pyc │ │ ├── model_registry.cpython-39.pyc │ │ └── utils.cpython-39.pyc │ ├── basic_stats.py │ ├── clean_battle_data.py │ ├── elo_analysis.py │ ├── generate_leaderboard.py │ ├── inspect_conv_rating.py │ ├── inspect_cost.py │ ├── inspect_elo_rating_pkl.py │ ├── model_registry.py │ ├── upload_battle_data.py │ └── utils.py ├── evaluator │ ├── convert_to_evaluator_data.py │ └── rating_analysis.ipynb ├── generation_model_info.json ├── get_latest_data.sh ├── pyproject.toml ├── requirements.txt ├── results │ ├── 20240220 │ │ ├── elo_results_image_editing.pkl │ │ ├── elo_results_t2i_generation.pkl │ │ ├── image_editing_leaderboard.csv │ │ └── t2i_generation_leaderboard.csv │ └── latest │ │ ├── elo_results_image_editing.pkl │ │ ├── elo_results_t2i_generation.pkl │ │ ├── image_editing_leaderboard.csv │ │ └── t2i_generation_leaderboard.csv ├── simple_test.py ├── test.ipynb └── update_elo_rating.sh ├── docker ├── Dockerfile └── docker-compose.yml ├── examples ├── .DS_Store ├── banana.jpg ├── cat.jpeg ├── city.jpg ├── dog.jpg ├── duck.jpg ├── duck_hat.jpg ├── fire.jpg ├── mouse.jpg ├── oranges.jpg ├── pig.jpg ├── rabbit.jpg └── strawberries.jpg ├── fastchat ├── __init__.py ├── constants.py ├── conversation.py ├── data │ ├── __init__.py │ ├── clean_sharegpt.py │ ├── convert_alpaca.py │ ├── extract_gpt4_only.py │ ├── extract_single_round.py │ ├── filter_wrong_format.py │ ├── get_stats.py │ ├── hardcoded_questions.py │ ├── inspect_data.py │ ├── merge.py │ ├── optional_clean.py │ ├── optional_replace.py │ ├── prepare_all.py │ ├── pretty_json.py │ ├── sample.py │ ├── split_long_conversation.py │ └── split_train_test.py ├── llm_judge │ ├── README.md │ ├── clean_judgment.py │ ├── common.py │ ├── compute_agreement.py │ ├── data │ │ ├── judge_prompts.jsonl │ │ ├── mt_bench │ │ │ ├── misc │ │ │ │ └── radar.png │ │ │ ├── question.jsonl │ │ │ └── reference_answer │ │ │ │ └── gpt-4.jsonl │ │ └── vicuna_bench │ │ │ ├── question.jsonl │ │ │ └── reference_answer │ │ │ └── gpt-4.jsonl │ ├── download_mt_bench_pregenerated.py │ ├── gen_api_answer.py │ ├── gen_judgment.py │ ├── gen_model_answer.py │ ├── qa_browser.py │ └── show_result.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── apply_lora.py │ ├── compression.py │ ├── convert_fp16.py │ ├── llama_condense_monkey_patch.py │ ├── make_delta.py │ ├── model_adapter.py │ ├── model_chatglm.py │ ├── model_codet5p.py │ ├── model_exllama.py │ ├── model_falcon.py │ ├── model_imagenhub.py │ ├── model_imagenhub_ie.py │ ├── model_lavie.py │ ├── model_playground.py │ ├── model_registry.py │ ├── model_stable_diffusion.py │ ├── model_xfastertransformer.py │ ├── monkey_patch_non_inplace.py │ ├── rwkv_model.py │ └── upload_hub.py ├── modules │ ├── __init__.py │ ├── awq.py │ ├── exllama.py │ ├── gptq.py │ └── xfastertransformer.py ├── protocol │ ├── api_protocol.py │ └── openai_api_protocol.py ├── serve │ ├── __init__.py │ ├── api_provider.py │ ├── base_model_worker.py │ ├── cli.py │ ├── controller.py │ ├── gateway │ │ ├── README.md │ │ └── nginx.conf │ ├── gradio_block_arena_anony.py │ ├── 
gradio_block_arena_anony_ie.py │ ├── gradio_block_arena_named.py │ ├── gradio_block_arena_named_ie.py │ ├── gradio_web_image_editing_server.py │ ├── gradio_web_server.py │ ├── gradio_web_server_image_editing_multi.py │ ├── gradio_web_server_multi.py │ ├── huggingface_api.py │ ├── huggingface_api_worker.py │ ├── inference.py │ ├── launch_all_serve.py │ ├── model_worker.py │ ├── monitor │ │ ├── basic_stats.py │ │ ├── clean_battle_data.py │ │ ├── clean_chat_data.py │ │ ├── dataset_release_scripts │ │ │ ├── arena_33k │ │ │ │ ├── count_unique_users.py │ │ │ │ ├── filter_bad_conv.py │ │ │ │ ├── merge_field.py │ │ │ │ ├── sample.py │ │ │ │ └── upload_hf_dataset.py │ │ │ └── lmsys_chat_1m │ │ │ │ ├── approve_all.py │ │ │ │ ├── compute_stats.py │ │ │ │ ├── filter_bad_conv.py │ │ │ │ ├── final_post_processing.py │ │ │ │ ├── instructions.md │ │ │ │ ├── merge_oai_tag.py │ │ │ │ ├── process_all.sh │ │ │ │ ├── sample.py │ │ │ │ └── upload_hf_dataset.py │ │ ├── elo_analysis.py │ │ ├── inspect_conv.py │ │ ├── intersect_conv_file.py │ │ ├── leaderboard_csv_to_html.py │ │ ├── monitor.py │ │ ├── summarize_cluster.py │ │ ├── tag_openai_moderation.py │ │ └── topic_clustering.py │ ├── multi_model_worker.py │ ├── openai_api_server.py │ ├── register_worker.py │ ├── shutdown_serve.py │ ├── test_message.py │ ├── test_throughput.py │ └── vllm_worker.py ├── train │ ├── llama2_flash_attn_monkey_patch.py │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── train.py │ ├── train_baichuan.py │ ├── train_flant5.py │ ├── train_lora.py │ ├── train_lora_t5.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── format.sh ├── imagenhub_requirements.txt └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.json 3 | __pycache__/ 4 | *log* 5 | *.egg-info/ 6 | /src -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ImagenHub"] 2 | path = ImagenHub 3 | url = https://github.com/TIGER-AI-Lab/ImagenHub.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Installation 2 | 3 | ``` 4 | pip install -e . 5 | pip install -r imagenhub_requirements.txt 6 | cd ImagenHub && pip install -e . 
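# Optional sanity check — a minimal sketch, assuming the two editable installs above succeeded.
# `fastchat.__version__` is defined in fastchat/__init__.py; the `imagen_hub` import name for the ImagenHub submodule is an assumption.
python3 -c "import fastchat; print('fastchat', fastchat.__version__)"
python3 -c "import imagen_hub; print('ImagenHub import OK')"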
7 | ``` 8 | 9 | ### run 10 | 11 | 12 | ``` 13 | # run controller 14 | python3 -m fastchat.serve.controller > controller.log & 15 | 16 | # run worker 17 | CUDA_VISIBLE_DEVICES=0 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_LCM_generation --controller http://localhost:21001 --port 31005 --worker http://localhost:31005 > model_log/lcm.log & 18 | 19 | CUDA_VISIBLE_DEVICES=0 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_SDXLTurbo_generation --controller http://localhost:21001 --port 31010 --worker http://localhost:31010 > model_log/SDXLTurbo.log & 20 | 21 | CUDA_VISIBLE_DEVICES=2 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_SDXL_generation --controller http://localhost:21001 --port 31017 --worker http://localhost:31017 > model_log/SDXL.log & 22 | 23 | CUDA_VISIBLE_DEVICES=1 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_OpenJourney_generation --controller http://localhost:21001 --port 31011 --worker http://localhost:31011 > model_log/openjourney.log & 24 | 25 | CUDA_VISIBLE_DEVICES=1 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_PixArtAlpha_generation --controller http://localhost:21001 --port 31012 --worker http://localhost:31012 --limit-worker-concurrency 1 > model_log/PixArtAlpha.log & 26 | 27 | CUDA_VISIBLE_DEVICES=6 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_SDXLLightning_generation --controller http://localhost:21001 --port 31022 --worker http://localhost:31022 --limit-worker-concurrency 1 > model_log/SDXLLightning.log & 28 | 29 | CUDA_VISIBLE_DEVICES=5 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_StableCascade_generation --controller http://localhost:21001 --port 31023 --worker http://localhost:31023 --limit-worker-concurrency 1 > model_log/StableCascade.log & 30 | 31 | nohup python3 -m fastchat.serve.model_worker --model-path "Playground v2.5" --controller http://localhost:21001 --port 31024 --worker http://localhost:31024 --limit-worker-concurrency 1 > model_log/PlayGroundV2.5.log & 32 | 33 | nohup python3 -m fastchat.serve.model_worker --model-path "Playground v2" --controller http://localhost:21001 --port 31021 --worker http://localhost:31021 --limit-worker-concurrency 1 > model_log/PlayGroundV2.log & 34 | 35 | 36 | 37 | CUDA_VISIBLE_DEVICES=2 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_CycleDiffusion_edition --controller http://localhost:21001 --port 31013 --worker http://localhost:31013 > model_log/CycleDiffusion.log & 38 | 39 | CUDA_VISIBLE_DEVICES=4 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_Pix2PixZero_edition --controller http://localhost:21001 --port 31014 --worker http://localhost:31014 --limit-worker-concurrency 1 > model_log/Pix2PixZero.log & 40 | 41 | CUDA_VISIBLE_DEVICES=7 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_Prompt2prompt_edition --controller http://localhost:21001 --port 31015 --worker http://localhost:31015 --limit-worker-concurrency 1 > model_log/Prompt2prompt.log & 42 | 43 | CUDA_VISIBLE_DEVICES=5 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_SDEdit_edition --controller http://localhost:21001 --port 31016 --worker http://localhost:31016 > model_log/SDEdit.log & 44 | 45 | CUDA_VISIBLE_DEVICES=5 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_InstructPix2Pix_edition --controller http://localhost:21001 --port 31018 --worker http://localhost:31018 > model_log/InstructPix2Pix.log & 46 | 47 | 
CUDA_VISIBLE_DEVICES=6 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_MagicBrush_edition --controller http://localhost:21001 --port 31019 --worker http://localhost:31019 > model_log/MagicBrush.log & 48 | 49 | CUDA_VISIBLE_DEVICES=6 nohup python3 -m fastchat.serve.model_worker --model-path imagenhub_PNP_edition --controller http://localhost:21001 --port 31020 --worker http://localhost:31020 --limit-worker-concurrency 1 > model_log/PNP.log & 50 | 51 | # run web server UI (without leaderboard) 52 | python3 -m fastchat.serve.gradio_web_server_image_editing_multi --share --controller-url http://localhost:21001 53 | 54 | # run web server UI (with leaderboard) 55 | python3 -m fastchat.serve.gradio_web_server_image_editing_multi --share --controller-url http://localhost:21001 --elo_results_dir ./arena_elo/results/latest/ 56 | ``` 57 | 58 | ### update leaderboard data 59 | 60 | ``` 61 | cd arena_elo 62 | export LOGDIR="/home/tianle/arena_vote" 63 | bash update_elo_rating.sh 64 | ``` 65 | The results are then updated in `arena_elo/results/latest/`. 66 | 67 | ### WishList 68 | 69 | 1. LEDITS: https://huggingface.co/spaces/editing-images/ledits 70 | 2. InfEdit: https://huggingface.co/spaces/sled-umich/InfEdit 71 | 3. MGIE: https://huggingface.co/spaces/tsujuifu/ml-mgie 72 | 4. OpenDalle: https://huggingface.co/dataautogpt3/OpenDalleV1.1 73 | -------------------------------------------------------------------------------- /arena_elo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 WildVision-Bench 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /arena_elo/README.md: -------------------------------------------------------------------------------- 1 | ## Computing the Elo Ratings 2 | 3 | 4 | ```bash 5 | apt-get -y install pkg-config 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | 10 | ### to update the leaderboard 11 | 12 | ```bash 13 | export LOGDIR="/path/to/your/logdir" 14 | bash update_elo_rating.sh 15 | ``` 16 | 17 | ### to inspect the leaderboard status 18 | ```bash 19 | python -m elo_rating.inspect_elo_rating_pkl 20 | ``` 21 | 22 | ### to inspect the collected data status and cost 23 | ```bash 24 | export LOGDIR="/path/to/your/logdir" 25 | python -m elo_rating.inspect_cost 26 | ``` 27 | 28 | ### to upload the battle data to hugging face🤗 29 | ```bash 30 | export HUGGINGFACE_TOKEN="your_huggingface_token" 31 | bash get_latest_data.sh 32 | python -m elo_rating.upload_battle_data --repo_id "WildVision/wildvision-bench" --log_dir "./vision-arena-logs/" 33 | ``` 34 | 35 | ### to upload the chat data to hugging face🤗 36 | ```bash 37 | export HUGGINGFACE_TOKEN="your_huggingface_token" 38 | bash get_latest_data.sh 39 | python -m elo_rating.upload_chat_data --repo_id "WildVision/wildvision-bench" --log_dir "./vision-arena-logs/" 40 | ``` 41 | 42 | 43 | ### to get the collected data 44 | ```bash 45 | python -m 46 | 47 | -------------------------------------------------------------------------------- /arena_elo/arena_elo.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: arena_elo 3 | Version: 0.2.35 4 | Summary: Elo rating system for WildVision Bench Arena 5 | Project-URL: Homepage, https://github.com/WildVision-Bench/Arena-Elo 6 | Project-URL: Bug Tracker, https://github.com/WildVision-Bench/Arena-Elo/issues 7 | Classifier: Programming Language :: Python :: 3 8 | Classifier: License :: OSI Approved :: Apache Software License 9 | Requires-Python: >=3.9 10 | Description-Content-Type: text/markdown 11 | License-File: LICENSE 12 | Requires-Dist: numpy 13 | Requires-Dist: prompt_toolkit>=3.0.0 14 | Requires-Dist: uvicorn 15 | Requires-Dist: polyglot 16 | Requires-Dist: pyicu 17 | Requires-Dist: pycld2 18 | Requires-Dist: morfessor 19 | Requires-Dist: scikit-learn 20 | Requires-Dist: pytz 21 | Requires-Dist: tqdm 22 | Requires-Dist: pandas 23 | Requires-Dist: plotly 24 | Requires-Dist: fire 25 | Requires-Dist: Pillow 26 | 27 | ## Computing the Elo Ratings 28 | 29 | 30 | ```bash 31 | apt-get -y install pkg-config 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | 36 | ### to update the leaderboard 37 | 38 | ```bash 39 | export LOGDIR="/path/to/your/logdir" 40 | bash update_elo_rating.sh 41 | ``` 42 | 43 | ### to inspect the leaderboard status 44 | ```bash 45 | python -m elo_rating.inspect_elo_rating_pkl 46 | ``` 47 | 48 | ### to inspect the collected data status and cost 49 | ```bash 50 | export LOGDIR="/path/to/your/logdir" 51 | python -m elo_rating.inspect_cost 52 | ``` 53 | 54 | ### to upload the battle data to hugging face🤗 55 | ```bash 56 | export HUGGINGFACE_TOKEN="your_huggingface_token" 57 | bash get_latest_data.sh 58 | python -m elo_rating.upload_battle_data --repo_id "WildVision/wildvision-bench" --log_dir "./vision-arena-logs/" 59 | ``` 60 | 61 | ### to upload the chat data to hugging face🤗 62 | ```bash 63 | export HUGGINGFACE_TOKEN="your_huggingface_token" 64 | bash get_latest_data.sh 65 | python -m elo_rating.upload_chat_data --repo_id 
"WildVision/wildvision-bench" --log_dir "./vision-arena-logs/" 66 | ``` 67 | 68 | 69 | ### to get the collected data 70 | ```bash 71 | python -m 72 | 73 | -------------------------------------------------------------------------------- /arena_elo/arena_elo.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | pyproject.toml 4 | arena_elo.egg-info/PKG-INFO 5 | arena_elo.egg-info/SOURCES.txt 6 | arena_elo.egg-info/dependency_links.txt 7 | arena_elo.egg-info/requires.txt 8 | arena_elo.egg-info/top_level.txt 9 | data_taxnomy/inspect_conv_category.py 10 | data_taxnomy/taxnomy_creator.py 11 | elo_rating/__init__.py 12 | elo_rating/basic_stats.py 13 | elo_rating/clean_battle_data.py 14 | elo_rating/clean_chat_data.py 15 | elo_rating/elo_analysis.py 16 | elo_rating/generate_leaderboard.py 17 | elo_rating/inspect_conv.py 18 | elo_rating/inspect_conv_rating.py 19 | elo_rating/inspect_cost.py 20 | elo_rating/inspect_elo_rating_pkl.py 21 | elo_rating/model_registry.py 22 | elo_rating/upload_battle_data.py 23 | elo_rating/upload_chat_data.py 24 | elo_rating/utils.py 25 | evaluator/convert_to_evaluator_data.py -------------------------------------------------------------------------------- /arena_elo/arena_elo.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /arena_elo/arena_elo.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | prompt_toolkit>=3.0.0 3 | uvicorn 4 | polyglot 5 | pyicu 6 | pycld2 7 | morfessor 8 | scikit-learn 9 | pytz 10 | tqdm 11 | pandas 12 | plotly 13 | fire 14 | Pillow 15 | -------------------------------------------------------------------------------- /arena_elo/arena_elo.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | data_taxnomy 2 | elo_rating 3 | evaluator 4 | results 5 | vision-arena-logs 6 | -------------------------------------------------------------------------------- /arena_elo/edition_model_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "CycleDiffusion": { 3 | "Link": "https://github.com/ChenWu98/cycle-diffusion", 4 | "License": "X11", 5 | "Organization": "Carnegie Mellon University" 6 | }, 7 | "PNP": { 8 | "Link": "https://github.com/MichalGeyer/plug-and-play", 9 | "License": "-", 10 | "Organization": "Weizmann Institute of Science" 11 | }, 12 | "InstructPix2Pix": { 13 | "Link": "https://www.timothybrooks.com/instruct-pix2pix", 14 | "License": "Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. 
Efros", 15 | "Organization": "University of California, Berkeley" 16 | }, 17 | "Pix2PixZero": { 18 | "Link": "https://pix2pixzero.github.io", 19 | "License": "MIT License", 20 | "Organization": "Carnegie Mellon University, Adobe Research" 21 | }, 22 | "MagicBrush": { 23 | "Link": "https://osu-nlp-group.github.io/MagicBrush", 24 | "License": "CC-BY-4.0", 25 | "Organization": "The Ohio State University, University of Waterloo" 26 | }, 27 | "Prompt2prompt": { 28 | "Link": "https://prompt-to-prompt.github.io", 29 | "License": "Apache-2.0", 30 | "Organization": "Google, Tel Aviv University" 31 | }, 32 | "SDEdit": { 33 | "Link": "https://sde-image-editing.github.io", 34 | "License": "MIT License", 35 | "Organization": "Stanford University" 36 | } 37 | } -------------------------------------------------------------------------------- /arena_elo/elo_rating/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__init__.py -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/basic_stats.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/basic_stats.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/clean_battle_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/clean_battle_data.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/elo_analysis.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/elo_analysis.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/generate_leaderboard.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/generate_leaderboard.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/__pycache__/model_registry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/model_registry.cpython-39.pyc -------------------------------------------------------------------------------- 
/arena_elo/elo_rating/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/elo_rating/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /arena_elo/elo_rating/generate_leaderboard.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import json 3 | import pandas as pd 4 | import pickle 5 | 6 | 7 | def main( 8 | model_info_file: str, 9 | elo_rating_pkl: str, 10 | output_csv: str 11 | ): 12 | model_info = json.load(open(model_info_file)) 13 | 14 | with open(elo_rating_pkl, "rb") as fin: 15 | elo_rating_results = pickle.load(fin) 16 | 17 | anony_elo_rating_results = elo_rating_results["anony"] 18 | full_elo_rating_results = elo_rating_results["full"] 19 | anony_leaderboard_data = anony_elo_rating_results["leaderboard_table_df"] 20 | full_leaderboard_data = full_elo_rating_results["leaderboard_table_df"] 21 | 22 | # Model,MT-bench (score),Arena Elo rating,MMLU,License,Link 23 | fields = ["key", "Model", "Arena Elo rating (anony)", "Arena Elo rating (full)", "License", "Organization", "Link"] 24 | # set Organization and license to empty for now 25 | all_models = anony_leaderboard_data.index.tolist() 26 | 27 | for model in all_models: 28 | if not model in model_info: 29 | model_info[model] = {} 30 | model_info[model]["License"] = "N/A" 31 | model_info[model]["Organization"] = "N/A" 32 | model_info[model]["Link"] = "N/A" 33 | model_info[model]["Model"] = model 34 | model_info[model]["key"] = model 35 | 36 | if model in anony_leaderboard_data.index: 37 | model_info[model]["Arena Elo rating (anony)"] = anony_leaderboard_data.loc[model, "rating"] 38 | else: 39 | model_info[model]["Arena Elo rating (anony)"] = 0 40 | 41 | if model in full_elo_rating_results["leaderboard_table_df"].index: 42 | model_info[model]["Arena Elo rating (full)"] = full_leaderboard_data.loc[model, "rating"] 43 | else: 44 | model_info[model]["Arena Elo rating (full)"] = 0 45 | # if model in anony_leaderboard_data.index: 46 | # model_info[model]["Arena Elo rating"] = anony_leaderboard_data.loc[model, "rating"] 47 | # else: 48 | # model_info[model]["Arena Elo rating"] = 0 49 | 50 | final_model_info = {} 51 | for model in model_info: 52 | if "Model" in model_info[model]: 53 | final_model_info[model] = model_info[model] 54 | model_info = final_model_info 55 | 56 | exclude_keys = ['starting_from'] 57 | for key in exclude_keys: 58 | for model in model_info: 59 | if key in model_info[model]: 60 | del model_info[model][key] 61 | df = pd.DataFrame(model_info).T 62 | df = df[fields] 63 | # sort by anony rating 64 | df = df.sort_values(by=["Arena Elo rating (anony)"], ascending=False) 65 | df.to_csv(output_csv, index=False) 66 | print("Leaderboard data saved to", output_csv) 67 | print(df) 68 | 69 | 70 | if __name__ == "__main__": 71 | fire.Fire(main) -------------------------------------------------------------------------------- /arena_elo/elo_rating/inspect_elo_rating_pkl.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import plotly.graph_objects as go 3 | 4 | def output_figure(data, figure_name="battle_count_heatmap", label="annoy"): 5 | fig = data[label][figure_name] 6 | fig.update_layout( 7 | height=700, 8 | width=700, 9 | title={'text': f'{figure_name}', 'x': 0.5, 'y': 0.07}, 10 | 
xaxis_title="Model B", 11 | yaxis_title="Model A", 12 | # coloraxis_colorscale=[[0.0, '#0d0887'], [1.0, '#f0f921']], 13 | margin={'t': 60} 14 | ) 15 | fig.write_image(f"{figure_name}.png") 16 | 17 | with open("./results/latest/elo_results.pkl",'rb') as f: 18 | data = pickle.load(f) 19 | print() 20 | df = data["anony"]["leaderboard_table_df"] 21 | # sort by rating 22 | print(data["anony"].keys()) 23 | 24 | for figure_name in [ 'win_fraction_heatmap', 'battle_count_heatmap',]: 25 | output_figure(data, figure_name, "anony") 26 | 27 | df = df.sort_values(by=["rating"], ascending=False) 28 | print(df) 29 | df = data["full"]["leaderboard_table_df"] 30 | # sort by rating 31 | df = df.sort_values(by=["rating"], ascending=False) 32 | print(df) 33 | print('done') 34 | -------------------------------------------------------------------------------- /arena_elo/elo_rating/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import pytz 3 | import PIL 4 | import os 5 | 6 | def detect_language(text: str) -> str: 7 | """Detect the langauge of a string.""" 8 | import polyglot # pip3 install polyglot pyicu pycld2 9 | from polyglot.detect import Detector 10 | from polyglot.detect.base import logger as polyglot_logger 11 | import pycld2 12 | 13 | polyglot_logger.setLevel("ERROR") 14 | 15 | try: 16 | lang_code = Detector(text).language.name 17 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 18 | lang_code = "unknown" 19 | return lang_code 20 | 21 | 22 | def get_time_stamp_from_date(date_str:str): 23 | """ 24 | Convert a date string to a Unix timestamp 25 | Args: 26 | date_str (str): The input date string in the format 'YYYY-MM-DD-HH:MM-TZ', e.g. '2024-02-10-14:00-PT' 27 | """ 28 | 29 | # Convert the date string into a format that Python's datetime can understand 30 | # and specify the correct timezone for PT, which is 'US/Pacific' 31 | date_format = "%Y-%m-%d-%H:%M-%Z" 32 | 33 | # Parse the date string into a datetime object 34 | # Note: PT is not directly recognized by pytz, so we manually map it to 'US/Pacific' 35 | timezone_map = { 36 | "PT": "US/Pacific", 37 | } 38 | 39 | # Extract the timezone abbreviation 40 | tz_abbr = date_str.split("-")[-1] 41 | # Map the abbreviation to a pytz timezone 42 | tz_info = pytz.timezone(timezone_map[tz_abbr]) 43 | 44 | # Remove the timezone abbreviation for parsing 45 | date_str_parsed = date_str.rsplit("-", 1)[0] 46 | 47 | # Create a datetime object with the corresponding timezone 48 | dt = datetime.strptime(date_str_parsed, "%Y-%m-%d-%H:%M").replace(tzinfo=tz_info) 49 | 50 | # Convert the datetime object to a Unix timestamp 51 | unix_timestamp = dt.timestamp() 52 | return unix_timestamp 53 | 54 | def get_date_from_time_stamp(unix_timestamp: int): 55 | # Create a datetime object from the Unix timestamp 56 | dt = datetime.fromtimestamp(unix_timestamp) 57 | 58 | # Convert the datetime object to a string with the desired format 59 | date_str = dt.strftime("%Y-%m-%d %H:%M:%S %Z") 60 | return date_str 61 | 62 | 63 | def get_input_image_path(tstamp, conv_id): 64 | # from tstamp to date e.g. 
2024-02-10 65 | date_str = datetime.fromtimestamp(tstamp, tz=pytz.timezone("US/Pacific")).strftime("%Y-%m-%d") 66 | LOGDIR = os.getenv("LOGDIR") 67 | return f"{LOGDIR}/{date_str}-convinput_images/input_image_{conv_id}.png" 68 | 69 | def load_image_from_path(image_path): 70 | # Load the image from the specified 71 | # path using the Python Imaging Library (PIL) 72 | try: 73 | image = PIL.Image.open(image_path) 74 | return image 75 | except FileNotFoundError: 76 | print(f"Image not found at path: {image_path}") 77 | return None 78 | except PIL.UnidentifiedImageError: 79 | print(f"Unidentified image format at path: {image_path}") 80 | return None 81 | 82 | 83 | -------------------------------------------------------------------------------- /arena_elo/generation_model_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "LCM": { 3 | "Link": "https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7", 4 | "License": "MIT License", 5 | "Organization": "Tsinghua University" 6 | }, 7 | "Playground v2": { 8 | "Link": "https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic", 9 | "License": "Playground v2 Community License", 10 | "Organization": "Playground" 11 | }, 12 | "OpenJourney": { 13 | "Link": "https://huggingface.co/prompthero/openjourney", 14 | "License": "creativeml-openrail-m", 15 | "Organization": "PromptHero" 16 | }, 17 | "SDXLTurbo": { 18 | "Link": "https://huggingface.co/stabilityai/sdxl-turbo", 19 | "License": "sai-nc-community (other)", 20 | "Organization": "Stability AI" 21 | }, 22 | "SDXL": { 23 | "Link": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0", 24 | "License": "openrail++", 25 | "Organization": "Stability AI" 26 | }, 27 | "PixArtAlpha": { 28 | "Link": "https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS", 29 | "License": "openrail++", 30 | "Organization": "PixArt-alpha" 31 | }, 32 | "SDXLLightning": { 33 | "Link": "https://huggingface.co/ByteDance/SDXL-Lightning", 34 | "License": "openrail++", 35 | "Organization": "ByteDance" 36 | }, 37 | "StableCascade": { 38 | "Link": "https://huggingface.co/stabilityai/stable-cascade", 39 | "License": "stable-cascade-nc-community (other)", 40 | "Organization": "Stability AI" 41 | } 42 | } -------------------------------------------------------------------------------- /arena_elo/get_latest_data.sh: -------------------------------------------------------------------------------- 1 | 2 | # set LOGDIR to default if not set before 3 | if [ -z "$LOGDIR" ]; then 4 | export LOGDIR="./vision-arena-logs" 5 | fi 6 | mkdir -p results 7 | 8 | 9 | # # for battle data 10 | python -m elo_rating.clean_battle_data --model_infos_file "./model_infos.json" --mode conv_release 11 | battle_cutoff_date=`cat cut_off_date.txt` && rm cut_off_date.txt && echo "Battle data last updated on $battle_cutoff_date" 12 | 13 | mkdir -p ./results/latest 14 | mkdir -p ./results/$battle_cutoff_date && mv ./clean_battle_conv_$battle_cutoff_date.json ./results/$battle_cutoff_date/clean_battle_conv.json 15 | cp ./results/$battle_cutoff_date/clean_battle_conv.json ./results/latest/clean_battle_conv.json 16 | 17 | echo "Battle data last updated on $battle_cutoff_date" >> ./results/latest/latest_updated_date.txt 18 | -------------------------------------------------------------------------------- /arena_elo/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 
5 | [project] 6 | name = "arena_elo" 7 | version = "0.2.35" 8 | description = "Elo rating system for WildVision Bench Arena" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "numpy", "prompt_toolkit>=3.0.0", "uvicorn","polyglot", "pyicu", "pycld2", "morfessor", "scikit-learn", 17 | "pytz", "tqdm", "pandas", "plotly", "fire", "Pillow" 18 | ] 19 | 20 | [project.urls] 21 | "Homepage" = "https://github.com/WildVision-Bench/Arena-Elo" 22 | "Bug Tracker" = "https://github.com/WildVision-Bench/Arena-Elo/issues" 23 | 24 | [tool.setuptools.packages.find] 25 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 26 | 27 | [tool.wheel] 28 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] -------------------------------------------------------------------------------- /arena_elo/requirements.txt: -------------------------------------------------------------------------------- 1 | -e git+https://github.com/WildVision-Bench/Arena-Elo.git@9dc2fa8543a2e9eda3d5bc01c2212fdfcdd4bfb5#egg=arena_elo 2 | click==8.1.7 3 | fire==0.5.0 4 | h11==0.14.0 5 | joblib==1.3.2 6 | Morfessor==2.0.6 7 | numpy==1.26.4 8 | packaging==23.2 9 | pandas==2.2.0 10 | pillow==10.2.0 11 | plotly==5.18.0 12 | polyglot==16.7.4 13 | prompt-toolkit==3.0.43 14 | pycld2==0.41 15 | PyICU==2.12 16 | python-dateutil==2.8.2 17 | pytz==2024.1 18 | scikit-learn==1.4.0 19 | scipy==1.12.0 20 | six==1.16.0 21 | tenacity==8.2.3 22 | termcolor==2.4.0 23 | threadpoolctl==3.2.0 24 | tqdm==4.66.2 25 | typing_extensions==4.9.0 26 | tzdata==2024.1 27 | uvicorn==0.27.1 28 | wcwidth==0.2.13 29 | -------------------------------------------------------------------------------- /arena_elo/results/20240220/elo_results_image_editing.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/results/20240220/elo_results_image_editing.pkl -------------------------------------------------------------------------------- /arena_elo/results/20240220/elo_results_t2i_generation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/results/20240220/elo_results_t2i_generation.pkl -------------------------------------------------------------------------------- /arena_elo/results/20240220/image_editing_leaderboard.csv: -------------------------------------------------------------------------------- 1 | key,Model,Arena Elo rating (anony),Arena Elo rating (full),License,Organization,Link 2 | Prompt2prompt,Prompt2prompt,1252.820838097007,1216.6489026518666,Apache-2.0,"Google, Tel Aviv University",https://prompt-to-prompt.github.io 3 | PNP,PNP,1175.6261555831445,1171.3279007979363,-,Weizmann Institute of Science,https://github.com/MichalGeyer/plug-and-play 4 | InstructPix2Pix,InstructPix2Pix,1155.8431458813104,1142.6827834982837,"Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. 
Efros","University of California, Berkeley",https://www.timothybrooks.com/instruct-pix2pix 5 | MagicBrush,MagicBrush,1051.428411953954,1089.4499296239383,CC-BY-4.0,"The Ohio State University, University of Waterloo",https://osu-nlp-group.github.io/MagicBrush 6 | Pix2PixZero,Pix2PixZero,955.5903260059122,929.2296611307636,MIT License,"Carnegie Mellon University, Adobe Research",https://pix2pixzero.github.io 7 | CycleDiffusion,CycleDiffusion,771.4360186105207,753.4930725653142,X11,Carnegie Mellon University,https://github.com/ChenWu98/cycle-diffusion 8 | SDEdit,SDEdit,637.2551038681513,697.1677497318974,MIT License,Stanford University,https://sde-image-editing.github.io 9 | -------------------------------------------------------------------------------- /arena_elo/results/20240220/t2i_generation_leaderboard.csv: -------------------------------------------------------------------------------- 1 | key,Model,Arena Elo rating (anony),Arena Elo rating (full),License,Organization,Link 2 | PlayGroundV2,PlayGroundV2,1151.1834096302248,1150.901721636401,Playground v2 Community License,Playground,https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic 3 | PixArtAlpha,PixArtAlpha,1078.3583466674136,1069.815012597113,openrail++,PixArt-alpha,https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS 4 | SDXL,SDXL,1027.258044463298,1035.47732509915,openrail++,Stability AI,https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 5 | SDXLTurbo,SDXLTurbo,972.0904914158416,969.4933207298967,sai-nc-community (other),Stability AI,https://huggingface.co/stabilityai/sdxl-turbo 6 | OpenJourney,OpenJourney,921.3424873878607,906.3184453708288,creativeml-openrail-m,PromptHero,https://huggingface.co/prompthero/openjourney 7 | LCM,LCM,849.7672204353615,868.2154196730218,MIT License,Tsinghua University,https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7 8 | -------------------------------------------------------------------------------- /arena_elo/results/latest/elo_results_image_editing.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/results/latest/elo_results_image_editing.pkl -------------------------------------------------------------------------------- /arena_elo/results/latest/elo_results_t2i_generation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/results/latest/elo_results_t2i_generation.pkl -------------------------------------------------------------------------------- /arena_elo/results/latest/image_editing_leaderboard.csv: -------------------------------------------------------------------------------- 1 | key,Model,Arena Elo rating (anony),Arena Elo rating (full),License,Organization,Link 2 | Prompt2prompt,Prompt2prompt,1252.820838097007,1216.6489026518666,Apache-2.0,"Google, Tel Aviv University",https://prompt-to-prompt.github.io 3 | PNP,PNP,1175.6261555831445,1171.3279007979363,-,Weizmann Institute of Science,https://github.com/MichalGeyer/plug-and-play 4 | InstructPix2Pix,InstructPix2Pix,1155.8431458813104,1142.6827834982837,"Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. 
Efros","University of California, Berkeley",https://www.timothybrooks.com/instruct-pix2pix 5 | MagicBrush,MagicBrush,1051.428411953954,1089.4499296239383,CC-BY-4.0,"The Ohio State University, University of Waterloo",https://osu-nlp-group.github.io/MagicBrush 6 | Pix2PixZero,Pix2PixZero,955.5903260059122,929.2296611307636,MIT License,"Carnegie Mellon University, Adobe Research",https://pix2pixzero.github.io 7 | CycleDiffusion,CycleDiffusion,771.4360186105207,753.4930725653142,X11,Carnegie Mellon University,https://github.com/ChenWu98/cycle-diffusion 8 | SDEdit,SDEdit,637.2551038681513,697.1677497318974,MIT License,Stanford University,https://sde-image-editing.github.io 9 | -------------------------------------------------------------------------------- /arena_elo/results/latest/t2i_generation_leaderboard.csv: -------------------------------------------------------------------------------- 1 | key,Model,Arena Elo rating (anony),Arena Elo rating (full),License,Organization,Link 2 | PlayGroundV2,PlayGroundV2,1151.1834096302248,1150.901721636401,Playground v2 Community License,Playground,https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic 3 | PixArtAlpha,PixArtAlpha,1078.3583466674136,1069.815012597113,openrail++,PixArt-alpha,https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS 4 | SDXL,SDXL,1027.258044463298,1035.47732509915,openrail++,Stability AI,https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 5 | SDXLTurbo,SDXLTurbo,972.0904914158416,969.4933207298967,sai-nc-community (other),Stability AI,https://huggingface.co/stabilityai/sdxl-turbo 6 | OpenJourney,OpenJourney,921.3424873878607,906.3184453708288,creativeml-openrail-m,PromptHero,https://huggingface.co/prompthero/openjourney 7 | LCM,LCM,849.7672204353615,868.2154196730218,MIT License,Tsinghua University,https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7 8 | -------------------------------------------------------------------------------- /arena_elo/simple_test.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | with open("./results/latest/elo_results.pkl",'rb') as f: 3 | data = pickle.load(f) 4 | print() 5 | df = data["anony"]["leaderboard_table_df"] 6 | # sort by rating 7 | df = df.sort_values(by=["rating"], ascending=False) 8 | print(df) 9 | 10 | print() 11 | 12 | df = data["full"]["leaderboard_table_df"] 13 | # sort by rating 14 | df = df.sort_values(by=["rating"], ascending=False) 15 | print(df) 16 | print('done') -------------------------------------------------------------------------------- /arena_elo/test.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/arena_elo/test.ipynb -------------------------------------------------------------------------------- /arena_elo/update_elo_rating.sh: -------------------------------------------------------------------------------- 1 | # set LOGDIR to default if not set before 2 | if [ -z "$LOGDIR" ]; then 3 | echo "LOGDIR is not set. Please set LOGDIR to the directory where the logs will be stored. 
Command: export LOGDIR=/path/to/logdir" 4 | exit 1 5 | fi 6 | 7 | mkdir -p results 8 | 9 | # for battle data 10 | python -m elo_rating.clean_battle_data --task_name "image_editing" 11 | edition_battle_cutoff_date=`cat cut_off_date.txt` && rm cut_off_date.txt && echo "Image editing battle data last updated on $edition_battle_cutoff_date" 12 | 13 | python -m elo_rating.clean_battle_data --task_name "t2i_generation" 14 | generation_battle_cutoff_date=`cat cut_off_date.txt` && rm cut_off_date.txt && echo "T2I image generation battle data last updated on $generation_battle_cutoff_date" 15 | 16 | mkdir -p ./results/$edition_battle_cutoff_date ./results/$generation_battle_cutoff_date 17 | 18 | python3 -m elo_rating.elo_analysis --clean-battle-file clean_battle_image_editing_$edition_battle_cutoff_date.json 19 | mv ./elo_results_$edition_battle_cutoff_date.pkl ./results/$edition_battle_cutoff_date/elo_results_image_editing.pkl 20 | 21 | python3 -m elo_rating.elo_analysis --clean-battle-file clean_battle_t2i_generation_$generation_battle_cutoff_date.json 22 | mv ./elo_results_$generation_battle_cutoff_date.pkl ./results/$generation_battle_cutoff_date/elo_results_t2i_generation.pkl 23 | 24 | # generate the leaderboards 25 | 26 | python -m elo_rating.generate_leaderboard \ 27 | --model_info_file "./edition_model_info.json" \ 28 | --elo_rating_pkl "./results/$edition_battle_cutoff_date/elo_results_image_editing.pkl" \ 29 | --output_csv "./results/$edition_battle_cutoff_date/image_editing_leaderboard.csv" 30 | 31 | python -m elo_rating.generate_leaderboard \ 32 | --model_info_file "./generation_model_info.json" \ 33 | --elo_rating_pkl "./results/$generation_battle_cutoff_date/elo_results_t2i_generation.pkl" \ 34 | --output_csv "./results/$generation_battle_cutoff_date/t2i_generation_leaderboard.csv" 35 | 36 | mkdir -p ./results/latest 37 | cp ./results/$edition_battle_cutoff_date/image_editing_leaderboard.csv ./results/latest/image_editing_leaderboard.csv 38 | cp ./results/$generation_battle_cutoff_date/t2i_generation_leaderboard.csv ./results/latest/t2i_generation_leaderboard.csv 39 | cp ./results/$edition_battle_cutoff_date/elo_results_image_editing.pkl ./results/latest/elo_results_image_editing.pkl 40 | cp ./results/$generation_battle_cutoff_date/elo_results_t2i_generation.pkl ./results/latest/elo_results_t2i_generation.pkl 41 | 42 | 43 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 2 | 3 | RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl 4 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py 5 | RUN python3.9 get-pip.py 6 | RUN pip3 install fschat 7 | RUN pip3 install fschat[model_worker,webui] pydantic==1.10.13 -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | fastchat-controller: 5 | build: 6 | context: . 7 | dockerfile: Dockerfile 8 | image: fastchat:latest 9 | ports: 10 | - "21001:21001" 11 | entrypoint: ["python3.9", "-m", "fastchat.serve.controller", "--host", "0.0.0.0", "--port", "21001"] 12 | fastchat-model-worker: 13 | build: 14 | context: .
15 | dockerfile: Dockerfile 16 | volumes: 17 | - huggingface:/root/.cache/huggingface 18 | image: fastchat:latest 19 | deploy: 20 | resources: 21 | reservations: 22 | devices: 23 | - driver: nvidia 24 | count: 1 25 | capabilities: [gpu] 26 | entrypoint: ["python3.9", "-m", "fastchat.serve.model_worker", "--model-names", "${FASTCHAT_WORKER_MODEL_NAMES:-vicuna-7b-v1.5}", "--model-path", "${FASTCHAT_WORKER_MODEL_PATH:-lmsys/vicuna-7b-v1.5}", "--worker-address", "http://fastchat-model-worker:21002", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "21002"] 27 | fastchat-api-server: 28 | build: 29 | context: . 30 | dockerfile: Dockerfile 31 | image: fastchat:latest 32 | ports: 33 | - "8000:8000" 34 | entrypoint: ["python3.9", "-m", "fastchat.serve.openai_api_server", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "8000"] 35 | volumes: 36 | huggingface: 37 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/.DS_Store -------------------------------------------------------------------------------- /examples/banana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/banana.jpg -------------------------------------------------------------------------------- /examples/cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/cat.jpeg -------------------------------------------------------------------------------- /examples/city.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/city.jpg -------------------------------------------------------------------------------- /examples/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/dog.jpg -------------------------------------------------------------------------------- /examples/duck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/duck.jpg -------------------------------------------------------------------------------- /examples/duck_hat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/duck_hat.jpg -------------------------------------------------------------------------------- /examples/fire.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/fire.jpg -------------------------------------------------------------------------------- /examples/mouse.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/mouse.jpg -------------------------------------------------------------------------------- /examples/oranges.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/oranges.jpg -------------------------------------------------------------------------------- /examples/pig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/pig.jpg -------------------------------------------------------------------------------- /examples/rabbit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/rabbit.jpg -------------------------------------------------------------------------------- /examples/strawberries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/examples/strawberries.jpg -------------------------------------------------------------------------------- /fastchat/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.34" 2 | -------------------------------------------------------------------------------- /fastchat/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global constants. 3 | """ 4 | 5 | from enum import IntEnum 6 | import os 7 | 8 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 9 | 10 | ##### For the gradio web server 11 | SERVER_ERROR_MSG = ( 12 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | ) 14 | MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." 15 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 16 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 17 | SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." 18 | # Maximum input length 19 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) 20 | # Maximum conversation turns 21 | CONVERSATION_TURN_LIMIT = 50 22 | # Session expiration time 23 | SESSION_EXPIRATION_TIME = 3600 24 | # The output dir of log files 25 | LOGDIR = os.getenv("LOGDIR", "./vote_log/") 26 | # CPU Instruction Set Architecture 27 | CPU_ISA = os.getenv("CPU_ISA") 28 | 29 | 30 | ##### For the controller and workers (could be overwritten through ENV variables.) 
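# Each constant below is read from the named FASTCHAT_* environment variable with the
# default shown, so it can be overridden before launching the controller or a worker.
# For example (illustrative values, not defaults from this repo):
#   export FASTCHAT_WORKER_API_TIMEOUT=300  # give slow image models more time per request
#   export FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION=180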
31 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 32 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 33 | ) 34 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 35 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 36 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 37 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 38 | ) 39 | 40 | 41 | class ErrorCode(IntEnum): 42 | """ 43 | https://platform.openai.com/docs/guides/error-codes/api-errors 44 | """ 45 | 46 | VALIDATION_TYPE_ERROR = 40001 47 | 48 | INVALID_AUTH_KEY = 40101 49 | INCORRECT_AUTH_KEY = 40102 50 | NO_PERMISSION = 40103 51 | 52 | INVALID_MODEL = 40301 53 | PARAM_OUT_OF_RANGE = 40302 54 | CONTEXT_OVERFLOW = 40303 55 | 56 | RATE_LIMIT = 42901 57 | QUOTA_EXCEEDED = 42902 58 | ENGINE_OVERLOADED = 42903 59 | 60 | INTERNAL_ERROR = 50001 61 | CUDA_OUT_OF_MEMORY = 50002 62 | GRADIO_REQUEST_ERROR = 50003 63 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 64 | CONTROLLER_NO_WORKER = 50005 65 | CONTROLLER_WORKER_TIMEOUT = 50006 66 | -------------------------------------------------------------------------------- /fastchat/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/fastchat/data/__init__.py -------------------------------------------------------------------------------- /fastchat/data/convert_alpaca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert alpaca dataset into sharegpt format. 3 | 4 | Usage: python3 -m fastchat.data.convert_alpaca --in alpaca_data.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | import numpy as np 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--in-file", type=str) 17 | parser.add_argument("--out-file", type=str) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | new_content = [] 22 | for i, c in enumerate(content): 23 | if len(c["input"].strip()) > 1: 24 | q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"] 25 | else: 26 | q, a = c["instruction"], c["output"] 27 | new_content.append( 28 | { 29 | "id": f"alpaca_{i}", 30 | "conversations": [ 31 | {"from": "human", "value": q}, 32 | {"from": "gpt", "value": a}, 33 | ], 34 | } 35 | ) 36 | 37 | print(f"#out: {len(new_content)}") 38 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 39 | -------------------------------------------------------------------------------- /fastchat/data/extract_gpt4_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the conversations generated by GPT-4 only. 
3 | 4 | Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | new_content = [] 21 | for c in content: 22 | model = c.get("model", None) 23 | if model == "gpt4" or model is None: 24 | new_content.append(c) 25 | 26 | if args.out_file: 27 | out_file = args.out_file 28 | else: 29 | out_file = args.in_file.replace(".json", "_gpt4.json") 30 | 31 | print(f"#in: {len(content)}, #out: {len(new_content)}") 32 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 33 | -------------------------------------------------------------------------------- /fastchat/data/extract_single_round.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract the first round of the conversations. 3 | 4 | Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str) 14 | parser.add_argument("--begin", type=int) 15 | parser.add_argument("--end", type=int) 16 | args = parser.parse_args() 17 | 18 | content = json.load(open(args.in_file, "r")) 19 | content = content[args.begin : args.end] 20 | for c in content: 21 | c["conversations"] = c["conversations"][:2] 22 | 23 | if args.out_file: 24 | out_file = args.out_file 25 | else: 26 | out_file = args.in_file.replace(".json", "_single.json") 27 | 28 | print(f"#in: {len(content)}, #out: {len(content)}") 29 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 30 | -------------------------------------------------------------------------------- /fastchat/data/filter_wrong_format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations with wrong formats. 3 | 4 | Usage: 5 | python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json 6 | 7 | """ 8 | import argparse 9 | import json 10 | import re 11 | 12 | from tqdm import tqdm 13 | 14 | wrong_indices_pattern = re.compile("\n1\. [^2]*\n1\. 
") 15 | 16 | 17 | def should_skip(conv): 18 | # Filter wrong list indices like https://sharegpt.com/c/1pREAGO 19 | for sentence in conv["conversations"]: 20 | val = sentence["value"] 21 | sub = re.search(wrong_indices_pattern, val) 22 | if sub is not None: 23 | return True 24 | 25 | return False 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--in-file", type=str, required=True) 31 | parser.add_argument("--out-file", type=str, required=True) 32 | args = parser.parse_args() 33 | 34 | content = json.load(open(args.in_file, "r")) 35 | 36 | new_content = [] 37 | for conv in tqdm(content): 38 | if should_skip(conv): 39 | print(f"{conv['id']} contains a wrong format.") 40 | else: 41 | new_content.append(conv) 42 | 43 | print(f"#in: {len(content)}, #out: {len(new_content)}") 44 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 45 | -------------------------------------------------------------------------------- /fastchat/data/get_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Get stats of a dataset. 3 | 4 | Usage: python3 -m fastchat.data.get_stats --in sharegpt.json 5 | """ 6 | 7 | import argparse 8 | from concurrent.futures import ProcessPoolExecutor 9 | import json 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | 15 | K = 1e3 16 | M = 1e6 17 | 18 | 19 | def tokenize_one_sample(c): 20 | for i in range(len(c["conversations"])): 21 | v = c["conversations"][i]["value"] 22 | c["conversations"][i]["value"] = tokenizer.tokenize(v) 23 | return c 24 | 25 | 26 | def tokenize_dataset(content): 27 | processed = [] 28 | with ProcessPoolExecutor() as executor: 29 | for result in tqdm( 30 | executor.map(tokenize_one_sample, content), total=len(content) 31 | ): 32 | processed.append(result) 33 | 34 | return processed 35 | 36 | 37 | def compute_stats(content): 38 | sample_lens = [] 39 | sample_turns = [] 40 | prompt_lens = [] 41 | res_lens = [] 42 | 43 | for c in content: 44 | sample_len = 0 45 | sample_turns.append(len(c["conversations"]) // 2) 46 | for i in range(len(c["conversations"]) // 2): 47 | p = c["conversations"][i * 2]["value"] 48 | r = c["conversations"][i * 2 + 1]["value"] 49 | 50 | turn_len = len(p) + len(r) 51 | sample_len += turn_len 52 | prompt_lens.append(len(p)) 53 | res_lens.append(len(r)) 54 | sample_lens.append(sample_len) 55 | 56 | return sample_lens, sample_turns, prompt_lens, res_lens 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--in-file", type=str) 62 | parser.add_argument( 63 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 64 | ) 65 | args = parser.parse_args() 66 | 67 | content = json.load(open(args.in_file, "r")) 68 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 69 | content = tokenize_dataset(content) 70 | 71 | sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) 72 | print(f"#sequence: {len(content)/K:.2f} K") 73 | print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") 74 | print(f"avg. turns: {np.mean(sample_turns):.2f}") 75 | print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") 76 | print(f"avg. 
response length: {np.mean(res_lens):.2f}") 77 | 78 | print("\n- Histogram -") 79 | bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] 80 | hist = np.histogram(sample_lens, bins=bin_edges)[0] 81 | for i in range(len(hist)): 82 | print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}") 83 | -------------------------------------------------------------------------------- /fastchat/data/inspect_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json 4 | """ 5 | import argparse 6 | import json 7 | import random 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--begin", type=int) 14 | parser.add_argument("--random-n", type=int) 15 | args = parser.parse_args() 16 | 17 | content = json.load(open(args.in_file, "r")) 18 | 19 | if args.random_n: 20 | indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)] 21 | elif args.begin: 22 | indices = range(args.begin, len(content)) 23 | else: 24 | indices = range(0, len(content)) 25 | 26 | for idx in indices: 27 | sample = content[idx] 28 | print("=" * 40) 29 | print(f"no: {idx}, id: {sample['id']}") 30 | for conv in sample["conversations"]: 31 | print(conv["from"] + ": ") 32 | print(conv["value"]) 33 | input() 34 | -------------------------------------------------------------------------------- /fastchat/data/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge two conversation files into one 3 | 4 | Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in-file", type=str, required=True, nargs="+") 14 | parser.add_argument("--out-file", type=str, default="merged.json") 15 | args = parser.parse_args() 16 | 17 | new_content = [] 18 | for in_file in args.in_file: 19 | content = json.load(open(in_file, "r")) 20 | new_content.extend(content) 21 | 22 | print(f"#out: {len(new_content)}") 23 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 24 | -------------------------------------------------------------------------------- /fastchat/data/optional_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional cleaning (e.g., remove some languages). 
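The language of each record is detected with polyglot's Detector over the concatenated conversation text and filtered by --keep-lang / --skip-lang; --reduce-rep additionally drops records containing a digit repeated nine or more times in a row (use with care, since long numbers such as addresses can trigger it).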
3 | 4 | Usage: 5 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en 6 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en 7 | 8 | Requirement: 9 | pip3 install polyglot pyicu pycld2 10 | """ 11 | import argparse 12 | import json 13 | import re 14 | 15 | import polyglot 16 | from polyglot.detect import Detector 17 | import pycld2 18 | from tqdm import tqdm 19 | 20 | 21 | def skip(conv, args): 22 | # Remove certain languages 23 | if args.keep_lang != "all" or args.skip_lang is not None: 24 | text = "\n".join([x["value"] for x in conv["conversations"]]) 25 | try: 26 | lang_code = Detector(text).language.code 27 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 28 | lang_code = "unknown" 29 | 30 | if args.keep_lang != "all" and lang_code != args.keep_lang: 31 | return True 32 | 33 | if lang_code == args.skip_lang: 34 | return True 35 | 36 | # Remove repetitive numbers 37 | if args.reduce_rep: 38 | for sentence in conv["conversations"]: 39 | val = sentence["value"] 40 | sub = re.search(r"(\d)\1{8}", val) 41 | if sub is not None: 42 | return True 43 | 44 | return False 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--in-file", type=str, required=True) 50 | parser.add_argument("--out-file", type=str) 51 | parser.add_argument( 52 | "--keep-lang", 53 | type=str, 54 | default="all", 55 | choices=["all", "en"], 56 | help="Only keep certain langauges.", 57 | ) 58 | parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") 59 | # NOTE: Be careful about reduce_rep which may remove some good data. 60 | # For example, addresses could have long consecutive 0's 61 | parser.add_argument("--reduce-rep", action="store_true") 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | keep_lang = args.keep_lang 67 | skip_lang = args.skip_lang 68 | reduce_rep = args.reduce_rep 69 | assert keep_lang == "all" or skip_lang is None 70 | 71 | if out_file is None: 72 | out_file = "sharegpt_clean" 73 | if keep_lang != "all": 74 | out_file += "_" + keep_lang 75 | if skip_lang is not None: 76 | out_file += "_skip_" + skip_lang 77 | if reduce_rep: 78 | out_file += "_reduce_rep" 79 | out_file += ".json" 80 | 81 | content = json.load(open(in_file, "r")) 82 | num_conv = len(content) 83 | 84 | new_content = [] 85 | for conv in tqdm(content): 86 | if not skip(conv, args): 87 | new_content.append(conv) 88 | 89 | print(f"#in: {len(content)}, #out: {len(new_content)}") 90 | json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) 91 | -------------------------------------------------------------------------------- /fastchat/data/optional_replace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional replace of bos/eos/pad/unk. 
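Special tokens that appear literally in the conversation text are defused by inserting vertical bars into them (for example "</s>" becomes "<|/s|>"), so that a later tokenization pass does not treat the plain text as a real control token.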
3 | 4 | Usage: 5 | python3 -m fastchat.data.optional_replace --in input.json --out output.json --model-name-or-path 6 | 7 | Requirement: 8 | pip3 install transformers tqdm 9 | """ 10 | import argparse 11 | import json 12 | import traceback 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def replace_special_tokens( 19 | tokenizer: transformers.PreTrainedTokenizer, text: str 20 | ) -> str: 21 | if not text: 22 | return text 23 | 24 | def _insert_vline(token: str) -> str: 25 | if len(token) < 2: 26 | return " " 27 | elif len(token) == 2: 28 | return f"{token[0]}|{token[1]}" 29 | else: 30 | return f"{token[:1]}|{token[1:-1]}|{token[-1:]}" 31 | 32 | if tokenizer.bos_token: 33 | text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token)) 34 | if tokenizer.eos_token: 35 | text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token)) 36 | if tokenizer.pad_token: 37 | text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token)) 38 | if tokenizer.unk_token: 39 | text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token)) 40 | return text 41 | 42 | 43 | def replace(conv, tokenizer): 44 | # Replace bos/eos/pad/unk tokens 45 | if tokenizer: 46 | try: 47 | for sentence in conv["conversations"]: 48 | sentence["value"] = replace_special_tokens(tokenizer, sentence["value"]) 49 | except Exception as e: 50 | traceback.print_exc() 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--in-file", type=str, required=True) 56 | parser.add_argument("--out-file", type=str) 57 | parser.add_argument( 58 | "--model-name-or-path", 59 | type=str, 60 | help="The directory or address where the model token is stored.", 61 | ) 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | tokenizer = None 67 | if args.model_name_or_path: 68 | tokenizer = transformers.AutoTokenizer.from_pretrained( 69 | args.model_name_or_path, 70 | trust_remote_code=True, 71 | use_fast=False, 72 | ) 73 | 74 | if out_file is None: 75 | out_file = f"{in_file}_replace.json" 76 | 77 | content = json.load(open(in_file, "r")) 78 | 79 | for conv in tqdm(content): 80 | replace(conv, tokenizer) 81 | 82 | json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) 83 | -------------------------------------------------------------------------------- /fastchat/data/prepare_all.py: -------------------------------------------------------------------------------- 1 | """Prepare all datasets.""" 2 | 3 | import argparse 4 | import os 5 | 6 | from fastchat.utils import run_cmd 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521") 12 | parser.add_argument( 13 | "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" 14 | ) 15 | parser.add_argument("--seq-len", type=int, default=4096) 16 | args = parser.parse_args() 17 | 18 | in_prefix = args.prefix 19 | model_path = args.model_name_or_path 20 | seq_len = args.seq_len 21 | prefix = ( 22 | f"{in_prefix}_{seq_len}".replace("4096", "4k") 23 | .replace("8192", "8k") 24 | .replace("16384", "16k") 25 | ) 26 | 27 | cmd_list = [ 28 | f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json", 29 | f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko", 30 | f"python3 -m fastchat.data.split_long_conversation 
--in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}", 31 | f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json", 32 | f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99", 33 | f"python3 -m fastchat.data.hardcoded_questions", 34 | f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json", 35 | f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json", 36 | f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json", 37 | ] 38 | 39 | for cmd in cmd_list: 40 | ret = run_cmd(cmd) 41 | if ret != 0: 42 | exit(ret) 43 | -------------------------------------------------------------------------------- /fastchat/data/pretty_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 pretty_json.py --in in.json --out out.json 4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | with open(args.in_file, "r") as fin: 17 | data = json.load(fin) 18 | 19 | with open(args.out_file, "w") as fout: 20 | json.dump(data, fout, indent=2, ensure_ascii=False) 21 | -------------------------------------------------------------------------------- /fastchat/data/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample some conversations from a file. 3 | 4 | Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--out-file", type=str, default="sampled.json") 16 | parser.add_argument("--begin", type=int, default=0) 17 | parser.add_argument("--end", type=int, default=100) 18 | parser.add_argument("--max-length", type=int, default=1024) 19 | parser.add_argument("--keep-order", action="store_true") 20 | args = parser.parse_args() 21 | 22 | content = json.load(open(args.in_file, "r")) 23 | if not args.keep_order: 24 | np.random.seed(42) 25 | np.random.shuffle(content) 26 | 27 | new_content = [] 28 | for i in range(args.begin, min(args.end, len(content))): 29 | sample = content[i] 30 | concat = "" 31 | for s in sample["conversations"]: 32 | concat += s["value"] 33 | 34 | if len(concat) > args.max_length: 35 | continue 36 | 37 | new_content.append(sample) 38 | 39 | print(f"#in: {len(content)}, #out: {len(new_content)}") 40 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 41 | -------------------------------------------------------------------------------- /fastchat/data/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split long conversations based on certain max length. 
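Conversations are cut into chunks of complete human/gpt turn pairs whose combined token count (measured with the given tokenizer) stays within --max-length; each chunk keeps the original id plus a "_<start index>" suffix, and chunks whose roles do not alternate human/gpt are dropped.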
3 | 4 | Usage: python3 -m fastchat.data.split_long_conversation \ 5 | --in sharegpt_clean.json \ 6 | --out sharegpt_split.json \ 7 | --model-name-or-path $ 8 | """ 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | from typing import Dict, Sequence, Optional 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def make_sample(sample, start_idx, end_idx): 19 | assert (end_idx - start_idx) % 2 == 0 20 | return { 21 | "id": sample["id"] + "_" + str(start_idx), 22 | "model": sample.get("model", ""), 23 | "conversations": sample["conversations"][start_idx:end_idx], 24 | } 25 | 26 | 27 | tokenizer = max_length = None 28 | 29 | 30 | def split_one_sample(sample): 31 | tokenized_lens = [] 32 | conversations = sample["conversations"] 33 | conversations = conversations[: len(conversations) // 2 * 2] 34 | for c in conversations: 35 | length = len(tokenizer(c["value"]).input_ids) + 6 36 | tokenized_lens.append(length) 37 | 38 | start_idx = 0 39 | cur_len = 0 40 | 41 | if len(conversations) % 2 != 0 or len(conversations) < 2: 42 | return [] 43 | 44 | new_samples = [] 45 | for i in range(0, len(conversations), 2): 46 | tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] 47 | if cur_len + tmp_len > max_length: 48 | new_samples.append(make_sample(sample, start_idx, i)) 49 | start_idx = i 50 | cur_len = 0 51 | elif i == len(conversations) - 2: 52 | new_samples.append(make_sample(sample, start_idx, i + 2)) 53 | 54 | cur_len += tmp_len 55 | 56 | return new_samples 57 | 58 | 59 | def worker(input_data): 60 | result = [] 61 | for sample in input_data: 62 | result.extend(split_one_sample(sample)) 63 | return result 64 | 65 | 66 | def split_all(content, begin, end, tokenizer_, max_length_): 67 | """ 68 | Keep the maximum round of conversations within the max token length constraint 69 | """ 70 | global tokenizer, max_length 71 | tokenizer = tokenizer_ 72 | max_length = max_length_ 73 | 74 | content = content[begin:end] 75 | new_content = [] 76 | 77 | # Split content into chunks 78 | chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)] 79 | with ProcessPoolExecutor() as executor: 80 | for result in tqdm(executor.map(worker, chunks), total=len(chunks)): 81 | new_content.extend(result) 82 | 83 | return new_content 84 | 85 | 86 | def filter_invalid_roles(content): 87 | new_content = [] 88 | for i, c in enumerate(content): 89 | roles = ["human", "gpt"] 90 | if len(c["conversations"]) <= 0: 91 | continue 92 | 93 | valid = True 94 | for j, s in enumerate(c["conversations"]): 95 | if s["from"] != roles[j % 2]: 96 | valid = False 97 | break 98 | 99 | if valid: 100 | new_content.append(c) 101 | 102 | return new_content 103 | 104 | 105 | def main(args): 106 | content = json.load(open(args.in_file, "r")) 107 | tokenizer = transformers.AutoTokenizer.from_pretrained( 108 | args.model_name_or_path, 109 | model_max_length=args.max_length, 110 | padding_side="right", 111 | use_fast=False, 112 | ) 113 | new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) 114 | new_content = filter_invalid_roles(new_content) 115 | 116 | print(f"#in: {len(content)}, #out: {len(new_content)}") 117 | json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--in-file", type=str, required=True) 123 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json") 124 | 
parser.add_argument("--begin", type=int) 125 | parser.add_argument("--end", type=int) 126 | parser.add_argument("--model-name-or-path", type=str, required=True) 127 | parser.add_argument("--max-length", type=int, default=2048) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /fastchat/data/split_train_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split the dataset into training and test set. 3 | 4 | Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json 5 | """ 6 | import argparse 7 | import json 8 | 9 | import numpy as np 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True) 15 | parser.add_argument("--begin", type=int, default=0) 16 | parser.add_argument("--end", type=int, default=100) 17 | parser.add_argument("--ratio", type=float, default=0.9) 18 | args = parser.parse_args() 19 | 20 | content = json.load(open(args.in_file, "r")) 21 | np.random.seed(0) 22 | 23 | perm = np.random.permutation(len(content)) 24 | content = [content[i] for i in perm] 25 | split = int(args.ratio * len(content)) 26 | 27 | train_set = content[:split] 28 | test_set = content[split:] 29 | 30 | print(f"#train: {len(train_set)}, #test: {len(test_set)}") 31 | train_name = args.in_file.replace(".json", "_train.json") 32 | test_name = args.in_file.replace(".json", "_test.json") 33 | json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False) 34 | json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /fastchat/llm_judge/clean_judgment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Clean model judgment files. 
3 | """ 4 | import argparse 5 | import json 6 | 7 | selected_models = [ 8 | "alpaca-13b", 9 | "baize-v2-13b", 10 | "chatglm-6b", 11 | "claude-instant-v1", 12 | "claude-v1", 13 | "dolly-v2-12b", 14 | "falcon-40b-instruct", 15 | "fastchat-t5-3b", 16 | "gpt-3.5-turbo", 17 | "gpt-4", 18 | "gpt4all-13b-snoozy", 19 | "guanaco-33b", 20 | "guanaco-65b", 21 | "h2ogpt-oasst-open-llama-13b", 22 | "koala-13b", 23 | "llama-13b", 24 | "mpt-30b-chat", 25 | "mpt-30b-instruct", 26 | "mpt-7b-chat", 27 | "nous-hermes-13b", 28 | "oasst-sft-4-pythia-12b", 29 | "oasst-sft-7-llama-30b", 30 | "palm-2-chat-bison-001", 31 | "rwkv-4-raven-14b", 32 | "stablelm-tuned-alpha-7b", 33 | "tulu-30b", 34 | "vicuna-13b-v1.3", 35 | "vicuna-33b-v1.3", 36 | "vicuna-7b-v1.3", 37 | "wizardlm-13b", 38 | "wizardlm-30b", 39 | ] 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--infile", type=str) 45 | args = parser.parse_args() 46 | 47 | infile = args.infile 48 | outfile = infile.replace(".jsonl", "_clean.jsonl") 49 | 50 | raw_lines = open(infile).readlines() 51 | rets = [] 52 | models = set() 53 | visited = set() 54 | for line in raw_lines: 55 | obj = json.loads(line) 56 | 57 | if "model_1" in obj: # pair 58 | model = obj["model_1"] 59 | key = ( 60 | obj["model_1"], 61 | obj["model_2"], 62 | obj["question_id"], 63 | tuple(obj["judge"]), 64 | ) 65 | else: # single 66 | model = obj["model"] 67 | key = (obj["model"], obj["question_id"], tuple(obj["judge"])) 68 | 69 | if key in visited: 70 | continue 71 | visited.add(key) 72 | 73 | if model not in selected_models: 74 | continue 75 | models.add(model) 76 | rets.append(obj) 77 | 78 | models = sorted(list(models)) 79 | missing_models = [x for x in selected_models if x not in models] 80 | print(f"in models: {models}, number: {len(models)}") 81 | print(f"missing models: {missing_models}") 82 | print(f"#in: {len(raw_lines)}, #out: {len(rets)}") 83 | rets.sort( 84 | key=lambda x: ( 85 | x["model"] if "model" in x else x["model_1"], 86 | x["question_id"], 87 | x["turn"], 88 | ) 89 | ) 90 | 91 | with open(outfile, "w") as fout: 92 | for x in rets: 93 | fout.write(json.dumps(x) + "\n") 94 | -------------------------------------------------------------------------------- /fastchat/llm_judge/compute_agreement.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute agreement among judges. 
3 | 4 | Usage: 5 | python compute_agreement.py --judges gpt4-pair human --votefiles human_judgments.json gpt4_pair_judgments.json 6 | python compute_agreement.py --judges human human --votefiles human_judgments.json 7 | """ 8 | import argparse 9 | import json 10 | import os 11 | 12 | import numpy as np 13 | 14 | 15 | def get_judge_name(judge): 16 | if isinstance(judge, list) and judge[0] == "gpt-4" and judge[1].startswith("pair"): 17 | return "gpt4-pair" 18 | if judge.startswith("expert"): 19 | return "human" 20 | if judge.startswith("author"): 21 | return "author" 22 | 23 | 24 | def revert(vote): 25 | if vote == "model_a": 26 | return "model_b" 27 | elif vote == "model_b": 28 | return "model_a" 29 | return vote 30 | 31 | 32 | def get_mt_bench_votes_data(raw_votes): 33 | data = [{}, {}] 34 | 35 | for judge_votes in raw_votes: 36 | for vote in judge_votes: 37 | turn = vote["turn"] - 1 38 | if vote["model_a"] < vote["model_b"]: 39 | key = (vote["question_id"], vote["model_a"], vote["model_b"]) 40 | winner = vote["winner"] 41 | else: 42 | key = (vote["question_id"], vote["model_b"], vote["model_a"]) 43 | winner = revert(vote["winner"]) 44 | judge = get_judge_name(vote["judge"]) 45 | if key not in data[turn]: 46 | data[turn][key] = {} 47 | if judge not in data[turn][key]: 48 | data[turn][key][judge] = [] 49 | data[turn][key][judge].append(winner) 50 | 51 | return data 52 | 53 | 54 | def convertvote(vote): 55 | if "tie" in vote: 56 | return "tie" 57 | return vote 58 | 59 | 60 | def equalvote(vote1, vote2): 61 | if "tie" in vote1 and "tie" in vote2: 62 | return True 63 | return vote1 == vote2 64 | 65 | 66 | # data: Dict[qid -> List[vote]] 67 | def get_mt_bench_agreement(data, judge1, judge2, ban): 68 | if judge1.startswith("gpt4") and judge2 == "human": 69 | stats = [0, 0] 70 | for votes in data.values(): 71 | if judge1 not in votes or judge2 not in votes: 72 | continue 73 | assert len(votes[judge1]) == 1 74 | if convertvote(votes[judge1][0]) in ban: 75 | continue 76 | for v in votes[judge2]: 77 | if convertvote(v) in ban: 78 | continue 79 | stats[1] += 1 80 | stats[0] += equalvote(votes[judge1][0], v) 81 | return stats[0], stats[1] 82 | elif judge1 == "human" and judge2 == "human": 83 | stats = [0, 0] 84 | for votes in data.values(): 85 | if "human" not in votes: 86 | continue 87 | for i in range(len(votes["human"]) - 1): 88 | for j in range(i + 1, len(votes["human"])): 89 | if ( 90 | convertvote(votes["human"][i]) in ban 91 | or convertvote(votes["human"][j]) in ban 92 | ): 93 | continue 94 | stats[1] += 1 95 | stats[0] += equalvote(votes["human"][i], votes["human"][j]) 96 | return stats[0], stats[1] 97 | else: 98 | raise Exception("Unsupported judges.") 99 | 100 | 101 | def run_mt_bench_agreement(judges, votefiles): 102 | # votes[i]: List of votes 103 | votes = [] 104 | for filename in votefiles: 105 | with open(filename, "r") as f: 106 | data = json.load(f) 107 | votes.append(data) 108 | 109 | data = get_mt_bench_votes_data(votes) 110 | 111 | agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=[]) 112 | print( 113 | f"turn 1 with tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" 114 | ) 115 | agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=["tie"]) 116 | print( 117 | f"turn 1 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" 118 | ) 119 | agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=[]) 120 | print( 121 | f"turn 2 with tie. 
#total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" 122 | ) 123 | agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=["tie"]) 124 | print( 125 | f"turn 2 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--judges", nargs=2, type=str, default=["gpt4-pair", "human"]) 132 | parser.add_argument( 133 | "--votefiles", 134 | nargs="+", 135 | type=str, 136 | default=["gpt4_judgments.json", "human_judgments.json"], 137 | ) 138 | args = parser.parse_args() 139 | 140 | run_mt_bench_agreement(args.judges, args.votefiles) 141 | -------------------------------------------------------------------------------- /fastchat/llm_judge/data/mt_bench/misc/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/fastchat/llm_judge/data/mt_bench/misc/radar.png -------------------------------------------------------------------------------- /fastchat/llm_judge/download_mt_bench_pregenerated.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the pre-generated model answers and judgments for MT-bench. 3 | """ 4 | import os 5 | 6 | from fastchat.utils import run_cmd 7 | 8 | filenames = [ 9 | "data/mt_bench/model_answer/alpaca-13b.jsonl", 10 | "data/mt_bench/model_answer/baize-v2-13b.jsonl", 11 | "data/mt_bench/model_answer/chatglm-6b.jsonl", 12 | "data/mt_bench/model_answer/claude-instant-v1.jsonl", 13 | "data/mt_bench/model_answer/claude-v1.jsonl", 14 | "data/mt_bench/model_answer/dolly-v2-12b.jsonl", 15 | "data/mt_bench/model_answer/falcon-40b-instruct.jsonl", 16 | "data/mt_bench/model_answer/fastchat-t5-3b.jsonl", 17 | "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl", 18 | "data/mt_bench/model_answer/gpt-4.jsonl", 19 | "data/mt_bench/model_answer/gpt4all-13b-snoozy.jsonl", 20 | "data/mt_bench/model_answer/guanaco-33b.jsonl", 21 | "data/mt_bench/model_answer/guanaco-65b.jsonl", 22 | "data/mt_bench/model_answer/h2ogpt-oasst-open-llama-13b.jsonl", 23 | "data/mt_bench/model_answer/koala-13b.jsonl", 24 | "data/mt_bench/model_answer/llama-13b.jsonl", 25 | "data/mt_bench/model_answer/mpt-30b-chat.jsonl", 26 | "data/mt_bench/model_answer/mpt-30b-instruct.jsonl", 27 | "data/mt_bench/model_answer/mpt-7b-chat.jsonl", 28 | "data/mt_bench/model_answer/nous-hermes-13b.jsonl", 29 | "data/mt_bench/model_answer/oasst-sft-4-pythia-12b.jsonl", 30 | "data/mt_bench/model_answer/oasst-sft-7-llama-30b.jsonl", 31 | "data/mt_bench/model_answer/palm-2-chat-bison-001.jsonl", 32 | "data/mt_bench/model_answer/rwkv-4-raven-14b.jsonl", 33 | "data/mt_bench/model_answer/stablelm-tuned-alpha-7b.jsonl", 34 | "data/mt_bench/model_answer/tulu-30b.jsonl", 35 | "data/mt_bench/model_answer/vicuna-13b-v1.3.jsonl", 36 | "data/mt_bench/model_answer/vicuna-33b-v1.3.jsonl", 37 | "data/mt_bench/model_answer/vicuna-7b-v1.3.jsonl", 38 | "data/mt_bench/model_answer/wizardlm-13b.jsonl", 39 | "data/mt_bench/model_answer/wizardlm-30b.jsonl", 40 | "data/mt_bench/model_judgment/gpt-4_single.jsonl", 41 | "data/mt_bench/model_judgment/gpt-4_pair.jsonl", 42 | ] 43 | 44 | 45 | if __name__ == "__main__": 46 | prefix = "https://huggingface.co/spaces/lmsys/mt-bench/resolve/main/" 47 | 48 | for name in filenames: 49 | os.makedirs(os.path.dirname(name), exist_ok=True) 50 | ret = run_cmd(f"wget -q --show-progress -O 
{name} {prefix + name}") 51 | assert ret == 0 52 | -------------------------------------------------------------------------------- /fastchat/llm_judge/gen_api_answer.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-4 2 | 3 | Usage: 4 | python3 gen_api_answer.py --model gpt-3.5-turbo 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import time 10 | import concurrent.futures 11 | 12 | import openai 13 | import shortuuid 14 | import tqdm 15 | 16 | from fastchat.llm_judge.common import ( 17 | load_questions, 18 | temperature_config, 19 | chat_compeletion_openai, 20 | chat_compeletion_anthropic, 21 | chat_compeletion_palm, 22 | ) 23 | from fastchat.llm_judge.gen_model_answer import reorg_answer_file 24 | from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST 25 | 26 | 27 | def get_answer( 28 | question: dict, model: str, num_choices: int, max_tokens: int, answer_file: str 29 | ): 30 | assert ( 31 | args.force_temperature is not None and "required_temperature" in question.keys() 32 | ) == False 33 | if args.force_temperature is not None: 34 | temperature = args.force_temperature 35 | elif "required_temperature" in question.keys(): 36 | temperature = question["required_temperature"] 37 | elif question["category"] in temperature_config: 38 | temperature = temperature_config[question["category"]] 39 | else: 40 | temperature = 0.7 41 | 42 | choices = [] 43 | chat_state = None # for palm-2 model 44 | for i in range(num_choices): 45 | conv = get_conversation_template(model) 46 | 47 | turns = [] 48 | for j in range(len(question["turns"])): 49 | conv.append_message(conv.roles[0], question["turns"][j]) 50 | conv.append_message(conv.roles[1], None) 51 | 52 | if model in ANTHROPIC_MODEL_LIST: 53 | output = chat_compeletion_anthropic( 54 | model, conv, temperature, max_tokens 55 | ) 56 | elif model == "palm-2-chat-bison-001": 57 | chat_state, output = chat_compeletion_palm( 58 | chat_state, model, conv, temperature, max_tokens 59 | ) 60 | else: 61 | output = chat_compeletion_openai(model, conv, temperature, max_tokens) 62 | 63 | conv.update_last_message(output) 64 | turns.append(output) 65 | 66 | choices.append({"index": i, "turns": turns}) 67 | 68 | # Dump answers 69 | ans = { 70 | "question_id": question["question_id"], 71 | "answer_id": shortuuid.uuid(), 72 | "model_id": model, 73 | "choices": choices, 74 | "tstamp": time.time(), 75 | } 76 | 77 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 78 | with open(answer_file, "a") as fout: 79 | fout.write(json.dumps(ans) + "\n") 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument( 85 | "--bench-name", 86 | type=str, 87 | default="mt_bench", 88 | help="The name of the benchmark question set.", 89 | ) 90 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 91 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 92 | parser.add_argument( 93 | "--num-choices", 94 | type=int, 95 | default=1, 96 | help="How many completion choices to generate.", 97 | ) 98 | parser.add_argument( 99 | "--force-temperature", type=float, help="Forcibly set a sampling temperature." 100 | ) 101 | parser.add_argument( 102 | "--max-tokens", 103 | type=int, 104 | default=1024, 105 | help="The maximum number of new generated tokens.", 106 | ) 107 | parser.add_argument( 108 | "--question-begin", 109 | type=int, 110 | help="A debug option. 
The begin index of questions.", 111 | ) 112 | parser.add_argument( 113 | "--question-end", type=int, help="A debug option. The end index of questions." 114 | ) 115 | parser.add_argument( 116 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 117 | ) 118 | parser.add_argument("--openai-api-base", type=str, default=None) 119 | args = parser.parse_args() 120 | 121 | if args.openai_api_base is not None: 122 | openai.api_base = args.openai_api_base 123 | 124 | question_file = f"data/{args.bench_name}/question.jsonl" 125 | questions = load_questions(question_file, args.question_begin, args.question_end) 126 | 127 | if args.answer_file: 128 | answer_file = args.answer_file 129 | else: 130 | answer_file = f"data/{args.bench_name}/model_answer/{args.model}.jsonl" 131 | print(f"Output to {answer_file}") 132 | 133 | with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor: 134 | futures = [] 135 | for question in questions: 136 | future = executor.submit( 137 | get_answer, 138 | question, 139 | args.model, 140 | args.num_choices, 141 | args.max_tokens, 142 | answer_file, 143 | ) 144 | futures.append(future) 145 | 146 | for future in tqdm.tqdm( 147 | concurrent.futures.as_completed(futures), total=len(futures) 148 | ): 149 | future.result() 150 | 151 | reorg_answer_file(answer_file) 152 | -------------------------------------------------------------------------------- /fastchat/llm_judge/show_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 show_result.py --mode [single|pairwise-baseline|pairwise-all] 4 | """ 5 | import argparse 6 | import pandas as pd 7 | 8 | 9 | def display_result_single(args): 10 | if args.input_file is None: 11 | input_file = ( 12 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" 13 | ) 14 | else: 15 | input_file = args.input_file 16 | 17 | print(f"Input file: {input_file}") 18 | df_all = pd.read_json(input_file, lines=True) 19 | df = df_all[["model", "score", "turn"]] 20 | df = df[df["score"] != -1] 21 | 22 | if args.model_list is not None: 23 | df = df[df["model"].isin(args.model_list)] 24 | 25 | print("\n########## First turn ##########") 26 | df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean() 27 | print(df_1.sort_values(by="score", ascending=False)) 28 | 29 | if args.bench_name == "mt_bench": 30 | print("\n########## Second turn ##########") 31 | df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean() 32 | print(df_2.sort_values(by="score", ascending=False)) 33 | 34 | print("\n########## Average ##########") 35 | df_3 = df[["model", "score"]].groupby(["model"]).mean() 36 | print(df_3.sort_values(by="score", ascending=False)) 37 | 38 | 39 | def display_result_pairwise(args): 40 | if args.input_file is None: 41 | input_file = ( 42 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" 43 | ) 44 | else: 45 | input_file = args.input_file 46 | 47 | print(f"Input file: {input_file}") 48 | df_all = pd.read_json(input_file, lines=True) 49 | df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")] 50 | 51 | model_list = ( 52 | df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist() 53 | ) 54 | model_list = list(set(model_list)) 55 | 56 | list_res = [] 57 | # traverse df row by row 58 | for index, row in df_all.iterrows(): 59 | if args.model_list is not None and row["model_1"] not in args.model_list: 60 | continue 61 | if args.baseline_model is not None: 62 | 
if args.baseline_model not in [row["model_1"], row["model_2"]]: 63 | continue 64 | if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]: 65 | list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1}) 66 | list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1}) 67 | else: 68 | if row["g1_winner"] == "model_1": 69 | winner = row["model_1"] 70 | loser = row["model_2"] 71 | else: 72 | winner = row["model_2"] 73 | loser = row["model_1"] 74 | list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0}) 75 | list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0}) 76 | 77 | df = pd.DataFrame(list_res) 78 | df = df.groupby(["model"]).sum() 79 | 80 | # remove baseline model 81 | if args.baseline_model is not None: 82 | df = df[df.index != args.baseline_model] 83 | # add win rate 84 | df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"]) 85 | df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"]) 86 | # each tie counts as 0.5 win + 0.5 loss 87 | df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / ( 88 | df["win"] + df["loss"] + df["tie"] 89 | ) 90 | # print(df.sort_values(by="win_rate", ascending=False)) 91 | # print(df.sort_values(by="loss_rate", ascending=True)) 92 | print(df.sort_values(by="win_rate_adjusted", ascending=False)) 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--bench-name", type=str, default="mt_bench") 98 | parser.add_argument("--input-file", type=str) 99 | parser.add_argument("--judge-model", type=str, default="gpt-4") 100 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") 101 | parser.add_argument( 102 | "--model-list", 103 | type=str, 104 | nargs="+", 105 | default=None, 106 | help="A list of models to be evaluated", 107 | ) 108 | parser.add_argument( 109 | "--mode", 110 | type=str, 111 | default="single", 112 | choices=["pairwise-baseline", "pairwise-all", "single"], 113 | help=( 114 | "Evaluation mode. " 115 | "`pairwise-baseline` runs pairwise comparision against a baseline. " 116 | "`pairwise-all` runs pairwise comparision between all pairs. " 117 | "`single` runs single answer grading." 118 | ), 119 | ) 120 | args = parser.parse_args() 121 | 122 | if args.mode == "single": 123 | display_result_func = display_result_single 124 | else: 125 | if args.mode == "pairwise-all": 126 | args.baseline_model = None 127 | display_result_func = display_result_pairwise 128 | 129 | print(f"Mode: {args.mode}") 130 | display_result_func(args) 131 | -------------------------------------------------------------------------------- /fastchat/model/__init__.py: -------------------------------------------------------------------------------- 1 | from fastchat.model.model_adapter import ( 2 | load_model, 3 | get_conversation_template, 4 | add_model_args, 5 | ) 6 | -------------------------------------------------------------------------------- /fastchat/model/apply_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the LoRA weights on top of a base model. 
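The adapter is merged into the fp16 base weights with peft's merge_and_unload(), and the merged model together with the base tokenizer is saved to --target-model-path as a standalone checkpoint.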
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B 6 | 7 | Dependency: 8 | pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b 9 | """ 10 | import argparse 11 | 12 | import torch 13 | from peft import PeftModel 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | 16 | 17 | def apply_lora(base_model_path, target_model_path, lora_path): 18 | print(f"Loading the base model from {base_model_path}") 19 | base = AutoModelForCausalLM.from_pretrained( 20 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 21 | ) 22 | base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) 23 | 24 | print(f"Loading the LoRA adapter from {lora_path}") 25 | 26 | lora_model = PeftModel.from_pretrained( 27 | base, 28 | lora_path, 29 | # torch_dtype=torch.float16 30 | ) 31 | 32 | print("Applying the LoRA") 33 | model = lora_model.merge_and_unload() 34 | 35 | print(f"Saving the target model to {target_model_path}") 36 | model.save_pretrained(target_model_path) 37 | base_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--lora-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_lora(args.base_model_path, args.target_model_path, args.lora_path) 49 | -------------------------------------------------------------------------------- /fastchat/model/convert_fp16.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder 4 | """ 5 | import argparse 6 | 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import torch 9 | 10 | 11 | def convert_fp16(in_checkpoint, out_checkpoint): 12 | tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) 13 | model = AutoModelForCausalLM.from_pretrained( 14 | in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True 15 | ) 16 | model.save_pretrained(out_checkpoint) 17 | tokenizer.save_pretrained(out_checkpoint) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--in-checkpoint", type=str, help="Path to the model") 23 | parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") 24 | args = parser.parse_args() 25 | 26 | convert_fp16(args.in_checkpoint, args.out_checkpoint) 27 | -------------------------------------------------------------------------------- /fastchat/model/llama_condense_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import transformers 7 | import transformers.models.llama.modeling_llama 8 | 9 | 10 | class CondenseRotaryEmbedding(torch.nn.Module): 11 | def __init__( 12 | self, dim, ratio, max_position_embeddings=2048, base=10000, device=None 13 | ): 14 | super().__init__() 15 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 16 | self.register_buffer("inv_freq", inv_freq) 17 | 18 | # Build 
here to make `torch.jit.trace` work. 19 | self.ratio = ratio 20 | max_position_embeddings *= ratio 21 | self.max_seq_len_cached = max_position_embeddings 22 | # print(f"Monkey Patching condense ratio {ratio}") 23 | t = ( 24 | torch.arange( 25 | self.max_seq_len_cached, 26 | device=self.inv_freq.device, 27 | dtype=self.inv_freq.dtype, 28 | ) 29 | / ratio 30 | ) 31 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 32 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 33 | emb = torch.cat((freqs, freqs), dim=-1) 34 | dtype = torch.get_default_dtype() 35 | self.register_buffer( 36 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 37 | ) 38 | self.register_buffer( 39 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 40 | ) 41 | 42 | def forward(self, x, seq_len=None): 43 | # x: [bs, num_attention_heads, seq_len, head_size] 44 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 45 | if seq_len > self.max_seq_len_cached: 46 | self.max_seq_len_cached = seq_len 47 | t = ( 48 | torch.arange( 49 | self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype 50 | ) 51 | / self.ratio 52 | ) 53 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 54 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 55 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 56 | self.register_buffer( 57 | "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False 58 | ) 59 | self.register_buffer( 60 | "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False 61 | ) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def replace_llama_with_condense(ratio): 69 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( 70 | CondenseRotaryEmbedding, ratio=ratio 71 | ) 72 | -------------------------------------------------------------------------------- /fastchat/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make the delta weights by subtracting base weights. 
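The delta stores target minus base for every parameter, so the original target model can be reconstructed later by adding the delta back onto the base weights (see apply_delta.py); pass --hub-repo-id to also push the delta to the Hugging Face Hub.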
3 | 4 | Usage: 5 | python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 6 | """ 7 | import argparse 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path): 15 | print(f"Loading the base model from {base_model_path}") 16 | base = AutoModelForCausalLM.from_pretrained( 17 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | 20 | print(f"Loading the target model from {target_model_path}") 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) 25 | 26 | print("Calculating the delta") 27 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 28 | assert name in base.state_dict() 29 | param.data -= base.state_dict()[name] 30 | 31 | print(f"Saving the delta to {delta_path}") 32 | if args.hub_repo_id: 33 | kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} 34 | else: 35 | kwargs = {} 36 | target.save_pretrained(delta_path, **kwargs) 37 | target_tokenizer.save_pretrained(delta_path, **kwargs) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | parser.add_argument("--hub-repo-id", type=str) 46 | args = parser.parse_args() 47 | 48 | make_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /fastchat/model/model_chatglm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference code for ChatGLM. 3 | Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. 
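Provides the streaming generation loop for ChatGLM-style models, a logits processor that recovers from NaN/inf scores, and response post-processing that normalizes punctuation around Chinese text.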
4 | """ 5 | import re 6 | 7 | import torch 8 | from transformers.generation.logits_process import LogitsProcessor 9 | 10 | 11 | class InvalidScoreLogitsProcessor(LogitsProcessor): 12 | def __call__( 13 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 14 | ) -> torch.FloatTensor: 15 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 16 | scores.zero_() 17 | scores[..., 5] = 5e4 18 | return scores 19 | 20 | 21 | invalid_score_processor = InvalidScoreLogitsProcessor() 22 | 23 | 24 | def process_response(response): 25 | response = response.strip() 26 | response = response.replace("[[训练时间]]", "2023年") 27 | punkts = [ 28 | [",", ","], 29 | ["!", "!"], 30 | [":", ":"], 31 | [";", ";"], 32 | ["\?", "?"], 33 | ] 34 | for item in punkts: 35 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 36 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 37 | return response 38 | 39 | 40 | @torch.inference_mode() 41 | def generate_stream_chatglm( 42 | model, 43 | tokenizer, 44 | params, 45 | device, 46 | context_len=2048, 47 | stream_interval=2, 48 | judge_sent_end=False, 49 | ): 50 | prompt = params["prompt"] 51 | temperature = float(params.get("temperature", 1.0)) 52 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 53 | top_p = float(params.get("top_p", 1.0)) 54 | max_new_tokens = int(params.get("max_new_tokens", 256)) 55 | echo = params.get("echo", True) 56 | 57 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) 58 | input_echo_len = len(inputs["input_ids"][0]) 59 | 60 | gen_kwargs = { 61 | "max_length": max_new_tokens + input_echo_len, 62 | "do_sample": True if temperature > 1e-5 else False, 63 | "top_p": top_p, 64 | "repetition_penalty": repetition_penalty, 65 | "logits_processor": [invalid_score_processor], 66 | } 67 | if temperature > 1e-5: 68 | gen_kwargs["temperature"] = temperature 69 | 70 | total_len = 0 71 | for total_ids in model.stream_generate(**inputs, **gen_kwargs): 72 | total_ids = total_ids.tolist()[0] 73 | total_len = len(total_ids) 74 | if echo: 75 | output_ids = total_ids 76 | else: 77 | output_ids = total_ids[input_echo_len:] 78 | response = tokenizer.decode(output_ids) 79 | response = process_response(response) 80 | 81 | yield { 82 | "text": response, 83 | "usage": { 84 | "prompt_tokens": input_echo_len, 85 | "completion_tokens": total_len - input_echo_len, 86 | "total_tokens": total_len, 87 | }, 88 | "finish_reason": None, 89 | } 90 | 91 | # TODO: ChatGLM stop when it reach max length 92 | # Only last stream result contains finish_reason, we set finish_reason as stop 93 | ret = { 94 | "text": response, 95 | "usage": { 96 | "prompt_tokens": input_echo_len, 97 | "completion_tokens": total_len - input_echo_len, 98 | "total_tokens": total_len, 99 | }, 100 | "finish_reason": "stop", 101 | } 102 | yield ret 103 | -------------------------------------------------------------------------------- /fastchat/model/model_codet5p.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | import transformers 5 | from transformers import ( 6 | GenerationConfig, 7 | StoppingCriteria, 8 | StoppingCriteriaList, 9 | TextIteratorStreamer, 10 | ) 11 | 12 | 13 | @torch.inference_mode() 14 | def generate_stream_codet5p( 15 | model, 16 | tokenizer, 17 | params, 18 | device, 19 | context_len=2048, 20 | stream_interval=2, 21 | judge_sent_end=False, 22 | ): 23 | prompt = params["prompt"] 24 | 
temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 1024)) 29 | stop_token_ids = params.get("stop_token_ids", None) or [] 30 | stop_token_ids.append(tokenizer.eos_token_id) 31 | 32 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 33 | streamer = TextIteratorStreamer(tokenizer, **decode_config) 34 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 35 | input_ids = encoding.input_ids 36 | encoding["decoder_input_ids"] = encoding["input_ids"].clone() 37 | input_echo_len = len(input_ids) 38 | 39 | generation_config = GenerationConfig( 40 | max_new_tokens=max_new_tokens, 41 | do_sample=temperature >= 1e-5, 42 | temperature=temperature, 43 | repetition_penalty=repetition_penalty, 44 | no_repeat_ngram_size=10, 45 | top_p=top_p, 46 | top_k=top_k, 47 | eos_token_id=stop_token_ids, 48 | ) 49 | 50 | class CodeBlockStopper(StoppingCriteria): 51 | def __call__( 52 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 53 | ) -> bool: 54 | # Code-completion is open-end generation. 55 | # We check \n\n to stop at end of a code block. 56 | if list(input_ids[0][-2:]) == [628, 198]: 57 | return True 58 | return False 59 | 60 | gen_kwargs = dict( 61 | **encoding, 62 | streamer=streamer, 63 | generation_config=generation_config, 64 | stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 65 | ) 66 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 67 | thread.start() 68 | i = 0 69 | output = "" 70 | for new_text in streamer: 71 | i += 1 72 | output += new_text 73 | if i % stream_interval == 0 or i == max_new_tokens - 1: 74 | yield { 75 | "text": output, 76 | "usage": { 77 | "prompt_tokens": input_echo_len, 78 | "completion_tokens": i, 79 | "total_tokens": input_echo_len + i, 80 | }, 81 | "finish_reason": None, 82 | } 83 | if i >= max_new_tokens: 84 | break 85 | 86 | if i >= max_new_tokens: 87 | finish_reason = "length" 88 | else: 89 | finish_reason = "stop" 90 | 91 | yield { 92 | "text": output, 93 | "usage": { 94 | "prompt_tokens": input_echo_len, 95 | "completion_tokens": i, 96 | "total_tokens": input_echo_len + i, 97 | }, 98 | "finish_reason": finish_reason, 99 | } 100 | thread.join() 101 | 102 | # clean 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | if device == "xpu": 106 | torch.xpu.empty_cache() 107 | if device == "npu": 108 | torch.npu.empty_cache() 109 | -------------------------------------------------------------------------------- /fastchat/model/model_exllama.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import sys 3 | from typing import Dict 4 | 5 | import torch 6 | 7 | 8 | def generate_stream_exllama( 9 | model, 10 | tokenizer, 11 | params: Dict, 12 | device: str, 13 | context_len: int, 14 | stream_interval: int = 2, 15 | judge_sent_end: bool = False, 16 | ): 17 | try: 18 | from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler 19 | except ImportError as e: 20 | print(f"Error: Failed to load Exllamav2. 
{e}") 21 | sys.exit(-1) 22 | 23 | prompt = params["prompt"] 24 | 25 | generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) 26 | settings = ExLlamaV2Sampler.Settings() 27 | 28 | settings.temperature = float(params.get("temperature", 0.85)) 29 | settings.top_k = int(params.get("top_k", 50)) 30 | settings.top_p = float(params.get("top_p", 0.8)) 31 | settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) 32 | settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) 33 | 34 | max_new_tokens = int(params.get("max_new_tokens", 256)) 35 | 36 | generator.set_stop_conditions(params.get("stop_token_ids", None) or []) 37 | echo = bool(params.get("echo", True)) 38 | 39 | input_ids = generator.tokenizer.encode(prompt) 40 | prompt_tokens = input_ids.shape[-1] 41 | generator.begin_stream(input_ids, settings) 42 | 43 | generated_tokens = 0 44 | if echo: 45 | output = prompt 46 | else: 47 | output = "" 48 | while True: 49 | chunk, eos, _ = generator.stream() 50 | output += chunk 51 | generated_tokens += 1 52 | if generated_tokens == max_new_tokens: 53 | finish_reason = "length" 54 | break 55 | elif eos: 56 | finish_reason = "length" 57 | break 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": prompt_tokens, 62 | "completion_tokens": generated_tokens, 63 | "total_tokens": prompt_tokens + generated_tokens, 64 | }, 65 | "finish_reason": None, 66 | } 67 | 68 | yield { 69 | "text": output, 70 | "usage": { 71 | "prompt_tokens": prompt_tokens, 72 | "completion_tokens": generated_tokens, 73 | "total_tokens": prompt_tokens + generated_tokens, 74 | }, 75 | "finish_reason": finish_reason, 76 | } 77 | gc.collect() 78 | -------------------------------------------------------------------------------- /fastchat/model/model_falcon.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | from typing import Iterable 4 | 5 | import torch 6 | import transformers 7 | from transformers import TextIteratorStreamer, GenerationConfig 8 | 9 | from fastchat.utils import is_partial_stop 10 | 11 | 12 | @torch.inference_mode() 13 | def generate_stream_falcon( 14 | model, 15 | tokenizer, 16 | params, 17 | device, 18 | context_len=2048, 19 | stream_interval=2, 20 | judge_sent_end=False, 21 | ): 22 | prompt = params["prompt"] 23 | len_prompt = len(prompt) 24 | temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 256)) 29 | stop_str = params.get("stop", None) 30 | echo = bool(params.get("echo", True)) 31 | stop_token_ids = params.get("stop_token_ids", None) or [] 32 | stop_token_ids.append(tokenizer.eos_token_id) 33 | 34 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 35 | input_ids = inputs["input_ids"] 36 | attention_mask = inputs["attention_mask"] 37 | 38 | max_src_len = context_len - max_new_tokens - 8 39 | 40 | input_ids = input_ids[-max_src_len:] # truncate from the left 41 | attention_mask = attention_mask[-max_src_len:] # truncate from the left 42 | input_echo_len = len(input_ids) 43 | 44 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 45 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 46 | 47 | generation_config = GenerationConfig( 48 | 
max_new_tokens=max_new_tokens, 49 | do_sample=temperature >= 1e-5, 50 | temperature=temperature, 51 | repetition_penalty=repetition_penalty, 52 | no_repeat_ngram_size=10, 53 | top_p=top_p, 54 | top_k=top_k, 55 | eos_token_id=stop_token_ids, 56 | ) 57 | 58 | generation_kwargs = dict( 59 | inputs=input_ids, 60 | attention_mask=attention_mask, 61 | streamer=streamer, 62 | generation_config=generation_config, 63 | ) 64 | 65 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 66 | thread.start() 67 | 68 | if echo: 69 | # means keep the prompt 70 | output = prompt 71 | else: 72 | output = "" 73 | 74 | for i, new_text in enumerate(streamer): 75 | output += new_text 76 | if i % stream_interval == 0: 77 | if echo: 78 | rfind_start = len_prompt 79 | else: 80 | rfind_start = 0 81 | 82 | partially_stopped = False 83 | if stop_str: 84 | if isinstance(stop_str, str): 85 | pos = output.rfind(stop_str, rfind_start) 86 | if pos != -1: 87 | output = output[:pos] 88 | else: 89 | partially_stopped = is_partial_stop(output, stop_str) 90 | elif isinstance(stop_str, Iterable): 91 | for each_stop in stop_str: 92 | pos = output.rfind(each_stop, rfind_start) 93 | if pos != -1: 94 | output = output[:pos] 95 | break 96 | else: 97 | partially_stopped = is_partial_stop(output, each_stop) 98 | if partially_stopped: 99 | break 100 | else: 101 | raise ValueError("Invalid stop field type.") 102 | 103 | # prevent yielding partial stop sequence 104 | if not partially_stopped: 105 | yield { 106 | "text": output, 107 | "usage": { 108 | "prompt_tokens": input_echo_len, 109 | "completion_tokens": i, 110 | "total_tokens": input_echo_len + i, 111 | }, 112 | "finish_reason": None, 113 | } 114 | output = output.strip() 115 | 116 | # finish stream event, which contains finish reason 117 | if i == max_new_tokens - 1: 118 | finish_reason = "length" 119 | elif partially_stopped: 120 | finish_reason = None 121 | else: 122 | finish_reason = "stop" 123 | 124 | yield { 125 | "text": output, 126 | "usage": { 127 | "prompt_tokens": input_echo_len, 128 | "completion_tokens": i, 129 | "total_tokens": input_echo_len + i, 130 | }, 131 | "finish_reason": finish_reason, 132 | } 133 | 134 | # clean 135 | gc.collect() 136 | torch.cuda.empty_cache() 137 | if device == "xpu": 138 | torch.xpu.empty_cache() 139 | if device == "npu": 140 | torch.npu.empty_cache() 141 | -------------------------------------------------------------------------------- /fastchat/model/model_imagenhub.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | from diffusers import DDIMScheduler 5 | import transformers 6 | from transformers import ( 7 | GenerationConfig, 8 | StoppingCriteria, 9 | StoppingCriteriaList, 10 | TextIteratorStreamer, 11 | ) 12 | from fastchat.utils import build_logger 13 | 14 | logger = build_logger("diffusion_infer", 'diffusion_infer.log') 15 | 16 | @torch.inference_mode() 17 | def generate_stream_imagen( 18 | model, 19 | tokenizer, 20 | params, 21 | device, 22 | context_len=256, 23 | stream_interval=2, 24 | ): 25 | prompt = params["prompt"] 26 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 27 | input_ids = encoding.input_ids 28 | # encoding["decoder_input_ids"] = encoding["input_ids"].clone() 29 | input_echo_len = len(input_ids) 30 | # 31 | # generation_config = GenerationConfig( 32 | # max_new_tokens=max_new_tokens, 33 | # do_sample=temperature >= 1e-5, 34 | # temperature=temperature, 35 | # 
repetition_penalty=repetition_penalty, 36 | # no_repeat_ngram_size=10, 37 | # top_p=top_p, 38 | # top_k=top_k, 39 | # eos_token_id=stop_token_ids, 40 | # ) 41 | # 42 | # class CodeBlockStopper(StoppingCriteria): 43 | # def __call__( 44 | # self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 45 | # ) -> bool: 46 | # # Code-completion is open-end generation. 47 | # # We check \n\n to stop at end of a code block. 48 | # if list(input_ids[0][-2:]) == [628, 198]: 49 | # return True 50 | # return False 51 | 52 | # gen_kwargs = dict( 53 | # **encoding, 54 | # streamer=streamer, 55 | # generation_config=generation_config, 56 | # stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 57 | # ) 58 | # generation_kwargs = {"prompt": prompt} 59 | # 60 | # model.pipe.scheduler = DDIMScheduler.from_config(model.pipe.scheduler.config) 61 | # thread = Thread(target=model.infer_one_image, kwargs=generation_kwargs) 62 | # thread.start() 63 | # i = 0 64 | # output = "" 65 | # for new_text in streamer: 66 | # i += 1 67 | # output += new_text 68 | # if i % stream_interval == 0 or i == max_new_tokens - 1: 69 | # yield { 70 | # "text": output, 71 | # "usage": { 72 | # "prompt_tokens": input_echo_len, 73 | # "completion_tokens": i, 74 | # "total_tokens": input_echo_len + i, 75 | # }, 76 | # "finish_reason": None, 77 | # } 78 | # if i >= max_new_tokens: 79 | # break 80 | # 81 | # if i >= max_new_tokens: 82 | # finish_reason = "length" 83 | # else: 84 | # finish_reason = "stop" 85 | logger.info(f"prompt: {prompt}") 86 | logger.info(f"model.scheduler: {model.pipe.scheduler}") 87 | logger.info(f"model.type: {type(model)}") 88 | # logger.info(f"prompt: {prompt}") 89 | output = model.infer_one_image(prompt=prompt, seed=42) 90 | 91 | yield { 92 | "text": output, 93 | "usage": { 94 | "prompt_tokens": input_echo_len, 95 | "completion_tokens": 0, 96 | "total_tokens": input_echo_len, 97 | }, 98 | "finish_reason": "stop", 99 | } 100 | # thread.join() 101 | 102 | # clean 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | if device == "xpu": 106 | torch.xpu.empty_cache() 107 | if device == "npu": 108 | torch.npu.empty_cache() 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /fastchat/model/model_imagenhub_ie.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import numpy as np 3 | from threading import Thread 4 | import torch 5 | from diffusers import DDIMScheduler 6 | import transformers 7 | from transformers import ( 8 | GenerationConfig, 9 | StoppingCriteria, 10 | StoppingCriteriaList, 11 | TextIteratorStreamer, 12 | ) 13 | from fastchat.utils import build_logger 14 | import PIL 15 | 16 | logger = build_logger("diffusion_infer", 'diffusion_infer.log') 17 | 18 | def generate_stream_imagen_ie( 19 | model, 20 | tokenizer, 21 | params, 22 | device, 23 | context_len=256, 24 | stream_interval=2, 25 | ): 26 | prompt_source = params["prompt_source"] 27 | prompt_target = params["prompt_target"] 28 | prompt_instruct = params["prompt_instruct"] 29 | grey = np.array(params["image_source"]) 30 | image_source = PIL.Image.fromarray(grey.astype(np.uint8)).resize((256, 256)) 31 | # image_source = PIL.Image.fromarray(np.array(params["image_source"])) 32 | encoding = tokenizer(prompt_source, return_tensors="pt").to(device) 33 | input_ids = encoding.input_ids 34 | # encoding["decoder_input_ids"] = encoding["input_ids"].clone() 35 | input_echo_len = len(input_ids) 36 | 37 | logger.info(f"prompt source: 
{prompt_source}") 38 | logger.info(f"prompt target: {prompt_target}") 39 | logger.info(f"image source shape: {image_source.size}") 40 | logger.info(f"model.scheduler: {model.pipe.scheduler}") 41 | logger.info(f"model.type: {type(model)}") 42 | # logger.info(f"prompt: {prompt}") 43 | # logger.info(f"prompt: {prompt}") 44 | output = model.infer_one_image(src_image=image_source, src_prompt=prompt_source, target_prompt=prompt_target, 45 | instruct_prompt=prompt_instruct, seed=42) 46 | 47 | yield { 48 | "text": output, 49 | "usage": { 50 | "prompt_tokens": input_echo_len, 51 | "completion_tokens": 0, 52 | "total_tokens": input_echo_len, 53 | }, 54 | "finish_reason": "stop", 55 | } 56 | # thread.join() 57 | 58 | # clean 59 | gc.collect() 60 | torch.cuda.empty_cache() 61 | if device == "xpu": 62 | torch.xpu.empty_cache() 63 | if device == "npu": 64 | torch.npu.empty_cache() 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /fastchat/model/model_lavie.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | from diffusers import DDIMScheduler 5 | 6 | from fastchat.utils import build_logger 7 | 8 | logger = build_logger("diffusion_infer", 'diffusion_infer.log') 9 | 10 | @torch.inference_mode() 11 | def generate_stream_lavie( 12 | model, 13 | tokenizer, 14 | params, 15 | device, 16 | ): 17 | prompt = params["prompt"] 18 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 19 | input_ids = encoding.input_ids 20 | input_echo_len = len(input_ids) 21 | 22 | logger.info(f"prompt: {prompt}") 23 | logger.info(f"model.scheduler: {model.pipe.scheduler}") 24 | logger.info(f"model.type: {type(model)}") 25 | # logger.info(f"prompt: {prompt}") 26 | output = model(prompt=prompt, 27 | video_length=16, 28 | height=360, 29 | width=512, 30 | num_inference_steps=50, 31 | guidance_scale=7.5).video 32 | 33 | yield { 34 | "text": output, 35 | "usage": { 36 | "prompt_tokens": input_echo_len, 37 | "completion_tokens": 0, 38 | "total_tokens": input_echo_len, 39 | }, 40 | "finish_reason": "stop", 41 | } 42 | # thread.join() 43 | 44 | # clean 45 | gc.collect() 46 | torch.cuda.empty_cache() 47 | if device == "xpu": 48 | torch.xpu.empty_cache() 49 | if device == "npu": 50 | torch.npu.empty_cache() 51 | -------------------------------------------------------------------------------- /fastchat/model/model_playground.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import io, base64, json 3 | from PIL import Image 4 | def generate_stream_playground( 5 | model, 6 | tokenizer, 7 | params, 8 | device, 9 | context_len=256, 10 | stream_interval=2, 11 | ): 12 | prompt = params["prompt"] 13 | if model == "Playground v2": 14 | model_name = "Playground_v2" 15 | elif model == "Playground v2.5": 16 | model_name = "Playground_v2.5" 17 | headers = { 18 | 'Content-Type': 'application/json', 19 | 'Authorization': 'Bearer pg_0061b0a63475918714c4be28ec9a4a861a5012b57b12b77adaee97677cb35a87', 20 | } 21 | 22 | data = json.dumps({"prompt": prompt, "filter_model": model_name, "scheduler": "DPMPP_2M_K", "guidance_scale": 3}) 23 | 24 | response = requests.post('https://playground.com/api/models/external/v1', headers=headers, data=data) 25 | json_obj = response.json() 26 | image = json_obj['images'][0] 27 | img = Image.open(io.BytesIO(base64.decodebytes(bytes(image, "utf-8")))) 28 | yield { 29 | "text": img, 30 | "usage": { 31 | 
"prompt_tokens": 0, 32 | "completion_tokens": 0, 33 | "total_tokens": 0, 34 | }, 35 | "finish_reason": "stop", 36 | } -------------------------------------------------------------------------------- /fastchat/model/model_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | from diffusers import DDIMScheduler 5 | import transformers 6 | from transformers import ( 7 | GenerationConfig, 8 | StoppingCriteria, 9 | StoppingCriteriaList, 10 | TextIteratorStreamer, 11 | ) 12 | from fastchat.utils import build_logger 13 | 14 | logger = build_logger("diffusion_infer", 'diffusion_infer.log') 15 | 16 | @torch.inference_mode() 17 | def generate_stream_sde( 18 | model, 19 | tokenizer, 20 | params, 21 | device, 22 | context_len=256, 23 | stream_interval=2, 24 | ): 25 | prompt = params["prompt"] 26 | # temperature = float(params.get("temperature", 1.0)) 27 | # repetition_penalty = float(params.get("repetition_penalty", 1.0)) 28 | # top_p = float(params.get("top_p", 1.0)) 29 | # top_k = int(params.get("top_k", 50)) # -1 means disable 30 | # max_new_tokens = int(params.get("max_new_tokens", 1024)) 31 | # stop_token_ids = params.get("stop_token_ids", None) or [] 32 | # stop_token_ids.append(tokenizer.eos_token_id) 33 | # 34 | # decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 35 | # streamer = TextIteratorStreamer(tokenizer, **decode_config) 36 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 37 | input_ids = encoding.input_ids 38 | # encoding["decoder_input_ids"] = encoding["input_ids"].clone() 39 | input_echo_len = len(input_ids) 40 | # 41 | # generation_config = GenerationConfig( 42 | # max_new_tokens=max_new_tokens, 43 | # do_sample=temperature >= 1e-5, 44 | # temperature=temperature, 45 | # repetition_penalty=repetition_penalty, 46 | # no_repeat_ngram_size=10, 47 | # top_p=top_p, 48 | # top_k=top_k, 49 | # eos_token_id=stop_token_ids, 50 | # ) 51 | # 52 | # class CodeBlockStopper(StoppingCriteria): 53 | # def __call__( 54 | # self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 55 | # ) -> bool: 56 | # # Code-completion is open-end generation. 57 | # # We check \n\n to stop at end of a code block. 
58 | # if list(input_ids[0][-2:]) == [628, 198]: 59 | # return True 60 | # return False 61 | 62 | # gen_kwargs = dict( 63 | # **encoding, 64 | # streamer=streamer, 65 | # generation_config=generation_config, 66 | # stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 67 | # ) 68 | generation_kwargs = {"prompt": prompt} 69 | logger.info(f"str(type(model)): {str(type(model))}") 70 | if "StableCascade" not in str(type(model)): 71 | model.scheduler = DDIMScheduler.from_config(model.scheduler.config) 72 | logger.info(f"model.scheduler: {model.scheduler}") 73 | thread = Thread(target=model, kwargs=generation_kwargs) 74 | thread.start() 75 | # i = 0 76 | # output = "" 77 | # for new_text in streamer: 78 | # i += 1 79 | # output += new_text 80 | # if i % stream_interval == 0 or i == max_new_tokens - 1: 81 | # yield { 82 | # "text": output, 83 | # "usage": { 84 | # "prompt_tokens": input_echo_len, 85 | # "completion_tokens": i, 86 | # "total_tokens": input_echo_len + i, 87 | # }, 88 | # "finish_reason": None, 89 | # } 90 | # if i >= max_new_tokens: 91 | # break 92 | # 93 | # if i >= max_new_tokens: 94 | # finish_reason = "length" 95 | # else: 96 | # finish_reason = "stop" 97 | logger.info(f"prompt: {prompt}") 98 | output = model(prompt=prompt).images[0] 99 | 100 | yield { 101 | "text": output, 102 | "usage": { 103 | "prompt_tokens": input_echo_len, 104 | "completion_tokens": 0, 105 | "total_tokens": input_echo_len, 106 | }, 107 | "finish_reason": "stop", 108 | } 109 | thread.join() 110 | 111 | # clean 112 | gc.collect() 113 | torch.cuda.empty_cache() 114 | if device == "xpu": 115 | torch.xpu.empty_cache() 116 | if device == "npu": 117 | torch.npu.empty_cache() 118 | -------------------------------------------------------------------------------- /fastchat/model/model_xfastertransformer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | 4 | import torch 5 | from transformers import TextIteratorStreamer 6 | 7 | 8 | @torch.inference_mode() 9 | def generate_stream_xft( 10 | model, 11 | tokenizer, 12 | params, 13 | device, 14 | context_len=8192, 15 | stream_interval=2, 16 | judge_sent_end=False, 17 | ): 18 | prompt = params["prompt"] 19 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 20 | 21 | # unused now, and placehold for future. 
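# The commented-out sampling knobs are not forwarded to generate(); decoding below is driven by the model config (beam width, early stopping, eos/pad ids).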
22 | # temperature = float(params.get("temperature", 1.0)) 23 | # top_p = float(params.get("top_p", 1.0)) 24 | 25 | max_new_tokens = int(params.get("max_new_tokens", 4096)) 26 | echo = params.get("echo", True) 27 | 28 | inputs = tokenizer( 29 | prompt, return_tensors="pt", padding=model.config.padding 30 | ).input_ids 31 | input_echo_len = len(inputs[0]) 32 | max_len = max_new_tokens + input_echo_len 33 | 34 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 35 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 36 | generation_kwargs = { 37 | "input_ids": inputs, 38 | "streamer": streamer, 39 | "max_length": max_len, 40 | "num_beams": model.config.beam_width, 41 | "length_penalty": repetition_penalty, 42 | "num_return_sequences": model.config.num_return_sequences, 43 | "early_stopping": model.config.early_stopping, 44 | "eos_token_id": model.config.eos_token_id, 45 | "pad_token_id": model.config.pad_token_id, 46 | } 47 | 48 | thread = Thread(target=model.model.generate, kwargs=generation_kwargs) 49 | thread.start() 50 | if echo: 51 | # means keep the prompt 52 | output = prompt 53 | else: 54 | output = "" 55 | i = 0 56 | for i, new_text in enumerate(streamer): 57 | output += new_text 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": input_echo_len, 62 | "completion_tokens": i, 63 | "total_tokens": input_echo_len + i, 64 | }, 65 | "finish_reason": None, 66 | } 67 | output = output.strip() 68 | if i == max_new_tokens - 1: 69 | finish_reason = "length" 70 | else: 71 | finish_reason = "stop" 72 | yield { 73 | "text": output, 74 | "usage": { 75 | "prompt_tokens": input_echo_len, 76 | "completion_tokens": i, 77 | "total_tokens": input_echo_len + i, 78 | }, 79 | "finish_reason": finish_reason, 80 | } 81 | gc.collect() 82 | -------------------------------------------------------------------------------- /fastchat/model/monkey_patch_non_inplace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patch the llama implementation in the huggingface/transformers library. 3 | Avoid bugs in mps backend by not using in-place operations. 
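The patched attention clones the rotated tensor halves (see rotate_half below) rather than mutating them in place.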
4 | """ 5 | import math 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch import nn 10 | import transformers 11 | 12 | 13 | def rotate_half(x): 14 | """Rotates half the hidden dims of the input.""" 15 | x1 = x[..., : x.shape[-1] // 2].clone() 16 | x2 = x[..., x.shape[-1] // 2 :].clone() 17 | return torch.cat((-x2, x1), dim=-1) 18 | 19 | 20 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 21 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 22 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 23 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 24 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 25 | q_embed = (q * cos) + (rotate_half(q) * sin) 26 | k_embed = (k * cos) + (rotate_half(k) * sin) 27 | return q_embed, k_embed 28 | 29 | 30 | def forward( 31 | self, 32 | hidden_states: torch.Tensor, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.LongTensor] = None, 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 36 | output_attentions: bool = False, 37 | use_cache: bool = False, 38 | padding_mask: Optional[torch.LongTensor] = None, 39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 40 | bsz, q_len, _ = hidden_states.size() 41 | 42 | query_states = ( 43 | self.q_proj(hidden_states) 44 | .view(bsz, q_len, self.num_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) 47 | key_states = ( 48 | self.k_proj(hidden_states) 49 | .view(bsz, q_len, self.num_heads, self.head_dim) 50 | .transpose(1, 2) 51 | ) 52 | value_states = ( 53 | self.v_proj(hidden_states) 54 | .view(bsz, q_len, self.num_heads, self.head_dim) 55 | .transpose(1, 2) 56 | ) 57 | 58 | kv_seq_len = key_states.shape[-2] 59 | if past_key_value is not None: 60 | kv_seq_len += past_key_value[0].shape[-2] 61 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 62 | query_states, key_states = apply_rotary_pos_emb( 63 | query_states, key_states, cos, sin, position_ids 64 | ) 65 | # [bsz, nh, t, hd] 66 | 67 | if past_key_value is not None: 68 | # reuse k, v, self_attention 69 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 70 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 71 | 72 | past_key_value = (key_states, value_states) if use_cache else None 73 | 74 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( 75 | self.head_dim 76 | ) 77 | 78 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 79 | raise ValueError( 80 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 81 | f" {attn_weights.size()}" 82 | ) 83 | 84 | if attention_mask is not None: 85 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 86 | raise ValueError( 87 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 88 | ) 89 | attn_weights = attn_weights + attention_mask 90 | attn_weights = torch.max( 91 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 92 | ) 93 | 94 | # upcast attention to fp32 95 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 96 | query_states.dtype 97 | ) 98 | attn_output = torch.matmul(attn_weights, value_states) 99 | 100 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 101 | raise ValueError( 102 | f"`attn_output` should be of size {(bsz, self.num_heads, 
q_len, self.head_dim)}, but is" 103 | f" {attn_output.size()}" 104 | ) 105 | 106 | attn_output = attn_output.transpose(1, 2) 107 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 108 | 109 | attn_output = self.o_proj(attn_output) 110 | 111 | if not output_attentions: 112 | attn_weights = None 113 | 114 | return attn_output, attn_weights, past_key_value 115 | 116 | 117 | def replace_llama_attn_with_non_inplace_operations(): 118 | """Avoid bugs in mps backend by not using in-place operations.""" 119 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 120 | -------------------------------------------------------------------------------- /fastchat/model/rwkv_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from types import SimpleNamespace 3 | import warnings 4 | 5 | import torch 6 | 7 | os.environ["RWKV_JIT_ON"] = "1" 8 | os.environ["RWKV_CUDA_ON"] = "1" 9 | 10 | from rwkv.model import RWKV 11 | from rwkv.utils import PIPELINE, PIPELINE_ARGS 12 | 13 | 14 | class RwkvModel: 15 | def __init__(self, model_path): 16 | warnings.warn( 17 | "Experimental support. Please use ChatRWKV if you want to chat with RWKV" 18 | ) 19 | self.config = SimpleNamespace(is_encoder_decoder=False) 20 | self.model = RWKV(model=model_path, strategy="cuda fp16") 21 | # two GPUs 22 | # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") 23 | 24 | self.tokenizer = None 25 | self.model_path = model_path 26 | 27 | def to(self, target): 28 | assert target == "cuda" 29 | 30 | def __call__(self, input_ids, use_cache, past_key_values=None): 31 | assert use_cache == True 32 | input_ids = input_ids[0].detach().cpu().numpy() 33 | # print(input_ids) 34 | logits, state = self.model.forward(input_ids, past_key_values) 35 | # print(logits) 36 | logits = logits.unsqueeze(0).unsqueeze(0) 37 | out = SimpleNamespace(logits=logits, past_key_values=state) 38 | return out 39 | 40 | def generate( 41 | self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 42 | ): 43 | # This function is used by fastchat.llm_judge. 44 | # Because RWKV does not support huggingface generation API, 45 | # we reuse fastchat.serve.inference.generate_stream as a workaround. 46 | from transformers import AutoTokenizer 47 | 48 | from fastchat.serve.inference import generate_stream 49 | from fastchat.conversation import get_conv_template 50 | 51 | if self.tokenizer is None: 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | "EleutherAI/pythia-160m", use_fast=True 54 | ) 55 | prompt = self.tokenizer.decode(input_ids[0].tolist()) 56 | conv = get_conv_template("rwkv") 57 | 58 | gen_params = { 59 | "model": self.model_path, 60 | "prompt": prompt, 61 | "temperature": temperature, 62 | "repetition_penalty": repetition_penalty, 63 | "max_new_tokens": max_new_tokens, 64 | "stop": conv.stop_str, 65 | "stop_token_ids": conv.stop_token_ids, 66 | "echo": False, 67 | } 68 | res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") 69 | 70 | for res in res_iter: 71 | pass 72 | 73 | output = res["text"] 74 | output_ids = self.tokenizer.encode(output) 75 | 76 | return [input_ids[0].tolist() + output_ids] 77 | -------------------------------------------------------------------------------- /fastchat/model/upload_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload weights to huggingface. 
3 | 4 | Usage: 5 | python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 6 | """ 7 | import argparse 8 | import tempfile 9 | 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def upload_hub(model_path, hub_repo_id, component, private): 15 | if component == "all": 16 | components = ["model", "tokenizer"] 17 | else: 18 | components = [component] 19 | 20 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} 21 | 22 | if "model" in components: 23 | model = AutoModelForCausalLM.from_pretrained( 24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 25 | ) 26 | with tempfile.TemporaryDirectory() as tmp_path: 27 | model.save_pretrained(tmp_path, **kwargs) 28 | 29 | if "tokenizer" in components: 30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 31 | with tempfile.TemporaryDirectory() as tmp_path: 32 | tokenizer.save_pretrained(tmp_path, **kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model-path", type=str, required=True) 38 | parser.add_argument("--hub-repo-id", type=str, required=True) 39 | parser.add_argument( 40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all" 41 | ) 42 | parser.add_argument("--private", action="store_true") 43 | args = parser.parse_args() 44 | 45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) 46 | -------------------------------------------------------------------------------- /fastchat/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/fastchat/modules/__init__.py -------------------------------------------------------------------------------- /fastchat/modules/awq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | import sys 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils 7 | 8 | 9 | @dataclass 10 | class AWQConfig: 11 | ckpt: str = field( 12 | default=None, 13 | metadata={ 14 | "help": "Load quantized model. The path to the local AWQ checkpoint." 15 | }, 16 | ) 17 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 18 | groupsize: int = field( 19 | default=-1, 20 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 21 | ) 22 | 23 | 24 | def load_awq_quantized(model_name, awq_config: AWQConfig, device): 25 | print("Loading AWQ quantized model...") 26 | 27 | try: 28 | from tinychat.utils import load_quant 29 | from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp 30 | except ImportError as e: 31 | print(f"Error: Failed to import tinychat. 
{e}") 32 | print("Please double check if you have successfully installed AWQ") 33 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") 34 | sys.exit(-1) 35 | 36 | config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | model_name, use_fast=False, trust_remote_code=True 39 | ) 40 | 41 | def skip(*args, **kwargs): 42 | pass 43 | 44 | torch.nn.init.kaiming_uniform_ = skip 45 | torch.nn.init.kaiming_normal_ = skip 46 | torch.nn.init.uniform_ = skip 47 | torch.nn.init.normal_ = skip 48 | modeling_utils._init_weights = False 49 | 50 | torch.set_default_dtype(torch.half) 51 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 52 | 53 | if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): 54 | model = load_quant.load_awq_llama_fast( 55 | model, 56 | find_awq_ckpt(awq_config), 57 | awq_config.wbits, 58 | awq_config.groupsize, 59 | device, 60 | ) 61 | make_quant_attn(model, device) 62 | make_quant_norm(model) 63 | make_fused_mlp(model) 64 | else: 65 | model = load_quant.load_awq_model( 66 | model, 67 | find_awq_ckpt(awq_config), 68 | awq_config.wbits, 69 | awq_config.groupsize, 70 | device, 71 | ) 72 | return model, tokenizer 73 | 74 | 75 | def find_awq_ckpt(awq_config: AWQConfig): 76 | if Path(awq_config.ckpt).is_file(): 77 | return awq_config.ckpt 78 | 79 | for ext in ["*.pt", "*.safetensors"]: 80 | matched_result = sorted(Path(awq_config.ckpt).glob(ext)) 81 | if len(matched_result) > 0: 82 | return str(matched_result[-1]) 83 | 84 | print("Error: AWQ checkpoint not found") 85 | sys.exit(1) 86 | -------------------------------------------------------------------------------- /fastchat/modules/exllama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import sys 3 | 4 | 5 | @dataclass 6 | class ExllamaConfig: 7 | max_seq_len: int 8 | gpu_split: str = None 9 | cache_8bit: bool = False 10 | 11 | 12 | class ExllamaModel: 13 | def __init__(self, exllama_model, exllama_cache): 14 | self.model = exllama_model 15 | self.cache = exllama_cache 16 | self.config = self.model.config 17 | 18 | 19 | def load_exllama_model(model_path, exllama_config: ExllamaConfig): 20 | try: 21 | from exllamav2 import ( 22 | ExLlamaV2Config, 23 | ExLlamaV2Tokenizer, 24 | ExLlamaV2, 25 | ExLlamaV2Cache, 26 | ExLlamaV2Cache_8bit, 27 | ) 28 | except ImportError as e: 29 | print(f"Error: Failed to load Exllamav2. 
{e}") 30 | sys.exit(-1) 31 | 32 | exllamav2_config = ExLlamaV2Config() 33 | exllamav2_config.model_dir = model_path 34 | exllamav2_config.prepare() 35 | exllamav2_config.max_seq_len = exllama_config.max_seq_len 36 | exllamav2_config.cache_8bit = exllama_config.cache_8bit 37 | 38 | exllama_model = ExLlamaV2(exllamav2_config) 39 | tokenizer = ExLlamaV2Tokenizer(exllamav2_config) 40 | 41 | split = None 42 | if exllama_config.gpu_split: 43 | split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] 44 | exllama_model.load(split) 45 | 46 | cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache 47 | exllama_cache = cache_class(exllama_model) 48 | model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) 49 | 50 | return model, tokenizer 51 | -------------------------------------------------------------------------------- /fastchat/modules/gptq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import os 3 | from os.path import isdir, isfile 4 | from pathlib import Path 5 | import sys 6 | 7 | from transformers import AutoTokenizer 8 | 9 | 10 | @dataclass 11 | class GptqConfig: 12 | ckpt: str = field( 13 | default=None, 14 | metadata={ 15 | "help": "Load quantized model. The path to the local GPTQ checkpoint." 16 | }, 17 | ) 18 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 19 | groupsize: int = field( 20 | default=-1, 21 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 22 | ) 23 | act_order: bool = field( 24 | default=True, 25 | metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, 26 | ) 27 | 28 | 29 | def load_gptq_quantized(model_name, gptq_config: GptqConfig): 30 | print("Loading GPTQ quantized model...") 31 | 32 | try: 33 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") 35 | 36 | sys.path.insert(0, module_path) 37 | from llama import load_quant 38 | except ImportError as e: 39 | print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") 40 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") 41 | sys.exit(-1) 42 | 43 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 44 | # only `fastest-inference-4bit` branch cares about `act_order` 45 | if gptq_config.act_order: 46 | model = load_quant( 47 | model_name, 48 | find_gptq_ckpt(gptq_config), 49 | gptq_config.wbits, 50 | gptq_config.groupsize, 51 | act_order=gptq_config.act_order, 52 | ) 53 | else: 54 | # other branches 55 | model = load_quant( 56 | model_name, 57 | find_gptq_ckpt(gptq_config), 58 | gptq_config.wbits, 59 | gptq_config.groupsize, 60 | ) 61 | 62 | return model, tokenizer 63 | 64 | 65 | def find_gptq_ckpt(gptq_config: GptqConfig): 66 | if Path(gptq_config.ckpt).is_file(): 67 | return gptq_config.ckpt 68 | 69 | for ext in ["*.pt", "*.safetensors"]: 70 | matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) 71 | if len(matched_result) > 0: 72 | return str(matched_result[-1]) 73 | 74 | print("Error: gptq checkpoint not found") 75 | sys.exit(1) 76 | -------------------------------------------------------------------------------- /fastchat/modules/xfastertransformer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import sys 3 | 4 | 5 | @dataclass 6 | class XftConfig: 7 | max_seq_len: int = 4096 8 | beam_width: int = 1 9 | eos_token_id: int = -1 10 | pad_token_id: int = -1 11 | num_return_sequences: int = 1 12 | is_encoder_decoder: bool = False 13 | padding: bool = True 14 | early_stopping: bool = False 15 | data_type: str = "bf16_fp16" 16 | 17 | 18 | class XftModel: 19 | def __init__(self, xft_model, xft_config): 20 | self.model = xft_model 21 | self.config = xft_config 22 | 23 | 24 | def load_xft_model(model_path, xft_config: XftConfig): 25 | try: 26 | import xfastertransformer 27 | from transformers import AutoTokenizer 28 | except ImportError as e: 29 | print(f"Error: Failed to load xFasterTransformer. 
{e}") 30 | sys.exit(-1) 31 | 32 | if xft_config.data_type is None or xft_config.data_type == "": 33 | data_type = "bf16_fp16" 34 | else: 35 | data_type = xft_config.data_type 36 | tokenizer = AutoTokenizer.from_pretrained( 37 | model_path, use_fast=False, padding_side="left", trust_remote_code=True 38 | ) 39 | xft_model = xfastertransformer.AutoModel.from_pretrained( 40 | model_path, dtype=data_type 41 | ) 42 | model = XftModel(xft_model=xft_model, xft_config=xft_config) 43 | if model.model.rank > 0: 44 | while True: 45 | model.model.generate() 46 | return model, tokenizer 47 | -------------------------------------------------------------------------------- /fastchat/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/GenAI-Arena/1792672eed6a4b146b194f9ba7316f25b95cb53b/fastchat/serve/__init__.py -------------------------------------------------------------------------------- /fastchat/serve/api_provider.py: -------------------------------------------------------------------------------- 1 | """Call API providers.""" 2 | 3 | import os 4 | import random 5 | import time 6 | 7 | from fastchat.utils import build_logger 8 | from fastchat.constants import WORKER_API_TIMEOUT 9 | 10 | 11 | logger = build_logger("gradio_web_server", "gradio_web_server.log") 12 | 13 | 14 | def openai_api_stream_iter( 15 | model_name, 16 | messages, 17 | temperature, 18 | top_p, 19 | max_new_tokens, 20 | api_base=None, 21 | api_key=None, 22 | ): 23 | import openai 24 | 25 | is_azure = False 26 | if "azure" in model_name: 27 | is_azure = True 28 | openai.api_type = "azure" 29 | openai.api_version = "2023-07-01-preview" 30 | else: 31 | openai.api_type = "open_ai" 32 | openai.api_version = None 33 | 34 | openai.api_base = api_base or "https://api.openai.com/v1" 35 | openai.api_key = api_key or os.environ["OPENAI_API_KEY"] 36 | if model_name == "gpt-4-turbo": 37 | model_name = "gpt-4-1106-preview" 38 | 39 | # Make requests 40 | gen_params = { 41 | "model": model_name, 42 | "prompt": messages, 43 | "temperature": temperature, 44 | "top_p": top_p, 45 | "max_new_tokens": max_new_tokens, 46 | } 47 | logger.info(f"==== request ====\n{gen_params}") 48 | 49 | if is_azure: 50 | res = openai.ChatCompletion.create( 51 | engine=model_name, 52 | messages=messages, 53 | temperature=temperature, 54 | max_tokens=max_new_tokens, 55 | stream=True, 56 | ) 57 | else: 58 | res = openai.ChatCompletion.create( 59 | model=model_name, 60 | messages=messages, 61 | temperature=temperature, 62 | max_tokens=max_new_tokens, 63 | stream=True, 64 | ) 65 | text = "" 66 | for chunk in res: 67 | if len(chunk["choices"]) > 0: 68 | text += chunk["choices"][0]["delta"].get("content", "") 69 | data = { 70 | "text": text, 71 | "error_code": 0, 72 | } 73 | yield data 74 | 75 | 76 | def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): 77 | import anthropic 78 | 79 | c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) 80 | 81 | # Make requests 82 | gen_params = { 83 | "model": model_name, 84 | "prompt": prompt, 85 | "temperature": temperature, 86 | "top_p": top_p, 87 | "max_new_tokens": max_new_tokens, 88 | } 89 | logger.info(f"==== request ====\n{gen_params}") 90 | 91 | res = c.completions.create( 92 | prompt=prompt, 93 | stop_sequences=[anthropic.HUMAN_PROMPT], 94 | max_tokens_to_sample=max_new_tokens, 95 | temperature=temperature, 96 | top_p=top_p, 97 | model=model_name, 98 | stream=True, 99 | ) 100 | text = "" 101 | 
for chunk in res: 102 | text += chunk.completion 103 | data = { 104 | "text": text, 105 | "error_code": 0, 106 | } 107 | yield data 108 | 109 | 110 | def init_palm_chat(model_name): 111 | import vertexai # pip3 install google-cloud-aiplatform 112 | from vertexai.preview.language_models import ChatModel 113 | 114 | project_id = os.environ["GCP_PROJECT_ID"] 115 | location = "us-central1" 116 | vertexai.init(project=project_id, location=location) 117 | 118 | chat_model = ChatModel.from_pretrained(model_name) 119 | chat = chat_model.start_chat(examples=[]) 120 | return chat 121 | 122 | 123 | def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens): 124 | parameters = { 125 | "temperature": temperature, 126 | "top_p": top_p, 127 | "max_output_tokens": max_new_tokens, 128 | } 129 | gen_params = { 130 | "model": "palm-2", 131 | "prompt": message, 132 | } 133 | gen_params.update(parameters) 134 | logger.info(f"==== request ====\n{gen_params}") 135 | 136 | response = chat.send_message(message, **parameters) 137 | content = response.text 138 | 139 | pos = 0 140 | while pos < len(content): 141 | # This is a fancy way to simulate token generation latency combined 142 | # with a Poisson process. 143 | pos += random.randint(10, 20) 144 | time.sleep(random.expovariate(50)) 145 | data = { 146 | "text": content[:pos], 147 | "error_code": 0, 148 | } 149 | yield data 150 | -------------------------------------------------------------------------------- /fastchat/serve/gateway/README.md: -------------------------------------------------------------------------------- 1 | # fastchat Nginx Gateway 2 | 3 | ## Purpose of the Gateway 4 | 5 | The Nginx gateway serves the following purposes: 6 | 7 | 1. Protects Gradio servers by acting as a firewall. 8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers. 9 | 3. Provides load balancing for Gradio servers. 10 | 4. Offers additional security features, such as total connection limit. 11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving. 12 | 13 | ## Deployment and Updating of the Gateway 14 | 15 | ### Installing Nginx 16 | 17 | On Debian-based distributions (e.g., Ubuntu): 18 | 19 | ```bash 20 | sudo apt update 21 | sudo apt install nginx 22 | ``` 23 | On Red Hat-based distributions (e.g., CentOS, Fedora): 24 | 25 | ```bash 26 | sudo yum install epel-release 27 | sudo yum install nginx 28 | ``` 29 | 30 | ### Deployment 31 | 32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). 33 | 34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. 35 | 36 | Modify `upstream websocket` to configure Gradio servers behind the gateway. 37 | 38 | Lastly, update Nginx. 39 | 40 | 41 | ### HTTPS Deployment with a Public Domain URL 42 | 43 | Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. 44 | 45 | Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. 46 | 47 | If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. 48 | 49 | ### Updating 50 | 51 | Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: 52 | 53 | ```bash 54 | sudo nginx -t # check `/etc/nginx/nginx.conf` 55 | sudo systemctl reload nginx # restart Nginx service to load the new config 56 | sudo systemctl status nginx # check the status of the Nginx service. 
It should be active (running). 57 | ``` 58 | -------------------------------------------------------------------------------- /fastchat/serve/gateway/nginx.conf: -------------------------------------------------------------------------------- 1 | user www-data; 2 | worker_processes auto; 3 | pid /run/nginx.pid; 4 | include /etc/nginx/modules-enabled/*.conf; 5 | 6 | events { 7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently 8 | # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle 9 | 10 | } 11 | 12 | http { 13 | ## 14 | # Basic Settings 15 | ## 16 | 17 | sendfile on; # enable sendfile for performance optimization 18 | tcp_nopush on; # enable TCP no-pushing 19 | tcp_nodelay on; # enable TCP no-delay 20 | keepalive_timeout 65; # sets the timeout for keep-alive connections 21 | types_hash_max_size 2048; # maximum size of the types hash table 22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security 23 | 24 | # server_names_hash_bucket_size 64; 25 | # server_name_in_redirect off; 26 | 27 | include /etc/nginx/mime.types; # include MIME types file 28 | default_type application/octet-stream; # default MIME type for unknown file types 29 | 30 | ## 31 | # SSL Settings 32 | ## 33 | 34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use 35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers 36 | 37 | ## 38 | # Logging Settings 39 | ## 40 | 41 | access_log /var/log/nginx/access.log; # path to access log file 42 | error_log /var/log/nginx/error.log; # path to error log file 43 | 44 | ## 45 | # Gzip Settings 46 | ## 47 | gzip on; # enable Gzip compression 48 | 49 | ## 50 | # Virtual Host Configs 51 | ## 52 | 53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory 54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files 55 | 56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ 57 | map $http_upgrade $connection_upgrade { 58 | default upgrade; 59 | '' close; 60 | } 61 | 62 | upstream websocket { 63 | ip_hash; # load balancing by IP to guarantee session persistence 64 | server localhost:7860; # The port should be the gradio web server port 65 | # server localhost:7861; # extra gradio server if more than one 66 | } 67 | 68 | limit_conn_status 429; 69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP 70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server 71 | 72 | server { 73 | listen 443 ssl; # the listening port of our server 74 | ssl_certificate [PATH_TO_SSL_CERT]; 75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY]; 76 | server_name chat.lmsys.org; # replace the url with your own domain url 77 | limit_conn perserver 1024; # connections per server 78 | location / { 79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server 80 | limit_conn perip 5; # connections per IP 81 | proxy_set_header Host $host; # set the Host header for the upstream server 82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server 83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header 84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication 
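# the Upgrade/Connection headers below complete the WebSocket handshake pass-through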
85 | proxy_set_header Upgrade $http_upgrade; 86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication 87 | } 88 | } 89 | 90 | # the following block routes all HTTP traffic to HTTPS via nginx 91 | server { 92 | listen 80; 93 | server_name chat.lmsys.org; 94 | return 301 https://chat.lmsys.org$request_uri; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /fastchat/serve/huggingface_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use FastChat with Hugging Face generation APIs. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 6 | python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 7 | """ 8 | import argparse 9 | 10 | import torch 11 | 12 | from fastchat.model import load_model, get_conversation_template, add_model_args 13 | 14 | 15 | @torch.inference_mode() 16 | def main(args): 17 | # Load model 18 | model, tokenizer = load_model( 19 | args.model_path, 20 | device=args.device, 21 | num_gpus=args.num_gpus, 22 | max_gpu_memory=args.max_gpu_memory, 23 | load_8bit=args.load_8bit, 24 | cpu_offloading=args.cpu_offloading, 25 | revision=args.revision, 26 | debug=args.debug, 27 | ) 28 | 29 | # Build the prompt with a conversation template 30 | msg = args.message 31 | conv = get_conversation_template(args.model_path) 32 | conv.append_message(conv.roles[0], msg) 33 | conv.append_message(conv.roles[1], None) 34 | prompt = conv.get_prompt() 35 | 36 | # Run inference 37 | inputs = tokenizer([prompt], return_tensors="pt").to(args.device) 38 | output_ids = model.generate( 39 | **inputs, 40 | do_sample=True if args.temperature > 1e-5 else False, 41 | temperature=args.temperature, 42 | repetition_penalty=args.repetition_penalty, 43 | max_new_tokens=args.max_new_tokens, 44 | ) 45 | 46 | if model.config.is_encoder_decoder: 47 | output_ids = output_ids[0] 48 | else: 49 | output_ids = output_ids[0][len(inputs["input_ids"][0]) :] 50 | outputs = tokenizer.decode( 51 | output_ids, skip_special_tokens=True, spaces_between_special_tokens=False 52 | ) 53 | 54 | # Print results 55 | print(f"{conv.roles[0]}: {msg}") 56 | print(f"{conv.roles[1]}: {outputs}") 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | add_model_args(parser) 62 | parser.add_argument("--temperature", type=float, default=0.7) 63 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 64 | parser.add_argument("--max-new-tokens", type=int, default=512) 65 | parser.add_argument("--debug", action="store_true") 66 | parser.add_argument("--message", type=str, default="Hello! Who are you?") 67 | args = parser.parse_args() 68 | 69 | # Reset default repetition penalty for T5 models. 
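# fastchat-t5 tends to repeat itself at penalty 1.0, so bump the default to 1.2 unless the user explicitly set a different value.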
70 | if "t5" in args.model_path and args.repetition_penalty == 1.0: 71 | args.repetition_penalty = 1.2 72 | 73 | main(args) 74 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | args = parser.parse_args() 11 | 12 | lines = json.load(open(args.input)) 13 | ct_anony_votes = 0 14 | all_users = set() 15 | all_models = set() 16 | for l in lines: 17 | if not l["anony"]: 18 | continue 19 | all_users.add(l["judge"]) 20 | all_models.add(l["model_a"]) 21 | all_models.add(l["model_b"]) 22 | ct_anony_votes += 1 23 | 24 | print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") 25 | print(f"#model: {len(all_models)}") 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations for release. 3 | 4 | Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json 5 | """ 6 | import argparse 7 | from collections import defaultdict 8 | from enum import Enum, auto 9 | import json 10 | import os 11 | import random 12 | 13 | from tqdm import tqdm 14 | 15 | BLOCKED_WORDS_FILENAME = "blocked_words.json" 16 | blocked_words = [] 17 | frequency = defaultdict(lambda: 0) 18 | 19 | 20 | class TypeCode(Enum): 21 | CORRECT = auto() 22 | ANONYMIZED = auto() 23 | REDACTED = auto() 24 | BAD_FORMAT = auto() 25 | BLOCKED_WORD = auto() 26 | BLOCKED_MODEL = auto() 27 | TOO_SHORT = auto() 28 | TOO_FREQUENT = auto() 29 | 30 | 31 | def detect_type(conv): 32 | for key in ["conversation_a", "conversation_b"]: 33 | messages = [row["content"] for row in conv[key]] 34 | for msg in messages: 35 | if not isinstance(msg, str): 36 | return TypeCode.BAD_FORMAT 37 | 38 | user_prompts = [ 39 | row["content"].lower().strip() for row in conv[key] if row["role"] == "user" 40 | ] 41 | if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): 42 | return TypeCode.TOO_SHORT 43 | 44 | if all(x in frequent_prompts for x in user_prompts): 45 | return TypeCode.TOO_FREQUENT 46 | 47 | for msg in messages: 48 | msg = msg.lower() 49 | if "" in msg: 50 | return TypeCode.ANONYMIZED 51 | if "" in msg: 52 | return TypeCode.REDACTED 53 | 54 | for w in blocked_words: 55 | if w in msg: 56 | return TypeCode.BLOCKED_WORD 57 | 58 | for key in ["model_a", "model_b"]: 59 | if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: 60 | return TypeCode.BLOCKED_MODEL 61 | 62 | return TypeCode.CORRECT 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--in-file", type=str, required=True) 68 | parser.add_argument("--sample", type=int) 69 | args = parser.parse_args() 70 | 71 | # Read conversations 72 | convs = json.load(open(args.in_file)) 73 | print(f"#conv: {len(convs)}") 74 | 75 | # Read blocked words 76 | if os.path.exists(BLOCKED_WORDS_FILENAME): 77 | blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) 78 | 79 | # Count frequency 80 | for conv in convs: 81 | for key in ["conversation_a", "conversation_b"]: 82 | messages = [row["content"] for row in conv[key] 
if row["role"] == "user"] 83 | for msg in messages: 84 | if not isinstance(msg, str): 85 | continue 86 | msg = msg.lower().strip() 87 | frequency[msg] += 1 88 | 89 | keys = list(frequency.keys()) 90 | keys.sort(key=lambda x: -frequency[x]) 91 | frequent_prompts = keys[:10] 92 | frequent_prompts = set(frequent_prompts) 93 | frequent_prompts.add("") 94 | 95 | # Start filter 96 | ct_bad_format = 0 97 | ct_anonymized = 0 98 | ct_redacted = 0 99 | ct_error = 0 100 | ct_lang_filter = 0 101 | ct_flagged = 0 102 | ct_blocked_word = 0 103 | ct_blocked_model = 0 104 | ct_too_short = 0 105 | ct_too_frequent = 0 106 | 107 | new_convs = [] 108 | for conv in tqdm(convs): 109 | type_code = detect_type(conv) 110 | 111 | if type_code == TypeCode.BAD_FORMAT: 112 | ct_bad_format += 1 113 | continue 114 | 115 | if type_code == TypeCode.ANONYMIZED: 116 | ct_anonymized += 1 117 | continue 118 | elif type_code == TypeCode.REDACTED: 119 | ct_redacted += 1 120 | continue 121 | elif type_code == TypeCode.BLOCKED_WORD: 122 | ct_blocked_word += 1 123 | continue 124 | elif type_code == TypeCode.BLOCKED_MODEL: 125 | ct_blocked_model += 1 126 | continue 127 | elif type_code == TypeCode.TOO_SHORT: 128 | ct_too_short += 1 129 | continue 130 | elif type_code == TypeCode.TOO_FREQUENT: 131 | ct_too_frequent += 1 132 | continue 133 | 134 | if conv["openai_moderation"]["flagged"]: 135 | ct_flagged += 1 136 | continue 137 | 138 | if type_code in [TypeCode.CORRECT]: 139 | new_convs.append(conv) 140 | 141 | if args.sample: 142 | # random.seed(0) 143 | # random.shuffle(new_convs) 144 | new_convs = new_convs[: args.sample] 145 | 146 | print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") 147 | print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") 148 | print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") 149 | print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") 150 | print(f"new_conv: {len(new_convs)}") 151 | 152 | out_file = args.in_file.replace(".json", ".out.json") 153 | print(f"Output to {out_file}") 154 | with open(out_file, "w") as fout: 155 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 156 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py: -------------------------------------------------------------------------------- 1 | """Count the unique users in a battle log file.""" 2 | 3 | import argparse 4 | import json 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str) 10 | parser.add_argument("--tag-file", type=str) 11 | args = parser.parse_args() 12 | 13 | # build index 14 | objs = json.load(open(args.tag_file)) 15 | new_field_dict = {} 16 | for obj in objs: 17 | new_field_dict[obj["question_id"]] = obj["toxic_chat"] 18 | 19 | objs = json.load(open(args.input)) 20 | for obj in objs: 21 | obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] 22 | 23 | output = args.input.replace(".json", "_added.json") 24 | with open(output, "w") as fout: 25 | json.dump(objs, fout, indent=2, ensure_ascii=False) 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Count the unique users in a battle log file. 
3 | 4 | Usage: 5 | python3 -input in.json --number 1000 6 | """ 7 | 8 | import argparse 9 | import json 10 | import random 11 | 12 | K = 1000 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--input", type=str) 17 | parser.add_argument("--number", type=int, nargs="+") 18 | args = parser.parse_args() 19 | 20 | convs = json.load(open(args.input)) 21 | random.seed(0) 22 | random.shuffle(convs) 23 | 24 | for number in args.number: 25 | new_convs = convs[:number] 26 | 27 | output = args.input.replace(".json", f"_{number//K}k.json") 28 | with open(output, "w") as fout: 29 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 30 | 31 | print(f"#in: {len(convs)}, #out: {len(new_convs)}") 32 | print(f"Write to file: {output}") 33 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload to huggingface. 3 | """ 4 | import json 5 | from datasets import Dataset, DatasetDict, load_dataset 6 | 7 | objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) 8 | data = Dataset.from_list(objs) 9 | data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) 10 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | headers = {"authorization": "Bearer hf_XXX"} 4 | 5 | url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" 6 | a = requests.get(url, headers=headers) 7 | 8 | for u in a.json(): 9 | user = u["user"]["user"] 10 | url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" 11 | ret = requests.post(url, headers=headers, json={"user": user}) 12 | print(user, ret.status_code) 13 | assert ret.status_code == 200 14 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | From colab: 3 | https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing 4 | """ 5 | import argparse 6 | import datetime 7 | import json 8 | import os 9 | from pytz import timezone 10 | import time 11 | 12 | import kaleido 13 | import numpy as np 14 | import pandas as pd 15 | import plotly.express as px 16 | import plotly.graph_objects as go 17 | from tqdm import tqdm 18 | 19 | import plotly.io as pio 20 | 21 | pio.kaleido.scope.mathjax = None 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--in-file", type=str, required=True) 25 | parser.add_argument("--scale", type=int, required=True) 26 | args = parser.parse_args() 27 | 28 | filename = args.in_file 29 | scale = args.scale 30 | convs = json.load(open(filename)) 31 | df = pd.DataFrame(convs) 32 | df 33 | 34 | print(f"#ips: {df['user_id'].nunique() * scale}") 35 | print(f"#models: {df['model'].nunique()}") 36 | print(f"#language: {df['language'].nunique()}") 37 | print(f"#turns: {df['turn'].mean()}") 38 | 39 | model_counts = df["model"].value_counts() * scale 40 | # print("model counts", model_counts) 41 | fig = px.bar(x=model_counts.index, y=model_counts) 42 | 
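# Compact fixed-size layout so the per-model count chart exports cleanly via write_image below.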
fig.update_layout( 43 | xaxis_title=None, 44 | yaxis_title="Count", 45 | height=200, 46 | width=950, 47 | margin=dict(l=0, r=0, t=0, b=0), 48 | ) 49 | fig.show() 50 | fig.write_image("model_count.pdf") 51 | 52 | 53 | model_counts = df["language"].value_counts().head(25) * scale 54 | fig = px.bar(x=model_counts.index, y=model_counts) 55 | fig.update_layout( 56 | xaxis_title=None, 57 | yaxis_title="Count", 58 | height=200, 59 | width=950, 60 | margin=dict(l=0, r=0, t=0, b=0), 61 | ) 62 | fig.show() 63 | fig.write_image("language_count.pdf") 64 | 65 | chat_dates = [ 66 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") 67 | for x in df["tstamp"] 68 | ] 69 | 70 | 71 | def to_remove(x): 72 | for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: 73 | if d in x: 74 | return True 75 | return False 76 | 77 | 78 | chat_dates = [x for x in chat_dates if not to_remove(x)] 79 | 80 | chat_dates_counts = pd.value_counts(chat_dates) * scale 81 | print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") 82 | 83 | fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) 84 | fig.update_layout( 85 | xaxis_title="Dates", 86 | yaxis_title="Count", 87 | height=200, 88 | width=950, 89 | margin=dict(l=0, r=0, t=0, b=0), 90 | ) 91 | fig.show() 92 | fig.write_image("daily_conversation_count.pdf") 93 | 94 | import transformers 95 | 96 | tokenizer = transformers.AutoTokenizer.from_pretrained( 97 | "lmsys/vicuna-7b-v1.5", use_fast=False 98 | ) 99 | 100 | prompts = [] 101 | responses = [] 102 | for conv in df["conversation"]: 103 | for row in conv: 104 | if row["role"] == "user": 105 | prompts.append(row["content"]) 106 | else: 107 | responses.append(row["content"]) 108 | 109 | print(f"#prompts: {len(prompts)}") 110 | print(f"#responses: {len(responses)}") 111 | 112 | 113 | prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] 114 | print() 115 | print(f"mean prompt len: {np.mean(prompt_lens):.2f}") 116 | 117 | response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] 118 | print() 119 | print(f"mean response len: {np.mean(response_lens):.2f}") 120 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter conversations for release. 
3 | 4 | Dependency: 5 | pip install opencc-python-reimplemented 6 | 7 | Usage: 8 | python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json 9 | """ 10 | import argparse 11 | from concurrent.futures import ProcessPoolExecutor 12 | from collections import defaultdict 13 | from enum import Enum, auto 14 | import json 15 | import os 16 | import random 17 | 18 | from tqdm import tqdm 19 | import opencc 20 | 21 | BLOCKED_WORDS_FILENAME = "blocked_words.json" 22 | blocked_words = [] 23 | frequency = defaultdict(lambda: 0) 24 | 25 | cc_converter = opencc.OpenCC("t2s") 26 | 27 | 28 | class TypeCode(Enum): 29 | CORRECT = auto() 30 | ANONYMIZED = auto() 31 | REDACTED = auto() 32 | BAD_FORMAT = auto() 33 | BLOCKED_WORD = auto() 34 | BLOCKED_MODEL = auto() 35 | TOO_SHORT = auto() 36 | TOO_FREQUENT = auto() 37 | 38 | 39 | def detect_type(conv): 40 | for key in ["conversation_a", "conversation_b", "conversation"]: 41 | if key not in conv: 42 | continue 43 | 44 | messages = [row["content"] for row in conv[key]] 45 | for msg in messages: 46 | if not isinstance(msg, str): 47 | return TypeCode.BAD_FORMAT 48 | 49 | if len(messages) == 0: 50 | return TypeCode.BAD_FORMAT 51 | 52 | user_prompts = [ 53 | row["content"].lower().strip() for row in conv[key] if row["role"] == "user" 54 | ] 55 | 56 | for msg in messages: 57 | msg = cc_converter.convert(msg.lower()) 58 | if "<anonymized>" in msg: 59 | return TypeCode.ANONYMIZED 60 | if "<redacted>" in msg: 61 | return TypeCode.REDACTED 62 | 63 | for w in blocked_words: 64 | if w in msg: 65 | return TypeCode.BLOCKED_WORD 66 | 67 | return TypeCode.CORRECT 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--in-file", type=str, required=True) 73 | parser.add_argument("--sample", type=int) 74 | args = parser.parse_args() 75 | 76 | # Read conversations 77 | convs = json.load(open(args.in_file)) 78 | print(f"#conv: {len(convs)}") 79 | 80 | # Read blocked words 81 | if os.path.exists(BLOCKED_WORDS_FILENAME): 82 | blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) 83 | blocked_words = [cc_converter.convert(w) for w in blocked_words] 84 | 85 | # Start filter 86 | ct_bad_format = 0 87 | ct_anonymized = 0 88 | ct_redacted = 0 89 | ct_error = 0 90 | ct_lang_filter = 0 91 | ct_flagged = 0 92 | ct_blocked_word = 0 93 | ct_blocked_model = 0 94 | ct_too_short = 0 95 | ct_too_frequent = 0 96 | 97 | type_codes = [] 98 | with ProcessPoolExecutor() as executor: 99 | for result in tqdm(executor.map(detect_type, convs), total=len(convs)): 100 | type_codes.append(result) 101 | 102 | new_convs = [] 103 | for conv, type_code in zip(convs, type_codes): 104 | if type_code == TypeCode.BAD_FORMAT: 105 | ct_bad_format += 1 106 | continue 107 | 108 | if type_code == TypeCode.ANONYMIZED: 109 | ct_anonymized += 1 110 | continue 111 | elif type_code == TypeCode.REDACTED: 112 | ct_redacted += 1 113 | continue 114 | elif type_code == TypeCode.BLOCKED_WORD: 115 | ct_blocked_word += 1 116 | continue 117 | elif type_code == TypeCode.BLOCKED_MODEL: 118 | ct_blocked_model += 1 119 | continue 120 | elif type_code == TypeCode.TOO_SHORT: 121 | ct_too_short += 1 122 | continue 123 | elif type_code == TypeCode.TOO_FREQUENT: 124 | ct_too_frequent += 1 125 | continue 126 | 127 | if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]: 128 | ct_flagged += 1 129 | continue 130 | 131 | if type_code in [TypeCode.CORRECT]: 132 | new_convs.append(conv) 133 | 134 | if args.sample: 135 |
random.seed(42) 136 | random.shuffle(new_convs) 137 | new_convs = new_convs[: args.sample] 138 | 139 | print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") 140 | print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") 141 | print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") 142 | print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}") 143 | print(f"new_conv: {len(new_convs)}") 144 | 145 | out_file = args.in_file.replace(".json", ".s1.json") 146 | print(f"Output to {out_file}") 147 | with open(out_file, "w") as fout: 148 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 149 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--in-file", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | # Read conversations 14 | convs = json.load(open(args.in_file)) 15 | print(f"#conv: {len(convs)}") 16 | 17 | # Delete some fields 18 | for c in convs: 19 | del c["tstamp"] 20 | del c["user_id"] 21 | 22 | # Write 23 | print(f"#out conv: {len(convs)}") 24 | out_file = args.in_file.replace(".json", ".s2.json") 25 | print(f"Output to {out_file}") 26 | with open(out_file, "w") as fout: 27 | json.dump(convs, fout, indent=2, ensure_ascii=False) 28 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md: -------------------------------------------------------------------------------- 1 | ``` 2 | export BASE=clean_conv_20230809_100k_pii 3 | export SCALE=10 4 | 5 | # filter words 6 | python3 filter_bad_conv.py --in $BASE.json 7 | 8 | # Clean up some fields (e.g., timestamps) 9 | python3 final_post_processing.py --in $BASE.s1.json 10 | 11 | # upload to hf 12 | python3 upload_hf_dataset.py --in $BASE.s1.s2.json 13 | 14 | # Make another version with openai moderation tag 15 | python3 merge_oai_tag.py --in $BASE.s1.s2.json 16 | 17 | # Make visualizations 18 | python3 compute_stats.py --in $BASE.s1.json --scale $SCALE 19 | 20 | # Copy figures 21 | scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" .
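# Note on the filter step: filter_bad_conv.py also looks for an optional
# blocked_words.json (a JSON list of strings) in the working directory and,
# if it is present, additionally drops conversations containing those words.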
22 | ``` 23 | 24 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--in-file", type=str, required=True) 11 | parser.add_argument("--sample", type=int) 12 | args = parser.parse_args() 13 | 14 | tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" 15 | # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" 16 | in_file = args.in_file 17 | tic = time.time() 18 | 19 | # Load tags 20 | print("Load tags...") 21 | tag_data = json.load(open(tag_file)) 22 | tag_dict = {} 23 | for c in tqdm(tag_data): 24 | tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] 25 | print(f"elapsed: {time.time() - tic:.2f} s") 26 | 27 | # Append to input_file 28 | print("Load inputs...") 29 | input_data = json.load(open(in_file)) 30 | for c in tqdm(input_data): 31 | cid = c["conversation_id"] 32 | if cid in tag_dict: 33 | c["openai_moderation"] = tag_dict[cid] 34 | else: 35 | print(f"missing tag for conv {cid}") 36 | exit() 37 | print(f"elapsed: {time.time() - tic:.2f} s") 38 | 39 | # Write output 40 | print("Write outputs...") 41 | out_file = in_file.replace(".json", ".with_tag.json") 42 | print(f"Output to {out_file}") 43 | with open(out_file, "w") as fout: 44 | json.dump(input_data, fout, indent=2, ensure_ascii=False) 45 | print(f"elapsed: {time.time() - tic:.2f} s") 46 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh: -------------------------------------------------------------------------------- 1 | export BASE=clean_conv_20230809_1.5M_pii 2 | #export BASE=clean_conv_20230809_100k_pii 3 | export SCALE=1 4 | 5 | # Filter words 6 | python3 filter_bad_conv.py --in $BASE.json --sample 1000000 7 | 8 | # Clean up some fields (e.g., timestamps) 9 | python3 final_post_processing.py --in $BASE.s1.json 10 | 11 | # Upload to hf 12 | python3 upload_hf_dataset.py --in $BASE.s1.s2.json 13 | 14 | # Make another version with openai moderation tag 15 | python3 merge_oai_tag.py --in $BASE.s1.s2.json 16 | 17 | # Make visualizations 18 | python3 compute_stats.py --in $BASE.s1.json --scale $SCALE 19 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample a subset of conversations from a battle log file.
3 | 4 | Usage: 5 | python3 sample.py --input in.json --number 1000 6 | """ 7 | 8 | import argparse 9 | import json 10 | import random 11 | 12 | K = 1000 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--input", type=str) 17 | parser.add_argument("--number", type=int, nargs="+") 18 | args = parser.parse_args() 19 | 20 | convs = json.load(open(args.input)) 21 | random.seed(42) 22 | random.shuffle(convs) 23 | 24 | for number in args.number: 25 | new_convs = convs[:number] 26 | 27 | output = args.input.replace(".json", f"_{number//K}k.json") 28 | with open(output, "w") as fout: 29 | json.dump(new_convs, fout, indent=2, ensure_ascii=False) 30 | 31 | print(f"#in: {len(convs)}, #out: {len(new_convs)}") 32 | print(f"Write to file: {output}") 33 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload to huggingface. 3 | """ 4 | import argparse 5 | import json 6 | from datasets import Dataset, DatasetDict, load_dataset 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--in-file", type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | objs = json.load(open(args.in_file)) 15 | print(f"#convs: {len(objs)}") 16 | data = Dataset.from_list(objs) 17 | data.push_to_hub("lmsys/lmsys-chat-1m", private=True) 18 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/inspect_conv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import code 3 | import datetime 4 | import json 5 | import os 6 | from pytz import timezone 7 | import time 8 | 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | 13 | def get_log_files(max_num_files=None): 14 | dates = [] 15 | for month in [4, 5]: 16 | for day in range(1, 32): 17 | dates.append(f"2023-{month:02d}-{day:02d}") 18 | 19 | num_servers = 14 20 | filenames = [] 21 | for d in dates: 22 | for i in range(num_servers): 23 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") 24 | if os.path.exists(name): 25 | filenames.append(name) 26 | max_num_files = max_num_files or len(filenames) 27 | filenames = filenames[-max_num_files:] 28 | return filenames 29 | 30 | 31 | def pretty_print_conversation(messages): 32 | for role, msg in messages: 33 | print(f"[[{role}]]: {msg}") 34 | 35 | 36 | def inspect_convs(log_files): 37 | data = [] 38 | for filename in tqdm(log_files, desc="read files"): 39 | for retry in range(5): 40 | try: 41 | lines = open(filename).readlines() 42 | break 43 | except FileNotFoundError: 44 | time.sleep(2) 45 | 46 | for l in lines: 47 | row = json.loads(l) 48 | 49 | if "states" not in row: 50 | continue 51 | if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: 52 | continue 53 | 54 | model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] 55 | if row["type"] == "leftvote": 56 | winner, loser = model_names[0], model_names[1] 57 | winner_conv, loser_conv = row["states"][0], row["states"][1] 58 | elif row["type"] == "rightvote": 59 | loser, winner = model_names[0], model_names[1] 60 | loser_conv, winner_conv = row["states"][0], row["states"][1] 61 | 62 | if loser == "bard" and winner == "vicuna-13b": 63 | print("=" * 20) 64 | print(f"Winner: {winner}") 65 |
pretty_print_conversation(winner_conv["messages"]) 66 | print(f"Loser: {loser}") 67 | pretty_print_conversation(loser_conv["messages"]) 68 | print("=" * 20) 69 | input() 70 | 71 | # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: 72 | # print("=" * 20) 73 | # print(f"Model A: {model_names[0]}") 74 | # pretty_print_conversation(row["states"][0]["messages"]) 75 | # print(f"Model B: {model_names[1]}") 76 | # pretty_print_conversation(row["states"][1]["messages"]) 77 | # print("=" * 20) 78 | # input() 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("--max-num-files", type=int) 84 | args = parser.parse_args() 85 | 86 | log_files = get_log_files(args.max_num_files) 87 | inspect_convs(log_files) 88 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/intersect_conv_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take the intersection of two conversation files. 3 | 4 | Usage: python3 -m fastchat.serve.monitor.intersect_conv_file --input input.json --conv-id conv_id_file.json --out-file intersect.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input", type=str, required=True) 14 | parser.add_argument("--conv-id", type=str, required=True) 15 | parser.add_argument("--out-file", type=str, default="intersect.json") 16 | args = parser.parse_args() 17 | 18 | conv_id_objs = json.load(open(args.conv_id, "r")) 19 | conv_ids = set(x["conversation_id"] for x in conv_id_objs) 20 | 21 | objs = json.load(open(args.input, "r")) 22 | after_objs = [x for x in objs if x["conversation_id"] in conv_ids] 23 | 24 | print(f"#in: {len(objs)}, #out: {len(after_objs)}") 25 | json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) 26 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/leaderboard_csv_to_html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a leaderboard csv file to html table used in the blog.
3 | 4 | Usage: 5 | python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv 6 | """ 7 | import argparse 8 | 9 | import numpy as np 10 | 11 | from fastchat.serve.monitor.monitor import load_leaderboard_table_csv 12 | 13 | 14 | def model_hyperlink(model_name, link): 15 | return f'<a href="{link}"> {model_name} </a>' 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--input", type=str, required=True) 21 | args = parser.parse_args() 22 | 23 | data = load_leaderboard_table_csv(args.input, add_hyperlink=False) 24 | headers = [ 25 | "Model", 26 | "MT-bench (score)", 27 | "Arena Elo rating", 28 | "MMLU", 29 | "License", 30 | ] 31 | values = [] 32 | for item in data: 33 | row = [] 34 | for key in headers: 35 | value = item[key] 36 | row.append(value) 37 | row[0] = model_hyperlink(item["Model"], item["Link"]) 38 | values.append(row) 39 | values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) 40 | 41 | for value in values: 42 | row = "<tr>" 43 | for x in value: 44 | try: 45 | if np.isnan(x): 46 | x = "-" 47 | except TypeError: 48 | pass 49 | row += f"<td> {x} </td>" 50 | row += "</tr>" 51 | print(row) 52 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/summarize_cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 4 | python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 5 | """ 6 | import argparse 7 | import pickle 8 | 9 | from fastchat.llm_judge.common import ( 10 | chat_compeletion_openai, 11 | chat_compeletion_openai_azure, 12 | chat_compeletion_anthropic, 13 | ) 14 | from fastchat.conversation import get_conv_template 15 | 16 | 17 | def truncate_string(s, l): 18 | half = int(l // 2) 19 | return s[:half] + s[-half:] if len(s) > l else s 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--input-file", type=str, required=True) 25 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo") 26 | parser.add_argument("--num-prompts", type=int, default=100) 27 | args = parser.parse_args() 28 | 29 | model = args.model 30 | 31 | cluster_infos = pickle.load(open(args.input_file, "rb")) 32 | num_total_prompts = sum([x[0] for x in cluster_infos]) 33 | 34 | topics = [] 35 | percentages = [] 36 | for i, info in enumerate(cluster_infos): 37 | num_samples, topk_prompts, random_prompts = info 38 | percentage = num_samples / num_total_prompts 39 | print( 40 | f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" 41 | ) 42 | instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." 43 | split = int(args.num_prompts * 0.8) 44 | prompt = "\n".join( 45 | [truncate_string(x, l=200) for x in topk_prompts[:split]] 46 | + [ 47 | truncate_string(x, l=200) 48 | for x in random_prompts[: args.num_prompts - split] 49 | ] 50 | ) 51 | prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST."
52 | 53 | if "azure-" in model: 54 | template_name = "chatgpt" 55 | completion_func = chat_compeletion_openai_azure 56 | elif "gpt" in model: 57 | template_name = "chatgpt" 58 | completion_func = chat_compeletion_openai 59 | elif "claude" in model: 60 | template_name = "claude" 61 | completion_func = chat_compeletion_anthropic 62 | 63 | conv = get_conv_template(template_name) 64 | conv.set_system_message(instruct) 65 | conv.append_message(conv.roles[0], prompt) 66 | conv.append_message(conv.roles[1], None) 67 | 68 | topic = completion_func(model, conv, temperature=0, max_tokens=256) 69 | print(topic) 70 | 71 | topics.append(topic) 72 | percentages.append(round(percentage, 6)) 73 | 74 | print() 75 | print(f"topics: {topics}") 76 | print(f"percentages: {percentages}") 77 | -------------------------------------------------------------------------------- /fastchat/serve/monitor/tag_openai_moderation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add OpenAI moderation API results to all conversations. 3 | """ 4 | import argparse 5 | from concurrent.futures import ThreadPoolExecutor 6 | import json 7 | import os 8 | import time 9 | 10 | import openai 11 | import requests 12 | from tqdm import tqdm 13 | 14 | 15 | API_MAX_RETRY = 16 16 | API_RETRY_SLEEP = 10 17 | API_ERROR_OUTPUT = "$ERROR$" 18 | 19 | 20 | def tag_moderation(text): 21 | result = API_ERROR_OUTPUT 22 | for _ in range(API_MAX_RETRY): 23 | try: 24 | result = openai.Moderation.create(input=text)["results"][0] 25 | break 26 | except openai.error.OpenAIError as e: 27 | print(type(e), e) 28 | time.sleep(API_RETRY_SLEEP) 29 | 30 | return result 31 | 32 | 33 | def tag_openai_moderation(x): 34 | conv = x["conversation_a"] 35 | user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) 36 | result = tag_moderation(user_prompts) 37 | x["openai_moderation"] = result 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--input", type=str, required=True) 43 | parser.add_argument( 44 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 45 | ) 46 | parser.add_argument("--first-n", type=int) 47 | args = parser.parse_args() 48 | 49 | battles = json.load(open(args.input)) 50 | 51 | if args.first_n: 52 | battles = battles[: args.first_n] 53 | 54 | with ThreadPoolExecutor(args.parallel) as executor: 55 | for line in tqdm( 56 | executor.map(tag_openai_moderation, battles), total=len(battles) 57 | ): 58 | pass 59 | 60 | output = args.input.replace(".json", "_tagged.json") 61 | with open(output, "w") as fout: 62 | json.dump(battles, fout, indent=2, ensure_ascii=False) 63 | print(f"Write cleaned data to {output}") 64 | -------------------------------------------------------------------------------- /fastchat/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /fastchat/serve/shutdown_serve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python shutdown_serve.py --down all 4 | options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers 5 | """ 6 | 7 | import argparse 8 | import os 9 | import subprocess 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--down", choices=["all", "controller", "model_worker", "openai_api_server"] 14 | ) 15 | args = parser.parse_args() 16 | base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" 17 | if args.down == "all": 18 | shell_script = base_shell.format("") 19 | else: 20 | serve = f".{args.down}" 21 | shell_script = base_shell.format(serve) 22 | print(f"execute shell cmd: {shell_script}") 23 | subprocess.run(shell_script, shell=True, check=True) 24 | print(f"{args.down} has been shutdown!") 25 | -------------------------------------------------------------------------------- /fastchat/serve/test_message.py: -------------------------------------------------------------------------------- 1 | """Send a test message.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | 7 | from fastchat.model.model_adapter import get_conversation_template 8 | 9 | 10 | def main(): 11 | model_name = args.model_name 12 | 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | print(f"No available workers for {model_name}") 31 | return 32 | 33 | conv = get_conversation_template(model_name) 34 | conv.append_message(conv.roles[0], args.message) 35 | conv.append_message(conv.roles[1], None) 36 | prompt = conv.get_prompt() 37 | 38 | headers = {"User-Agent": "FastChat Client"} 39 | gen_params = { 40 | "model": model_name, 41 | "prompt": prompt, 42 | "temperature": args.temperature, 43 | "max_new_tokens": args.max_new_tokens, 44 | "stop": conv.stop_str, 45 | "stop_token_ids": conv.stop_token_ids, 46 | "echo": False, 47 | } 48 | response = requests.post( 49 | worker_addr + "/worker_generate_stream", 50 | headers=headers, 51 | json=gen_params, 52 | stream=True, 53 | ) 54 | 55 | print(f"{conv.roles[0]}: {args.message}") 56 | 
print(f"{conv.roles[1]}: ", end="") 57 | prev = 0 58 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 59 | if chunk: 60 | data = json.loads(chunk.decode()) 61 | output = data["text"].strip() 62 | print(output[prev:], end="", flush=True) 63 | prev = len(output) 64 | print("") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "--controller-address", type=str, default="http://localhost:21001" 71 | ) 72 | parser.add_argument("--worker-address", type=str) 73 | parser.add_argument("--model-name", type=str, required=True) 74 | parser.add_argument("--temperature", type=float, default=0.0) 75 | parser.add_argument("--max-new-tokens", type=int, default=32) 76 | parser.add_argument( 77 | "--message", type=str, default="Tell me a story with more than 1000 words." 78 | ) 79 | args = parser.parse_args() 80 | 81 | main() 82 | -------------------------------------------------------------------------------- /fastchat/serve/test_throughput.py: -------------------------------------------------------------------------------- 1 | """Benchmarking script to test the throughput of serving workers.""" 2 | import argparse 3 | import json 4 | 5 | import requests 6 | import threading 7 | import time 8 | 9 | from fastchat.conversation import get_conv_template 10 | 11 | 12 | def main(): 13 | if args.worker_address: 14 | worker_addr = args.worker_address 15 | else: 16 | controller_addr = args.controller_address 17 | ret = requests.post(controller_addr + "/refresh_all_workers") 18 | ret = requests.post(controller_addr + "/list_models") 19 | models = ret.json()["models"] 20 | models.sort() 21 | print(f"Models: {models}") 22 | 23 | ret = requests.post( 24 | controller_addr + "/get_worker_address", json={"model": args.model_name} 25 | ) 26 | worker_addr = ret.json()["address"] 27 | print(f"worker_addr: {worker_addr}") 28 | 29 | if worker_addr == "": 30 | return 31 | 32 | conv = get_conv_template("vicuna_v1.1") 33 | conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") 34 | prompt_template = conv.get_prompt() 35 | prompts = [prompt_template for _ in range(args.n_thread)] 36 | 37 | headers = {"User-Agent": "fastchat Client"} 38 | ploads = [ 39 | { 40 | "model": args.model_name, 41 | "prompt": prompts[i], 42 | "max_new_tokens": args.max_new_tokens, 43 | "temperature": 0.0, 44 | # "stop": conv.sep, 45 | } 46 | for i in range(len(prompts)) 47 | ] 48 | 49 | def send_request(results, i): 50 | if args.test_dispatch: 51 | ret = requests.post( 52 | controller_addr + "/get_worker_address", json={"model": args.model_name} 53 | ) 54 | thread_worker_addr = ret.json()["address"] 55 | else: 56 | thread_worker_addr = worker_addr 57 | print(f"thread {i} goes to {thread_worker_addr}") 58 | response = requests.post( 59 | thread_worker_addr + "/worker_generate_stream", 60 | headers=headers, 61 | json=ploads[i], 62 | stream=False, 63 | ) 64 | k = list( 65 | response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") 66 | ) 67 | # print(k) 68 | response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 69 | error_code = json.loads(k[-2].decode("utf-8"))["error_code"] 70 | # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") 71 | results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 72 | 73 | # use N threads to prompt the backend 74 | tik = time.time() 75 | threads = [] 76 | results = [None] * args.n_thread 77 | for i in range(args.n_thread): 78 | t = 
threading.Thread(target=send_request, args=(results, i)) 79 | t.start() 80 | # time.sleep(0.5) 81 | threads.append(t) 82 | 83 | for t in threads: 84 | t.join() 85 | 86 | print(f"Time (POST): {time.time() - tik} s") 87 | # n_words = 0 88 | # for i, response in enumerate(results): 89 | # # print(prompt[i].replace(conv.sep, "\n"), end="") 90 | # # make sure the streaming finishes at EOS or stopping criteria 91 | # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) 92 | # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] 93 | # # print(response_new_words) 94 | # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) 95 | n_words = sum(results) 96 | time_seconds = time.time() - tik 97 | print( 98 | f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " 99 | f"throughput: {n_words / time_seconds} words/s." 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument( 106 | "--controller-address", type=str, default="http://localhost:21001" 107 | ) 108 | parser.add_argument("--worker-address", type=str) 109 | parser.add_argument("--model-name", type=str, default="vicuna") 110 | parser.add_argument("--max-new-tokens", type=int, default=2048) 111 | parser.add_argument("--n-thread", type=int, default=8) 112 | parser.add_argument("--test-dispatch", action="store_true") 113 | args = parser.parse_args() 114 | 115 | main() 116 | -------------------------------------------------------------------------------- /fastchat/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 10 | from flash_attn.bert_padding import unpad_input, pad_input 11 | 12 | 13 | def forward( 14 | self, 15 | hidden_states: torch.Tensor, 16 | attention_mask: Optional[torch.Tensor] = None, 17 | position_ids: Optional[torch.Tensor] = None, 18 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 19 | output_attentions: bool = False, 20 | use_cache: bool = False, 21 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 22 | if output_attentions: 23 | warnings.warn( 24 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
25 | ) 26 | 27 | bsz, q_len, _ = hidden_states.size() 28 | 29 | query_states = ( 30 | self.q_proj(hidden_states) 31 | .view(bsz, q_len, self.num_heads, self.head_dim) 32 | .transpose(1, 2) 33 | ) 34 | key_states = ( 35 | self.k_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | value_states = ( 40 | self.v_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) # shape: (b, num_heads, s, head_dim) 44 | 45 | kv_seq_len = key_states.shape[-2] 46 | if past_key_value is not None: 47 | kv_seq_len += past_key_value[0].shape[-2] 48 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 49 | query_states, key_states = apply_rotary_pos_emb( 50 | query_states, key_states, cos, sin, position_ids 51 | ) 52 | 53 | if past_key_value is not None: 54 | # reuse k, v 55 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 56 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 57 | 58 | past_key_value = (key_states, value_states) if use_cache else None 59 | 60 | # Transform the data into the format required by flash attention 61 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 62 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 63 | key_padding_mask = attention_mask 64 | 65 | if key_padding_mask is None: 66 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 67 | cu_q_lens = torch.arange( 68 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 69 | ) 70 | max_s = q_len 71 | output = flash_attn_varlen_qkvpacked_func( 72 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 73 | ) 74 | output = output.view(bsz, q_len, -1) 75 | else: 76 | qkv = qkv.reshape(bsz, q_len, -1) 77 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 78 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 79 | output_unpad = flash_attn_varlen_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 83 | output = pad_input(output_unpad, indices, bsz, q_len) 84 | 85 | return self.o_proj(output), None, past_key_value 86 | 87 | 88 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 89 | # requires the attention mask to be the same as the key_padding_mask 90 | def _prepare_decoder_attention_mask( 91 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 92 | ): 93 | # [bsz, seq_len] 94 | return attention_mask 95 | 96 | 97 | def replace_llama_attn_with_flash_attn(): 98 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 99 | if cuda_major < 8: 100 | warnings.warn( 101 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
102 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 103 | ) 104 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 105 | _prepare_decoder_attention_mask 106 | ) 107 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 108 | -------------------------------------------------------------------------------- /fastchat/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /fastchat/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 2 | 3 | # Need to call this before importing transformers. 4 | from fastchat.train.llama2_flash_attn_monkey_patch import ( 5 | replace_llama_attn_with_flash_attn, 6 | ) 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from fastchat.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /fastchat/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from fastchat.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from fastchat.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Adapted from https://github.com/skypilot-org/skypilot/blob/master/format.sh 4 | 5 | # Cause the script to exit if a single command fails 6 | set -eo pipefail 7 | 8 | # this stops git rev-parse from failing if we run this from the .git directory 9 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 10 | ROOT="$(git rev-parse --show-toplevel)" 11 | builtin cd "$ROOT" || exit 1 12 | 13 | BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') 14 | PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') 15 | 16 | # # params: tool name, tool version, required version 17 | tool_version_check() { 18 | if [[ $2 != $3 ]]; then 19 | echo "Wrong $1 version installed: $3 is required, not $2." 20 | exit 1 21 | fi 22 | } 23 | 24 | tool_version_check "black" $BLACK_VERSION "23.3.0" 25 | tool_version_check "pylint" $PYLINT_VERSION "2.8.2" 26 | 27 | # Format files that differ from main branch. Ignores dirs that are not slated 28 | # for autoformat yet. 29 | format_changed() { 30 | # The `if` guard ensures that the list of filenames is not empty, which 31 | # could cause yapf to receive 0 positional arguments, making it hang 32 | # waiting for STDIN. 33 | # 34 | # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that 35 | # exist on both branches. 36 | MERGEBASE="$(git merge-base origin/main HEAD)" 37 | 38 | if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then 39 | git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 black 40 | fi 41 | } 42 | 43 | ## This flag formats individual files. --files *must* be the first command line 44 | ## arg to use this option. 45 | if [[ "$1" == '--files' ]]; then 46 | black "${@:2}" 47 | # If `--all` is passed, then any further arguments are ignored and the 48 | # entire python directory is formatted. 49 | elif [[ "$1" == '--all' ]]; then 50 | # Format all files 51 | black fastchat 52 | else 53 | # Format only the files that changed in last commit. 54 | format_changed 55 | fi 56 | echo 'FastChat Black: Done' 57 | 58 | # Run Pylint 59 | echo 'FastChat Pylint:' 60 | pylint fastchat 61 | # TODO(suquark): disable 'pylint_quotes' for now due to too many inconsistent quotes 62 | # pylint --load-plugins pylint_quotes fastchat 63 | 64 | if ! git diff --quiet &>/dev/null; then 65 | echo 'Reformatted files. Please review and stage the changes.' 
66 | echo 'Changes not staged for commit:' 67 | echo 68 | git --no-pager diff --name-only 69 | 70 | exit 1 71 | fi 72 | -------------------------------------------------------------------------------- /imagenhub_requirements.txt: -------------------------------------------------------------------------------- 1 | xformers 2 | flask 3 | flask_restful 4 | flask_cors 5 | faiss-cpu 6 | fire 7 | h5py 8 | numpy>=1.24.0 9 | pandas<2.0.0 10 | peft 11 | torch 12 | torchvision 13 | torchaudio 14 | jupyterlab>=4.0.2 15 | notebook>=6.5.4 16 | albumentations>=1.1.0 17 | opencv-python>=4.2.0 18 | pudb~=2019.2 19 | imageio>=2.14.1 20 | imageio-ffmpeg>=0.4.7 21 | pytorch-lightning>=1.5.9 22 | omegaconf~=2.1.1 23 | gradio 24 | pillow~=9.5.0 25 | einops>=0.4.1 26 | torch-fidelity>=0.3.0 27 | setuptools>=59.5.0 28 | transformers~=4.28.0 29 | torchmetrics>=0.6.0 30 | lpips 31 | dreamsim 32 | image-reward 33 | kornia>=0.6 34 | diffusers==0.24.0 35 | accelerate>=0.20.3 36 | safetensors 37 | datasets 38 | tqdm>=4.64.1 39 | matplotlib>=3.7.1 40 | taming-transformers-rom1504~=0.0.6 41 | madgrad>=1.1 42 | -e git+https://github.com/openai/CLIP.git@main#egg=clip 43 | dominate>=2.8.0 44 | -e git+https://github.com/CompVis/latent-diffusion.git#egg=latent-diffusion #ldm 45 | openai 46 | nltk~=3.8.1 47 | krippendorff 48 | statsmodels -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "fschat" 7 | version = "0.2.34" 8 | description = "An open platform for training, serving, and evaluating large language model based chatbots." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", 17 | "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "requests", "rich>=10.0.0", 18 | "shortuuid", "tiktoken", "uvicorn", "diffusers", "chardet", "omegaconf", 19 | "openai", "IPython", "ftfy", "plotly", "gradio<4.0,>3.0" 20 | ] 21 | 22 | [project.optional-dependencies] 23 | model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] 24 | webui = ["gradio"] 25 | train = ["einops", "flash-attn>=2.0", "wandb"] 26 | llm_judge = ["openai<1", "anthropic>=0.3", "ray"] 27 | dev = ["black==23.3.0", "pylint==2.8.2"] 28 | 29 | [project.urls] 30 | "Homepage" = "https://github.com/lm-sys/fastchat" 31 | "Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" 32 | 33 | [tool.setuptools.packages.find] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | 36 | [tool.wheel] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | --------------------------------------------------------------------------------