├── infini_websearch
│   ├── __init__.py
│   ├── service
│   │   ├── __init__.py
│   │   └── search_service.py
│   ├── actions
│   │   ├── __init__.py
│   │   ├── base_action.py
│   │   ├── action_utils.py
│   │   └── websearch.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── misc.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── postprocessing.py
│   │   └── inference.py
│   └── configs
│       ├── server.py
│       ├── css_style.py
│       ├── __init__.py
│       └── prompt.py
├── assets
│   └── websearch_demo.gif
├── requirements.txt
├── setup.py
├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── README_en.md
├── LICENSE
└── gradio_app.py
/infini_websearch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/infini_websearch/service/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/websearch_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/infinigence/InfiniWebSearch/HEAD/assets/websearch_demo.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.115.6
2 | gradio==5.9.0
3 | gradio_toggle==2.0.2
4 | openai==1.57.4
5 | Requests==2.32.3
6 | selenium==4.27.1
7 | setuptools==68.2.2
8 | transformers==4.46.2
9 | uvicorn==0.32.1
10 | vllm==0.6.3.post1
11 |
--------------------------------------------------------------------------------
/infini_websearch/actions/__init__.py:
--------------------------------------------------------------------------------
1 | from infini_websearch.actions.action_utils import parse_function_call_from_model_ouput
2 | from infini_websearch.actions.websearch import GoogleSearch
3 |
4 | __all__ = ["parse_function_call_from_model_ouput", "GoogleSearch"]
5 |
--------------------------------------------------------------------------------
/infini_websearch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from infini_websearch.utils.misc import (
2 | extract_citations,
3 | format_search_results,
4 | functions2str,
5 | get_datetime_now,
6 | )
7 |
8 | __all__ = [
9 | "extract_citations",
10 | "format_search_results",
11 | "functions2str",
12 | "get_datetime_now",
13 | ]
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name='infini_websearch',
5 | version='0.1.0',
6 | description=(
7 | "A demo built on Megrez-3B-Instruct,"
8 | "integrating a web search tool to enhance the model's question-and-answer capabilities."
9 | ),
10 | packages=find_packages(),
11 | )
12 |
--------------------------------------------------------------------------------
/infini_websearch/actions/base_action.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, Optional
3 |
4 |
5 | class BaseAction(ABC):
6 | @property
7 | @abstractmethod
8 | def function_defination(self) -> Optional[Dict]:
9 | pass
10 |
11 | @abstractmethod
12 | def run(self, arguments: Dict) -> Any:
13 | pass
14 |
--------------------------------------------------------------------------------
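
To illustrate the contract above: a new action only needs a JSON-schema style `function_defination` and a `run` method. The hypothetical `EchoAction` below is a minimal sketch; `GoogleSearch` in `actions/websearch.py` is the real implementation.

```python
from typing import Any, Dict, Optional

from infini_websearch.actions.base_action import BaseAction


class EchoAction(BaseAction):
    """Toy action that echoes its arguments back as the observation."""

    @property
    def function_defination(self) -> Optional[Dict]:
        # JSON-schema style definition, in the same shape as GoogleSearch's
        return {
            "name": "echo",
            "description": "Echo the given text back to the caller.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        }

    def run(self, arguments: Dict) -> Any:
        return {"observation": arguments.get("text", "")}
```
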
/infini_websearch/model/__init__.py:
--------------------------------------------------------------------------------
1 | from infini_websearch.model.inference import get_vllm_model_output_function
2 | from infini_websearch.model.postprocessing import (
3 | include_special_tokens,
4 | split_text_by_special_token,
5 | )
6 |
7 | __all__ = [
8 | "get_vllm_model_output_function",
9 | "include_special_tokens",
10 | "split_text_by_special_token",
11 | ]
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | *.pyc
4 | *.egg-info
5 | dist
6 | .venv
7 |
8 | # Log
9 | *.log
10 | *.log.*
11 | *.json
12 | !playground/deepspeed_config_s2.json
13 | !playground/deepspeed_config_s3.json
14 |
15 | # Editor
16 | .idea
17 | *.swp
18 |
19 | # Other
20 | .DS_Store
21 | wandb
22 | output
23 | checkpoints_flant5_3b
24 |
25 | # Data
26 | *.pkl
27 | *.csv
28 | tests/state_of_the_union.txt
29 |
30 | # Build
31 | build
32 |
--------------------------------------------------------------------------------
/infini_websearch/model/postprocessing.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 |
4 | def include_special_tokens(text: str, tokens: List[str]) -> bool:
5 | include_all_tokens = True
6 | for token in tokens:
7 | if token not in text:
8 | include_all_tokens = False
9 | break
10 | return include_all_tokens
11 |
12 |
13 | def split_text_by_special_token(text: str, token: str) -> Tuple[str, str]:
14 |     assert token in text, f"{token} not found in {text}"
15 |     parts = text.split(token, 1)
16 |     return parts[0], parts[1]
17 |
--------------------------------------------------------------------------------
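
A quick sketch of how these helpers behave on the model's function-calling markup (the special-token strings come from `configs/server.py`; the sample text is made up):

```python
from infini_websearch.model.postprocessing import (
    include_special_tokens,
    split_text_by_special_token,
)

text = '好的<|function_start|>{"name": "googleWebSearch"}<|function_end|>'

# True only if every listed token occurs in the text
assert include_special_tokens(text, ["<|function_start|>", "<|function_end|>"])

# split once around the token: chat part before, function-call part after
chat_part, tool_part = split_text_by_special_token(text, "<|function_start|>")
assert chat_part == "好的"
assert tool_part == '{"name": "googleWebSearch"}<|function_end|>'
```
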
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.6.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-added-large-files
8 | - id: check-yaml
9 | - id: check-json
10 |
11 | - repo: https://github.com/pycqa/isort
12 | rev: 5.13.2
13 | hooks:
14 | - id: isort
15 | args:
16 | - "--profile=black"
17 |
18 | - repo: https://github.com/psf/black
19 | rev: 23.7.0
20 | hooks:
21 | - id: black
22 | language_version: python3.10
23 | args:
24 | - "--skip-string-normalization"
25 |
26 |
27 | - repo: https://github.com/PyCQA/flake8
28 | rev: 7.1.1
29 | hooks:
30 | - id: flake8
31 | args:
32 | - --max-line-length=120
33 | additional_dependencies:
34 | - flake8-bugbear
35 |
36 | default_language_version:
37 | python: python3.10
38 |
--------------------------------------------------------------------------------
/infini_websearch/configs/server.py:
--------------------------------------------------------------------------------
1 | # gradio
2 | SESSION_WINDOW_SIZE = 4
3 |
4 | # websearch service
5 | SEARCH_SERVER_URL = "http://localhost:8021/search"
6 | NUM_SEARCH_WEBPAGES = 5
7 | WEBPAGE_LOAD_TIMETOUT = 10.0
8 | PROXIES = {
9 | "http": None,
10 | "https": None,
11 | }
12 |
13 |
14 | # model
15 | MODEL_NAME = "megrez"
16 | MODEL_SERVER_URL = "http://localhost:8011/v1/"
17 | STOP_TOKENS = ["<|turn_end|>"]
18 | FUNCTION_START_TOKEN, FUNCTION_END_TOKEN = "<|function_start|>", "<|function_end|>"
19 | MAX_ACTION_TURNS = 1
20 |
21 | # 4k
22 | # WEBPAGE_SUMMARY_MAX_INPUT_TOKENS = 2048
23 | # WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS = 512
24 | # SESSION_MAX_INPUT_TOKENS = 3072
25 | # CHAT_TEMPERATURE = 0.4
26 | # CHAT_MAX_OUTPUT_TOKENS = 2048
27 | # AGENT_TEMPERATURE = 0.01
28 | # AGENT_MAX_OUTPUT_TOKENS = 512
29 |
30 | # 32k
31 | WEBPAGE_SUMMARY_MAX_INPUT_TOKENS = 2048
32 | WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS = 512
33 | SESSION_MAX_INPUT_TOKENS = 32768 - 4096
34 | CHAT_TEMPERATURE = 0.4
35 | CHAT_MAX_OUTPUT_TOKENS = 2048
36 | AGENT_TEMPERATURE = 0.01
37 | AGENT_MAX_OUTPUT_TOKENS = 2048
38 |
--------------------------------------------------------------------------------
/infini_websearch/utils/misc.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from datetime import datetime
4 |
5 |
6 | def functions2str(functions: list) -> str:
7 | return "\n\n".join(
8 | [json.dumps(function, ensure_ascii=False, indent=4) for function in functions]
9 | )
10 |
11 |
12 | def get_datetime_now():
13 | current_date = datetime.now()
14 | formatted_time = current_date.strftime("%Y-%m-%d %H:%M:%S")
15 | weekday_id = current_date.weekday()
16 | weekday_names = [
17 | "Monday",
18 | "Tuesday",
19 | "Wednesday",
20 | "Thursday",
21 | "Friday",
22 | "Saturday",
23 | "Sunday",
24 | ]
25 | return formatted_time, weekday_names[weekday_id]
26 |
27 |
28 | def format_search_results(url_infos: list) -> str:
29 | return "\n".join(
30 | f"[{url_info['title']}]({url_info['link']})" for url_info in url_infos
31 | )
32 |
33 |
34 | def extract_citations(text):
35 | citations1 = re.findall(r"\[citation:(\d+)\]", text)
36 | citations2 = re.findall(r"\[ citation:(\d+)\]", text)
37 | return citations1 or citations2
38 |
--------------------------------------------------------------------------------
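
For reference, a small sketch of each helper's input and output (all example values are illustrative):

```python
from infini_websearch.utils import (
    extract_citations,
    format_search_results,
    get_datetime_now,
)

# e.g. ("2024-12-16 10:30:00", "Monday")
formatted_time, weekday = get_datetime_now()

# expects a list of Serper-style result dicts with "title" and "link"
print(format_search_results([{"title": "Example", "link": "https://example.com"}]))
# -> [Example](https://example.com)

# pulls the numeric ids out of [citation:x] tags
assert extract_citations("今天晴[citation:1][citation:3]") == ["1", "3"]
```
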
/infini_websearch/configs/css_style.py:
--------------------------------------------------------------------------------
1 | CSS_STYLE = """
2 | .canvas {
3 | # width: 100% !important;
4 | # max-width: 100% !important;
5 | width: 100vh;
6 | }
7 |
8 | .fullheight {
9 | height: 80vh;
10 | }
11 |
12 | .chatbot {
13 | flex-grow: 1;
14 | overflow: auto;
15 | position: relative;
16 | z-index: 100;
17 | }
18 |
19 | .bottom-bar {
20 | position: fixed;
21 | bottom: 0;
22 | left: 50%;
23 | transform: translateX(-50%);
24 | display: flex;
25 | width: 80vh;
26 | z-index: 1000;
27 | }
28 |
29 | .unicode-circle {
30 | font-family: 'Arial Unicode MS', Arial, sans-serif;
31 | font-size: 14px;
32 | border-radius: 50%;
33 | border: 1px solid black;
34 | width: 20px;
35 | height: 20px;
36 | line-height: 20px;
37 | text-align: center;
38 | display: inline-block;
39 | background-color: white;
40 | }
41 |
42 | .circle-link {
43 | font-family: 'Arial Unicode MS', Arial, sans-serif;
44 | text-decoration: none;
45 | color: black;
46 | border-radius: 50%;
47 | border: 1px solid black;
48 | width: 20px;
49 | height: 20px;
50 | line-height: 20px;
51 | text-align: center;
52 | display: inline-block;
53 | background-color: white;
54 | cursor: pointer;
55 | }
56 |
57 | .circle-link:hover {
58 | background-color: #e0e0e0;
59 | }
60 |
61 | .circle-link:hover::after {
62 | content: attr(title);
63 | position: absolute;
64 | white-space: nowrap;
65 | }
66 | """
67 |
--------------------------------------------------------------------------------
/infini_websearch/actions/action_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Dict, List, Optional, Tuple, Union
4 |
5 |
6 | def parse_function_call_from_model_ouput(
7 | output: str,
8 | registered_function_names: Optional[List[str]],
9 | speical_tokens_map: Optional[Dict],
10 | ) -> Tuple[Optional[str], Union[Dict, Optional[str]]]:
11 | if speical_tokens_map is None:
12 | speical_tokens_map = dict(
13 | function_start_token="<|function_start|>",
14 | function_end_token="<|function_end|>",
15 | )
16 |
17 | function_name, function_arguments = None, None
18 | function_call_texts = re.findall(
19 | f'{re.escape(speical_tokens_map["function_start_token"])}(.*?){re.escape(speical_tokens_map["function_end_token"])}', # noqa: E501
20 | output,
21 | re.DOTALL,
22 | )
23 | if len(function_call_texts) > 0:
24 | # support only one action per turn, choose the first one
25 | function_call_text = function_call_texts[0].strip()
26 | try:
27 | function_call_dict = json.loads(function_call_text)
28 | function_name = function_call_dict["name"]
29 | function_arguments = function_call_dict["arguments"]
30 | except Exception as e:
31 | print(e)
32 | function_name = None
33 | print("function call json输入格式错误")
34 | print(function_call_text)
35 |
36 | if function_name is not None and function_name not in registered_function_names:
37 | function_arguments = f"{function_name}不在可以使用的工具列表中"
38 | function_name = None
39 | return function_name, function_arguments
40 |
--------------------------------------------------------------------------------
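
A usage sketch for the parser, assuming the default special tokens (passing `None` for the token map). On malformed JSON, or a function name that is not registered, it returns `None` as the name together with an error message instead:

```python
from infini_websearch.actions import parse_function_call_from_model_ouput

model_output = (
    "我需要搜索一下"
    '<|function_start|>{"name": "googleWebSearch", "arguments": {"query": "北京天气"}}<|function_end|>'
)
name, arguments = parse_function_call_from_model_ouput(
    model_output,
    registered_function_names=["googleWebSearch"],
    speical_tokens_map=None,  # falls back to <|function_start|> / <|function_end|>
)
assert name == "googleWebSearch"
assert arguments == {"query": "北京天气"}
```
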
/infini_websearch/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from infini_websearch.configs.css_style import CSS_STYLE
2 | from infini_websearch.configs.prompt import (
3 | FUNCTION_CALLING_PROMPT_TEMPLATE,
4 | OBSERVATION_PROMPT_TEMPLATE,
5 | ROLE_PROMPT,
6 | SUMMARY_PROMPT_TEMPLATE,
7 | TIME_PROMPT_TEMPLATE,
8 | )
9 | from infini_websearch.configs.server import (
10 | AGENT_MAX_OUTPUT_TOKENS,
11 | AGENT_TEMPERATURE,
12 | CHAT_MAX_OUTPUT_TOKENS,
13 | CHAT_TEMPERATURE,
14 | FUNCTION_END_TOKEN,
15 | FUNCTION_START_TOKEN,
16 | MAX_ACTION_TURNS,
17 | MODEL_NAME,
18 | MODEL_SERVER_URL,
19 | NUM_SEARCH_WEBPAGES,
20 | PROXIES,
21 | SEARCH_SERVER_URL,
22 | SESSION_MAX_INPUT_TOKENS,
23 | SESSION_WINDOW_SIZE,
24 | STOP_TOKENS,
25 | WEBPAGE_LOAD_TIMETOUT,
26 | WEBPAGE_SUMMARY_MAX_INPUT_TOKENS,
27 | WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS,
28 | )
29 |
30 | __all__ = [
31 | "CSS_STYLE",
32 | "FUNCTION_CALLING_PROMPT_TEMPLATE",
33 | "OBSERVATION_PROMPT_TEMPLATE",
34 | "ROLE_PROMPT",
35 | "SUMMARY_PROMPT_TEMPLATE",
36 | "TIME_PROMPT_TEMPLATE",
37 | "AGENT_MAX_OUTPUT_TOKENS",
38 | "AGENT_TEMPERATURE",
39 | "CHAT_MAX_OUTPUT_TOKENS",
40 | "CHAT_TEMPERATURE",
41 | "FUNCTION_END_TOKEN",
42 | "FUNCTION_START_TOKEN",
43 | "MAX_ACTION_TURNS",
44 | "MODEL_NAME",
45 | "MODEL_SERVER_URL",
46 | "NUM_SEARCH_WEBPAGES",
47 | "PROXIES",
48 | "SEARCH_SERVER_URL",
49 | "STOP_TOKENS",
50 | "WEBPAGE_LOAD_TIMETOUT",
51 | "WEBPAGE_SUMMARY_MAX_INPUT_TOKENS",
52 | "SESSION_MAX_INPUT_TOKENS",
53 | "SESSION_WINDOW_SIZE",
54 | "WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS",
55 | ]
56 |
--------------------------------------------------------------------------------
/infini_websearch/configs/prompt.py:
--------------------------------------------------------------------------------
1 | ROLE_PROMPT = "你是Megrez-3B-Instruct, 将针对用户的问题给出详细的、积极的回答."
2 |
3 | TIME_PROMPT_TEMPLATE = "The current time is {current_time}, {weekday}."
4 |
5 | FUNCTION_CALLING_PROMPT_TEMPLATE = (
6 | "You have access to the following functions. Use them if required -\n{functions}"
7 | )
8 |
9 | SUMMARY_PROMPT_TEMPLATE = (
10 | '从信息中总结能够回答问题的相关内容,要求简明扼要不能完全照搬原文。直接返回总结不要说其他话,如果没有相关内容则返回"无相关内容", 返回内容为中文。\n\n'
11 | "<问题>{question}问题>\n"
12 | "<信息>{context}信息>"
13 | )
14 |
15 | # this prompt was inspired by
16 | # https://github.com/leptonai/search_with_lepton/blob/main/search_with_lepton.py
17 | OBSERVATION_PROMPT_TEMPLATE = (
18 | "You will be given a set of related contexts to the question, "
19 | "each starting with a reference number like [[citation:x]], where x is a number. "
20 | "Please use the context and cite the context at the end of each sentence if applicable."
21 | "\n\n"
22 | "Please cite the contexts with the reference numbers, in the format [citation:x]. "
23 | "If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. "
24 | "If the context does not provide relevant information to answer the question, "
25 | "inform the user that there is no relevant information in the search results and that the question cannot be answered." # noqa: E501
26 | "\n\n"
27 | "Other than code and specific names and citations, your answer must be written in Chinese."
28 | "\n\n"
29 | "Ensure that your response is concise and clearly formatted. "
30 | "Group related content together and use Markdown points or lists where appropriate."
31 | "\n\n"
32 | "Remember, summarize and don't blindly repeat the contexts verbatim. And here is the user question:\n"
33 | "{question}\n"
34 | "Here is the keywords of the question:\n"
35 | "{keywords}"
36 | "\n\n"
37 | "Here are the set of contexts:"
38 | "\n\n"
39 | "{context}"
40 | )
41 |
--------------------------------------------------------------------------------
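
As an illustration, `GoogleSearch.make_summary_tasks` in `actions/websearch.py` fills the summarization template once per fetched page, roughly like this (the values are made up):

```python
from infini_websearch.configs.prompt import SUMMARY_PROMPT_TEMPLATE

prompt = SUMMARY_PROMPT_TEMPLATE.format(
    question="北京今天的天气怎么样?",
    context="北京今日晴, 北风一二级, 气温-2到8摄氏度...",
)
print(prompt)  # becomes the user message of one summarization request
```
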
/infini_websearch/model/inference.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable, Dict, Generator, List, Union
3 |
4 | import openai
5 |
6 |
7 | def get_vllm_model_output_function(
8 | url: str,
9 | model_name: str,
10 | chat_mode: bool,
11 | model_config: Dict,
12 | stream: bool,
13 | buffer_size: int = 20,
14 | timeout: int = 60,
15 | ) -> Callable:
16 | """
17 | Get model generate function: streaming/non-streaming
18 | """
19 | data = {
20 | "model": model_name,
21 | "stream": stream,
22 | **model_config,
23 | }
24 |
25 | openai.api_key = "EMPTY"
26 | openai.base_url = url
27 | openai.proxy = ""
28 | chat_func = (
29 | openai.chat.completions.create
30 | if chat_mode is True
31 | else openai.completions.create
32 | )
33 |
34 | if stream is True:
35 | return partial(
36 | get_model_streaming_output,
37 | llm_function=chat_func,
38 | model_config=data,
39 | chat_mode=chat_mode,
40 | buffer_size=buffer_size,
41 | timeout=timeout,
42 | )
43 | else:
44 | return partial(
45 | get_model_output,
46 | llm_function=chat_func,
47 | model_config=data,
48 | chat_mode=chat_mode,
49 | timeout=timeout,
50 | )
51 |
52 |
53 | def get_model_streaming_output(
54 | messages: List[Union[Dict, str]],
55 | model_config: Dict,
56 | llm_function: Callable,
57 | chat_mode: bool,
58 | buffer_size: int,
59 | timeout: int,
60 | ) -> Generator[str, None, None]:
61 |     if chat_mode is True:
62 |         model_config["messages"] = messages
63 |     else:
64 |         model_config["prompt"] = messages
65 |     buffer = ""
66 |     for chunk in llm_function(**model_config, timeout=timeout):
67 |         # chat chunks carry the text in `.delta.content`, completion chunks in `.text`
68 |         content = (
69 |             chunk.choices[0].delta.content
70 |             if chat_mode
71 |             else chunk.choices[0].text
72 |         )
73 |         if content:
74 |             buffer += content
75 |         if len(buffer) >= buffer_size:
76 |             # has a '[citation:x]' tag been cut off mid-way?
77 |             if buffer.rfind("]") < buffer.rfind("["):
78 |                 yield buffer[: buffer.rfind("[")]
79 |                 buffer = buffer[buffer.rfind("[") :]  # noqa: E203
80 |             else:
81 |                 yield buffer
82 |                 buffer = ""
83 |     if buffer:
84 |         yield buffer
85 |
86 |
87 | def get_model_output(
88 |     messages: List[Union[Dict, str]],
89 |     model_config: Dict,
90 |     llm_function: Callable,
91 |     chat_mode: bool,
92 |     timeout: int,
93 | ) -> str:
94 |     if chat_mode is True:
95 |         model_config["messages"] = messages
96 |     else:
97 |         model_config["prompt"] = messages
98 |     response = llm_function(**model_config, timeout=timeout)
99 |     return response
100 |
--------------------------------------------------------------------------------
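
A minimal sketch of the streaming path, assuming a vLLM OpenAI-compatible server is already running on localhost:8011 and serving a model named "megrez" (matching `configs/server.py`):

```python
from infini_websearch.model import get_vllm_model_output_function

stream_func = get_vllm_model_output_function(
    url="http://localhost:8011/v1/",
    model_name="megrez",
    chat_mode=True,
    model_config={"temperature": 0.4, "max_tokens": 256},
    stream=True,
)
for chunk in stream_func(messages=[{"role": "user", "content": "你好"}]):
    print(chunk, end="", flush=True)
```
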
/README.md:
--------------------------------------------------------------------------------
1 | # InfiniWebSearch
2 |
3 | 基于[Megrez-3B-Instruct](https://huggingface.co/Infinigence/Megrez-3B-Instruct)搭建的demo, 接入网络搜索工具增强模型的问答能力.
4 |
5 | Read this in [English](README_en.md)
6 |
7 |
8 | ![websearch_demo](assets/websearch_demo.gif)
9 |
10 |
11 | ## 特性
12 |
13 | 1. **意图识别**: LLM自动决定搜索工具调用时机
14 | 2. **上下文理解**: 根据多轮对话生成合理搜索关键词
15 | 3. **模型回答包含引用**: 回答内容可查明出处
16 | 4. **即插即用**: 通过system prompt控制WebSearch功能开启与否
17 |
18 | ## 项目简介
19 |
20 | 本工程包含如下几部分:
21 |
22 | - gradio服务 ([gradio_app.py](gradio_app.py)): 定义了整个app的workflow
23 | - 网络搜索服务 ([search_service.py](infini_websearch/service/search_service.py)): 搜索网页, 加载网页
24 | - 模型服务: 聊天问答, 工具调用, 总结网页内容
25 |
26 | ## 快速上手
27 |
28 | ### 安装
29 |
30 | ```shell
31 | git clone https://github.com/infinigence/InfiniWebSearch
32 | cd InfiniWebSearch
33 | conda create -n infini_websearch python=3.10 -y
34 | conda activate infini_websearch
35 | pip install -r requirements.txt
36 | pip install -e .
37 | ```
38 |
39 | ### 运行demo
40 |
41 | #### 1. 启动网络搜索服务
42 |
43 | 以Ubuntu环境为例, 从[Google Chrome Labs](https://googlechromelabs.github.io/chrome-for-testing/)下载**chrome**和**chromedriver**并解压到本地.
44 |
45 | ```
46 | cd infini_websearch/service
47 | wget https://storage.googleapis.com/chrome-for-testing-public/128.0.6613.86/linux64/chrome-linux64.zip
48 | wget https://storage.googleapis.com/chrome-for-testing-public/128.0.6613.86/linux64/chromedriver-linux64.zip
49 | unzip chrome-linux64.zip
50 | unzip chromedriver-linux64.zip
51 | ```
52 |
53 | 安装依赖
54 |
55 | ```shell
56 | sudo apt-get update
57 | sudo apt-get install -y \
58 | libatk-bridge2.0-0 \
59 | libatk1.0-0 \
60 | libgconf-2-4 \
61 | libnss3 \
62 | libxss1 \
63 | libappindicator1 \
64 | libindicator7 \
65 | libasound2 \
66 | libxcomposite1 \
67 | libxcursor1 \
68 | libxdamage1 \
69 | libxi6 \
70 | libxtst6 \
71 | libglib2.0-0 \
72 | libpango1.0-0 \
73 | libcups2 \
74 | libxrandr2 \
75 | libxrandr-dev \
76 | libxkbcommon0 \
77 | libgbm1
78 | ```
79 |
80 | 在[Serper](https://serper.dev/)注册账户获得`SERPER_API_KEY`, 并添加到环境变量中.
81 | 启动网络搜索服务并指定端口号.
82 | 在[server.py](infini_websearch/configs/server.py)设置`SEARCH_SERVER_URL`为 http://localhost:8021/search .
83 |
84 | ```shell
85 | export SERPER_API_KEY=$YOUR_API_KEY
86 | cd infini_websearch/service
87 | python search_service.py --port 8021 --chrome ./chrome-linux64/chrome --chromedriver ./chromedriver-linux64/chromedriver
88 | ```
89 |
90 | #### 2. 启动模型服务
91 |
92 | 使用vllm.entrypoints.openai.api_server启动服务并指定端口号, `--served-model-name`设置为megrez, `--max-model-len`设置为32768.
93 | 在[server.py](infini_websearch/configs/server.py)设置`MODEL_SERVER_URL`, 默认为 http://localhost:8011/v1/ . 设置`MODEL_NAME`为"megrez".
94 |
95 | ```shell
96 | python -m vllm.entrypoints.openai.api_server --served-model-name megrez --model $MODEL_PATH --port 8011 --max-model-len 32768 --trust-remote-code --gpu-memory-utilization 0.8
97 | ```
98 |
99 | #### 3. 启动gradio服务
100 |
101 | 运行[gradio_app.py](gradio_app.py), 指定模型路径和端口号.
102 | ```shell
103 | export no_proxy="localhost,127.0.0.1"
104 | python gradio_app.py -m $MODEL_PATH --port 7860
105 | ```
106 |
107 | 成功启动之后, 访问 http://localhost:7860/ 即可使用
108 |
109 | ## 说明
110 |
111 | 1. 我们提供了`WEBPAGE_SUMMARY_MAX_INPUT_TOKENS`, `WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS`, `SESSION_MAX_INPUT_TOKENS`, `CHAT_MAX_OUTPUT_TOKENS`, `AGENT_MAX_OUTPUT_TOKENS`来控制模型的输入和输出长度. 使用`SESSION_WINDOW_SIZE`来保留最近的几轮对话历史. 你可以在[server.py](infini_websearch/configs/server.py)中按需修改.
112 | 2. 已经开始首轮对话后, 点击websearch toggle切换状态会在后端清空对话历史, 但前端显示依然保留对话历史.
113 | 3. 如果搜索服务出现异常(例如: 网页加载超时或服务器异常), 工具调用的observation会返回预定义好的信息(例如: "搜索页面加载超时, 请重试"). 你可以在[websearch.py](infini_websearch/actions/websearch.py)和[search_service.py](infini_websearch/service/search_service.py)中自定义边界条件的后处理逻辑.
114 | 4. 使用[Serper](https://serper.dev/)时([search_service.py](infini_websearch/service/search_service.py)), 我们设置"hl"参数为"zh-CN"来尽可能得到中文搜索结果. 如果搜索结果英文网页太多, 可能导致模型用英文回答.
115 | 5. 如果网页搜索成功了, 但是模型表示搜索结果中没有包含回答问题的相关信息, 可以检查控制台打印的各网页的摘要信息. 如果摘要信息显示"无相关内容", 代表原网页没有与问题相关的信息, 或者模型提取与问题相关信息失败.
116 |
117 | ## 协议
118 |
119 | - 本开源仓库的代码遵循 [Apache 2.0](LICENSE) 协议.
120 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # InfiniWebSearch
2 |
3 | A demo built on [Megrez-3B-Instruct](https://huggingface.co/Infinigence/Megrez-3B-Instruct), integrating a web search tool to enhance the model's question-and-answer capabilities.
4 |
5 | [中文阅读](README.md).
6 |
7 |
8 | ![websearch_demo](assets/websearch_demo.gif)
9 |
10 |
11 | ## Features
12 |
13 | 1. **Intent Recognition**: LLM automatically determines when to invoke search tool calls.
14 | 2. **Context Understanding**: Generates reasonable search keywords based on multi-turn dialogue.
15 | 3. **Model Responses Include Citation Links**: The content of the responses can be traced back to their sources.
16 | 4. **Plug-and-Play**: Control the activation or deactivation of WebSearch functionality through system prompts.
17 |
18 | ## Project Introduction
19 |
20 | This project consists of the following parts:
21 |
22 | - Gradio service ([gradio_app.py](gradio_app.py)): Defines the entire app's workflow.
23 | - Web search service ([search_service.py](infini_websearch/service/search_service.py)): Searches for web pages and loads web pages.
24 | - Model service: Chat, function calling, and web page summarization.
25 |
26 | ## Quick start
27 |
28 | ### Install
29 |
30 | ```shell
31 | git clone https://github.com/infinigence/InfiniWebSearch
32 | cd InfiniWebSearch
33 | conda create -n infini_websearch python=3.10 -y
34 | conda activate infini_websearch
35 | pip install -r requirements.txt
36 | pip install -e .
37 | ```
38 |
39 | ### Running Demo
40 |
41 | #### 1. Starting Web Search Service
42 |
43 | Using Ubuntu as an example, download **chrome** and **chromedriver** from [Google Chrome Labs](https://googlechromelabs.github.io/chrome-for-testing/) and unzip them into a local directory.
44 |
45 | ```
46 | cd infini_websearch/service
47 | wget https://storage.googleapis.com/chrome-for-testing-public/128.0.6613.86/linux64/chrome-linux64.zip
48 | wget https://storage.googleapis.com/chrome-for-testing-public/128.0.6613.86/linux64/chromedriver-linux64.zip
49 | unzip chrome-linux64.zip
50 | unzip chromedriver-linux64.zip
51 | ```
52 |
53 | Install dependencies
54 |
55 | ```shell
56 | sudo apt-get update
57 | sudo apt-get install -y \
58 | libatk-bridge2.0-0 \
59 | libatk1.0-0 \
60 | libgconf-2-4 \
61 | libnss3 \
62 | libxss1 \
63 | libappindicator1 \
64 | libindicator7 \
65 | libasound2 \
66 | libxcomposite1 \
67 | libxcursor1 \
68 | libxdamage1 \
69 | libxi6 \
70 | libxtst6 \
71 | libglib2.0-0 \
72 | libpango1.0-0 \
73 | libcups2 \
74 | libxrandr2 \
75 | libxrandr-dev \
76 | libxkbcommon0 \
77 | libgbm1
78 | ```
79 |
80 | Register an account on [Serper](https://serper.dev/) to obtain a `SERPER_API_KEY` and add it to your environment variables.
81 | Start the web search service and specify the port number.
82 | Set the `SEARCH_SERVER_URL` to http://localhost:8021/search in the file [server.py](infini_websearch/configs/server.py).
83 |
84 | ```shell
85 | export SERPER_API_KEY=$YOUR_API_KEY
86 | cd infini_websearch/service
87 | python search_service.py --port 8021 --chrome ./chrome-linux64/chrome --chromedriver ./chromedriver-linux64/chromedriver
88 | ```
89 |
90 | #### 2. Starting Model Service
91 |
92 | Use vllm.entrypoints.openai.api_server to start the service and specify the port number. Set `--served-model-name` to "megrez" and `--max-model-len` to 32768.
93 | Set the `MODEL_SERVER_URL` in the file [server.py](infini_websearch/configs/server.py) with a default value of http://localhost:8011/v1/. Also, set the `MODEL_NAME` to "megrez".
94 |
95 | ```shell
96 | python -m vllm.entrypoints.openai.api_server --served-model-name megrez --model $MODEL_PATH --port 8011 --max-model-len 32768 --trust-remote-code --gpu-memory-utilization 0.8
97 | ```
98 |
99 | #### 3. Starting Gradio Service
100 |
101 | Run [gradio_app.py](gradio_app.py), specifying the model path and port number.
102 |
103 | ```shell
104 | export no_proxy="localhost,127.0.0.1"
105 | python gradio_app.py -m $MODEL_PATH --port 7860
106 | ```
107 |
108 | After successful startup, you can use it by visiting http://localhost:7860/.
109 |
110 | ## Notes
111 |
112 | 1. We provide `WEBPAGE_SUMMARY_MAX_INPUT_TOKENS`, `WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS`, `SESSION_MAX_INPUT_TOKENS`, `CHAT_MAX_OUTPUT_TOKENS`, `AGENT_MAX_OUTPUT_TOKENS` to control the input and output lengths of the model. Use `SESSION_WINDOW_SIZE` to retain the most recent dialogue history. You can modify these settings as needed in [server.py](infini_websearch/configs/server.py).
113 | 2. After starting the first round of dialogue, toggling the websearch state will clear the dialogue history on the backend, but the frontend will still display the dialogue history.
114 | 3. If there is an exception with the search service (e.g. webpage loading timeout or server error), the observation from the tool call will return predefined messages (e.g. "The search page loading timed out, please try again"). You can customize the post-processing logic for boundary conditions in [websearch.py](infini_websearch/actions/websearch.py) and [search_service.py](infini_websearch/service/search_service.py).
115 | 4. When using [Serper](https://serper.dev/) ([search_service.py](infini_websearch/service/search_service.py)), we set the "hl" parameter to "zh-CN" to obtain Chinese search results as much as possible. If there are too many English webpages in the search results, it may lead to the model responding in English.
116 | 5. If the web search is successful but the model indicates that the search results do not contain relevant information to answer the question, you can check the summary information of each webpage printed in the console. If the summary shows "No relevant content", it means either the original webpage does not contain information related to the question, or the model failed to extract relevant information from the webpage.
117 |
118 | ## License
119 |
120 | The code in this open-source repository follows the [Apache 2.0](LICENSE) license.
121 |
--------------------------------------------------------------------------------
/infini_websearch/service/search_service.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 | from concurrent.futures import ThreadPoolExecutor, as_completed
6 | from typing import Dict, Generator, Optional, Tuple, Union
7 |
8 | import requests
9 | from fastapi import FastAPI, HTTPException, Request
10 | from fastapi.responses import StreamingResponse
11 | from selenium import webdriver
12 | from selenium.common.exceptions import TimeoutException
13 | from selenium.webdriver.chrome.options import Options
14 | from selenium.webdriver.chrome.service import Service
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--chrome", type=str)
18 | parser.add_argument("--chromedriver", type=str)
19 | parser.add_argument("--port", type=int)
20 |
21 | args = parser.parse_args()
22 |
23 | app = FastAPI()
24 |
25 | SERPER_API_KEY = os.environ.get("SERPER_API_KEY")
26 |
27 |
28 | def get_webpage_content(url: str, chrome_path: str, chromedriver_path: str) -> str:
29 | """
30 | Load the content of web pages by chromedriver.
31 | """
32 | options = Options()
33 | options.add_argument("--no-sandbox")
34 | options.add_argument("--disable-dev-shm-usage")
35 | options.add_argument("--disable-extensions")
36 | options.add_argument("--disable-gpu")
37 | options.add_argument("--headless")
38 | options.add_argument("--disable-infobars")
39 | options.add_argument("--disable-browser-side-navigation")
40 | options.add_argument("--disable-features=VizDisplayCompositor")
41 | options.add_argument("--no-first-run")
42 | options.add_argument("--no-default-browser-check")
43 | options.add_argument("--disable-popup-blocking")
44 | options.add_argument("--disable-application-cache")
45 | options.add_argument("--dns-prefetch-disable")
46 | options.add_argument("--no-proxy-server")
47 | options.add_argument("--blink-settings=imagesEnabled=false")
48 | options.add_argument("--enable-http2")
49 | options.add_argument("--disable-quic")
50 | options.binary_location = chrome_path
51 | options.page_load_strategy = "eager"
52 | prefs = {
53 | "profile.managed_default_content_settings.images": 2,
54 | "profile.default_content_setting_values.notifications": 2,
55 | "download_restrictions": 3,
56 | }
57 | options.add_experimental_option("prefs", prefs)
58 | service = Service(executable_path=chromedriver_path)
59 | driver = webdriver.Chrome(options=options, service=service)
60 | try:
61 | timeout = 10
62 | start = time.time()
63 | driver.set_page_load_timeout(timeout)
64 | try:
65 | driver.get(url)
66 | except TimeoutException:
67 | print(f"页面加载超时({timeout}秒)")
68 | return "搜索页面加载超时, 请重试"
69 | end = time.time()
70 | print(f"读取网页内容耗时: {end - start}s")
71 | content = driver.execute_script("return document.body.innerText;")
72 | return content
73 | except Exception as e:
74 | print(e)
75 | return ""
76 | finally:
77 | driver.quit()
78 |
79 |
80 | def serper_search(
81 | search_term: str, search_type: Optional[str] = "search", timeout: int = 5, **kwargs
82 | ) -> Tuple[int, Union[Dict, str]]:
83 | """
84 | Get google search results by serper api (https://serper.dev/).
85 | """
86 | headers = {
87 | "X-API-KEY": SERPER_API_KEY,
88 | "Content-Type": "application/json",
89 | }
90 | params = {
91 | "q": search_term,
92 | "gl": "cn", # country
93 | "sort": "date",
94 | "hl": "zh-CN",
95 | **{key: value for key, value in kwargs.items() if value is not None},
96 | }
97 | try:
98 | response = requests.post(
99 | f"https://google.serper.dev/{search_type}",
100 | headers=headers,
101 | params=params,
102 | proxies=None,
103 | timeout=timeout,
104 | )
105 | except Exception as e:
106 | return -1, str(e)
107 | return response.status_code, response.json()
108 |
109 |
110 | def streaming_fetch_webpage_content(
111 | results: dict, num_search_pages: int, chrome_path: str, chromedriver_path: str
112 | ) -> Generator[Tuple[Dict, str], None, None]:
113 | url_infos = results["organic"][:num_search_pages]
114 |
115 | with ThreadPoolExecutor(max_workers=len(url_infos)) as executor:
116 | future_to_url = {
117 | executor.submit(
118 | get_webpage_content, url_info["link"], chrome_path, chromedriver_path
119 | ): url_info
120 | for url_info in url_infos
121 | }
122 | for future in as_completed(future_to_url):
123 | url_info = future_to_url[future]
124 | try:
125 | result = future.result()
126 | yield url_info, result
127 | except Exception as exc:
128 | print(f'{url_info["link"]} generated an exception: {exc}')
129 | yield url_info, ""
130 |
131 |
132 | @app.post("/search")
133 | async def search(request: Request):
134 | data = await request.json()
135 | print(data)
136 |
137 | start = time.time()
138 | status_code, response = serper_search(data["query"], timeout=10)
139 | end = time.time()
140 | print(f"搜索网页耗时: {end - start}s")
141 |
142 | if status_code != 200:
143 | raise HTTPException(status_code=500, detail="搜索网页超时, 请重试")
144 |
145 | def html_docs_text_generator():
146 | start = time.time()
147 | for url_info, content in streaming_fetch_webpage_content(
148 | response,
149 | num_search_pages=data["num_search_pages"],
150 | chrome_path=args.chrome,
151 | chromedriver_path=args.chromedriver,
152 | ):
153 | yield json.dumps(
154 | {
155 | "search_status_code": status_code,
156 | "search_response": response,
157 | "url_info": url_info,
158 | "html_content": content,
159 | },
160 | ensure_ascii=False,
161 | ) + "\n"
162 | end = time.time()
163 | print(f"解析网页耗时: {end - start}s")
164 |
165 | return StreamingResponse(html_docs_text_generator(), media_type="application/json")
166 |
167 |
168 | if __name__ == "__main__":
169 | import uvicorn
170 |
171 | uvicorn.run("search_service:app", host="0.0.0.0", port=args.port, reload=True)
172 |
--------------------------------------------------------------------------------
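
A minimal client sketch for the `/search` endpoint, assuming the service is running on localhost:8021; each line of the streamed response is one JSON record per fetched page:

```python
import json

import requests

response = requests.post(
    "http://localhost:8021/search",
    json={"query": "北京天气", "num_search_pages": 3},
    stream=True,
)
for line in response.iter_lines():
    if line:
        record = json.loads(line)
        print(record["url_info"]["link"], len(record["html_content"]))
```
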
/infini_websearch/actions/websearch.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Callable, Dict, Generator, List, Optional
3 |
4 | import requests
5 | from transformers import AutoTokenizer
6 |
7 | from infini_websearch.actions.base_action import BaseAction
8 |
9 |
10 | class GoogleSearch(BaseAction):
11 | def __init__(
12 | self,
13 | server_url: str,
14 | summary_prompt_template: str,
15 | observation_prompt_template: str,
16 | num_search_webpages: int = 5,
17 | webpage_summary_max_input_tokens: int = 2048,
18 | webpage_load_timetout: float = 10.0,
19 | proxies: Optional[Dict] = None,
20 | ) -> None:
21 | self.server_url = server_url
22 | self.summary_prompt_template = summary_prompt_template
23 | self.observation_prompt_template = observation_prompt_template
24 | self.num_search_webpages = num_search_webpages
25 | self.webpage_summary_max_input_tokens = webpage_summary_max_input_tokens
26 | self.webpage_load_timetout = webpage_load_timetout
27 | if proxies is None:
28 | proxies = {"http": None, "https": None}
29 | self.proxies = proxies
30 |
31 | @property
32 | def function_defination(self) -> Optional[Dict]:
33 | return {
34 | "name": "googleWebSearch",
35 | "description": (
36 | "A Google Search Engine. "
37 | "Useful when you need to search information you don't know such as weather, "
38 | "exchange rate, current events."
39 | "Never ever use this tool when user want to translate"
40 | ),
41 | "parameters": {
42 | "type": "object",
43 | "properties": {
44 | "query": {
45 | "type": "string",
46 | "description": (
47 | "Content that users want to search for, such as 'weather', 'current events', etc."
48 | "If special characters such as '\n' appear in the search, "
49 | "these special characters must be ignored.\n"
50 | "Chinese characters are preferred."
51 | ),
52 | }
53 | },
54 | "required": ["query"],
55 | },
56 | }
57 |
58 | def run(
59 | self,
60 | user_question: str,
61 | arguments: Dict,
62 | llm_completion_funcion: Callable,
63 | tokenizer: AutoTokenizer,
64 | return_webpage_details: bool,
65 | ) -> Generator[Dict, None, None]:
66 | if "query" not in arguments:
67 | return {"observation": "调用工具失败, 缺乏必要输入参数, 请重试"}
68 |
69 | # get webpage content
70 | webpage_detail_list = []
71 | try:
72 | for webpage_detail in self.streaming_fetch_search_results(
73 | self.server_url,
74 | {
75 | "query": arguments["query"],
76 | "num_search_pages": self.num_search_webpages,
77 | },
78 | self.proxies,
79 | ):
80 | webpage_detail_list.append(webpage_detail)
81 | if return_webpage_details:
82 | yield webpage_detail
83 | except Exception as e:
84 | print(e)
85 | yield {"observation": '输出"websearch server发生错误, 请重试"'}
86 | return
87 |
88 | webpage_texts = [
89 | webpage_detail["html_content"] for webpage_detail in webpage_detail_list
90 | ]
91 |
92 | no_webpages_loaded = True
93 | for webpage_text in webpage_texts:
94 | if webpage_text != "搜索页面加载超时, 请重试":
95 | no_webpages_loaded = False
96 | break
97 |
98 | # all web pages are timing out when loading
99 | if no_webpages_loaded:
100 | summaries = webpage_texts
101 | yield {"observation": "搜索页面加载超时, 请重试"}
102 | else:
103 | summary_prompts = self.make_summary_tasks(
104 | query=arguments["query"],
105 | webpage_texts=webpage_texts,
106 | summary_prompt_template=self.summary_prompt_template,
107 | tokenizer=tokenizer,
108 | webpage_summary_max_input_tokens=self.webpage_summary_max_input_tokens,
109 | )
110 | print("#######summary prompts[start]######")
111 | print(summary_prompts)
112 | print("#######summary prompts[end]######")
113 | response_message = llm_completion_funcion(messages=summary_prompts)
114 | summaries = [choice.text for choice in response_message.choices]
115 | context = "\n".join(
116 | [
117 | f"[[citation:{str(i+1)}]]\n{summary}"
118 | for i, summary in enumerate(summaries)
119 | ]
120 | )
121 | yield {
122 | "observation": self.observation_prompt_template.format(
123 | context=context, question=user_question, keywords=arguments["query"]
124 | )
125 | }
126 | return
127 |
128 | @staticmethod
129 | def make_summary_tasks(
130 | query: str,
131 | webpage_texts: List[str],
132 | summary_prompt_template: str,
133 | tokenizer: AutoTokenizer,
134 | webpage_summary_max_input_tokens: int = 2048,
135 | ) -> List[str]:
136 | messages_all = []
137 | for webpage_text in webpage_texts:
138 | if len(webpage_text) > 0:
139 | webpage_tokens = tokenizer.encode(webpage_text)
140 | webpage_text = tokenizer.decode(
141 | webpage_tokens[:webpage_summary_max_input_tokens]
142 | )
143 | messages = [
144 | {"role": "system", "content": "You are a helpful assistant."},
145 | {
146 | "role": "user",
147 | "content": summary_prompt_template.format(
148 | question=query, context=webpage_text
149 | ),
150 | },
151 | ]
152 | messages_all.append(
153 | tokenizer.apply_chat_template(
154 | messages, tokenize=False, add_generation_prompt=True
155 | )
156 | )
157 | return messages_all
158 |
159 | @staticmethod
160 | def streaming_fetch_search_results(
161 | url: str, content: Dict, proxies: Dict
162 | ) -> Generator[Dict, None, str]:
163 | try:
164 | with requests.post(
165 | url, json=content, stream=True, proxies=proxies
166 | ) as response:
167 | response.raise_for_status()
168 | for line in response.iter_lines():
169 | if line:
170 | yield json.loads(line)
171 | except requests.exceptions.HTTPError as error:
172 | print(f"HTTP error occurred: {error}")
173 | return "网页加载超时"
174 |
--------------------------------------------------------------------------------
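
Putting it together: a sketch of driving `GoogleSearch.run` outside of Gradio, assuming both the search service (port 8021) and the vLLM server (port 8011) are up; `MODEL_PATH` is a placeholder for your local Megrez-3B-Instruct directory:

```python
from transformers import AutoTokenizer

from infini_websearch.actions import GoogleSearch
from infini_websearch.configs import (
    OBSERVATION_PROMPT_TEMPLATE,
    SUMMARY_PROMPT_TEMPLATE,
)
from infini_websearch.model import get_vllm_model_output_function

MODEL_PATH = "/path/to/Megrez-3B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

search = GoogleSearch(
    server_url="http://localhost:8021/search",
    summary_prompt_template=SUMMARY_PROMPT_TEMPLATE,
    observation_prompt_template=OBSERVATION_PROMPT_TEMPLATE,
)
# non-streaming completions endpoint: the summary prompts are sent as a batch
summarize = get_vllm_model_output_function(
    url="http://localhost:8011/v1/",
    model_name="megrez",
    chat_mode=False,
    model_config={"temperature": 0.01, "max_tokens": 512},
    stream=False,
)
for item in search.run(
    user_question="北京今天的天气怎么样?",
    arguments={"query": "北京天气"},
    llm_completion_funcion=summarize,
    tokenizer=tokenizer,
    return_webpage_details=False,
):
    print(item["observation"])
```
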
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Dict, Generator, List, Optional, Tuple
4 |
5 | import gradio as gr
6 | from gradio_toggle import Toggle
7 | from transformers import AutoTokenizer
8 |
9 | from infini_websearch.actions import GoogleSearch, parse_function_call_from_model_ouput
10 | from infini_websearch.configs import (
11 | AGENT_MAX_OUTPUT_TOKENS,
12 | AGENT_TEMPERATURE,
13 | CHAT_MAX_OUTPUT_TOKENS,
14 | CHAT_TEMPERATURE,
15 | CSS_STYLE,
16 | FUNCTION_CALLING_PROMPT_TEMPLATE,
17 | FUNCTION_END_TOKEN,
18 | FUNCTION_START_TOKEN,
19 | MAX_ACTION_TURNS,
20 | MODEL_NAME,
21 | MODEL_SERVER_URL,
22 | NUM_SEARCH_WEBPAGES,
23 | OBSERVATION_PROMPT_TEMPLATE,
24 | PROXIES,
25 | ROLE_PROMPT,
26 | SEARCH_SERVER_URL,
27 | SESSION_MAX_INPUT_TOKENS,
28 | SESSION_WINDOW_SIZE,
29 | STOP_TOKENS,
30 | SUMMARY_PROMPT_TEMPLATE,
31 | TIME_PROMPT_TEMPLATE,
32 | WEBPAGE_LOAD_TIMETOUT,
33 | WEBPAGE_SUMMARY_MAX_INPUT_TOKENS,
34 | WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS,
35 | )
36 | from infini_websearch.model import (
37 | get_vllm_model_output_function,
38 | include_special_tokens,
39 | split_text_by_special_token,
40 | )
41 | from infini_websearch.utils import (
42 | extract_citations,
43 | format_search_results,
44 | functions2str,
45 | get_datetime_now,
46 | )
47 |
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--model-path", "-m", type=str)
50 | parser.add_argument("--port", type=int, default=7860)
51 |
52 | args = parser.parse_args()
53 |
54 | MODEL_PATH = args.model_path
55 | SERVER_PORT = args.port
56 |
57 | # tokenizer
58 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
59 | TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
60 |
61 |
62 | # function name -> action
63 | ACTIONS_MAP = {
64 | "googleWebSearch": GoogleSearch(
65 | server_url=SEARCH_SERVER_URL,
66 | num_search_webpages=NUM_SEARCH_WEBPAGES,
67 | summary_prompt_template=SUMMARY_PROMPT_TEMPLATE,
68 | observation_prompt_template=OBSERVATION_PROMPT_TEMPLATE,
69 | webpage_summary_max_input_tokens=WEBPAGE_SUMMARY_MAX_INPUT_TOKENS,
70 | webpage_load_timetout=WEBPAGE_LOAD_TIMETOUT,
71 | proxies=PROXIES,
72 | ),
73 | }
74 | # tool -> function name
75 | TOOLS_TO_ACTION_NAMES = {
76 | "websearch": "googleWebSearch",
77 | }
78 |
79 |
80 | def get_system_prompt(functions: Optional[List] = None) -> str:
81 | """
82 | Get system prompt for current conversation.
83 | """
84 | if functions is None:
85 | functions = []
86 | current_time, weekday = get_datetime_now()
87 | time_info = TIME_PROMPT_TEMPLATE.format(current_time=current_time, weekday=weekday)
88 | system_prompt = ROLE_PROMPT + "\n" + time_info
89 | if len(functions) > 0:
90 | system_prompt += "\n" + FUNCTION_CALLING_PROMPT_TEMPLATE.format(
91 | functions=functions2str(functions)
92 | )
93 | return system_prompt
94 |
95 |
96 | def user(
97 | user_message: str, history: List[Dict], session_state: gr.State
98 | ) -> Tuple[str, List[Dict], gr.State]:
99 | """
100 | Add user input message to history.
101 | """
102 | session_state["messages"] += [{"role": "user", "content": user_message}]
103 | return "", history + [{"role": "user", "content": user_message}], session_state
104 |
105 |
106 | def bot(
107 | history: List[Dict],
108 | websearch: bool,
109 | session_state: gr.State,
110 | ) -> Generator[List[Dict], None, None]:
111 | """
112 | Main workflow.
113 | """
114 | # get registered tools
115 | registered_tools = []
116 | if websearch is True:
117 | registered_tools.append("websearch")
118 | temperature = AGENT_TEMPERATURE
119 | max_gen_length = AGENT_MAX_OUTPUT_TOKENS
120 | else:
121 | temperature = CHAT_TEMPERATURE
122 | max_gen_length = CHAT_MAX_OUTPUT_TOKENS
123 | registered_function_names = [
124 | TOOLS_TO_ACTION_NAMES[tool] for tool in registered_tools
125 | ]
126 | registered_functions = [
127 | ACTIONS_MAP[function_name] for function_name in registered_function_names
128 | ]
129 | functions = [
130 | function.function_defination
131 | for function in registered_functions
132 | if function.function_defination is not None
133 | ]
134 |
135 | # get system prompt
136 | system_prompt = get_system_prompt(functions=functions)
137 | # get model streaming output function
138 | llm_streaming_output_func = get_vllm_model_output_function(
139 | url=MODEL_SERVER_URL,
140 | model_name=MODEL_NAME,
141 | chat_mode=True,
142 | stream=True,
143 | model_config={
144 | "temperature": temperature,
145 | "max_tokens": max_gen_length,
146 | "stop": STOP_TOKENS,
147 | },
148 | )
149 |
150 | input_dict = {
151 | "temperature": temperature,
152 | "max_gen_length": max_gen_length,
153 | "use_websearch": "websearch" in registered_tools,
154 | "MODEL_NAME": MODEL_NAME,
155 | }
156 |
157 | for _ in range(MAX_ACTION_TURNS * 2):
158 | # ASSISTANT answers two times per action turn
159 | # answer1: <|function_start|>xxx<|function_end|>
160 | # answer2: observation -> final answer
161 |
162 | """
163 | retain SESSION_WINDOW_SIZE turns (max_sequence_length is short -> 4096)
164 | """
165 | messages_truncated = truncate_messages(
166 | messages=session_state["messages"],
167 | tokenizer=TOKENIZER,
168 | session_window_size=SESSION_WINDOW_SIZE,
169 | max_input_tokens=SESSION_MAX_INPUT_TOKENS,
170 | system_prompt=system_prompt,
171 | )
172 |
173 | messages_input = [
174 | {"role": "system", "content": system_prompt}
175 | ] + messages_truncated
176 | input_dict.update(dict(messages=messages_input))
177 | # print input prompt
178 | print(
179 | TOKENIZER.apply_chat_template(
180 | messages_input, tokenize=False, add_generation_prompt=True
181 | )
182 | )
183 |
184 | response_raw = ""
185 | response_gradio = ""
186 |         # streaming output status:
187 |         # 1. [chat]: generating chat message
188 |         # 2. [function start]: start generating function calling
189 |         #    information ([chat] -> [function])
190 |         # 3. [function]: generating function calling information
191 |         # 4. [function end]: end generating function calling information
192 |         #    ([function] -> [chat])
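    |         # a hypothetical chunk stream and the states it passes through
    |         # (the payload format between the tokens is assumed):
    |         #   "Let me check."                    -> [chat]
    |         #   "<|function_start|>"               -> [function start]
    |         #   '{"name": "googleWebSearch", ...}' -> [function]
    |         #   "<|function_end|>"                 -> [function end]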
193 |         # are we currently inside a function-call block?
194 | function_status = False
195 | chunk_buffer = ""
196 | for chunk in llm_streaming_output_func(messages=input_dict["messages"]):
197 | chunk_buffer += chunk
198 |             # a special token may be split across chunks: if the last '<|' has no closing '|>' yet, keep buffering
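    |             # e.g. a buffer ending in "...<|function_st" waits here until
    |             # the rest of the token arrives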
199 | if chunk_buffer.rfind("|>") < chunk_buffer.rfind("<|"):
200 | continue
201 | # [function start] status: ([chat] -> [function])
202 | if function_status is False and include_special_tokens(
203 | chunk_buffer, [FUNCTION_START_TOKEN]
204 | ):
205 | function_status = True
206 | chat_part, tool_part = split_text_by_special_token(
207 | chunk_buffer, FUNCTION_START_TOKEN
208 | )
209 | tool_part = FUNCTION_START_TOKEN + tool_part
210 | # add chat message to history
211 | if len(response_gradio + chat_part) > 0:
212 | history.append(
213 | {"role": "assistant", "content": response_gradio + chat_part}
214 | )
215 | response_gradio = tool_part
216 | response_raw += chat_part + tool_part
217 | chunk_buffer = ""
218 | yield history + [
219 | {
220 | "role": "assistant",
221 | "content": response_gradio,
222 | "metadata": {"title": "tool parameters"},
223 | }
224 | ]
225 | # [function end] status: ([function] -> [chat])
226 | elif function_status is True and include_special_tokens(
227 | chunk_buffer, [FUNCTION_END_TOKEN]
228 | ):
229 | chat_part, _ = split_text_by_special_token(
230 | chunk_buffer, FUNCTION_END_TOKEN
231 | )
232 | response_gradio += chat_part + FUNCTION_END_TOKEN
233 | response_raw += chat_part + FUNCTION_END_TOKEN
234 | chunk_buffer = ""
235 | history.append(
236 | {
237 | "role": "assistant",
238 | "content": response_gradio,
239 | "metadata": {"title": "tool parameters"},
240 | }
241 | )
242 | yield history
243 | break
244 | # [function] status
245 | elif function_status is True:
246 | response_gradio += chunk_buffer
247 | response_raw += chunk_buffer
248 | chunk_buffer = ""
249 | yield history + [
250 | {
251 | "role": "assistant",
252 | "content": response_gradio,
253 | "metadata": {"title": "tool parameters"},
254 | }
255 | ]
256 | # [chat] status
257 | elif function_status is False:
258 | citations = extract_citations(chunk_buffer)
259 | if len(citations) > 0 and len(session_state["url_infos"]) > 0:
260 | chunk_new = chunk_buffer
261 | for citation in citations:
262 | url_ind = int(citation) - 1
263 |                         # clamp out-of-bounds citation indices
264 | if url_ind < 0:
265 | url_ind = 0
266 | elif url_ind >= len(session_state["url_infos"]):
267 | url_ind = len(session_state["url_infos"]) - 1
268 |                         # Add a space before the tag to prevent rendering errors
269 |                         # when multiple tags are adjacent. NOTE: the anchor markup
270 |                         # below is a reconstruction; it assumes a "url" key.
271 |                         chunk_new = chunk_new.replace(
272 |                             f"[citation:{citation}]",
273 |                             f' <a href="{session_state["url_infos"][url_ind]["url"]}">{citation}</a>',  # noqa: E501
274 |                         )
275 | response_gradio += chunk_new
276 | else:
277 | response_gradio += chunk_buffer
278 | response_raw += chunk_buffer
279 |
280 | chunk_buffer = ""
281 | yield history + [{"role": "assistant", "content": response_gradio}]
282 |
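    |             # cooperative stop: the Stop button's handler flips this flag
    |             # from another event; it is checked once per streamed chunk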
283 | if session_state["stop_generation"] is True:
284 | session_state["stop_generation"] = False
285 | break
286 |
287 |         # if streaming ended in [chat] status, add the response to history
288 |         if not include_special_tokens(response_gradio, [FUNCTION_END_TOKEN]):
289 | history.append({"role": "assistant", "content": response_gradio})
290 |
291 | session_state["messages"].append({"role": "assistant", "content": response_raw})
292 |
293 | # no tool registered, end this turn
294 | if len(registered_tools) == 0:
295 | break
296 |
297 | function_name, function_arguments = parse_function_call_from_model_ouput(
298 | response_raw,
299 | registered_function_names,
300 | speical_tokens_map=dict(
301 | function_start_token=FUNCTION_START_TOKEN,
302 | function_end_token=FUNCTION_END_TOKEN,
303 | ),
304 | )
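    |         # return conventions (inferred from the branches below):
    |         #   (name, args)      -> well-formed call to a registered function
    |         #   (None, err_str)   -> a call was attempted but failed to parse
    |         #   (_, None)         -> no function call in this response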
305 |
306 |         # no function call this turn, end the loop
307 | if function_arguments is None:
308 | break
309 |
310 |         # malformed call: function_arguments holds the parser's error
311 |         # message, so feed it back to the model as the observation
312 | if function_name is None and isinstance(function_arguments, str):
313 | history.append({"role": "observation", "content": function_arguments})
314 | session_state["messages"].append(
315 | {"role": "observation", "content": function_arguments}
316 | )
317 | continue
318 |
319 | url_infos, html_contents = [], []
320 | latest_tool_response = None
321 | action = ACTIONS_MAP[function_name]
322 | observation = None
323 | if function_name == "googleWebSearch":
324 |             observation_generator = action.run(
325 | user_question=session_state["messages"][-2]["content"],
326 | arguments=function_arguments,
327 | llm_completion_funcion=get_vllm_model_output_function(
328 | url=MODEL_SERVER_URL,
329 | model_name=MODEL_NAME,
330 | chat_mode=False,
331 | stream=False,
332 | model_config={
333 | "temperature": temperature,
334 | "max_tokens": WEBPAGE_SUMMARY_MAX_OUTPUT_TOKENS,
335 | "stop": STOP_TOKENS,
336 | },
337 | ),
338 | tokenizer=TOKENIZER,
339 | return_webpage_details=True,
340 | )
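    |             # the generator yields one dict per fetched page ("url_info" +
    |             # "html_content"), a final {"observation": ...} dict, and plain
    |             # strings for per-page error messages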
341 |             for item in gr.Progress().tqdm(observation_generator, desc="summarizing..."):
342 | if isinstance(item, dict):
343 | if "observation" in item:
344 | observation = item["observation"]
345 | break
346 | url_infos.append(item["url_info"])
347 | html_contents.append(item["html_content"])
348 | yield history + [
349 | {
350 | "role": "assistant",
351 | "content": format_search_results(url_infos),
352 | "metadata": {"title": "tool results"},
353 | }
354 | ]
355 | # error message
356 | elif isinstance(item, str):
357 | latest_tool_response = item
358 | yield history + [
359 | {
360 | "role": "assistant",
361 | "content": latest_tool_response,
362 | "metadata": {"title": "tool results"},
363 | }
364 | ]
365 | else:
366 | raise NotImplementedError
367 |
368 | if len(url_infos) > 0:
369 | history.append(
370 | {
371 | "role": "assistant",
372 | "content": format_search_results(url_infos),
373 | "metadata": {"title": "tool results"},
374 | }
375 | )
376 | else:
377 | history.append(
378 | {
379 | "role": "assistant",
380 | "content": latest_tool_response,
381 | "metadata": {"title": "tool results"},
382 | }
383 | )
384 |
385 |             # remember url_infos so later [chat] chunks can render citation links
386 | session_state["url_infos"] = url_infos
387 | else:
388 | observation = action.run(function_arguments)
389 |
390 | assert observation is not None
391 | history.append({"role": "observation", "content": observation})
392 | session_state["messages"].append(
393 | {"role": "observation", "content": observation}
394 | )
395 |
396 |
397 | def truncate_messages(
398 | messages: List[Dict],
399 | tokenizer: AutoTokenizer,
400 | session_window_size: int,
401 | max_input_tokens: int,
402 | system_prompt: str,
403 | ) -> List[Dict]:
404 |     """
405 |     Truncate messages: keep the last session_window_size user turns, then drop the oldest turns until the input fits max_input_tokens.
406 |     """
407 | # get parts for each turn
408 | turn_start_inds = []
409 | for ind, message in enumerate(messages):
410 | if message["role"] == "user":
411 | turn_start_inds.append(ind)
412 | # only latest turns are used as input
413 | turn_start_inds_used = turn_start_inds[-session_window_size:]
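    |     # e.g. session_window_size=3 keeps at most the 3 latest user turns as
    |     # candidates; the token-budget loop below always keeps the newest turn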
414 | messages_parts = []
415 | for i in range(len(turn_start_inds_used)):
416 | turn_start_ind = turn_start_inds_used[i]
417 | turn_end_ind = (
418 | len(messages)
419 | if i + 1 >= len(turn_start_inds_used)
420 | else turn_start_inds_used[i + 1]
421 | )
422 | messages_parts.append(messages[turn_start_ind:turn_end_ind])
423 | # truncate by max_input_tokens
424 | messages_truncated = []
425 | for i, messages_part in enumerate(reversed(messages_parts)):
426 | if (
427 | i == 0
428 | or len(
429 | tokenizer.apply_chat_template(
430 | [{"role": "system", "content": system_prompt}]
431 | + messages_truncated
432 | + messages_part,
433 | tokenize=True,
434 | )
435 | )
436 | < max_input_tokens
437 | ):
438 | messages_truncated = messages_part + messages_truncated
439 | else:
440 | break
441 | return messages_truncated
442 |
443 |
444 | def stop_response(session_state: gr.State) -> gr.State:
445 | session_state["stop_generation"] = True
446 | return session_state
447 |
448 |
449 | def clear(history: List[Dict], session_state: gr.State) -> Tuple[List[Dict], gr.State]:
450 | session_state["messages"] = []
451 | session_state["url_infos"] = []
452 | session_state["stop_generation"] = False
453 | return [], session_state
454 |
455 |
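    | # toggling websearch resets the model-side history, presumably so turns
    | # generated under one system prompt (with or without tool schemas) are
    | # not mixed with turns generated under the other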
456 | def toggle_change(session_state: gr.State) -> gr.State:
457 | session_state["messages"] = []
458 | return session_state
459 |
460 |
461 | with gr.Blocks(
462 | css=CSS_STYLE, fill_height=True, elem_classes="canvas", theme=gr.themes.Monochrome()
463 | ) as demo:
464 | # chatbot interface
465 | with gr.Row(equal_height=False, variant="compact"):
466 |         with gr.Column(scale=1, elem_classes="fullheight"):
467 | chatbot = gr.Chatbot(
468 | type="messages",
469 | elem_classes="chatbot",
470 | label="infini-websearch",
471 | autoscroll=True,
472 | )
473 |
474 | # conversation state vars
475 | session_state = gr.State(
476 | dict(
477 | messages=[],
478 | url_infos=[],
479 | stop_generation=False,
480 | )
481 | )
482 | toggle_is_interactive = gr.State(value=True)
483 |
484 | # bottom bar
485 | with gr.Group(elem_classes="bottom-bar") as bottom_bar:
486 | msg = gr.Textbox(label="question")
487 | with gr.Row():
488 | clear_btn = gr.Button("Clear")
489 | stop_btn = gr.Button("Stop")
490 |
491 | # toggle
492 | with gr.Group() as toggle_group:
493 | websearch = Toggle(
494 | label="websearch",
495 | value=True,
496 | interactive=True,
497 | )
498 |
499 | websearch.change(toggle_change, [session_state], [session_state])
500 | msg.submit(
501 | user, [msg, chatbot, session_state], outputs=[msg, chatbot, session_state]
502 | ).then(
503 | bot,
504 | [chatbot, websearch, session_state],
505 | outputs=[chatbot],
506 | concurrency_limit=2,
507 | )
508 | clear_btn.click(clear, [chatbot, session_state], outputs=[chatbot, session_state])
509 | stop_btn.click(stop_response, [session_state], outputs=[session_state], queue=False)
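    |     # msg.submit chains user() (records the message) then bot() (streams the
    |     # reply); stop_btn skips the queue so it can interrupt a running stream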
510 |
511 |
512 | if __name__ == "__main__":
513 | demo.launch(share=False, server_port=SERVER_PORT)
514 |
--------------------------------------------------------------------------------