├── .gitignore
├── Auto-RAG
│   ├── __init__.py
│   ├── autorag.py
│   └── gui.py
├── LICENSE
├── assets
│   ├── autorag.gif
│   └── results_.png
├── readme.md
└── scripts
    ├── deploy.sh
    ├── prepare_retriever.sh
    ├── run_evaluation.sh
    └── run_gui.sh

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
applications/DeepSpeed-Chat/data
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


# vscode
.vscode


# third party models
yala/model/third_party_models

# aim
.aim

# test files
_test*.py
_test*.ipynb

# experimental configs
experimental_configs/

# hydra logs
outputs/

# pytest configs
tests/configs/

# cibuildwheel
wheelhouse/

--------------------------------------------------------------------------------
/Auto-RAG/__init__.py:
--------------------------------------------------------------------------------
from .autorag import AutoRAGAssistant, AutoRAGConfig

__all__ = ["AutoRAGAssistant", "AutoRAGConfig"]

--------------------------------------------------------------------------------
/Auto-RAG/autorag.py:
--------------------------------------------------------------------------------
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Generator

from omegaconf import MISSING

from flexrag.assistant import AssistantBase, ASSISTANTS
from flexrag.common_dataclass import RetrievedContext
from flexrag.models import GenerationConfig, OpenAIGenerator, OpenAIGeneratorConfig
from flexrag.prompt import ChatPrompt, ChatTurn
from flexrag.retriever import DenseRetriever, DenseRetrieverConfig


Knowledge_Prompt = """
Your task is to generate one corresponding Wikipedia document based on the given query to help the LLM answer questions.

Demonstration:

Origin Question: How many episodes in a season of vampire diaries?

Query: The Vampire Diaries episode count

Document: The Vampire Diaries has a total of 171 episodes over 8 seasons. The show's first season had 22 episodes, the second season had 22 episodes, the third season had 22 episodes, the fourth season had 23 episodes, the fifth season had 22 episodes, the sixth season had 22 episodes, the seventh season had 22 episodes, and the eighth season had 16 episodes.

###

Origin Question: Who developed an explanation for the photoelectric effect?

Query: Photoelectric Effect Explanation

Document: To make sense of the fact that light can eject electrons even if its intensity is low, Albert Einstein proposed that a beam of light is not a wave propagating through space, but rather a collection of discrete wave packets (photons), each with energy hν. This shed light on Max Planck's previous discovery of the Planck relation (E = hν) linking energy (E) and frequency (ν) as arising from quantization of energy. The factor h is known as the Planck constant. In 1887, Heinrich Hertz discovered that electrodes illuminated with ultraviolet light create electric sparks more easily. In 1900, while studying black-body radiation, the German physicist Max Planck suggested that the energy carried by electromagnetic waves could only be released

###

Origin Question: District of maharashtra that are part of red corridor?

Query: Red Corridor in Maharashtra districts

Document: The Red Corridor in Maharashtra includes the following districts: Chandrapur, Gondia, and Gadchiroli.

###

Origin Question: Who played jason in mighty morphin power rangers?

Query: Mighty Morphin Power Rangers Jason

Document: from Dairanger were featured in the second season while only the Kakuranger mecha was featured in the third season, though the Kakuranger costumes were later used for the mini-series Mighty Morphin Alien Rangers. The series was produced by MMPR Productions and distributed by Saban Entertainment, while the show's merchandise was produced and distributed by Bandai Entertainment. The series was well known for its campy tone. In 2010, a re-version of Mighty Morphin Power Rangers, with a revised new look of the original 1993 logo, comic book-referenced graphics, and extra alternative visual effects, was broadcast on ABC Kids, and Bandai produced brand new toys to coincide with the series. Only the first 32 of season one's 60 episodes were remade.

###

Origin Question: {}

Query: {}

Document: """


@dataclass
class AutoRAGConfig(OpenAIGeneratorConfig, DenseRetrieverConfig):
    data_path: str = MISSING
    max_iter: int = 10
    elicit_max_iter: int = 5
    max_passages: int = 2
    verbose: bool = False


@ASSISTANTS("autorag", config_class=AutoRAGConfig)
class AutoRAGAssistant(AssistantBase):
    def __init__(self, cfg: AutoRAGConfig):
        self.prompt = ChatPrompt(
            system=(
                "Answer the question by retrieving external knowledge. "
                "Extract useful information from each retrieved document. "
                "If the information is insufficient or irrelevant, "
                "refine your query and search again until you are able to answer the question."
            )
        )
        # load model & retriever
        self.main_model = OpenAIGenerator(cfg)
        self.retriever = DenseRetriever(cfg)

        # set parameters
        self.verbose = cfg.verbose
        self.elicit_max_iter = cfg.elicit_max_iter
        self.max_iter = cfg.max_iter
        self.max_passages = cfg.max_passages
        return

    def interactive_answer(
        self, history: list[dict[str, Any]], show_details: bool
    ) -> Generator[tuple[list[dict], list], None, None]:
        # store the search history for interacting with the user
        question = history[-1]["content"].strip()
        if not show_details:
            history.append(
                {
                    "role": "assistant",
                    "content": "Retrieving and reasoning...",
                    "metadata": {"title": "🤖 Auto-RAG"},
                }
            )
            yield history, []

        queries = [question]
        retrieved_ids = []
        prompt = deepcopy(self.prompt)
        prompt.update(ChatTurn(role="user", content="Question: " + question.strip()))
        current_iter = 0
        first_model_output = None
        max_iter = self.max_iter
        # start the retrieval iteration
        while max_iter > 0:
            if self.verbose:
                print("input", prompt)

            # generate thought & action
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(do_sample=False, max_new_tokens=200),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))
            history.append(
                {
                    "role": "assistant",
                    "content": first_model_output,
                    "metadata": {"title": "🤖 Auto-RAG"},
                }
            )
            if show_details:
                yield history, []

            # extract the action: either a new query or the final answer
            if "query:" in first_model_output.lower():
                queries = [first_model_output.split("Query:")[-1].strip()]
                current_iter += 1
            elif "final answer" in first_model_output.lower():
                # the final answer has been generated; stop retrieving
                break
            else:
                print("Exception: Follow Failed")
                print(prompt)
                print(first_model_output)

            # retrieve documents
            queries[0] = queries[0].replace("[Dense]", "").strip()
            documents = []
            retrieval_results = self.retriever.search(queries[0])[0]

            # keep only passages that have not been retrieved before
            for result in retrieval_results:
                if result.context_id not in retrieved_ids:
                    retrieved_ids.append(result.context_id)
                    documents.append(result.data["text"].split("\n")[-1])
                if len(documents) >= self.max_passages:
                    break
            document = " ".join(documents)
            prompt.update(
                ChatTurn(
                    role="user",
                    content=f"Retrieved Document_{current_iter}: {document.strip()}",
                )
            )
            history.append(
                {
                    "role": "assistant",
                    "content": f"Retrieved Document_{current_iter}: {document.strip()}",
                    "metadata": {"title": "🔍︎ **Dense Retriever**"},
                }
            )
            if show_details:
                yield history, []

            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # retrieval budget exhausted: ask the model to conclude anyway
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(temperature=0.0, max_new_tokens=150),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))
            history.append(
                {
                    "role": "assistant",
                    "content": first_model_output,
                    "metadata": {"title": "🤖 Auto-RAG"},
                }
            )
            if show_details:
                yield history, []

        max_iter = self.elicit_max_iter

        # try to generate a pseudo document to answer the question
        while "Refined Query:" in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split("Refined Query:")[-1].strip()

            document_prompt = Knowledge_Prompt.format(question, query)

            document = self.main_model.generate(
                prefixes=[document_prompt],
                generation_config=GenerationConfig(
                    temperature=0.0,
                    max_new_tokens=200,
                    stop_str=["<|eot_id|>", "\n"],
                ),
            )[0][0].strip()

            # generate thought & action based on the pseudo document
            prompt.update(
                ChatTurn(
                    role="user",
                    content=f"Retrieved Document_{current_iter}: {document.strip()}",
                )
            )
            history.append(
                {
                    "role": "user",
                    "content": document.strip(),
                    "metadata": {"title": "Parametric Knowledge"},
                }
            )
            if show_details:
                yield history, []
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(
                    do_sample=False,
                    max_new_tokens=150,
                ),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))
            history.append(
                {
                    "role": "assistant",
                    "content": first_model_output,
                    "metadata": {"title": "🤖 Auto-RAG"},
                }
            )
            if show_details:
                yield history, []
            max_iter -= 1

        # generate the final answer
        if not show_details:
            backup_history = []
            for idx in range(len(history)):
                new_item = {}
                if isinstance(history[idx], dict):
                    new_item["role"] = history[idx]["role"]
                    new_item["content"] = history[idx]["content"]
                    if "metadata" in history[idx]:
                        new_item["metadata"] = history[idx]["metadata"]
                else:
                    new_item["role"] = history[idx]["role"]
                    new_item["content"] = history[idx]["content"]
                    if history[idx]["metadata"]:
                        new_item["metadata"] = history[idx]["metadata"]
                backup_history.append(new_item)
            history = [history[0], history[-1]]
            history[-1]["content"] = (
                history[-1]["content"].split("Final Answer:")[-1].strip()
            )
        else:
            backup_history = []
            backup_history.append(history[0])
            backup_history.append(
                {
                    "role": history[-1]["role"],
                    "content": history[-1]["content"]
                    .split("Final Answer:")[-1]
                    .strip(),
                    "metadata": history[-1]["metadata"],
                }
            )
        yield history, backup_history

    def answer(
        self, question: str
    ) -> tuple[str, Optional[list[RetrievedContext]], Optional[dict]]:
        queries = [question]
        retrieved_ids = []
        prompt = deepcopy(self.prompt)
        prompt.update(ChatTurn(role="user", content="Question: " + question.strip()))
        current_iter = 0
        first_model_output = None
        max_iter = self.max_iter
        response = ""
        document = ""
        # start the retrieval iteration
        while max_iter > 0:
            if self.verbose:
                print("input", prompt)

            # generate thought & action
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(do_sample=False, max_new_tokens=200),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))

            # extract the action: either a new query or the final answer
            if "query:" in first_model_output.lower():
                queries = [first_model_output.split("Query:")[-1].strip()]
                current_iter += 1
            elif "final answer" in first_model_output.lower():
                response = first_model_output.split("Final Answer:")[-1].strip()
                break
            else:
                print("Exception: Follow Failed")
                print(prompt)
                print(first_model_output)

            # retrieve documents
            queries[0] = queries[0].replace("[Dense]", "").strip()
            documents = []
            retrieval_results = self.retriever.search(queries[0])[0]

            # keep only passages that have not been retrieved before
            for result in retrieval_results:
                if result.context_id not in retrieved_ids:
                    retrieved_ids.append(result.context_id)
                    documents.append(result.data["text"].split("\n")[-1])
                if len(documents) >= self.max_passages:
                    break
            document = " ".join(documents)
            prompt.update(
                ChatTurn(
                    role="user",
                    content=f"Retrieved Document_{current_iter}: {document.strip()}",
                )
            )
            max_iter -= 1

        first_model_output = ""
        if max_iter == 0:
            # retrieval budget exhausted: ask the model to conclude anyway
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(temperature=0.0, max_new_tokens=150),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))

        # try to generate a pseudo document to answer the question
        max_iter = self.elicit_max_iter
        while "Refined Query:" in first_model_output and max_iter > 0:
            current_iter += 1
            query = first_model_output.split("Refined Query:")[-1].strip()

            document_prompt = Knowledge_Prompt.format(question, query)

            document = self.main_model.generate(
                prefixes=[document_prompt],
                generation_config=GenerationConfig(
                    temperature=0.0,
                    max_new_tokens=200,
                    stop_str=["<|eot_id|>", "\n"],
                ),
            )[0][0].strip()

            # generate thought & action based on the pseudo document
            prompt.update(
                ChatTurn(
                    role="user",
                    content=f"Retrieved Document_{current_iter}: {document.strip()}",
                )
            )
            first_model_output = self.main_model.chat(
                prompts=[prompt],
                generation_config=GenerationConfig(
                    do_sample=False,
                    max_new_tokens=150,
                ),
            )[0][0].strip()
            prompt.update(ChatTurn(role="assistant", content=first_model_output))
            max_iter -= 1

        # extract the final answer if it was produced after the elicitation phase
        if not response and "final answer" in first_model_output.lower():
            response = first_model_output.split("Final Answer:")[-1].strip()
        return (
            response,
            [
                RetrievedContext(
                    retriever="autorag", query=question, data={"text": document}
                )
            ],
            {"prompt": prompt},
        )

--------------------------------------------------------------------------------
/Auto-RAG/gui.py:
--------------------------------------------------------------------------------
import gradio as gr
import hydra
from hydra.core.config_store import ConfigStore

from autorag import AutoRAGAssistant, AutoRAGConfig


def update_details_button(show_details):
    new_label = "Hide Details" if show_details else "Show Details"
    return gr.update(value=new_label)


def update_show_details(show_details):
    show_details = not show_details
    return show_details


def update_history(history, backup_history):
    # swap the displayed history with the backup (full vs. condensed view)
    history, backup_history = backup_history, history
    return history, backup_history


def user(user_message, history: list):
    history.append({"role": "user", "content": user_message})
    yield "", history


html_output = """
<h1 align="center">
Auto-RAG: Autonomous Retrieval-Augmented Generation for Large Language Models
</h1>
<div align="center">
Authors: Tian Yu, Shaolei Zhang, and Yang Feng
</div>
40 | """ 41 | 42 | 43 | cs = ConfigStore.instance() 44 | cs.store(name="default", node=AutoRAGConfig) 45 | 46 | 47 | @hydra.main(version_base="1.3", config_path=None, config_name="default") 48 | def main(config: AutoRAGConfig): 49 | # load assistant 50 | assistant = AutoRAGAssistant(config) 51 | 52 | # run assistant 53 | with gr.Blocks() as demo: 54 | gr.HTML(html_output) 55 | show_details = gr.State(True) 56 | backup_history = gr.State([]) 57 | 58 | chatbot = gr.Chatbot( 59 | type="messages", 60 | label="Auto-RAG", 61 | height=500, 62 | placeholder="Ask me anything!", 63 | show_copy_button=True, 64 | bubble_full_width=False, 65 | layout="bubble", 66 | ) 67 | msg = gr.Textbox() 68 | with gr.Row(): 69 | toggle_button = gr.Button(f"Hide Details") 70 | clear_button = gr.Button("Clear") 71 | toggle_button.click(update_show_details, show_details, show_details).then( 72 | update_details_button, show_details, toggle_button 73 | ).then(update_history, [chatbot, backup_history], [chatbot, backup_history]) 74 | msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( 75 | assistant.interactive_answer, 76 | [chatbot, show_details], 77 | [chatbot, backup_history], 78 | ) 79 | clear_button.click(lambda x: [], chatbot, chatbot).then( 80 | lambda x: [], backup_history, backup_history 81 | ) 82 | demo.launch(server_name="0.0.0.0") 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/assets/autorag.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/Auto-RAG/b487d235b39f9db5dcd0f23cc7c03c22882dab7e/assets/autorag.gif

--------------------------------------------------------------------------------
/assets/results_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/Auto-RAG/b487d235b39f9db5dcd0f23cc7c03c22882dab7e/assets/results_.png

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Auto-RAG: Autonomous Retrieval-Augmented Generation for Large Language Models

> [Tian Yu](https://tianyu0313.github.io/), [Shaolei Zhang](https://zhangshaolei1998.github.io/), [Yang Feng](https://people.ucas.edu.cn/~yangfeng?language=en)*

[![arXiv](https://img.shields.io/badge/arXiv-2411.19443-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2411.19443)
[![code](https://img.shields.io/badge/Github-Code-keygen.svg?logo=github)](https://github.com/ictnlp/Auto-RAG)
[![model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging_Face-Model-blue.svg)](https://huggingface.co/ICTNLP/Auto-RAG)

Source code for the paper "[Auto-RAG: Autonomous Retrieval-Augmented Generation for Large Language Models](https://arxiv.org/abs/2411.19443)".

If you find this project useful, feel free to ⭐️ it and give it a [citation](#citation)!


## Overview

**Auto-RAG** is an autonomous iterative retrieval model centered on the LLM's powerful decision-making capabilities. Auto-RAG models the interaction between the LLM and the retriever as a multi-turn dialogue: it reasons iteratively to decide when and what to retrieve, stops iterating once sufficient external knowledge is available, and then provides the answer to the user.

- **GUI interaction**: We provide a deployable user interaction interface. After a question is entered, Auto-RAG interacts with the retriever autonomously, without any human intervention. Users can choose whether to display the details of the interaction between Auto-RAG and the retriever.

<div align="center">
<img src="./assets/autorag.gif" alt="img">
</div>
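
To make the decision loop concrete: the assistant maintains a multi-turn conversation with fixed textual markers, where user turns carry `Question: ...` or `Retrieved Document_k: ...` and each model turn either issues a new `Query: ...` or terminates with `Final Answer: ...` (see `Auto-RAG/autorag.py`). Below is an illustrative sketch of one such trace; the question, query, and document texts are invented for illustration.

```python
# An illustrative Auto-RAG dialogue trace (contents are invented); the turn
# markers are the ones emitted and parsed by AutoRAGAssistant in autorag.py.
trace = [
    ("user", "Question: Who developed an explanation for the photoelectric effect?"),
    ("assistant", "I need information about the photoelectric effect. "
                  "Query: photoelectric effect explanation"),
    ("user", "Retrieved Document_1: Albert Einstein proposed that a beam of light "
             "is a collection of discrete wave packets (photons) ..."),
    ("assistant", "The document attributes the explanation to Albert Einstein, "
                  "which is sufficient. Final Answer: Albert Einstein"),
]
```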

- To interact with Auto-RAG in your browser, follow the guide for [GUI interaction](#gui-interaction).


## Models Download

We provide Auto-RAG models trained on the synthetic data. Please refer to https://huggingface.co/ICTNLP/Auto-RAG-Llama-3-8B-Instruct.

## Installation

- Environment requirements: Python 3.12, [FlexRAG](https://github.com/ictnlp/flexrag).

```bash
conda create -n autorag python=3.12
conda activate autorag

pip install flexrag==0.2.0
```

- Clone the Auto-RAG repo.

```bash
git clone https://github.com/ictnlp/Auto-RAG.git
cd Auto-RAG
```

- Download the corpus and prepare the retriever.

We use the Wikipedia corpus provided by the [DPR](https://github.com/facebookresearch/DPR) project. You can prepare the dense retriever by running the following command:

```bash
bash scripts/prepare_retriever.sh
```


## Model Deployment

We use vLLM to deploy the model for inference. Update the parameters in `scripts/deploy.sh` to adjust the GPU and model path configuration, then execute:

```bash
bash scripts/deploy.sh
```


## GUI Interaction

To interact with Auto-RAG in your browser, run the following command:

```bash
bash scripts/run_gui.sh
```

> [!Tip]
> The interaction process between Auto-RAG and the retriever can be optionally displayed by toggling the details button.

## Run as a FlexRAG Assistant

You can also run Auto-RAG as a FlexRAG assistant. To do this, execute the following command:

```bash
ENCODER_PATH='intfloat/e5-base-v2'
MODEL_NAME=""
BASE_URL="http://127.0.0.1:8000/v1"


python -m flexrag.entrypoints.run_assistant \
    user_module=Auto-RAG \
    name=nq \
    split=test \
    assistant_type=autorag \
    autorag_config.model_name=$MODEL_NAME \
    autorag_config.base_url=$BASE_URL \
    autorag_config.database_path=wiki \
    autorag_config.index_type=faiss \
    autorag_config.query_encoder_config.encoder_type=hf \
    autorag_config.query_encoder_config.hf_config.model_path=$ENCODER_PATH \
    eval_config.metrics_type=[retrieval_success_rate,generation_f1,generation_em] \
    eval_config.retrieval_success_rate_config.eval_field=text \
    eval_config.response_preprocess.processor_type=[simplify_answer] \
    log_interval=10
```
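
If you prefer to drive the assistant from Python instead of the GUI or the FlexRAG entrypoint, a minimal sketch looks like the following. It assumes a vLLM server launched via `scripts/deploy.sh` and the `wiki` index built by `scripts/prepare_retriever.sh`; the field values are illustrative (they mirror the overrides in `scripts/run_gui.sh`), and the full configuration schema comes from FlexRAG's `OpenAIGeneratorConfig` and `DenseRetrieverConfig`.

```python
# Hypothetical driver script (not part of this repository); the values below
# are illustrative and mirror the overrides used in scripts/run_gui.sh.
from omegaconf import OmegaConf

from autorag import AutoRAGAssistant, AutoRAGConfig

cfg = OmegaConf.structured(AutoRAGConfig)
cfg.model_name = "ICTNLP/Auto-RAG-Llama-3-8B-Instruct"  # model served by vLLM
cfg.base_url = "http://127.0.0.1:8000/v1"  # endpoint opened by scripts/deploy.sh
cfg.database_path = "wiki"  # index built by scripts/prepare_retriever.sh
cfg.index_type = "faiss"
cfg.query_encoder_config.encoder_type = "hf"
cfg.query_encoder_config.hf_config.model_path = "intfloat/e5-base-v2"

assistant = AutoRAGAssistant(cfg)
response, contexts, metadata = assistant.answer(
    "Who developed an explanation for the photoelectric effect?"
)
print(response)
```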

## Experimental Results

> [!Note]
> Experimental results show that Auto-RAG outperforms all baselines across six benchmarks.

<div align="center">
<img src="./assets/results_.png" alt="img">
</div>

## License

This project is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text.

## Citation

If this repository is useful for you, please cite as:

```
@article{yu2024autorag,
    title={Auto-RAG: Autonomous Retrieval-Augmented Generation for Large Language Models},
    author={Tian Yu and Shaolei Zhang and Yang Feng},
    year={2024},
    eprint={2411.19443},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2411.19443},
}
```

If you have any questions, feel free to contact `yutian23s@ict.ac.cn`.

--------------------------------------------------------------------------------
/scripts/deploy.sh:
--------------------------------------------------------------------------------
#!/bin/bash


MODEL_PATH=""

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server \
    --model $MODEL_PATH \
    --gpu-memory-utilization 0.9 \
    --tensor-parallel-size 4 \
    --max-model-len 8192 \
    --port 8000 \
    --host 0.0.0.0

--------------------------------------------------------------------------------
/scripts/prepare_retriever.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -euo pipefail

DEVICE_ID='[0,1,2,3]'
ENCODER_PATH='intfloat/e5-base-v2'

wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz

python -m flexrag.entrypoints.prepare_index \
    retriever_type=dense \
    file_paths=[psgs_w100.tsv] \
    id_field='id' \
    saving_fields=[title,text] \
    dense_config.database_path=wiki \
    dense_config.encode_fields=[text] \
    dense_config.passage_encoder_config.encoder_type=hf \
    dense_config.passage_encoder_config.hf_config.model_path=$ENCODER_PATH \
    dense_config.passage_encoder_config.hf_config.prompt='query: ' \
    dense_config.passage_encoder_config.hf_config.normalize=True \
    dense_config.passage_encoder_config.hf_config.device_id=$DEVICE_ID \
    dense_config.index_type=faiss \
    dense_config.faiss_config.batch_size=12288 \
    dense_config.faiss_config.log_interval=100000 \
    dense_config.batch_size=1024 \
    dense_config.log_interval=100000

--------------------------------------------------------------------------------
/scripts/run_evaluation.sh:
--------------------------------------------------------------------------------
#!/bin/bash

ENCODER_PATH='intfloat/e5-base-v2'
MODEL_NAME=""
BASE_URL="http://127.0.0.1:8000/v1"


python -m flexrag.entrypoints.run_assistant \
    user_module=Auto-RAG \
    name=nq \
    split=test \
    assistant_type=autorag \
    autorag_config.model_name=$MODEL_NAME \
    autorag_config.base_url=$BASE_URL \
    autorag_config.database_path=wiki \
    autorag_config.index_type=faiss \
    autorag_config.query_encoder_config.encoder_type=hf \
    autorag_config.query_encoder_config.hf_config.model_path=$ENCODER_PATH \
    eval_config.metrics_type=[retrieval_success_rate,generation_f1,generation_em] \
    eval_config.retrieval_success_rate_config.eval_field=text \
    eval_config.response_preprocess.processor_type=[simplify_answer] \
    log_interval=10

--------------------------------------------------------------------------------
/scripts/run_gui.sh:
--------------------------------------------------------------------------------
#!/bin/bash
MODEL_NAME=""
BASE_URL=""
ENCODER_PATH='intfloat/e5-base-v2'


CUDA_VISIBLE_DEVICES=4 python Auto-RAG/gui.py \
    model_name=$MODEL_NAME \
    base_url=$BASE_URL \
    database_path=wiki \
    index_type=faiss \
    query_encoder_config.encoder_type=hf \
    query_encoder_config.hf_config.model_path=$ENCODER_PATH
--------------------------------------------------------------------------------