35 | #
36 | # ```
37 | # This function converts it into the correct markdown format.
38 | return re.sub(code_lang_pattern, code_lang_format, val)
39 |
40 |
41 | def html_to_markdown(val: str) -> str:
42 | # Remove all <div> tags. This is required to make indentation work in code blocks.
43 | val = re.sub(div_pattern, "", val)
44 | # Remove all <span> tags. This is required to make underscores work in code blocks.
45 | val = re.sub(span_pattern, "", val)
46 | # Convert HTML to markdown
47 | val = markdownify.markdownify(val).strip()
48 | # Reformat code
49 | val = reformat_code(val)
50 |
51 | # Remove noisy "[number] / [number]" at the beginning
52 | noise = re.search(regenerate_pattern, val)
53 | if noise and noise.start() == 0:
54 | val = val[noise.end() :]
55 | # Remove noisy "Copy[number] chars / [number] words"
56 | val = re.sub(copy_chars_pattern, "", val)
57 | # Remove empty code block ```\nCopy code\n```
58 | val = re.sub(copy_code_pattern, "", val)
59 |
60 | # Collapse extra blank lines and strip surrounding whitespace
61 | val = val.replace("\n\n\n", "\n").strip()
62 |
63 | return val
64 |
65 |
66 | def contain_blocked_words(val: str) -> bool:
67 | blocked_words = ["openai", "chatgpt"]
68 | for w in blocked_words:
69 | if w in val.lower():
70 | return True
71 | return False
72 |
73 |
74 | def clean_html_one_sample(sample):
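# Returns (sample, error_code): 0 = ok, 1 = too short, 2 = wrong role order,
# 3 = contains blocked words, 4 = HTML parser error.
# The codes are interpreted in clean_html_all below.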
75 | roles = ["human", "gpt"]
76 |
77 | if len(sample["conversations"]) <= 1:
78 | return (sample, 1)
79 |
80 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4
81 | if sample["conversations"][0]["from"] != "human":
82 | sample["conversations"] = sample["conversations"][1:]
83 | if len(sample["conversations"]) <= 1:
84 | return (sample, 1)
85 |
86 | if sample["conversations"][-1]["from"] == "human":
87 | sample["conversations"] = sample["conversations"][:-1]
88 | if len(sample["conversations"]) <= 1:
89 | return (sample, 1)
90 |
91 | char_count = 0
92 | new_conversations = []
93 | for i, c in enumerate(sample["conversations"]):
94 | if c["from"] != roles[i % 2]:
95 | return (sample, 2)
96 |
97 | if contain_blocked_words(c["value"]):
98 | return (sample, 3)
99 |
100 | try:
101 | new_val = html_to_markdown(c["value"])
102 | except (bs4.builder.ParserRejectedMarkup, AssertionError):
103 | return (sample, 4)
104 |
105 | # Filter empty answers like https://sharegpt.com/c/mrllZ6u
106 | if not new_val or not new_val[0].isprintable():
107 | break
108 |
109 | char_count += len(new_val)
110 | new_conversations.append(
111 | {
112 | "from": c["from"],
113 | "value": new_val,
114 | }
115 | )
116 |
117 | new_conversations = new_conversations[: len(new_conversations) // 2 * 2]
118 | sample["conversations"] = new_conversations
119 |
120 | if char_count < 16 or len(sample["conversations"]) <= 0:
121 | return (sample, 1)
122 |
123 | return (sample, 0)
124 |
125 |
126 | def clean_html_all(content, begin, end):
127 | """
128 | Clean the source html files.
129 | """
130 | cnt_skip = 0
131 | cnt_blocked_words = 0
132 | cnt_wrong_format = 0
133 | cnt_parser_error = 0
134 | cnt_too_short = 0
135 | cnt_id_duplication = 0
136 | cnt_value_duplication = 0
137 | cnt_plugin = 0
138 | cnt_tag = 0
139 |
140 | content = content[begin:end]
141 | processed = []
142 | with ProcessPoolExecutor() as executor:
143 | for result in tqdm(
144 | executor.map(clean_html_one_sample, content), total=len(content)
145 | ):
146 | processed.append(result)
147 |
148 | visited = {}
149 | new_content = []
150 | for sample, error_code in processed:
151 | cid = sample["id"]
152 | skipped = True
153 |
154 | if error_code != 0:
155 | if error_code == 1:
156 | print(f"id {cid} is too short")
157 | cnt_too_short += 1
158 | elif error_code == 2:
159 | print(f"id {cid} has a wrong format")
160 | cnt_wrong_format += 1
161 | elif error_code == 3:
162 | print(f"id {cid} contains blocked words")
163 | cnt_blocked_words += 1
164 | elif error_code == 4:
165 | print(f"id {cid} contains parser errors")
166 | cnt_parser_error += 1
167 | else:
168 | raise ValueError(f"Invalid error_code: {error_code}")
169 | elif cid in visited:
170 | print(f"id {cid} is an id duplication of {visited[cid]}")
171 | cnt_id_duplication += 1
172 | elif sample.get("plugins", None) is not None:
173 | print(f"id {cid} contains plugin")
174 | cnt_plugin += 1
175 | else:
176 | key = (
177 | sample["conversations"][0]["value"],
178 | sample["conversations"][1]["value"],
179 | )
180 | if key in visited:
181 | print(f"id {cid} is a value duplication of {visited[key]}")
182 | cnt_value_duplication += 1
183 | else:
184 | visited[cid] = visited[key] = cid
185 | skipped = False
186 |
187 | if not skipped:
188 | new_content.append(sample)
189 | else:
190 | cnt_skip += 1
191 |
192 | print(
193 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, "
194 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, "
195 | f"cnt_wrong_format: {cnt_wrong_format}, "
196 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, "
197 | f"cnt_value_duplication: {cnt_value_duplication}, cnt_plugin: {cnt_plugin}"
198 | )
199 |
200 | return new_content
201 |
202 |
203 | def main(args):
204 | content = json.load(open(args["in_file"], "r"))
205 | content = clean_html_all(content, args["begin"], args["end"])
206 | json.dump(content, open(args["out_file"], "w"), indent=2, ensure_ascii=False)
207 |
208 |
209 | if __name__ == "__main__":
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument("--in-file", type=str, required=True)
212 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json")
213 | parser.add_argument("--begin", type=int)
214 | parser.add_argument("--end", type=int)
215 | parser.add_argument("--debug", action="store_true")
216 | args = parser.parse_args()
217 | main(vars(args))
218 |
--------------------------------------------------------------------------------
/chat/server/monitor/basic_stats.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import code
3 | import datetime
4 | import json
5 | import os
6 | from pytz import timezone
7 | import time
8 |
9 | import pandas as pd # pandas>=2.0.3
10 | import plotly.express as px
11 | import plotly.graph_objects as go
12 | from tqdm import tqdm
13 |
14 |
15 | NUM_SERVERS = 14
16 |
17 |
18 | def get_log_files(max_num_files=None):
19 | dates = []
20 | for month in range(4, 9):
21 | for day in range(1, 33):
22 | dates.append(f"2023-{month:02d}-{day:02d}")
23 |
24 | filenames = []
25 | for d in dates:
26 | for i in range(NUM_SERVERS):
27 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
28 | if os.path.exists(name):
29 | filenames.append(name)
30 | max_num_files = max_num_files or len(filenames)
31 | filenames = filenames[-max_num_files:]
32 | return filenames
33 |
34 |
35 | def load_log_files(log_files):
36 | data = []
37 | for filename in tqdm(log_files, desc="read files"):
38 | for retry in range(5):
39 | try:
40 | lines = open(filename).readlines()
41 | break
42 | except FileNotFoundError:
43 | time.sleep(2)
44 |
45 | for l in lines:
46 | row = json.loads(l)
47 |
48 | data.append(
49 | dict(
50 | type=row["type"],
51 | tstamp=row["tstamp"],
52 | model=row.get("model", ""),
53 | models=row.get("models", ["", ""]),
54 | )
55 | )
56 |
57 | return data
58 |
59 |
60 | def get_anony_vote_df(df):
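# Keep only votes from anonymous battles, i.e. rows whose public "models"
# field has empty names (the model identities were hidden from the voter).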
61 | anony_vote_df = df[
62 | df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
63 | ]
64 | anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")]
65 | return anony_vote_df
66 |
67 |
68 | def merge_counts(series, on, names):
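# Merge several value_counts() Series on a shared key column and rename the
# resulting count columns (e.g. to "All", "Last Day", "Last Hour").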
69 | ret = pd.merge(series[0], series[1], on=on)
70 | for i in range(2, len(series)):
71 | ret = pd.merge(ret, series[i], on=on)
72 | ret = ret.reset_index()
73 | old_names = list(ret.columns)[-len(series) :]
74 | rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
75 | ret = ret.rename(columns=rename)
76 | return ret
77 |
78 |
79 | def report_basic_stats(log_files):
80 | df_all = load_log_files(log_files)
81 | df_all = pd.DataFrame(df_all)
82 | now_t = df_all["tstamp"].max()
83 | df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
84 | df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
85 | anony_vote_df_all = get_anony_vote_df(df_all)
86 |
87 | # Chat trends
88 | chat_dates = [
89 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
90 | "%Y-%m-%d"
91 | )
92 | for x in df_all[df_all["type"] == "chat"]["tstamp"]
93 | ]
94 | chat_dates_counts = pd.value_counts(chat_dates)
95 | vote_dates = [
96 | datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
97 | "%Y-%m-%d"
98 | )
99 | for x in anony_vote_df_all["tstamp"]
100 | ]
101 | vote_dates_counts = pd.value_counts(vote_dates)
102 | chat_dates_bar = go.Figure(
103 | data=[
104 | go.Bar(
105 | name="Anony. Vote",
106 | x=vote_dates_counts.index,
107 | y=vote_dates_counts,
108 | text=[f"{val:.0f}" for val in vote_dates_counts],
109 | textposition="auto",
110 | ),
111 | go.Bar(
112 | name="Chat",
113 | x=chat_dates_counts.index,
114 | y=chat_dates_counts,
115 | text=[f"{val:.0f}" for val in chat_dates_counts],
116 | textposition="auto",
117 | ),
118 | ]
119 | )
120 | chat_dates_bar.update_layout(
121 | barmode="stack",
122 | xaxis_title="Dates",
123 | yaxis_title="Count",
124 | height=300,
125 | width=1200,
126 | )
127 |
128 | # Model call counts
129 | model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
130 | model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
131 | model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
132 | model_hist = merge_counts(
133 | [model_hist_all, model_hist_1_day, model_hist_1_hour],
134 | on="model",
135 | names=["All", "Last Day", "Last Hour"],
136 | )
137 | model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
138 |
139 | # Action counts
140 | action_hist_all = df_all["type"].value_counts()
141 | action_hist_1_day = df_1_day["type"].value_counts()
142 | action_hist_1_hour = df_1_hour["type"].value_counts()
143 | action_hist = merge_counts(
144 | [action_hist_all, action_hist_1_day, action_hist_1_hour],
145 | on="type",
146 | names=["All", "Last Day", "Last Hour"],
147 | )
148 | action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
149 |
150 | # Anony vote counts
151 | anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
152 | anony_vote_df_1_day = get_anony_vote_df(df_1_day)
153 | anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
154 | # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
155 | # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
156 | anony_vote_hist = merge_counts(
157 | [anony_vote_hist_all, anony_vote_hist_1_day],
158 | on="type",
159 | names=["All", "Last Day"],
160 | )
161 | anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
162 |
163 | # Last 24 hours
164 | chat_1_day = df_1_day[df_1_day["type"] == "chat"]
165 | num_chats_last_24_hours = []
166 | base = df_1_day["tstamp"].min()
167 | for i in range(24, 0, -1):
168 | left = base + (i - 1) * 3600
169 | right = base + i * 3600
170 | num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
171 | num_chats_last_24_hours.append(num)
172 | times = [
173 | datetime.datetime.fromtimestamp(
174 | base + i * 3600, tz=timezone("US/Pacific")
175 | ).strftime("%Y-%m-%d %H:%M:%S %Z")
176 | for i in range(24, 0, -1)
177 | ]
178 | last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
179 | last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
180 |
181 | # Last update datetime
182 | last_updated_tstamp = now_t
183 | last_updated_datetime = datetime.datetime.fromtimestamp(
184 | last_updated_tstamp, tz=timezone("US/Pacific")
185 | ).strftime("%Y-%m-%d %H:%M:%S %Z")
186 |
187 | # code.interact(local=locals())
188 |
189 | return {
190 | "chat_dates_bar": chat_dates_bar,
191 | "model_hist_md": model_hist_md,
192 | "action_hist_md": action_hist_md,
193 | "anony_vote_hist_md": anony_vote_hist_md,
194 | "num_chats_last_24_hours": last_24_hours_md,
195 | "last_updated_datetime": last_updated_datetime,
196 | }
197 |
198 |
199 | if __name__ == "__main__":
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument("--max-num-files", type=int)
202 | args = parser.parse_args()
203 |
204 | log_files = get_log_files(args.max_num_files)
205 | basic_stats = report_basic_stats(log_files)
206 |
207 | print(basic_stats["action_hist_md"] + "\n")
208 | print(basic_stats["model_hist_md"] + "\n")
209 | print(basic_stats["anony_vote_hist_md"] + "\n")
210 | print(basic_stats["num_chats_last_24_hours"] + "\n")
211 |
--------------------------------------------------------------------------------
/chat/server/vllm_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker that executes the model based on vLLM.
3 |
4 | See documentation at docs/vllm_integration.md
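
Example launch (illustrative; all options are defined in the argument parser below):
    python3 -m chat.server.vllm_worker --model-path lmsys/vicuna-7b-v1.3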
5 | """
6 |
7 | import argparse
8 | import asyncio
9 | import json
10 | from typing import List
11 |
12 | from fastapi import FastAPI, Request, BackgroundTasks
13 | from fastapi.responses import StreamingResponse, JSONResponse
14 | import torch
15 | import uvicorn
16 | from vllm import AsyncLLMEngine
17 | from vllm.engine.arg_utils import AsyncEngineArgs
18 | from vllm.sampling_params import SamplingParams
19 | from vllm.utils import random_uuid
20 |
21 | from chat.server.model_worker import (
22 | BaseModelWorker,
23 | logger,
24 | worker_id,
25 | )
26 | from chat.utils import get_context_length
27 |
28 |
29 | app = FastAPI()
30 |
31 |
32 | class VLLMWorker(BaseModelWorker):
33 | def __init__(
34 | self,
35 | controller_addr: str,
36 | worker_addr: str,
37 | worker_id: str,
38 | model_path: str,
39 | model_names: List[str],
40 | limit_worker_concurrency: int,
41 | no_register: bool,
42 | llm_engine: AsyncLLMEngine,
43 | conv_template: str,
44 | ):
45 | super().__init__(
46 | controller_addr,
47 | worker_addr,
48 | worker_id,
49 | model_path,
50 | model_names,
51 | limit_worker_concurrency,
52 | conv_template,
53 | )
54 |
55 | logger.info(
56 | f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
57 | )
58 | self.tokenizer = llm_engine.engine.tokenizer
59 | self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)
60 |
61 | if not no_register:
62 | self.init_heart_beat()
63 |
64 | async def generate_stream(self, params):
65 | self.call_ct += 1
66 |
67 | context = params.pop("prompt")
68 | request_id = params.pop("request_id")
69 | temperature = float(params.get("temperature", 1.0))
70 | top_p = float(params.get("top_p", 1.0))
71 | max_new_tokens = params.get("max_new_tokens", 256)
72 | stop_str = params.get("stop", None)
73 | stop_token_ids = params.get("stop_token_ids", None) or []
74 | if self.tokenizer.eos_token_id is not None:
75 | stop_token_ids.append(self.tokenizer.eos_token_id)
76 | echo = params.get("echo", True)
77 |
78 | # Handle stop_str
79 | stop = set()
80 | if isinstance(stop_str, str) and stop_str != "":
81 | stop.add(stop_str)
82 | elif isinstance(stop_str, list) and stop_str != []:
83 | stop.update(stop_str)
84 |
85 | for tid in stop_token_ids:
86 | if tid is not None:
87 | stop.add(self.tokenizer.decode(tid))
88 |
89 |         # Make sampling params for vLLM
90 | top_p = max(top_p, 1e-5)
91 | if temperature <= 1e-5:
92 | top_p = 1.0
93 | sampling_params = SamplingParams(
94 | n=1,
95 | temperature=temperature,
96 | top_p=top_p,
97 | use_beam_search=False,
98 | stop=list(stop),
99 | max_tokens=max_new_tokens,
100 | )
101 | results_generator = engine.generate(context, sampling_params, request_id)
102 |
103 | async for request_output in results_generator:
104 | prompt = request_output.prompt
105 | if echo:
106 | text_outputs = [
107 | prompt + output.text for output in request_output.outputs
108 | ]
109 | else:
110 | text_outputs = [output.text for output in request_output.outputs]
111 | text_outputs = " ".join(text_outputs)
112 | # Note: usage is not supported yet
113 | ret = {"text": text_outputs, "error_code": 0, "usage": {}}
114 | yield (json.dumps(ret) + "\0").encode()
115 |
116 | async def generate(self, params):
117 | async for x in self.generate_stream(params):
118 | pass
119 | return json.loads(x[:-1].decode())
120 |
121 |
122 | def release_worker_semaphore():
123 | worker.semaphore.release()
124 |
125 |
126 | def acquire_worker_semaphore():
127 | if worker.semaphore is None:
128 | worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
129 | return worker.semaphore.acquire()
130 |
131 |
132 | def create_background_tasks(request_id):
133 | async def abort_request() -> None:
134 | await engine.abort(request_id)
135 |
136 | background_tasks = BackgroundTasks()
137 | background_tasks.add_task(release_worker_semaphore)
138 | background_tasks.add_task(abort_request)
139 | return background_tasks
140 |
141 |
142 | @app.post("/worker_generate_stream")
143 | async def api_generate_stream(request: Request):
144 | params = await request.json()
145 | await acquire_worker_semaphore()
146 | request_id = random_uuid()
147 | params["request_id"] = request_id
148 | generator = worker.generate_stream(params)
149 | background_tasks = create_background_tasks(request_id)
150 | return StreamingResponse(generator, background=background_tasks)
151 |
152 |
153 | @app.post("/worker_generate")
154 | async def api_generate(request: Request):
155 | params = await request.json()
156 | await acquire_worker_semaphore()
157 | request_id = random_uuid()
158 | params["request_id"] = request_id
159 | output = await worker.generate(params)
160 | release_worker_semaphore()
161 | await engine.abort(request_id)
162 | return JSONResponse(output)
163 |
164 |
165 | @app.post("/worker_get_status")
166 | async def api_get_status(request: Request):
167 | return worker.get_status()
168 |
169 |
170 | @app.post("/count_token")
171 | async def api_count_token(request: Request):
172 | params = await request.json()
173 | return worker.count_token(params)
174 |
175 |
176 | @app.post("/worker_get_conv_template")
177 | async def api_get_conv(request: Request):
178 | return worker.get_conv_template()
179 |
180 |
181 | @app.post("/model_details")
182 | async def api_model_details(request: Request):
183 | return {"context_length": worker.context_len}
184 |
185 |
186 | if __name__ == "__main__":
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument("--host", type=str, default="localhost")
189 | parser.add_argument("--port", type=int, default=21002)
190 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
191 | parser.add_argument(
192 | "--controller-address", type=str, default="http://localhost:21001"
193 | )
194 | parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3")
195 | parser.add_argument(
196 | "--model-names",
197 | type=lambda s: s.split(","),
198 |         help="Optional comma-separated display names",
199 | )
200 | parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
201 | parser.add_argument("--no-register", action="store_true")
202 | parser.add_argument("--num-gpus", type=int, default=1)
203 | parser.add_argument(
204 | "--conv-template", type=str, default=None, help="Conversation prompt template."
205 | )
206 |
207 | parser = AsyncEngineArgs.add_cli_args(parser)
208 | args = parser.parse_args()
209 | if args.model_path:
210 | args.model = args.model_path
211 | if args.num_gpus > 1:
212 | args.tensor_parallel_size = args.num_gpus
213 |
214 | engine_args = AsyncEngineArgs.from_cli_args(args)
215 | engine = AsyncLLMEngine.from_engine_args(engine_args)
216 | worker = VLLMWorker(
217 | args.controller_address,
218 | args.worker_address,
219 | worker_id,
220 | args.model_path,
221 | args.model_names,
222 | args.limit_worker_concurrency,
223 | args.no_register,
224 | engine,
225 | args.conv_template,
226 | )
227 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
228 |
--------------------------------------------------------------------------------
/chat/server/monitor/clean_battle_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Clean chatbot arena battle log.
3 |
4 | Usage:
5 | python3 clean_battle_data.py --mode conv_release
6 | """
7 | import argparse
8 | import datetime
9 | import json
10 | import os
11 | from pytz import timezone
12 | import time
13 |
14 | from tqdm import tqdm
15 |
16 | from chat.server.monitor.basic_stats import get_log_files, NUM_SERVERS
17 | from chat.utils import detect_language
18 |
19 |
20 | VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
21 | IDENTITY_WORDS = [
22 | "vicuna",
23 | "lmsys",
24 | "koala",
25 | "uc berkeley",
26 | "open assistant",
27 | "laion",
28 | "chatglm",
29 | "chatgpt",
30 | "openai",
31 | "anthropic",
32 | "claude",
33 | "bard",
34 | "palm",
35 | "lamda",
36 | "google",
37 | "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
38 | ]
39 |
40 | for i in range(len(IDENTITY_WORDS)):
41 | IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
42 |
43 |
44 | def get_log_files(max_num_files=None):
45 | dates = []
46 | for month in [4, 5, 6, 7]:
47 | for day in range(1, 32):
48 | dates.append(f"2023-{month:02d}-{day:02d}")
49 |
50 | for month in [8]:
51 | for day in range(1, 32):
52 | dates.append(f"2023-{month:02d}-{day:02d}")
53 |
54 | filenames = []
55 | for d in dates:
56 | for i in range(NUM_SERVERS):
57 | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
58 | if os.path.exists(name):
59 | filenames.append(name)
60 | max_num_files = max_num_files or len(filenames)
61 | filenames = filenames[-max_num_files:]
62 | return filenames
63 |
64 |
65 | def remove_html(raw):
66 |     if raw.startswith("<h3>"):
67 |         return raw[raw.find(": ") + 2 : -len("</h3>\n")]
68 | return raw
69 |
70 |
71 | def to_openai_format(messages):
72 | roles = ["user", "assistant"]
73 | ret = []
74 | for i, x in enumerate(messages):
75 | ret.append({"role": roles[i % 2], "content": x[1]})
76 | return ret
77 |
78 |
79 | def replace_model_name(old_name):
80 | return (
81 | old_name.replace("bard", "palm-2")
82 | .replace("claude-v1", "claude-1")
83 | .replace("claude-instant-v1", "claude-instant-1")
84 | .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b")
85 | )
86 |
87 |
88 | def clean_battle_data(log_files):
89 | data = []
90 | for filename in tqdm(log_files, desc="read files"):
91 | for retry in range(5):
92 | try:
93 | lines = open(filename).readlines()
94 | break
95 | except FileNotFoundError:
96 | time.sleep(2)
97 |
98 | for l in lines:
99 | row = json.loads(l)
100 | if row["type"] in VOTES:
101 | data.append(row)
102 |
103 | convert_type = {
104 | "leftvote": "model_a",
105 | "rightvote": "model_b",
106 | "tievote": "tie",
107 | "bothbad_vote": "tie (bothbad)",
108 | }
109 |
110 | all_models = set()
111 | all_ips = dict()
112 | ct_anony = 0
113 | ct_invalid = 0
114 | ct_leaked_identity = 0
115 | battles = []
116 | for row in data:
117 | if row["models"][0] is None or row["models"][1] is None:
118 | continue
119 |
120 | # Resolve model names
121 | models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
122 | if "model_name" in row["states"][0]:
123 | models_hidden = [
124 | row["states"][0]["model_name"],
125 | row["states"][1]["model_name"],
126 | ]
127 | if models_hidden[0] is None:
128 | models_hidden = models_public
129 | else:
130 | models_hidden = models_public
131 |
132 | if (models_public[0] == "" and models_public[1] != "") or (
133 | models_public[1] == "" and models_public[0] != ""
134 | ):
135 | ct_invalid += 1
136 | continue
137 |
138 | if models_public[0] == "" or models_public[0] == "Model A":
139 | anony = True
140 | models = models_hidden
141 | ct_anony += 1
142 | else:
143 | anony = False
144 | models = models_public
145 | if not models_public == models_hidden:
146 | ct_invalid += 1
147 | continue
148 |
149 | # Detect language
150 | state = row["states"][0]
151 | if state["offset"] >= len(state["messages"]):
152 | ct_invalid += 1
153 | continue
154 | lang_code = detect_language(state["messages"][state["offset"]][1])
155 |
156 | # Drop conversations if the model names are leaked
157 | leaked_identity = False
158 | messages = ""
159 | for i in range(2):
160 | state = row["states"][i]
161 | for role, msg in state["messages"][state["offset"] :]:
162 | if msg:
163 | messages += msg.lower()
164 | for word in IDENTITY_WORDS:
165 | if word in messages:
166 | leaked_identity = True
167 | break
168 |
169 | if leaked_identity:
170 | ct_leaked_identity += 1
171 | continue
172 |
173 | # Replace bard with palm
174 | models = [replace_model_name(m) for m in models]
175 |
176 | question_id = row["states"][0]["conv_id"]
177 | conversation_a = to_openai_format(
178 | row["states"][0]["messages"][row["states"][0]["offset"] :]
179 | )
180 | conversation_b = to_openai_format(
181 | row["states"][1]["messages"][row["states"][1]["offset"] :]
182 | )
183 |
184 | ip = row["ip"]
185 | if ip not in all_ips:
186 | all_ips[ip] = len(all_ips)
187 | user_id = all_ips[ip]
188 |
189 | # Save the result
190 | battles.append(
191 | dict(
192 | question_id=question_id,
193 | model_a=models[0],
194 | model_b=models[1],
195 | winner=convert_type[row["type"]],
196 | judge=f"arena_user_{user_id}",
197 | conversation_a=conversation_a,
198 | conversation_b=conversation_b,
199 | turn=len(conversation_a) // 2,
200 | anony=anony,
201 | language=lang_code,
202 | tstamp=row["tstamp"],
203 | )
204 | )
205 |
206 | all_models.update(models_hidden)
207 | battles.sort(key=lambda x: x["tstamp"])
208 | last_updated_tstamp = battles[-1]["tstamp"]
209 |
210 | last_updated_datetime = datetime.datetime.fromtimestamp(
211 | last_updated_tstamp, tz=timezone("US/Pacific")
212 | ).strftime("%Y-%m-%d %H:%M:%S %Z")
213 |
214 | print(
215 | f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
216 | f"#leaked_identity: {ct_leaked_identity}"
217 | )
218 | print(f"#battles: {len(battles)}, #anony: {ct_anony}")
219 | print(f"#models: {len(all_models)}, {all_models}")
220 | print(f"last-updated: {last_updated_datetime}")
221 |
222 | return battles
223 |
224 |
225 | if __name__ == "__main__":
226 | parser = argparse.ArgumentParser()
227 | parser.add_argument("--max-num-files", type=int)
228 | parser.add_argument(
229 | "--mode", type=str, choices=["simple", "conv_release"], default="simple"
230 | )
231 | args = parser.parse_args()
232 |
233 | log_files = get_log_files(args.max_num_files)
234 | battles = clean_battle_data(log_files)
235 | last_updated_tstamp = battles[-1]["tstamp"]
236 | cutoff_date = datetime.datetime.fromtimestamp(
237 | last_updated_tstamp, tz=timezone("US/Pacific")
238 | ).strftime("%Y%m%d")
239 |
240 | if args.mode == "simple":
241 | for x in battles:
242 | for key in [
243 | "conversation_a",
244 | "conversation_b",
245 | "question_id",
246 | ]:
247 | del x[key]
248 | print("Samples:")
249 | for i in range(4):
250 | print(battles[i])
251 | output = f"clean_battle_{cutoff_date}.json"
252 | elif args.mode == "conv_release":
253 | new_battles = []
254 | for x in battles:
255 | if not x["anony"]:
256 | continue
257 | for key in []:
258 | del x[key]
259 | new_battles.append(x)
260 | battles = new_battles
261 | output = f"clean_battle_conv_{cutoff_date}.json"
262 |
263 | with open(output, "w") as fout:
264 | json.dump(battles, fout, indent=2, ensure_ascii=False)
265 | print(f"Write cleaned data to {output}")
266 |
--------------------------------------------------------------------------------
/chat/server/monitor/topic_clustering.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Usage:
4 | python3 topic_clustering.py --input-file arena.json --english-only --min-length 32
5 | python3 topic_clustering.py --input-file clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1024
6 | """
7 | import argparse
8 | import json
9 | import pickle
10 | import string
11 | import time
12 |
13 | import numpy as np
14 | from sentence_transformers import SentenceTransformer
15 | from sentence_transformers.util import cos_sim
16 | from sklearn.cluster import KMeans, AgglomerativeClustering
17 | import torch
18 | from tqdm import tqdm
19 |
20 | from chat.utils import detect_language
21 |
22 |
23 | def remove_punctuation(input_string):
24 | # Make a translator object to remove all punctuation
25 | translator = str.maketrans("", "", string.punctuation)
26 |
27 | # Use the translator object to remove the punctuation
28 | no_punct = input_string.translate(translator)
29 | return no_punct
30 |
31 |
32 | def read_texts(input_file, min_length, max_length, english_only):
33 | visited = set()
34 | texts = []
35 |
36 | lines = json.load(open(input_file, "r"))
37 |
38 | for l in tqdm(lines):
39 | if "text" in l:
40 | line_texts = [l["text"]]
41 | elif "conversation_a" in l:
42 | line_texts = [
43 | x["content"] for x in l["conversation_a"] if x["role"] == "user"
44 | ]
45 | elif "conversation" in l:
46 | line_texts = [
47 | x["content"] for x in l["conversation"] if x["role"] == "user"
48 | ]
49 |
50 | for text in line_texts:
51 | text = text.strip()
52 |
53 | # Filter language
54 | if english_only:
55 | lang = detect_language(text)
56 | if lang != "English":
57 | continue
58 |
59 | # Filter short or long prompts
60 | if min_length:
61 | if len(text) < min_length:
62 | continue
63 |
64 | if max_length:
65 | if len(text) > max_length:
66 | continue
67 |
68 | # De-duplication
69 | words = sorted([x.lower() for x in remove_punctuation(text).split(" ")])
70 | words = "".join(words)
71 | if words in visited:
72 | continue
73 |
74 | visited.add(words)
75 | texts.append(text)
76 | return np.array(texts)
77 |
78 |
79 | def get_embeddings(texts, model_name, batch_size):
80 | model = SentenceTransformer(model_name)
81 | embeddings = model.encode(
82 | texts,
83 | batch_size=batch_size,
84 | show_progress_bar=True,
85 | device="cuda",
86 | convert_to_tensor=True,
87 | )
88 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
89 | return embeddings.cpu()
90 |
91 |
92 | def run_k_means(embeddings, num_clusters):
93 | np.random.seed(0)
94 | clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
95 | clustering_model.fit(embeddings.numpy())
96 | centers = torch.from_numpy(clustering_model.cluster_centers_)
97 | labels = torch.from_numpy(clustering_model.labels_)
98 |
99 | # Sort labels
100 | classes, counts = np.unique(labels, return_counts=True)
101 | indices = np.argsort(counts)[::-1]
102 | classes = [classes[i] for i in indices]
103 | new_labels = torch.empty_like(labels)
104 | new_centers = torch.empty_like(centers)
105 | for i, c in enumerate(classes):
106 | new_labels[labels == c] = i
107 | new_centers[i] = centers[c]
108 | return new_centers, new_labels
109 |
110 |
111 | def run_agg_cluster(embeddings, num_clusters):
112 | np.random.seed(0)
113 | clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
114 | clustering_model.fit(embeddings)
115 | labels = torch.from_numpy(clustering_model.labels_)
116 |
117 | # Sort labels
118 | classes, counts = np.unique(labels, return_counts=True)
119 | indices = np.argsort(counts)[::-1]
120 | classes = [classes[i] for i in indices]
121 | new_labels = torch.empty_like(labels)
122 | for i, c in enumerate(classes):
123 | new_labels[labels == c] = i
124 |
125 | # Compute centers
126 | centers = []
127 | for i in range(clustering_model.n_clusters_):
128 | centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
129 | centers = torch.cat(centers)
130 | return centers, new_labels
131 |
132 |
133 | def get_topk_indices(centers, labels, embeddings, topk):
134 | indices = []
135 | arange = torch.arange(len(labels))
136 | counts = torch.unique(labels, return_counts=True)[1]
137 | topk = min(topk, counts.min().item())
138 | for i in range(len(centers)):
139 | tmp_indices = labels == i
140 | tmp_arange = arange[tmp_indices]
141 | tmp_embeddings = embeddings[tmp_indices]
142 |
143 | scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
144 | sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
145 | indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0))
146 | return torch.cat(indices)
147 |
148 |
149 | def print_topk(texts, labels, topk_indices, show_cut_off):
150 | ret = ""
151 | for k in range(len(topk_indices)):
152 | num_samples = torch.sum(labels == k).item()
153 |
154 | ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n"
155 | for idx in topk_indices[k]:
156 | ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n"
157 | ret += "=" * 40 + "\n\n"
158 |
159 | return ret
160 |
161 |
162 | def get_cluster_info(texts, labels, topk_indices):
163 | cluster_info = []
164 | for k in range(len(topk_indices)):
165 | num_samples = torch.sum(labels == k).item()
166 | prompts = []
167 | for idx in topk_indices[k]:
168 | prompts.append(texts[idx])
169 | cluster_info.append((num_samples, prompts))
170 |
171 | return cluster_info
172 |
173 |
174 | if __name__ == "__main__":
175 | parser = argparse.ArgumentParser()
176 | parser.add_argument("--input-file", type=str, required=True)
177 | parser.add_argument("--model", type=str, default="all-mpnet-base-v2")
178 | # default="all-MiniLM-L12-v2")
179 | # default="multi-qa-distilbert-cos-v1")
180 | parser.add_argument("--batch-size", type=int, default=256)
181 | parser.add_argument("--min-length", type=int)
182 | parser.add_argument("--max-length", type=int)
183 | parser.add_argument("--english-only", action="store_true")
184 | parser.add_argument("--num-clusters", type=int, default=20)
185 | parser.add_argument(
186 | "--cluster-alg", type=str, choices=["kmeans", "aggcls"], default="kmeans"
187 | )
188 | parser.add_argument("--show-top-k", type=int, default=200)
189 | parser.add_argument("--show-cut-off", type=int, default=512)
190 | args = parser.parse_args()
191 |
192 | num_clusters = args.num_clusters
193 | show_top_k = args.show_top_k
194 | show_cut_off = args.show_cut_off
195 |
196 | texts = read_texts(
197 | args.input_file, args.min_length, args.max_length, args.english_only
198 | )
199 | print(f"#text: {len(texts)}")
200 |
201 | embeddings = get_embeddings(texts, args.model, args.batch_size)
202 | if args.cluster_alg == "kmeans":
203 | centers, labels = run_k_means(embeddings, num_clusters)
204 | elif args.cluster_alg == "aggcls":
205 | centers, labels = run_agg_cluster(embeddings, num_clusters)
206 | else:
207 | raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}")
208 |
209 | topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k)
210 | topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
211 | num_clusters = len(centers)
212 |
213 | cluster_info = get_cluster_info(texts, labels, topk_indices)
214 |
215 | # Dump results
216 | filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
217 | print(topk_str)
218 | with open(filename_prefix + "_topk.txt", "w") as fout:
219 | fout.write(topk_str)
220 |
221 | with open(filename_prefix + "_all.txt", "w") as fout:
222 | for i in range(len(centers)):
223 | tmp_indices = labels == i
224 | tmp_embeddings = embeddings[tmp_indices]
225 | tmp_texts = texts[tmp_indices]
226 |
227 | scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
228 | sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
229 |
230 | for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]):
231 | obj = {"cluster": i, "text": text, "sim": score.item()}
232 | fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
233 |
234 | with open(filename_prefix + "_cluster.pkl", "wb") as fout:
235 | pickle.dump(cluster_info, fout)
236 |
--------------------------------------------------------------------------------
/chat/server/multi_model_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A multi-model worker that contains multiple sub-workers, one for each model. This
3 | supports running a list of models on the same machine so that they can
4 | (potentially) share the same background weights.
5 |
6 | Each model can have one or more model names.
7 |
8 | This multi-model worker assumes the models share some underlying weights and
9 | thus reports the combined queue lengths for health checks.
10 |
11 | We recommend using this with multiple Peft models (with `peft` in the name)
12 | where all Peft models are trained on the exact same base model.
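
An illustrative launch serving two models (paths and names are placeholders; in
practice, prefer Peft adapters that share one base model, as noted above):
    python3 -m chat.server.multi_model_worker \
        --model-path lmsys/vicuna-7b-v1.3 --model-names vicuna-7b-v1.3 \
        --model-path THUDM/chatglm2-6b --model-names chatglm2-6b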
13 | """
14 | import argparse
15 | import asyncio
16 | import dataclasses
17 | import logging
18 | import json
19 | import os
20 | import time
21 | from typing import List, Union
22 | import threading
23 | import uuid
24 |
25 | from fastapi import FastAPI, Request, BackgroundTasks
26 | from fastapi.responses import StreamingResponse, JSONResponse
27 | import requests
28 |
29 | try:
30 | from transformers import (
31 | AutoTokenizer,
32 | AutoModelForCausalLM,
33 | LlamaTokenizer,
34 | AutoModel,
35 | )
36 | except ImportError:
37 | from transformers import (
38 | AutoTokenizer,
39 | AutoModelForCausalLM,
40 | LLaMATokenizer,
41 | AutoModel,
42 | )
43 | import torch
44 | import torch.nn.functional as F
45 | import uvicorn
46 |
47 | from chat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
48 | from chat.model.model_adapter import (
49 | load_model,
50 | add_model_args,
51 | get_conversation_template,
52 | )
53 | from chat.model.model_chatglm import generate_stream_chatglm
54 | from chat.model.model_falcon import generate_stream_falcon
55 | from chat.model.model_codet5p import generate_stream_codet5p
56 | from chat.modules.gptq import GptqConfig
57 | from chat.server.inference import generate_stream
58 | from chat.server.model_worker import ModelWorker, worker_id, logger
59 | from chat.utils import build_logger, pretty_print_semaphore, get_context_length
60 |
61 |
62 | # We store both the underlying workers and a mapping from their model names to
63 | # the worker instance. This makes it easy to fetch the appropriate worker for
64 | # each API call.
65 | workers = []
66 | worker_map = {}
67 | app = FastAPI()
68 |
69 |
70 | def release_worker_semaphore():
71 | workers[0].semaphore.release()
72 |
73 |
74 | def acquire_worker_semaphore():
75 | if workers[0].semaphore is None:
76 | # Share the same semaphore for all workers because
77 | # all workers share the same GPU.
78 | semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency)
79 | for w in workers:
80 | w.semaphore = semaphore
81 | return workers[0].semaphore.acquire()
82 |
83 |
84 | def create_background_tasks():
85 | background_tasks = BackgroundTasks()
86 | background_tasks.add_task(release_worker_semaphore)
87 | return background_tasks
88 |
89 |
90 | # Note: for all the calls below, we make a hard assumption that the caller
91 | # includes the model name in the payload, otherwise we can't figure out which
92 | # underlying sub-worker to call.
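# For example, a request body is expected to look roughly like
#   {"model": "vicuna-7b-v1.3", "prompt": "...", "temperature": 0.7}
# where only "model" is consumed here; the remaining fields are passed through
# to the selected sub-worker.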
93 |
94 |
95 | @app.post("/worker_generate_stream")
96 | async def api_generate_stream(request: Request):
97 | params = await request.json()
98 | await acquire_worker_semaphore()
99 | worker = worker_map[params["model"]]
100 | generator = worker.generate_stream_gate(params)
101 | background_tasks = create_background_tasks()
102 | return StreamingResponse(generator, background=background_tasks)
103 |
104 |
105 | @app.post("/worker_generate")
106 | async def api_generate(request: Request):
107 | params = await request.json()
108 | await acquire_worker_semaphore()
109 | worker = worker_map[params["model"]]
110 | output = worker.generate_gate(params)
111 | release_worker_semaphore()
112 | return JSONResponse(output)
113 |
114 |
115 | @app.post("/worker_get_embeddings")
116 | async def api_get_embeddings(request: Request):
117 | params = await request.json()
118 | await acquire_worker_semaphore()
119 | worker = worker_map[params["model"]]
120 | embedding = worker.get_embeddings(params)
121 | background_tasks = create_background_tasks()
122 | return JSONResponse(content=embedding, background=background_tasks)
123 |
124 |
125 | @app.post("/worker_get_status")
126 | async def api_get_status(request: Request):
127 | return {
128 | "model_names": [m for w in workers for m in w.model_names],
129 | "speed": 1,
130 | "queue_length": sum([w.get_queue_length() for w in workers]),
131 | }
132 |
133 |
134 | @app.post("/count_token")
135 | async def api_count_token(request: Request):
136 | params = await request.json()
137 | worker = worker_map[params["model"]]
138 | return worker.count_token(params)
139 |
140 |
141 | @app.post("/worker_get_conv_template")
142 | async def api_get_conv(request: Request):
143 | params = await request.json()
144 | worker = worker_map[params["model"]]
145 | return worker.get_conv_template()
146 |
147 |
148 | @app.post("/model_details")
149 | async def api_model_details(request: Request):
150 | params = await request.json()
151 | worker = worker_map[params["model"]]
152 | return {"context_length": worker.context_len}
153 |
154 |
155 | def create_multi_model_worker():
156 | # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST
157 | # of the model args but we'll override one to have an append action that
158 | # supports multiple values.
159 | parser = argparse.ArgumentParser(conflict_handler="resolve")
160 | parser.add_argument("--host", type=str, default="localhost")
161 | parser.add_argument("--port", type=int, default=21002)
162 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
163 | parser.add_argument(
164 | "--controller-address", type=str, default="http://localhost:21001"
165 | )
166 | add_model_args(parser)
167 | # Override the model path to be repeated and align it with model names.
168 | parser.add_argument(
169 | "--model-path",
170 | type=str,
171 | default=[],
172 | action="append",
173 | help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.",
174 | )
175 | parser.add_argument(
176 | "--model-names",
177 | type=lambda s: s.split(","),
178 | action="append",
179 | help="One or more model names. Values must be aligned with `--model-path` values.",
180 | )
181 | parser.add_argument("--limit-worker-concurrency", type=int, default=5)
182 | parser.add_argument("--stream-interval", type=int, default=2)
183 | parser.add_argument("--no-register", action="store_true")
184 | args = parser.parse_args()
185 | logger.info(f"args: {args}")
186 |
187 | if args.gpus:
188 | if len(args.gpus.split(",")) < args.num_gpus:
189 | raise ValueError(
190 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
191 | )
192 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
193 |
194 | gptq_config = GptqConfig(
195 | ckpt=args.gptq_ckpt or args.model_path,
196 | wbits=args.gptq_wbits,
197 | groupsize=args.gptq_groupsize,
198 | act_order=args.gptq_act_order,
199 | )
200 |
201 | if args.model_names is None:
202 | args.model_names = [[x.split("/")[-1]] for x in args.model_path]
203 |
204 | # Launch all workers
205 | workers = []
206 | for model_path, model_names in zip(args.model_path, args.model_names):
207 | w = ModelWorker(
208 | args.controller_address,
209 | args.worker_address,
210 | worker_id,
211 | model_path,
212 | model_names,
213 | args.limit_worker_concurrency,
214 | args.no_register,
215 | device=args.device,
216 | num_gpus=args.num_gpus,
217 | max_gpu_memory=args.max_gpu_memory,
218 | load_8bit=args.load_8bit,
219 | cpu_offloading=args.cpu_offloading,
220 | gptq_config=gptq_config,
221 | stream_interval=args.stream_interval,
222 | )
223 | workers.append(w)
224 | for model_name in model_names:
225 | worker_map[model_name] = w
226 |
227 | # Register all models
228 | url = args.controller_address + "/register_worker"
229 | data = {
230 | "worker_name": workers[0].worker_addr,
231 | "check_heart_beat": not args.no_register,
232 | "worker_status": {
233 | "model_names": [m for w in workers for m in w.model_names],
234 | "speed": 1,
235 | "queue_length": sum([w.get_queue_length() for w in workers]),
236 | },
237 | }
238 | r = requests.post(url, json=data)
239 | assert r.status_code == 200
240 |
241 | return args, workers
242 |
243 |
244 | if __name__ == "__main__":
245 | args, workers = create_multi_model_worker()
246 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
247 |
--------------------------------------------------------------------------------
/chat/server/launch_all_serve.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage: python launch_all_serve.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
3 |
4 | Workers are listed in the format `model-path`@`host`@`port`
5 |
6 | The key mechanism behind this script is:
7 | 1, execute shell cmd to launch the controller/worker/openai-api-server;
8 | 2, check the log of controller/worker/openai-api-server to ensure that the server is launched properly.
9 | Note that a few non-critical `chat.server` cmd options are currently not supported.
10 | """
11 | import sys
12 | import os
13 |
14 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
15 |
16 | import subprocess
17 | import re
18 | import argparse
19 |
20 | LOGDIR = "./logs/"
21 |
22 | if not os.path.exists(LOGDIR):
23 | os.makedirs(LOGDIR)
24 |
25 | parser = argparse.ArgumentParser()
26 | # ------multi worker-----------------
27 | parser.add_argument(
28 | "--model-path-address",
29 | default="THUDM/chatglm2-6b@localhost@20002",
30 | nargs="+",
31 | type=str,
32 | help="model path, host, and port, formatted as model-path@host@port",
33 | )
34 | # ---------------controller-------------------------
35 |
36 | parser.add_argument("--controller-host", type=str, default="localhost")
37 | parser.add_argument("--controller-port", type=int, default=21001)
38 | parser.add_argument(
39 | "--dispatch-method",
40 | type=str,
41 | choices=["lottery", "shortest_queue"],
42 | default="shortest_queue",
43 | )
44 | controller_args = ["controller-host", "controller-port", "dispatch-method"]
45 |
46 | # ----------------------worker------------------------------------------
47 |
48 | parser.add_argument("--worker-host", type=str, default="localhost")
49 | parser.add_argument("--worker-port", type=int, default=21002)
50 | # parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
51 | # parser.add_argument(
52 | # "--controller-address", type=str, default="http://localhost:21001"
53 | # )
54 | parser.add_argument(
55 | "--model-path",
56 | type=str,
57 | default="lmsys/vicuna-7b-v1.3",
58 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
59 | )
60 | parser.add_argument(
61 | "--revision",
62 | type=str,
63 | default="main",
64 | help="Hugging Face Hub model revision identifier",
65 | )
66 | parser.add_argument(
67 | "--device",
68 | type=str,
69 | choices=["cpu", "cuda", "mps", "xpu"],
70 | default="cuda",
71 | help="The device type",
72 | )
73 | parser.add_argument(
74 | "--gpus",
75 | type=str,
76 | default="0",
77 | help="A single GPU like 1 or multiple GPUs like 0,2",
78 | )
79 | parser.add_argument("--num-gpus", type=int, default=1)
80 | parser.add_argument(
81 | "--max-gpu-memory",
82 | type=str,
83 |     help="The maximum memory per GPU. Use a string like '13GiB'",
84 | )
85 | parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
86 | parser.add_argument(
87 | "--cpu-offloading",
88 | action="store_true",
89 |     help="Only when using 8-bit quantization: offload excess weights that do not fit on the GPU to the CPU",
90 | )
91 | parser.add_argument(
92 | "--gptq-ckpt",
93 | type=str,
94 | default=None,
95 | help="Load quantized model. The path to the local GPTQ checkpoint.",
96 | )
97 | parser.add_argument(
98 | "--gptq-wbits",
99 | type=int,
100 | default=16,
101 | choices=[2, 3, 4, 8, 16],
102 | help="#bits to use for quantization",
103 | )
104 | parser.add_argument(
105 | "--gptq-groupsize",
106 | type=int,
107 | default=-1,
108 | help="Groupsize to use for quantization; default uses full row.",
109 | )
110 | parser.add_argument(
111 | "--gptq-act-order",
112 | action="store_true",
113 | help="Whether to apply the activation order GPTQ heuristic",
114 | )
115 | parser.add_argument(
116 | "--model-names",
117 | type=lambda s: s.split(","),
118 |     help="Optional comma-separated display names",
119 | )
120 | parser.add_argument(
121 | "--limit-worker-concurrency",
122 | type=int,
123 | default=5,
124 | help="Limit the model concurrency to prevent OOM.",
125 | )
126 | parser.add_argument("--stream-interval", type=int, default=2)
127 | parser.add_argument("--no-register", action="store_true")
128 |
129 | worker_args = [
130 | "worker-host",
131 | "worker-port",
132 | "model-path",
133 | "revision",
134 | "device",
135 | "gpus",
136 | "num-gpus",
137 | "max-gpu-memory",
138 | "load-8bit",
139 | "cpu-offloading",
140 | "gptq-ckpt",
141 | "gptq-wbits",
142 | "gptq-groupsize",
143 | "gptq-act-order",
144 | "model-names",
145 | "limit-worker-concurrency",
146 | "stream-interval",
147 | "no-register",
148 | "controller-address",
149 | ]
150 | # -----------------openai server---------------------------
151 |
152 | parser.add_argument("--server-host", type=str, default="localhost", help="host name")
153 | parser.add_argument("--server-port", type=int, default=8001, help="port number")
154 | parser.add_argument(
155 | "--allow-credentials", action="store_true", help="allow credentials"
156 | )
157 | # parser.add_argument(
158 | # "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
159 | # )
160 | # parser.add_argument(
161 | # "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
162 | # )
163 | # parser.add_argument(
164 | # "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
165 | # )
166 | parser.add_argument(
167 | "--api-keys",
168 | type=lambda s: s.split(","),
169 | help="Optional list of comma separated API keys",
170 | )
171 | server_args = [
172 | "server-host",
173 | "server-port",
174 | "allow-credentials",
175 | "api-keys",
176 | "controller-address",
177 | ]
178 |
179 | args = parser.parse_args()
180 |
181 | args = argparse.Namespace(
182 | **vars(args),
183 | **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
184 | )
185 |
186 | if args.gpus:
187 | if len(args.gpus.split(",")) < args.num_gpus:
188 | raise ValueError(
189 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
190 | )
191 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
192 |
193 | # 0: controller, model_worker, openai_api_server
194 | # 1: cmd options
195 | # 2: LOGDIR
196 | # 3: log file name
197 | base_launch_sh = "nohup python3 -m chat.server.{0} {1} >{2}/{3}.log 2>&1 &"
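# e.g. base_launch_sh.format("controller", " --host localhost --port 21001 ", LOGDIR, "controller")
# expands roughly to:
#   nohup python3 -m chat.server.controller --host localhost --port 21001 >./logs/controller.log 2>&1 &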
198 |
199 | # 0: LOGDIR
200 | # 1: log file name
201 | # 2: controller, worker, openai_api_server
202 | base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
203 | sleep 1s;
204 | echo "wait {2} running"
205 | done
206 | echo '{2} running' """
207 |
208 |
209 | def string_args(args, args_list):
210 | args_str = ""
211 | for key, value in args._get_kwargs():
212 | key = key.replace("_", "-")
213 | if key not in args_list:
214 | continue
215 |
216 | key = key.split("-")[-1] if re.search("port|host", key) else key
217 | if not value:
218 | pass
219 | # 1==True -> True
220 | elif isinstance(value, bool) and value == True:
221 | args_str += f" --{key} "
222 | elif (
223 | isinstance(value, list)
224 | or isinstance(value, tuple)
225 | or isinstance(value, set)
226 | ):
227 | value = " ".join(value)
228 | args_str += f" --{key} {value} "
229 | else:
230 | args_str += f" --{key} {value} "
231 |
232 | return args_str
233 |
234 |
235 | def launch_worker(item):
236 | log_name = (
237 | item.split("/")[-1]
238 | .split("\\")[-1]
239 | .replace("-", "_")
240 | .replace("@", "_")
241 | .replace(".", "_")
242 | )
243 |
244 | args.model_path, args.worker_host, args.worker_port = item.split("@")
245 | print("*" * 80)
246 | worker_str_args = string_args(args, worker_args)
247 | print(worker_str_args)
248 | worker_sh = base_launch_sh.format(
249 | "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
250 | )
251 | worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
252 | subprocess.run(worker_sh, shell=True, check=True)
253 | subprocess.run(worker_check_sh, shell=True, check=True)
254 |
255 |
256 | def launch_all():
257 | controller_str_args = string_args(args, controller_args)
258 | controller_sh = base_launch_sh.format(
259 | "controller", controller_str_args, LOGDIR, "controller"
260 | )
261 | controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
262 | subprocess.run(controller_sh, shell=True, check=True)
263 | subprocess.run(controller_check_sh, shell=True, check=True)
264 |
265 | if isinstance(args.model_path_address, str):
266 | launch_worker(args.model_path_address)
267 | else:
268 | for idx, item in enumerate(args.model_path_address):
269 | print(f"loading {idx}th model:{item}")
270 | launch_worker(item)
271 |
272 | server_str_args = string_args(args, server_args)
273 | server_sh = base_launch_sh.format(
274 | "openai_api_server", server_str_args, LOGDIR, "openai_api_server"
275 | )
276 | server_check_sh = base_check_sh.format(
277 | LOGDIR, "openai_api_server", "openai_api_server"
278 | )
279 | subprocess.run(server_sh, shell=True, check=True)
280 | subprocess.run(server_check_sh, shell=True, check=True)
281 |
282 |
283 | if __name__ == "__main__":
284 | launch_all()
285 |
--------------------------------------------------------------------------------
/chat/server/monitor/elo_analysis.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | import datetime
4 | import json
5 | import math
6 | import pickle
7 | from pytz import timezone
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import plotly.express as px
12 | from tqdm import tqdm
13 |
14 | from chat.model.model_registry import get_model_info
15 | from chat.server.monitor.basic_stats import get_log_files
16 | from chat.server.monitor.clean_battle_data import clean_battle_data
17 |
18 |
19 | pd.options.display.float_format = "{:.2f}".format
20 |
21 |
22 | def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
23 | rating = defaultdict(lambda: INIT_RATING)
24 |
25 | for rd, model_a, model_b, winner in battles[
26 | ["model_a", "model_b", "winner"]
27 | ].itertuples():
28 | ra = rating[model_a]
29 | rb = rating[model_b]
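        # Expected scores for model_a and model_b under the Elo logistic model
        # (ea + eb == 1).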
30 | ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
31 | eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
32 | if winner == "model_a":
33 | sa = 1
34 | elif winner == "model_b":
35 | sa = 0
36 | elif winner == "tie" or winner == "tie (bothbad)":
37 | sa = 0.5
38 | else:
39 | raise Exception(f"unexpected vote {winner}")
40 | rating[model_a] += K * (sa - ea)
41 | rating[model_b] += K * (1 - sa - eb)
42 |
43 | return dict(rating)
44 |
45 |
46 | def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
47 | rows = []
48 | for i in tqdm(range(num_round), desc="bootstrap"):
49 | tmp_battles = battles.sample(frac=1.0, replace=True)
50 | rows.append(func_compute_elo(tmp_battles))
51 | df = pd.DataFrame(rows)
52 | return df[df.median().sort_values(ascending=False).index]
53 |
54 |
55 | def get_median_elo_from_bootstrap(bootstrap_df):
56 | median = dict(bootstrap_df.quantile(0.5))
57 | median = {k: int(v + 0.5) for k, v in median.items()}
58 | return median
59 |
60 |
61 | def compute_pairwise_win_fraction(battles, model_order):
62 | # Times each model wins as Model A
63 | a_win_ptbl = pd.pivot_table(
64 | battles[battles["winner"] == "model_a"],
65 | index="model_a",
66 | columns="model_b",
67 | aggfunc="size",
68 | fill_value=0,
69 | )
70 |
71 | # Table counting times each model wins as Model B
72 | b_win_ptbl = pd.pivot_table(
73 | battles[battles["winner"] == "model_b"],
74 | index="model_a",
75 | columns="model_b",
76 | aggfunc="size",
77 | fill_value=0,
78 | )
79 |
80 | # Table counting number of A-B pairs
81 | num_battles_ptbl = pd.pivot_table(
82 | battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
83 | )
84 |
85 | # Computing the proportion of wins for each model as A and as B
86 | # against all other models
87 | row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
88 | num_battles_ptbl + num_battles_ptbl.T
89 | )
90 |
91 | if model_order is None:
92 | prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
93 | model_order = list(prop_wins.keys())
94 |
95 |     # Arrange ordering according to proportion of wins
96 | row_beats_col = row_beats_col_freq.loc[model_order, model_order]
97 | return row_beats_col
98 |
99 |
100 | def visualize_leaderboard_table(rating):
101 | models = list(rating.keys())
102 | models.sort(key=lambda k: -rating[k])
103 |
104 | emoji_dict = {
105 | 1: "🥇",
106 | 2: "🥈",
107 | 3: "🥉",
108 | }
109 |
110 | md = ""
111 | md += "| Rank | Model | Elo Rating | Description |\n"
112 | md += "| --- | --- | --- | --- |\n"
113 | for i, model in enumerate(models):
114 | rank = i + 1
115 | minfo = get_model_info(model)
116 | emoji = emoji_dict.get(rank, "")
117 | md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
118 |
119 | return md
120 |
121 |
122 | def visualize_pairwise_win_fraction(battles, model_order):
123 | row_beats_col = compute_pairwise_win_fraction(battles, model_order)
124 | fig = px.imshow(
125 | row_beats_col,
126 | color_continuous_scale="RdBu",
127 | text_auto=".2f",
128 | height=700,
129 | width=700,
130 | )
131 | fig.update_layout(
132 | xaxis_title="Model B",
133 | yaxis_title="Model A",
134 | xaxis_side="top",
135 | title_y=0.07,
136 | title_x=0.5,
137 | )
138 | fig.update_traces(
139 |         hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}"
140 | )
141 |
142 | return fig
143 |
144 |
145 | def visualize_battle_count(battles, model_order):
146 | ptbl = pd.pivot_table(
147 | battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
148 | )
149 | battle_counts = ptbl + ptbl.T
150 | fig = px.imshow(
151 | battle_counts.loc[model_order, model_order],
152 | text_auto=True,
153 | height=700,
154 | width=700,
155 | )
156 | fig.update_layout(
157 | xaxis_title="Model B",
158 | yaxis_title="Model A",
159 | xaxis_side="top",
160 | title_y=0.07,
161 | title_x=0.5,
162 | )
163 | fig.update_traces(
164 |         hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}"
165 | )
166 | return fig
167 |
168 |
169 | def visualize_average_win_rate(battles):
170 | row_beats_col_freq = compute_pairwise_win_fraction(battles, None)
171 | fig = px.bar(
172 | row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
173 | text_auto=".2f",
174 | height=500,
175 | width=700,
176 | )
177 | fig.update_layout(
178 | yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
179 | )
180 | return fig
181 |
182 |
183 | def visualize_bootstrap_elo_rating(df):
184 | bars = (
185 | pd.DataFrame(
186 | dict(
187 | lower=df.quantile(0.025),
188 | rating=df.quantile(0.5),
189 | upper=df.quantile(0.975),
190 | )
191 | )
192 | .reset_index(names="model")
193 | .sort_values("rating", ascending=False)
194 | )
195 | bars["error_y"] = bars["upper"] - bars["rating"]
196 | bars["error_y_minus"] = bars["rating"] - bars["lower"]
197 | bars["rating_rounded"] = np.round(bars["rating"], 2)
198 | fig = px.scatter(
199 | bars,
200 | x="model",
201 | y="rating",
202 | error_y="error_y",
203 | error_y_minus="error_y_minus",
204 | text="rating_rounded",
205 | height=500,
206 | width=700,
207 | )
208 | fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
209 | return fig
210 |
211 |
212 | def report_elo_analysis_results(battles_json):
213 | battles = pd.DataFrame(battles_json)
214 | battles = battles.sort_values(ascending=True, by=["tstamp"])
215 | # Only use anonymous votes
216 | battles = battles[battles["anony"]].reset_index(drop=True)
217 | battles_no_ties = battles[~battles["winner"].str.contains("tie")]
218 |
219 | # Online update
220 | elo_rating_online = compute_elo(battles)
221 |
222 | # Bootstrap
223 | bootstrap_df = get_bootstrap_result(battles, compute_elo)
224 | elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
225 | model_order = list(elo_rating_median.keys())
226 | model_order.sort(key=lambda k: -elo_rating_median[k])
227 |
228 | # Plots
229 | leaderboard_table = visualize_leaderboard_table(elo_rating_median)
230 | win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
231 | battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
232 | average_win_rate_bar = visualize_average_win_rate(battles_no_ties)
233 | bootstrap_elo_rating = visualize_bootstrap_elo_rating(bootstrap_df)
234 |
235 | last_updated_tstamp = battles["tstamp"].max()
236 | last_updated_datetime = datetime.datetime.fromtimestamp(
237 | last_updated_tstamp, tz=timezone("US/Pacific")
238 | ).strftime("%Y-%m-%d %H:%M:%S %Z")
239 |
240 | return {
241 | "elo_rating_online": elo_rating_online,
242 | "elo_rating_median": elo_rating_median,
243 | "leaderboard_table": leaderboard_table,
244 | "win_fraction_heatmap": win_fraction_heatmap,
245 | "battle_count_heatmap": battle_count_heatmap,
246 | "average_win_rate_bar": average_win_rate_bar,
247 | "bootstrap_elo_rating": bootstrap_elo_rating,
248 | "last_updated_datetime": last_updated_datetime,
249 | "last_updated_tstamp": last_updated_tstamp,
250 | }
251 |
252 |
253 | def pretty_print_elo_rating(rating):
254 | model_order = list(rating.keys())
255 | model_order.sort(key=lambda k: -rating[k])
256 | for i, model in enumerate(model_order):
257 | print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
258 |
259 |
260 | if __name__ == "__main__":
261 | parser = argparse.ArgumentParser()
262 | parser.add_argument("--clean-battle-file", type=str)
263 | parser.add_argument("--max-num-files", type=int)
264 | args = parser.parse_args()
265 |
266 | np.random.seed(42)
267 |
268 | if args.clean_battle_file:
269 |         # Read data from a cleaned battle file
270 | battles = pd.read_json(args.clean_battle_file)
271 | else:
272 | # Read data from all log files
273 | log_files = get_log_files(args.max_num_files)
274 | battles = clean_battle_data(log_files)
275 |
276 | results = report_elo_analysis_results(battles)
277 |
278 | print("# Online")
279 | pretty_print_elo_rating(results["elo_rating_online"])
280 | print("# Median")
281 | pretty_print_elo_rating(results["elo_rating_median"])
282 | print(f"last update : {results['last_updated_datetime']}")
283 |
284 | last_updated_tstamp = results["last_updated_tstamp"]
285 | cutoff_date = datetime.datetime.fromtimestamp(
286 | last_updated_tstamp, tz=timezone("US/Pacific")
287 | ).strftime("%Y%m%d")
288 |
289 | with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
290 | pickle.dump(results, fout)
291 |
--------------------------------------------------------------------------------