├── .env.example
├── .gitignore
├── README.md
├── compress.py
├── report.py
├── requirements.txt
├── run.py
├── subquery.py
├── tavily.py
└── workflow.py

/.env.example:
--------------------------------------------------------------------------------
OPENAI_API_KEY=""
TAVILY_API_KEY=""
--------------------------------------------------------------------------------
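Both keys are loaded from the environment at runtime with `python-dotenv`. A quick sanity-check sketch (not part of the repo; the variable names come from `.env.example`, everything else is illustrative):

```python
import os

from dotenv import load_dotenv

# Read key=value pairs from .env into the process environment.
load_dotenv()

# The same lookups run.py and tavily.py rely on.
assert os.getenv("OPENAI_API_KEY"), "Populate OPENAI_API_KEY in .env"
assert os.getenv("TAVILY_API_KEY"), "Populate TAVILY_API_KEY in .env"
```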
/.gitignore:
--------------------------------------------------------------------------------
### Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Llama Researcher

In this tutorial, we'll create Llama-Researcher using LlamaIndex workflows, inspired by [GPT-Researcher](https://github.com/assafelovic/gpt-researcher).

Stack used:

- LlamaIndex workflows for orchestration
- Tavily API as the search engine
- Other LlamaIndex abstractions like `VectorStoreIndex`, node postprocessors, etc.

Full tutorial 👇

[![Llama-Researcher](https://img.youtube.com/vi/gHdQcoeNgMU/maxresdefault.jpg)](https://www.youtube.com/watch?v=gHdQcoeNgMU)

## How to use

- Clone the repo

```bash
git clone https://github.com/rsrohan99/Llama-Researcher.git
cd Llama-Researcher
```

- Install dependencies

```bash
pip install -r requirements.txt
```

- Create a `.env` file and add your `OPENAI_API_KEY` and `TAVILY_API_KEY`

```bash
cp .env.example .env
```

- Run the workflow with the topic to research

```bash
python run.py "topic to research"
```
--------------------------------------------------------------------------------
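Before diving into the source files, here is a minimal sketch of the LlamaIndex workflow pattern everything below is built on: steps are async methods, typed events carry data between them, and `StartEvent`/`StopEvent` bracket a run. The class and field names in this sketch are illustrative, not part of the repo:

```python
from llama_index.core.workflow import (
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class GreetingEvent(Event):
    # Illustrative intermediate event; workflow.py defines its own event types.
    name: str


class HelloWorkflow(Workflow):
    @step
    async def make_greeting(self, ev: StartEvent) -> GreetingEvent:
        # StartEvent exposes the keyword arguments passed to .run()
        return GreetingEvent(name=ev.name)

    @step
    async def finish(self, ev: GreetingEvent) -> StopEvent:
        # The StopEvent result becomes the return value of .run()
        return StopEvent(result=f"Hello, {ev.name}!")


# Usage (inside an event loop):
# result = await HelloWorkflow(timeout=60).run(name="world")
```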
/compress.py:
--------------------------------------------------------------------------------
from typing import List

from llama_index.core.schema import Document
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.text_splitter import SentenceSplitter
from llama_index.core import VectorStoreIndex
from llama_index.core.postprocessor import SimilarityPostprocessor


async def get_compressed_context(
    query: str, docs: List[Document], embed_model: BaseEmbedding
) -> str:
    # Build an in-memory index over the scraped docs, chunked by sentence.
    index = VectorStoreIndex.from_documents(
        docs,
        embed_model=embed_model,
        transformations=[SentenceSplitter()],
    )

    # Over-retrieve here; the similarity cutoff below does the real filtering.
    retriever = index.as_retriever(similarity_top_k=100)

    nodes = await retriever.aretrieve(query)

    # Drop nodes whose similarity to the sub-query falls below the cutoff.
    processor = SimilarityPostprocessor(similarity_cutoff=0.38)
    filtered_nodes = processor.postprocess_nodes(nodes)
    print(
        f"\n> Filtered {len(filtered_nodes)} nodes from {len(nodes)} nodes for subquery: {query}\n"
    )

    context = ""

    for node_with_score in filtered_nodes:
        node = node_with_score.node
        node_info = (
            f"---\nSource: {node.metadata.get('source', 'Unknown')}\n"
            f"Title: {node.metadata.get('title', '')}\n"
            f"Content: {node.text}\n---\n"
        )
        context += node_info + "\n"

    return context
--------------------------------------------------------------------------------
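A hedged sketch of using `get_compressed_context` on its own (the document, query, and model choice are illustrative; in the real workflow the docs come from Tavily and the embed model from run.py):

```python
import asyncio

from llama_index.core.schema import Document
from llama_index.embeddings.openai import OpenAIEmbedding

from compress import get_compressed_context


async def demo() -> None:
    docs = [
        Document(
            text="LlamaIndex workflows orchestrate async steps via typed events.",
            metadata={"source": "https://example.com", "title": "Example page"},
        )
    ]
    # Same embedding model that run.py configures for the real workflow.
    embed_model = OpenAIEmbedding(model="text-embedding-3-small")
    context = await get_compressed_context(
        "What are LlamaIndex workflows?", docs, embed_model
    )
    print(context)


# asyncio.run(demo())  # requires OPENAI_API_KEY in the environment
```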
/report.py:
--------------------------------------------------------------------------------
from datetime import datetime, timezone

from llama_index.core.llms.llm import LLM
from llama_index.core.prompts.base import PromptTemplate


async def generate_report_from_context(query: str, context: str, llm: LLM) -> str:
    prompt = PromptTemplate(
        """Information:
--------------------------------
{context}
--------------------------------
Using the above information, answer the following query or task: "{question}" in a detailed report --
The report should focus on the answer to the query, should be well structured, informative,
in-depth, and comprehensive, with facts and numbers if available, and at least {total_words} words.
You should strive to write the report as long as you can using all relevant and necessary information provided.

Please follow all of the following guidelines in your report:
- You MUST determine your own concrete and valid opinion based on the given information. Do NOT defer to general and meaningless conclusions.
- You MUST write the report with markdown syntax and in {report_format} format.
- You MUST prioritize the relevance, reliability, and significance of the sources you use. Choose trusted sources over less reliable ones.
- You must also prioritize newer articles over older articles if the source can be trusted.
- Use in-text citations in {report_format} format, rendered as markdown hyperlinks placed at the end of the sentence or paragraph that references them, like this: ([in-text citation](url)).
- Don't forget to add a reference list at the end of the report in {report_format} format, with full url links without hyperlinks.
- You MUST write all used source urls at the end of the report as references, and make sure to not add duplicated sources, but only one reference for each.
  Every url should be hyperlinked: [url website](url)
  Additionally, you MUST include hyperlinks to the relevant URLs wherever they are referenced in the report:

  eg: Author, A. A. (Year, Month Date). Title of web page. Website Name. [url website](url)

Please do your best, this is very important to my career.
Assume that the current date is {date_today}.
"""
    )
    response = await llm.apredict(
        prompt,
        context=context,
        question=query,
        total_words=1000,
        report_format="APA",
        date_today=datetime.now(timezone.utc).strftime("%B %d, %Y"),
    )

    print("\n> Done generating report\n")

    return response
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
llama-index-core==0.11.11
llama-index-embeddings-openai==0.2.5
llama-index-llms-openai==0.2.9
llama-index-utils-workflow==0.2.1
llama-parse==0.5.6
markdown_pdf==1.3
python-dotenv==1.0.1
requests
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import sys
import asyncio
import subprocess

from dotenv import load_dotenv

from llama_index.utils.workflow import draw_all_possible_flows
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

from workflow import ResearchAssistantWorkflow


async def main():
    if len(sys.argv) < 2:
        sys.exit('Usage: python run.py "topic to research"')
    load_dotenv()
    llm = OpenAI(model="gpt-4o-mini")
    embed_model = OpenAIEmbedding(model="text-embedding-3-small")
    workflow = ResearchAssistantWorkflow(
        llm=llm, embed_model=embed_model, verbose=True, timeout=240.0
    )
    # Uncomment to render the workflow graph as an HTML file:
    # draw_all_possible_flows(workflow, filename="research_assistant_workflow.html")
    topic = sys.argv[1]
    report_file = await workflow.run(query=topic)
    # `open` is macOS-specific; use `xdg-open` on Linux or `start` on Windows.
    subprocess.run(["open", report_file])


if __name__ == "__main__":
    asyncio.run(main())
--------------------------------------------------------------------------------
/subquery.py:
--------------------------------------------------------------------------------
from datetime import datetime, timezone
from typing import List

from llama_index.core.llms.llm import LLM
from llama_index.core.prompts.base import PromptTemplate

# Note: the date is baked into the prompt at import time via the f-strings.
SUB_QUERY_PROMPT = PromptTemplate(
    'Write {max_iterations} google search queries to search online that form an objective opinion from the following task: "{task}"\n'
    f"Assume the current date is {datetime.now(timezone.utc).strftime('%B %d, %Y')} if required.\n"
    "You must respond with the search queries separated by commas in the following format: query 1, query 2, query 3\n"
    "{max_iterations} google search queries for {task} (separated by commas): "
)


async def get_sub_queries(
    query: str,
    llm: LLM,
    num_sub_queries: int = 3,
) -> List[str]:
    """
    Generate search sub-queries for the given research task.

    Args:
        query: original query
        llm: LLM used to generate the sub-queries
        num_sub_queries: how many search queries to ask for

    Returns:
        List of sub-query strings
    """
    response = await llm.apredict(
        SUB_QUERY_PROMPT,
        task=query,
        max_iterations=num_sub_queries,
    )
    # The LLM returns one comma-separated line; split it and strip stray quotes.
    sub_queries = list(
        map(lambda x: x.strip().strip('"').strip("'"), response.split(","))
    )

    return sub_queries
--------------------------------------------------------------------------------
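A hedged sketch of calling `get_sub_queries` standalone (the topic and model choice are illustrative):

```python
import asyncio

from dotenv import load_dotenv
from llama_index.llms.openai import OpenAI

from subquery import get_sub_queries


async def demo() -> None:
    load_dotenv()
    sub_queries = await get_sub_queries(
        "impact of open-source LLMs",
        OpenAI(model="gpt-4o-mini"),
        num_sub_queries=3,
    )
    print(sub_queries)  # e.g. a list of three search-engine-ready query strings


# asyncio.run(demo())  # requires OPENAI_API_KEY in the environment
```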
/tavily.py:
--------------------------------------------------------------------------------
import os
from typing import List
import requests
from dotenv import load_dotenv

from llama_index.core.schema import Document


async def get_docs_from_tavily_search(
    sub_query: str, visited_urls: set[str]
) -> tuple[List[Document], set[str]]:
    load_dotenv()
    api_key = os.getenv("TAVILY_API_KEY")
    base_url = "https://api.tavily.com/search"
    headers = {
        "Content-Type": "application/json",
    }
    data = {
        "query": sub_query,
        "api_key": api_key,
        "include_raw_content": True,
    }

    docs = []
    print(f"\n> Searching Tavily for sub query: {sub_query}\n")
    # Note: requests is blocking; fine for this tutorial, but an async HTTP
    # client (e.g. httpx or aiohttp) would avoid stalling the event loop.
    response = requests.post(base_url, headers=headers, json=data)
    response.raise_for_status()  # surface HTTP errors early
    search_results = response.json().get("results", [])
    for search_result in search_results:
        url = search_result.get("url")
        # Skip results without scraped page content.
        if not search_result.get("raw_content"):
            continue
        # Deduplicate across sub-queries via the shared visited set.
        if url not in visited_urls:
            visited_urls.add(url)
            docs.append(
                Document(
                    text=search_result.get("raw_content"),
                    metadata={
                        "source": url,
                        "title": search_result.get("title"),
                    },
                )
            )
    print(f"\n> Found {len(docs)} docs from Tavily search on {sub_query}\n")
    return docs, visited_urls
--------------------------------------------------------------------------------
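workflow.py below fans each sub-query out as a separate event and later gathers the per-sub-query contexts. A minimal sketch of that send/collect pattern, mirroring the repo's own usage (event names and the worker payload are illustrative):

```python
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class WorkItem(Event):
    item: str


class WorkDone(Event):
    result: str


class FanOutWorkflow(Workflow):
    @step
    async def fan_out(self, ctx: Context, ev: StartEvent) -> WorkItem:
        items = ["a", "b", "c"]
        # Record how many results the gather step should wait for.
        await ctx.set("expected", len(items))
        for item in items:
            ctx.send_event(WorkItem(item=item))
        return None

    @step(num_workers=3)
    async def work(self, ev: WorkItem) -> WorkDone:
        # Up to num_workers WorkItem events are processed concurrently.
        return WorkDone(result=ev.item.upper())

    @step
    async def gather(self, ctx: Context, ev: WorkDone) -> StopEvent:
        # collect_events buffers until all expected WorkDone events arrive.
        done = ctx.collect_events(ev, [WorkDone] * await ctx.get("expected"))
        if done is None:
            return None
        return StopEvent(result=[d.result for d in done])
```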
/workflow.py:
--------------------------------------------------------------------------------
from typing import List, Any

from llama_index.core.schema import Document
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.core.workflow import (
    step,
    Context,
    Workflow,
    Event,
    StartEvent,
    StopEvent,
)
from markdown_pdf import MarkdownPdf, Section

from subquery import get_sub_queries
from tavily import get_docs_from_tavily_search
from compress import get_compressed_context
from report import generate_report_from_context


class SubQueriesCreatedEvent(Event):
    sub_queries: List[str]


class ToProcessSubQueryEvent(Event):
    sub_query: str


class DocsScrapedEvent(Event):
    sub_query: str
    docs: List[Document]


class ToCombineContextEvent(Event):
    sub_query: str
    context: str


class ReportPromptCreatedEvent(Event):
    context: str


class LLMResponseEvent(Event):
    response: str


class ResearchAssistantWorkflow(Workflow):
    def __init__(
        self,
        *args: Any,
        llm: LLM,
        embed_model: BaseEmbedding,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.llm = llm
        self.embed_model = embed_model
        self.visited_urls: set[str] = set()

    @step
    async def create_sub_queries(
        self, ctx: Context, ev: StartEvent
    ) -> SubQueriesCreatedEvent:
        query = ev.query
        await ctx.set("query", query)
        sub_queries = await get_sub_queries(query, self.llm)
        await ctx.set("num_sub_queries", len(sub_queries))
        return SubQueriesCreatedEvent(sub_queries=sub_queries)

    @step
    async def delegate_sub_queries(
        self, ctx: Context, ev: SubQueriesCreatedEvent
    ) -> ToProcessSubQueryEvent:
        # Fan out: emit one event per sub-query so they are processed independently.
        for sub_query in ev.sub_queries:
            ctx.send_event(ToProcessSubQueryEvent(sub_query=sub_query))
        return None

    @step
    async def get_docs_for_subquery(
        self, ev: ToProcessSubQueryEvent
    ) -> DocsScrapedEvent:
        sub_query = ev.sub_query
        docs, visited_urls = await get_docs_from_tavily_search(
            sub_query, self.visited_urls
        )
        self.visited_urls = visited_urls
        return DocsScrapedEvent(sub_query=sub_query, docs=docs)

    @step(num_workers=3)
    async def compress_docs(self, ev: DocsScrapedEvent) -> ToCombineContextEvent:
        sub_query = ev.sub_query
        docs = ev.docs
        print(f"\n> Compressing docs for sub query: {sub_query}\n")
        compressed_context = await get_compressed_context(
            sub_query, docs, self.embed_model
        )
        return ToCombineContextEvent(sub_query=sub_query, context=compressed_context)

    @step
    async def combine_contexts(
        self, ctx: Context, ev: ToCombineContextEvent
    ) -> ReportPromptCreatedEvent:
        # Fan in: wait until one ToCombineContextEvent has arrived per sub-query.
        events = ctx.collect_events(
            ev, [ToCombineContextEvent] * await ctx.get("num_sub_queries")
        )
        if events is None:
            return None

        context = ""

        for event in events:
            context += (
                f'Research findings for topic "{event.sub_query}":\n{event.context}\n\n'
            )

        return ReportPromptCreatedEvent(context=context)

    @step
    async def write_report(
        self, ctx: Context, ev: ReportPromptCreatedEvent
    ) -> StopEvent:
        context = ev.context
        query = await ctx.get("query")
        print("\n> Writing report. This will take a few minutes...\n")
        report = await generate_report_from_context(query, context, self.llm)
        pdf = MarkdownPdf()
        pdf.add_section(Section(report, toc=False))
        pdf.save("report.pdf")
        print("\n> Done writing report to report.pdf! Trying to open the file...\n")
        return StopEvent(result="report.pdf")
--------------------------------------------------------------------------------