├── lightrag ├── kg │ ├── __init__.py │ └── neo4j_impl.py ├── __init__.py ├── base.py ├── storage.py ├── utils.py ├── lightrag.py ├── prompt.py └── llm.py ├── dickens └── inbox │ └── 2410.05779v2-LightRAG.pdf ├── requirements.txt ├── .gitignore ├── .pre-commit-config.yaml ├── examples ├── graph_visual_with_html.py ├── lightrag_openai_demo.py ├── lightrag_bedrock_demo.py ├── lightrag_ollama_demo.py ├── lightrag_hf_demo.py ├── generate_query.py ├── lightrag_siliconcloud_demo.py ├── lightrag_lmdeploy_demo.py ├── vram_management_demo.py ├── lightrag_openai_compatible_demo.py ├── batch_eval.py ├── graph_visual_with_neo4j.py ├── lightrag_azure_openai_demo.py └── lightrag_api_openai_compatible_demo.py ├── .github └── workflows │ └── linting.yaml ├── pyproject.toml ├── reproduce ├── Step_1.py ├── Step_1_openai_compatible.py ├── Step_3.py ├── Step_2.py ├── Step_0.py └── Step_3_openai_compatible.py ├── LICENSE ├── change-journal.md ├── test.py ├── test_neo4j.py ├── notebooks ├── design-docs │ ├── tagrag-interface-spec.md │ └── Streamlit Chat element └── taipy-lightrag-app.py ├── get_all_edges_nx.py ├── Dockerfile └── setup.py /lightrag/kg/__init__.py: -------------------------------------------------------------------------------- 1 | # print ("init package vars here. ......") 2 | 3 | 4 | -------------------------------------------------------------------------------- /dickens/inbox/2410.05779v2-LightRAG.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiproductguy/LightRAG-gui/HEAD/dickens/inbox/2410.05779v2-LightRAG.pdf -------------------------------------------------------------------------------- /lightrag/__init__.py: -------------------------------------------------------------------------------- 1 | from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam 2 | 3 | __version__ = "0.0.8" 4 | __author__ = "Zirui Guo" 5 | __url__ = "https://github.com/HKUDS/LightRAG" 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | aioboto3 3 | aiohttp 4 | graspologic 5 | hnswlib 6 | nano-vectordb 7 | neo4j 8 | networkx 9 | ollama 10 | openai 11 | pyvis 12 | tenacity 13 | tiktoken 14 | torch 15 | transformers 16 | xxhash 17 | # lmdeploy[all] 18 | PyPDF2 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | dickens/ 4 | !dickens/inbox/book.txt 5 | !dickens/inbox/2410.05779v2-LightRAG.pdf 6 | lib/ 7 | book.txt 8 | lightrag-dev/ 9 | .idea/ 10 | dist/ 11 | env/ 12 | local_neo4jWorkDir/ 13 | neo4jWorkDir/ 14 | ignore_this.txt 15 | .venv/ 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: requirements-txt-fixer 8 | 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.6.4 12 | hooks: 13 | - id: ruff-format 14 | - id: ruff 15 | args: [--fix] 16 | 17 | 18 | - repo: https://github.com/mgedmin/check-manifest 19 | rev: "0.49" 20 | hooks: 21 | - id: check-manifest 22 | stages: [manual] 23 | 
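24 | # Local usage sketch (assumes the pre-commit package is installed, e.g. `pip install pre-commit`,
25 | # mirroring what .github/workflows/linting.yaml does in CI):
26 | #   pre-commit install           # register these hooks as a git pre-commit hook
27 | #   pre-commit run --all-files   # run every configured hook against the whole repository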
-------------------------------------------------------------------------------- /examples/graph_visual_with_html.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from pyvis.network import Network 3 | import random 4 | 5 | # Load the GraphML file 6 | G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml") 7 | 8 | # Create a Pyvis network 9 | net = Network(height="100vh", notebook=True) 10 | 11 | # Convert NetworkX graph to Pyvis network 12 | net.from_nx(G) 13 | 14 | # Add colors to nodes 15 | for node in net.nodes: 16 | node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) 17 | 18 | # Save and display the network 19 | net.show("knowledge_graph.html") 20 | -------------------------------------------------------------------------------- /.github/workflows/linting.yaml: -------------------------------------------------------------------------------- 1 | name: Linting and Formatting 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | lint-and-format: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: '3.x' 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install pre-commit 28 | 29 | - name: Run pre-commit 30 | run: pre-commit run --all-files 31 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lightrag-app" 3 | version = "0.0.8" 4 | description = "A lightweight graph retrieval augmented generation framework" 5 | authors = ["The AI Product Guy "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9,!=3.9.7,<3.12" 11 | streamlit = "^1.40.0" 12 | accelerate = "^0.25.0" 13 | aioboto3 = "^12.1.0" 14 | aiohttp = "^3.9.1" 15 | graspologic = "^3.3.0" 16 | hnswlib = "^0.8.0" 17 | nano-vectordb = "^0.0.2" 18 | neo4j = "^5.15.0" 19 | networkx = "^3.2.1" 20 | ollama = "^0.1.6" 21 | openai = "^1.3.7" 22 | pyvis = "^0.3.2" 23 | tenacity = "^8.2.3" 24 | tiktoken = "^0.8.0" 25 | torch = "^2.1.1" 26 | transformers = "^4.35.2" 27 | xxhash = "^3.4.1" 28 | pypdf2 = "^3.0.1" 29 | 30 | [tool.poetry.group.dev.dependencies] 31 | ipykernel = "^6.29.5" 32 | 33 | [build-system] 34 | requires = ["poetry-core"] 35 | build-backend = "poetry.core.masonry.api" 36 | -------------------------------------------------------------------------------- /reproduce/Step_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | 5 | from lightrag import LightRAG 6 | 7 | 8 | def insert_text(rag, file_path): 9 | with open(file_path, mode="r") as f: 10 | unique_contexts = json.load(f) 11 | 12 | retries = 0 13 | max_retries = 3 14 | while retries < max_retries: 15 | try: 16 | rag.insert(unique_contexts) 17 | break 18 | except Exception as e: 19 | retries += 1 20 | print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}") 21 | time.sleep(10) 22 | if retries == max_retries: 23 | print("Insertion failed after exceeding the maximum number of retries") 24 | 25 | 26 | cls = "agriculture" 27 | WORKING_DIR = f"../{cls}" 28 | 29 | if not os.path.exists(WORKING_DIR): 30 | os.mkdir(WORKING_DIR) 31 | 32 | rag = 
LightRAG(working_dir=WORKING_DIR) 33 | 34 | insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") 35 | -------------------------------------------------------------------------------- /examples/lightrag_openai_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import gpt_4o_mini_complete 5 | 6 | WORKING_DIR = "./dickens" 7 | 8 | if not os.path.exists(WORKING_DIR): 9 | os.mkdir(WORKING_DIR) 10 | 11 | rag = LightRAG( 12 | working_dir=WORKING_DIR, 13 | llm_model_func=gpt_4o_mini_complete, 14 | # llm_model_func=gpt_4o_complete 15 | ) 16 | 17 | 18 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 19 | rag.insert(f.read()) 20 | 21 | # Perform naive search 22 | print( 23 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 24 | ) 25 | 26 | # Perform local search 27 | print( 28 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 29 | ) 30 | 31 | # Perform global search 32 | print( 33 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 34 | ) 35 | 36 | # Perform hybrid search 37 | print( 38 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 39 | ) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Gustavo Ye 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /examples/lightrag_bedrock_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | LightRAG meets Amazon Bedrock ⛰️ 3 | """ 4 | 5 | import os 6 | import logging 7 | 8 | from lightrag import LightRAG, QueryParam 9 | from lightrag.llm import bedrock_complete, bedrock_embedding 10 | from lightrag.utils import EmbeddingFunc 11 | 12 | logging.getLogger("aiobotocore").setLevel(logging.WARNING) 13 | 14 | WORKING_DIR = "./dickens" 15 | if not os.path.exists(WORKING_DIR): 16 | os.mkdir(WORKING_DIR) 17 | 18 | rag = LightRAG( 19 | working_dir=WORKING_DIR, 20 | llm_model_func=bedrock_complete, 21 | llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock", 22 | embedding_func=EmbeddingFunc( 23 | embedding_dim=1024, max_token_size=8192, func=bedrock_embedding 24 | ), 25 | ) 26 | 27 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 28 | rag.insert(f.read()) 29 | 30 | for mode in ["naive", "local", "global", "hybrid"]: 31 | print("\n+-" + "-" * len(mode) + "-+") 32 | print(f"| {mode.capitalize()} |") 33 | print("+-" + "-" * len(mode) + "-+\n") 34 | print( 35 | rag.query("What are the top themes in this story?", param=QueryParam(mode=mode)) 36 | ) 37 | -------------------------------------------------------------------------------- /change-journal.md: -------------------------------------------------------------------------------- 1 | # Change Journal 2 | 3 | ## 2024-11-08 4 | - [x] clean up testing files and update README 5 | - move book.txt to dickens/inbox/ 6 | - [x] add Poetry pyproject.toml file for python 3.9-3.11 7 | - [x] add UI examples in ./ui-examples/ 8 | - add a simple chat UI for Import markdown file or URL and query the graph 9 | - @Streamlit Chat element - use this chat element to make @streamlit-import-query-lightrag.py into a detailed chat interface 10 | - [!] Fix streamlit ui glitchiness 11 | - [x] Add timestamps to chat ui 12 | - [x] add graph visualization 13 | - [!] blog about it https://blog.streamlit.io/how-to-build-a-llama-2-chatbot/ 14 | 15 | ## 2024-11-08 16 | - [?] secure api key handling 17 | - [ ] markdown messages in chat ui 18 | - [ ] trilingual PR release 19 | - [ ] hat feature, advanced settings 20 | - [?] apply the tagrag-interface-spec.md design spec 21 | 22 | 23 | ## 2024-11-10 24 | - [ ] Make a trilingual chatbot that accepts structured outputs 25 | 26 | - [ ] Implement GROQ-LLAMA3-8B-8192 as default AI model 27 | 28 | ## 2024-12-?? 
29 | - [ ] make a better version of @streamlit-import-query-lightrag.py using taipy rather than streamlit as @taipy-import-query-lightrag.py 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lightrag import LightRAG, QueryParam 3 | from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete 4 | ######### 5 | # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() 6 | # import nest_asyncio 7 | # nest_asyncio.apply() 8 | ######### 9 | 10 | WORKING_DIR = "./dickens" 11 | 12 | if not os.path.exists(WORKING_DIR): 13 | os.mkdir(WORKING_DIR) 14 | 15 | rag = LightRAG( 16 | working_dir=WORKING_DIR, 17 | llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model 18 | # llm_model_func=gpt_4o_complete # Optionally, use a stronger model 19 | ) 20 | 21 | with open("./dickens/inbox/book.txt") as f: 22 | rag.insert(f.read()) 23 | 24 | # Perform naive search 25 | print( 26 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 27 | ) 28 | 29 | # Perform local search 30 | print( 31 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 32 | ) 33 | 34 | # Perform global search 35 | print( 36 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 37 | ) 38 | 39 | # Perform hybrid search 40 | print( 41 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 42 | ) 43 | -------------------------------------------------------------------------------- /test_neo4j.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lightrag import LightRAG, QueryParam 3 | from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete 4 | 5 | 6 | ######### 7 | # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() 8 | # import nest_asyncio 9 | # nest_asyncio.apply() 10 | ######### 11 | 12 | WORKING_DIR = "./local_neo4jWorkDir" 13 | 14 | if not os.path.exists(WORKING_DIR): 15 | os.mkdir(WORKING_DIR) 16 | 17 | rag = LightRAG( 18 | working_dir=WORKING_DIR, 19 | llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model 20 | kg="Neo4JStorage", 21 | log_level="INFO", 22 | # llm_model_func=gpt_4o_complete # Optionally, use a stronger model 23 | ) 24 | 25 | with open("./dickens/inbox/book.txt") as f: 26 | rag.insert(f.read()) 27 | 28 | # Perform naive search 29 | print( 30 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 31 | ) 32 | 33 | # Perform local search 34 | print( 35 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 36 | ) 37 | 38 | # Perform global search 39 | print( 40 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 41 | ) 42 | 43 | # Perform hybrid search 44 | print( 45 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 46 | ) 47 | -------------------------------------------------------------------------------- /notebooks/design-docs/tagrag-interface-spec.md: -------------------------------------------------------------------------------- 1 | # Tagrag Interface Spec 2 | 3 | ## Markdown format 4 | --- 5 | user-prompt: `INPUT[{user-prompt}]` #{YYYY-MM-DD} 6 | 7 | ## ${user-prompt-summary} 8 | > [!inquiry]- 
[[#light-global/gpt-4o/eng/1]] #ds/${shahash12} 9 | > user-prompt: ${user-prompt} 10 | > rag: ${rag} // light-naive, light-local, light-global, light-hybrid 11 | > ai-model: gpt-4o 12 | > languages: eng 13 | > embedder: ollama-nomic-embed-text 14 | > template: [[Obsidian Researcher Note]] 15 | > user-prompt-summary: ${user-prompt-summary} 16 | > user-prompt-rewritten: ${user-prompt-rewritten} 17 | > prompt-hash: ${shahash12} 18 | 19 | > [!sources]- #toggle {sources-count} sources 20 | > ### Websearch (top 3 web sources) 21 | > [^1] {source-1} 22 | > [^2] {source-2} 23 | > [^3] {source-3} 24 | > ### Doc Graph (top 3 docs) [^4] 25 | > - {list-of-docs-used} 26 | 27 | > [!answer]+ Answered by {rag}@{ai-model} 🗓️{date} 28 | > subprompt-steps: ${subprompt-steps} 29 | > {answer} 30 | 31 | > [!footer]- {edit} | {rewrite} | {copy} | {import} 32 | 33 | ## ${user-prompt-summary} edit-2 34 | > [!inquiry]- [[#light-global/gpt-4o/eng/1]] #ds/${shahash12} 35 | > ... 36 | 37 | > [!sources]- #toggle {sources-count} sources 38 | > ... 39 | 40 | > [!answer]+ Answered by {rag}@{ai-model} 🗓️{date} 41 | > subprompt-steps: ${subprompt-steps} 42 | > {answer} 43 | 44 | > [!footer]- {edit} | {rewrite} | {copy} | {import} 45 | 46 | ... -------------------------------------------------------------------------------- /get_all_edges_nx.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml") 4 | 5 | 6 | def get_all_edges_and_nodes(G): 7 | # Get all edges and their properties 8 | edges_with_properties = [] 9 | for u, v, data in G.edges(data=True): 10 | edges_with_properties.append( 11 | { 12 | "start": u, 13 | "end": v, 14 | "label": data.get( 15 | "label", "" 16 | ), # Assuming 'label' is used for edge type 17 | "properties": data, 18 | "start_node_properties": G.nodes[u], 19 | "end_node_properties": G.nodes[v], 20 | } 21 | ) 22 | 23 | return edges_with_properties 24 | 25 | 26 | # Example usage 27 | if __name__ == "__main__": 28 | # Assume G is your NetworkX graph loaded from Neo4j 29 | 30 | all_edges = get_all_edges_and_nodes(G) 31 | 32 | # Print all edges and node properties 33 | for edge in all_edges: 34 | print(f"Edge Label: {edge['label']}") 35 | print(f"Edge Properties: {edge['properties']}") 36 | print(f"Start Node: {edge['start']}") 37 | print(f"Start Node Properties: {edge['start_node_properties']}") 38 | print(f"End Node: {edge['end']}") 39 | print(f"End Node Properties: {edge['end_node_properties']}") 40 | print("---") 41 | -------------------------------------------------------------------------------- /examples/lightrag_ollama_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import ollama_model_complete, ollama_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | 7 | WORKING_DIR = "./dickens" 8 | 9 | logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) 10 | 11 | if not os.path.exists(WORKING_DIR): 12 | os.mkdir(WORKING_DIR) 13 | 14 | rag = LightRAG( 15 | working_dir=WORKING_DIR, 16 | llm_model_func=ollama_model_complete, 17 | llm_model_name="gemma2:2b", 18 | llm_model_max_async=4, 19 | llm_model_max_token_size=32768, 20 | llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}}, 21 | embedding_func=EmbeddingFunc( 22 | embedding_dim=768, 23 | max_token_size=8192, 24 | 
func=lambda texts: ollama_embedding( 25 | texts, embed_model="nomic-embed-text", host="http://localhost:11434" 26 | ), 27 | ), 28 | ) 29 | 30 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 31 | rag.insert(f.read()) 32 | 33 | # Perform naive search 34 | print( 35 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 36 | ) 37 | 38 | # Perform local search 39 | print( 40 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 41 | ) 42 | 43 | # Perform global search 44 | print( 45 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 46 | ) 47 | 48 | # Perform hybrid search 49 | print( 50 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 51 | ) 52 | -------------------------------------------------------------------------------- /examples/lightrag_hf_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import hf_model_complete, hf_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | from transformers import AutoModel, AutoTokenizer 7 | 8 | WORKING_DIR = "./dickens" 9 | 10 | if not os.path.exists(WORKING_DIR): 11 | os.mkdir(WORKING_DIR) 12 | 13 | rag = LightRAG( 14 | working_dir=WORKING_DIR, 15 | llm_model_func=hf_model_complete, 16 | llm_model_name="meta-llama/Llama-3.1-8B-Instruct", 17 | embedding_func=EmbeddingFunc( 18 | embedding_dim=384, 19 | max_token_size=5000, 20 | func=lambda texts: hf_embedding( 21 | texts, 22 | tokenizer=AutoTokenizer.from_pretrained( 23 | "sentence-transformers/all-MiniLM-L6-v2" 24 | ), 25 | embed_model=AutoModel.from_pretrained( 26 | "sentence-transformers/all-MiniLM-L6-v2" 27 | ), 28 | ), 29 | ), 30 | ) 31 | 32 | 33 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 34 | rag.insert(f.read()) 35 | 36 | # Perform naive search 37 | print( 38 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 39 | ) 40 | 41 | # Perform local search 42 | print( 43 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 44 | ) 45 | 46 | # Perform global search 47 | print( 48 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 49 | ) 50 | 51 | # Perform hybrid search 52 | print( 53 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 54 | ) 55 | -------------------------------------------------------------------------------- /examples/generate_query.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | # os.environ["OPENAI_API_KEY"] = "" 4 | 5 | 6 | def openai_complete_if_cache( 7 | model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs 8 | ) -> str: 9 | openai_client = OpenAI() 10 | 11 | messages = [] 12 | if system_prompt: 13 | messages.append({"role": "system", "content": system_prompt}) 14 | messages.extend(history_messages) 15 | messages.append({"role": "user", "content": prompt}) 16 | 17 | response = openai_client.chat.completions.create( 18 | model=model, messages=messages, **kwargs 19 | ) 20 | return response.choices[0].message.content 21 | 22 | 23 | if __name__ == "__main__": 24 | description = "" 25 | prompt = f""" 26 | Given the following description of a dataset: 27 | 28 | {description} 29 | 30 | Please identify 5 potential users who would engage with 
this dataset. For each user, list 5 tasks they would perform with this dataset. Then, for each (user, task) combination, generate 5 questions that require a high-level understanding of the entire dataset. 31 | 32 | Output the results in the following structure: 33 | - User 1: [user description] 34 | - Task 1: [task description] 35 | - Question 1: 36 | - Question 2: 37 | - Question 3: 38 | - Question 4: 39 | - Question 5: 40 | - Task 2: [task description] 41 | ... 42 | - Task 5: [task description] 43 | - User 2: [user description] 44 | ... 45 | - User 5: [user description] 46 | ... 47 | """ 48 | 49 | result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt) 50 | 51 | file_path = "./queries.txt" 52 | with open(file_path, "w") as file: 53 | file.write(result) 54 | 55 | print(f"Queries written to {file_path}") 56 | -------------------------------------------------------------------------------- /reproduce/Step_1_openai_compatible.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | 6 | from lightrag import LightRAG 7 | from lightrag.utils import EmbeddingFunc 8 | from lightrag.llm import openai_complete_if_cache, openai_embedding 9 | 10 | 11 | ## For Upstage API 12 | # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry 13 | async def llm_model_func( 14 | prompt, system_prompt=None, history_messages=[], **kwargs 15 | ) -> str: 16 | return await openai_complete_if_cache( 17 | "solar-mini", 18 | prompt, 19 | system_prompt=system_prompt, 20 | history_messages=history_messages, 21 | api_key=os.getenv("UPSTAGE_API_KEY"), 22 | base_url="https://api.upstage.ai/v1/solar", 23 | **kwargs, 24 | ) 25 | 26 | 27 | async def embedding_func(texts: list[str]) -> np.ndarray: 28 | return await openai_embedding( 29 | texts, 30 | model="solar-embedding-1-large-query", 31 | api_key=os.getenv("UPSTAGE_API_KEY"), 32 | base_url="https://api.upstage.ai/v1/solar", 33 | ) 34 | 35 | 36 | ## /For Upstage API 37 | 38 | 39 | def insert_text(rag, file_path): 40 | with open(file_path, mode="r") as f: 41 | unique_contexts = json.load(f) 42 | 43 | retries = 0 44 | max_retries = 3 45 | while retries < max_retries: 46 | try: 47 | rag.insert(unique_contexts) 48 | break 49 | except Exception as e: 50 | retries += 1 51 | print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}") 52 | time.sleep(10) 53 | if retries == max_retries: 54 | print("Insertion failed after exceeding the maximum number of retries") 55 | 56 | 57 | cls = "mix" 58 | WORKING_DIR = f"../{cls}" 59 | 60 | if not os.path.exists(WORKING_DIR): 61 | os.mkdir(WORKING_DIR) 62 | 63 | rag = LightRAG( 64 | working_dir=WORKING_DIR, 65 | llm_model_func=llm_model_func, 66 | embedding_func=EmbeddingFunc( 67 | embedding_dim=4096, max_token_size=8192, func=embedding_func 68 | ), 69 | ) 70 | 71 | insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") 72 | -------------------------------------------------------------------------------- /examples/lightrag_siliconcloud_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | import numpy as np 7 | 8 | WORKING_DIR = "./dickens" 9 | 10 | if not os.path.exists(WORKING_DIR): 11 | os.mkdir(WORKING_DIR) 12 | 13 | 14 | async 
def llm_model_func( 15 | prompt, system_prompt=None, history_messages=[], **kwargs 16 | ) -> str: 17 | return await openai_complete_if_cache( 18 | "Qwen/Qwen2.5-7B-Instruct", 19 | prompt, 20 | system_prompt=system_prompt, 21 | history_messages=history_messages, 22 | api_key=os.getenv("SILICONFLOW_API_KEY"), 23 | base_url="https://api.siliconflow.cn/v1/", 24 | **kwargs, 25 | ) 26 | 27 | 28 | async def embedding_func(texts: list[str]) -> np.ndarray: 29 | return await siliconcloud_embedding( 30 | texts, 31 | model="netease-youdao/bce-embedding-base_v1", 32 | api_key=os.getenv("SILICONFLOW_API_KEY"), 33 | max_token_size=512, 34 | ) 35 | 36 | 37 | # function test 38 | async def test_funcs(): 39 | result = await llm_model_func("How are you?") 40 | print("llm_model_func: ", result) 41 | 42 | result = await embedding_func(["How are you?"]) 43 | print("embedding_func: ", result) 44 | 45 | 46 | asyncio.run(test_funcs()) 47 | 48 | 49 | rag = LightRAG( 50 | working_dir=WORKING_DIR, 51 | llm_model_func=llm_model_func, 52 | embedding_func=EmbeddingFunc( 53 | embedding_dim=768, max_token_size=512, func=embedding_func 54 | ), 55 | ) 56 | 57 | 58 | with open("./dickens/inbox/book.txt") as f: 59 | rag.insert(f.read()) 60 | 61 | # Perform naive search 62 | print( 63 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 64 | ) 65 | 66 | # Perform local search 67 | print( 68 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 69 | ) 70 | 71 | # Perform global search 72 | print( 73 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 74 | ) 75 | 76 | # Perform hybrid search 77 | print( 78 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 79 | ) 80 | -------------------------------------------------------------------------------- /reproduce/Step_3.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import asyncio 4 | from lightrag import LightRAG, QueryParam 5 | from tqdm import tqdm 6 | 7 | 8 | def extract_queries(file_path): 9 | with open(file_path, "r") as f: 10 | data = f.read() 11 | 12 | data = data.replace("**", "") 13 | 14 | queries = re.findall(r"- Question \d+: (.+)", data) 15 | 16 | return queries 17 | 18 | 19 | async def process_query(query_text, rag_instance, query_param): 20 | try: 21 | result = await rag_instance.aquery(query_text, param=query_param) 22 | return {"query": query_text, "result": result}, None 23 | except Exception as e: 24 | return None, {"query": query_text, "error": str(e)} 25 | 26 | 27 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 28 | try: 29 | loop = asyncio.get_event_loop() 30 | except RuntimeError: 31 | loop = asyncio.new_event_loop() 32 | asyncio.set_event_loop(loop) 33 | return loop 34 | 35 | 36 | def run_queries_and_save_to_json( 37 | queries, rag_instance, query_param, output_file, error_file 38 | ): 39 | loop = always_get_an_event_loop() 40 | 41 | with open(output_file, "a", encoding="utf-8") as result_file, open( 42 | error_file, "a", encoding="utf-8" 43 | ) as err_file: 44 | result_file.write("[\n") 45 | first_entry = True 46 | 47 | for query_text in tqdm(queries, desc="Processing queries", unit="query"): 48 | result, error = loop.run_until_complete( 49 | process_query(query_text, rag_instance, query_param) 50 | ) 51 | 52 | if result: 53 | if not first_entry: 54 | result_file.write(",\n") 55 | json.dump(result, result_file, ensure_ascii=False, indent=4) 
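# once the first result has been written, later entries are prefixed with ",\n" above, so the output file stays a valid JSON array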
56 | first_entry = False 57 | elif error: 58 | json.dump(error, err_file, ensure_ascii=False, indent=4) 59 | err_file.write("\n") 60 | 61 | result_file.write("\n]") 62 | 63 | 64 | if __name__ == "__main__": 65 | cls = "agriculture" 66 | mode = "hybrid" 67 | WORKING_DIR = f"../{cls}" 68 | 69 | rag = LightRAG(working_dir=WORKING_DIR) 70 | query_param = QueryParam(mode=mode) 71 | 72 | queries = extract_queries(f"../datasets/questions/{cls}_questions.txt") 73 | run_queries_and_save_to_json( 74 | queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json" 75 | ) 76 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM debian:bullseye-slim 2 | ENV JAVA_HOME=/opt/java/openjdk 3 | COPY --from=eclipse-temurin:17 $JAVA_HOME $JAVA_HOME 4 | ENV PATH="${JAVA_HOME}/bin:${PATH}" \ 5 | NEO4J_SHA256=7ce97bd9a4348af14df442f00b3dc5085b5983d6f03da643744838c7a1bc8ba7 \ 6 | NEO4J_TARBALL=neo4j-enterprise-5.24.2-unix.tar.gz \ 7 | NEO4J_EDITION=enterprise \ 8 | NEO4J_HOME="/var/lib/neo4j" \ 9 | LANG=C.UTF-8 10 | ARG NEO4J_URI=https://dist.neo4j.org/neo4j-enterprise-5.24.2-unix.tar.gz 11 | 12 | RUN addgroup --gid 7474 --system neo4j && adduser --uid 7474 --system --no-create-home --home "${NEO4J_HOME}" --ingroup neo4j neo4j 13 | 14 | COPY ./local-package/* /startup/ 15 | 16 | RUN apt update \ 17 | && apt-get install -y curl gcc git jq make procps tini wget \ 18 | && curl --fail --silent --show-error --location --remote-name ${NEO4J_URI} \ 19 | && echo "${NEO4J_SHA256} ${NEO4J_TARBALL}" | sha256sum -c --strict --quiet \ 20 | && tar --extract --file ${NEO4J_TARBALL} --directory /var/lib \ 21 | && mv /var/lib/neo4j-* "${NEO4J_HOME}" \ 22 | && rm ${NEO4J_TARBALL} \ 23 | && sed -i 's/Package Type:.*/Package Type: docker bullseye/' $NEO4J_HOME/packaging_info \ 24 | && mv /startup/neo4j-admin-report.sh "${NEO4J_HOME}"/bin/neo4j-admin-report \ 25 | && mv "${NEO4J_HOME}"/data /data \ 26 | && mv "${NEO4J_HOME}"/logs /logs \ 27 | && chown -R neo4j:neo4j /data \ 28 | && chmod -R 777 /data \ 29 | && chown -R neo4j:neo4j /logs \ 30 | && chmod -R 777 /logs \ 31 | && chown -R neo4j:neo4j "${NEO4J_HOME}" \ 32 | && chmod -R 777 "${NEO4J_HOME}" \ 33 | && chmod -R 755 "${NEO4J_HOME}/bin" \ 34 | && ln -s /data "${NEO4J_HOME}"/data \ 35 | && ln -s /logs "${NEO4J_HOME}"/logs \ 36 | && git clone https://github.com/ncopa/su-exec.git \ 37 | && cd su-exec \ 38 | && git checkout 4c3bb42b093f14da70d8ab924b487ccfbb1397af \ 39 | && echo d6c40440609a23483f12eb6295b5191e94baf08298a856bab6e15b10c3b82891 su-exec.c | sha256sum -c \ 40 | && echo 2a87af245eb125aca9305a0b1025525ac80825590800f047419dc57bba36b334 Makefile | sha256sum -c \ 41 | && make \ 42 | && mv /su-exec/su-exec /usr/bin/su-exec \ 43 | && apt-get -y purge --auto-remove curl gcc git make \ 44 | && rm -rf /var/lib/apt/lists/* /su-exec 45 | 46 | 47 | ENV PATH "${NEO4J_HOME}"/bin:$PATH 48 | 49 | WORKDIR "${NEO4J_HOME}" 50 | 51 | VOLUME /data /logs 52 | 53 | EXPOSE 7474 7473 7687 54 | 55 | ENTRYPOINT ["tini", "-g", "--", "/startup/docker-entrypoint.sh"] 56 | CMD ["neo4j"] -------------------------------------------------------------------------------- /reproduce/Step_2.py: -------------------------------------------------------------------------------- 1 | import json 2 | from openai import OpenAI 3 | from transformers import GPT2Tokenizer 4 | 5 | 6 | def openai_complete_if_cache( 7 | model="gpt-4o", prompt=None, system_prompt=None, 
history_messages=[], **kwargs 8 | ) -> str: 9 | openai_client = OpenAI() 10 | 11 | messages = [] 12 | if system_prompt: 13 | messages.append({"role": "system", "content": system_prompt}) 14 | messages.extend(history_messages) 15 | messages.append({"role": "user", "content": prompt}) 16 | 17 | response = openai_client.chat.completions.create( 18 | model=model, messages=messages, **kwargs 19 | ) 20 | return response.choices[0].message.content 21 | 22 | 23 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 24 | 25 | 26 | def get_summary(context, tot_tokens=2000): 27 | tokens = tokenizer.tokenize(context) 28 | half_tokens = tot_tokens // 2 29 | 30 | start_tokens = tokens[1000 : 1000 + half_tokens] 31 | end_tokens = tokens[-(1000 + half_tokens) : 1000] 32 | 33 | summary_tokens = start_tokens + end_tokens 34 | summary = tokenizer.convert_tokens_to_string(summary_tokens) 35 | 36 | return summary 37 | 38 | 39 | clses = ["agriculture"] 40 | for cls in clses: 41 | with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f: 42 | unique_contexts = json.load(f) 43 | 44 | summaries = [get_summary(context) for context in unique_contexts] 45 | 46 | total_description = "\n\n".join(summaries) 47 | 48 | prompt = f""" 49 | Given the following description of a dataset: 50 | 51 | {total_description} 52 | 53 | Please identify 5 potential users who would engage with this dataset. For each user, list 5 tasks they would perform with this dataset. Then, for each (user, task) combination, generate 5 questions that require a high-level understanding of the entire dataset. 54 | 55 | Output the results in the following structure: 56 | - User 1: [user description] 57 | - Task 1: [task description] 58 | - Question 1: 59 | - Question 2: 60 | - Question 3: 61 | - Question 4: 62 | - Question 5: 63 | - Task 2: [task description] 64 | ... 65 | - Task 5: [task description] 66 | - User 2: [user description] 67 | ... 68 | - User 5: [user description] 69 | ... 70 | """ 71 | 72 | result = openai_complete_if_cache(model="gpt-4o", prompt=prompt) 73 | 74 | file_path = f"../datasets/questions/{cls}_questions.txt" 75 | with open(file_path, "w") as file: 76 | file.write(result) 77 | 78 | print(f"{cls}_questions written to {file_path}") 79 | -------------------------------------------------------------------------------- /examples/lightrag_lmdeploy_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import lmdeploy_model_if_cache, hf_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | from transformers import AutoModel, AutoTokenizer 7 | 8 | WORKING_DIR = "./dickens" 9 | 10 | if not os.path.exists(WORKING_DIR): 11 | os.mkdir(WORKING_DIR) 12 | 13 | 14 | async def lmdeploy_model_complete( 15 | prompt=None, system_prompt=None, history_messages=[], **kwargs 16 | ) -> str: 17 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 18 | return await lmdeploy_model_if_cache( 19 | model_name, 20 | prompt, 21 | system_prompt=system_prompt, 22 | history_messages=history_messages, 23 | ## please specify chat_template if your local path does not follow original HF file name, 24 | ## or model_name is a pytorch model on huggingface.co, 25 | ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py 26 | ## for a list of chat_template available in lmdeploy. 27 | chat_template="llama3", 28 | # model_format ='awq', # if you are using awq quantization model. 
29 | # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. 30 | **kwargs, 31 | ) 32 | 33 | 34 | rag = LightRAG( 35 | working_dir=WORKING_DIR, 36 | llm_model_func=lmdeploy_model_complete, 37 | llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model 38 | embedding_func=EmbeddingFunc( 39 | embedding_dim=384, 40 | max_token_size=5000, 41 | func=lambda texts: hf_embedding( 42 | texts, 43 | tokenizer=AutoTokenizer.from_pretrained( 44 | "sentence-transformers/all-MiniLM-L6-v2" 45 | ), 46 | embed_model=AutoModel.from_pretrained( 47 | "sentence-transformers/all-MiniLM-L6-v2" 48 | ), 49 | ), 50 | ), 51 | ) 52 | 53 | 54 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 55 | rag.insert(f.read()) 56 | 57 | # Perform naive search 58 | print( 59 | rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) 60 | ) 61 | 62 | # Perform local search 63 | print( 64 | rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) 65 | ) 66 | 67 | # Perform global search 68 | print( 69 | rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) 70 | ) 71 | 72 | # Perform hybrid search 73 | print( 74 | rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) 75 | ) 76 | -------------------------------------------------------------------------------- /reproduce/Step_0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import argparse 5 | 6 | 7 | def extract_unique_contexts(input_directory, output_directory): 8 | os.makedirs(output_directory, exist_ok=True) 9 | 10 | jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl")) 11 | print(f"Found {len(jsonl_files)} JSONL files.") 12 | 13 | for file_path in jsonl_files: 14 | filename = os.path.basename(file_path) 15 | name, ext = os.path.splitext(filename) 16 | output_filename = f"{name}_unique_contexts.json" 17 | output_path = os.path.join(output_directory, output_filename) 18 | 19 | unique_contexts_dict = {} 20 | 21 | print(f"Processing file: {filename}") 22 | 23 | try: 24 | with open(file_path, "r", encoding="utf-8") as infile: 25 | for line_number, line in enumerate(infile, start=1): 26 | line = line.strip() 27 | if not line: 28 | continue 29 | try: 30 | json_obj = json.loads(line) 31 | context = json_obj.get("context") 32 | if context and context not in unique_contexts_dict: 33 | unique_contexts_dict[context] = None 34 | except json.JSONDecodeError as e: 35 | print( 36 | f"JSON decoding error in file {filename} at line {line_number}: {e}" 37 | ) 38 | except FileNotFoundError: 39 | print(f"File not found: {filename}") 40 | continue 41 | except Exception as e: 42 | print(f"An error occurred while processing file {filename}: {e}") 43 | continue 44 | 45 | unique_contexts_list = list(unique_contexts_dict.keys()) 46 | print( 47 | f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}." 
48 | ) 49 | 50 | try: 51 | with open(output_path, "w", encoding="utf-8") as outfile: 52 | json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4) 53 | print(f"Unique `context` entries have been saved to: {output_filename}") 54 | except Exception as e: 55 | print(f"An error occurred while saving to the file {output_filename}: {e}") 56 | 57 | print("All files have been processed.") 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument("-i", "--input_dir", type=str, default="../datasets") 63 | parser.add_argument( 64 | "-o", "--output_dir", type=str, default="../datasets/unique_contexts" 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | extract_unique_contexts(args.input_dir, args.output_dir) 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from pathlib import Path 3 | 4 | 5 | # Reading the long description from README.md 6 | def read_long_description(): 7 | try: 8 | return Path("README.md").read_text(encoding="utf-8") 9 | except FileNotFoundError: 10 | return "A description of LightRAG is currently unavailable." 11 | 12 | 13 | # Retrieving metadata from __init__.py 14 | def retrieve_metadata(): 15 | vars2find = ["__author__", "__version__", "__url__"] 16 | vars2readme = {} 17 | try: 18 | with open("./lightrag/__init__.py") as f: 19 | for line in f.readlines(): 20 | for v in vars2find: 21 | if line.startswith(v): 22 | line = ( 23 | line.replace(" ", "") 24 | .replace('"', "") 25 | .replace("'", "") 26 | .strip() 27 | ) 28 | vars2readme[v] = line.split("=")[1] 29 | except FileNotFoundError: 30 | raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.") 31 | 32 | # Checking if all required variables are found 33 | missing_vars = [v for v in vars2find if v not in vars2readme] 34 | if missing_vars: 35 | raise ValueError( 36 | f"Missing required metadata variables in __init__.py: {missing_vars}" 37 | ) 38 | 39 | return vars2readme 40 | 41 | 42 | # Reading dependencies from requirements.txt 43 | def read_requirements(): 44 | deps = [] 45 | try: 46 | with open("./requirements.txt") as f: 47 | deps = [line.strip() for line in f if line.strip()] 48 | except FileNotFoundError: 49 | print( 50 | "Warning: 'requirements.txt' not found. No dependencies will be installed." 
51 | ) 52 | return deps 53 | 54 | 55 | metadata = retrieve_metadata() 56 | long_description = read_long_description() 57 | requirements = read_requirements() 58 | 59 | setuptools.setup( 60 | name="lightrag-hku", 61 | url=metadata["__url__"], 62 | version=metadata["__version__"], 63 | author=metadata["__author__"], 64 | description="LightRAG: Simple and Fast Retrieval-Augmented Generation", 65 | long_description=long_description, 66 | long_description_content_type="text/markdown", 67 | packages=setuptools.find_packages( 68 | exclude=("tests*", "docs*") 69 | ), # Automatically find packages 70 | classifiers=[ 71 | "Development Status :: 4 - Beta", 72 | "Programming Language :: Python :: 3", 73 | "License :: OSI Approved :: MIT License", 74 | "Operating System :: OS Independent", 75 | "Intended Audience :: Developers", 76 | "Topic :: Software Development :: Libraries :: Python Modules", 77 | ], 78 | python_requires=">=3.9", 79 | install_requires=requirements, 80 | include_package_data=True, # Includes non-code files from MANIFEST.in 81 | project_urls={ # Additional project metadata 82 | "Documentation": metadata.get("__url__", ""), 83 | "Source": metadata.get("__url__", ""), 84 | "Tracker": f"{metadata.get('__url__', '')}/issues" 85 | if metadata.get("__url__") 86 | else "", 87 | }, 88 | ) 89 | -------------------------------------------------------------------------------- /examples/vram_management_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import ollama_model_complete, ollama_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | 7 | # Working directory and the directory path for text files 8 | WORKING_DIR = "./dickens" 9 | TEXT_FILES_DIR = "/llm/mt" 10 | 11 | # Create the working directory if it doesn't exist 12 | if not os.path.exists(WORKING_DIR): 13 | os.mkdir(WORKING_DIR) 14 | 15 | # Initialize LightRAG 16 | rag = LightRAG( 17 | working_dir=WORKING_DIR, 18 | llm_model_func=ollama_model_complete, 19 | llm_model_name="qwen2.5:3b-instruct-max-context", 20 | embedding_func=EmbeddingFunc( 21 | embedding_dim=768, 22 | max_token_size=8192, 23 | func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"), 24 | ), 25 | ) 26 | 27 | # Read all .txt files from the TEXT_FILES_DIR directory 28 | texts = [] 29 | for filename in os.listdir(TEXT_FILES_DIR): 30 | if filename.endswith(".txt"): 31 | file_path = os.path.join(TEXT_FILES_DIR, filename) 32 | with open(file_path, "r", encoding="utf-8") as file: 33 | texts.append(file.read()) 34 | 35 | 36 | # Batch insert texts into LightRAG with a retry mechanism 37 | def insert_texts_with_retry(rag, texts, retries=3, delay=5): 38 | for _ in range(retries): 39 | try: 40 | rag.insert(texts) 41 | return 42 | except Exception as e: 43 | print( 44 | f"Error occurred during insertion: {e}. Retrying in {delay} seconds..." 
45 | ) 46 | time.sleep(delay) 47 | raise RuntimeError("Failed to insert texts after multiple retries.") 48 | 49 | 50 | insert_texts_with_retry(rag, texts) 51 | 52 | # Perform different types of queries and handle potential errors 53 | try: 54 | print( 55 | rag.query( 56 | "What are the top themes in this story?", param=QueryParam(mode="naive") 57 | ) 58 | ) 59 | except Exception as e: 60 | print(f"Error performing naive search: {e}") 61 | 62 | try: 63 | print( 64 | rag.query( 65 | "What are the top themes in this story?", param=QueryParam(mode="local") 66 | ) 67 | ) 68 | except Exception as e: 69 | print(f"Error performing local search: {e}") 70 | 71 | try: 72 | print( 73 | rag.query( 74 | "What are the top themes in this story?", param=QueryParam(mode="global") 75 | ) 76 | ) 77 | except Exception as e: 78 | print(f"Error performing global search: {e}") 79 | 80 | try: 81 | print( 82 | rag.query( 83 | "What are the top themes in this story?", param=QueryParam(mode="hybrid") 84 | ) 85 | ) 86 | except Exception as e: 87 | print(f"Error performing hybrid search: {e}") 88 | 89 | 90 | # Function to clear VRAM resources 91 | def clear_vram(): 92 | os.system("sudo nvidia-smi --gpu-reset") 93 | 94 | 95 | # Regularly clear VRAM to prevent overflow 96 | clear_vram_interval = 3600 # Clear once every hour 97 | start_time = time.time() 98 | 99 | while True: 100 | current_time = time.time() 101 | if current_time - start_time > clear_vram_interval: 102 | clear_vram() 103 | start_time = current_time 104 | time.sleep(60) # Check the time every minute 105 | -------------------------------------------------------------------------------- /examples/lightrag_openai_compatible_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.llm import openai_complete_if_cache, openai_embedding 5 | from lightrag.utils import EmbeddingFunc 6 | import numpy as np 7 | 8 | WORKING_DIR = "./dickens" 9 | 10 | if not os.path.exists(WORKING_DIR): 11 | os.mkdir(WORKING_DIR) 12 | 13 | 14 | async def llm_model_func( 15 | prompt, system_prompt=None, history_messages=[], **kwargs 16 | ) -> str: 17 | return await openai_complete_if_cache( 18 | "solar-mini", 19 | prompt, 20 | system_prompt=system_prompt, 21 | history_messages=history_messages, 22 | api_key=os.getenv("UPSTAGE_API_KEY"), 23 | base_url="https://api.upstage.ai/v1/solar", 24 | **kwargs, 25 | ) 26 | 27 | 28 | async def embedding_func(texts: list[str]) -> np.ndarray: 29 | return await openai_embedding( 30 | texts, 31 | model="solar-embedding-1-large-query", 32 | api_key=os.getenv("UPSTAGE_API_KEY"), 33 | base_url="https://api.upstage.ai/v1/solar", 34 | ) 35 | 36 | 37 | async def get_embedding_dim(): 38 | test_text = ["This is a test sentence."] 39 | embedding = await embedding_func(test_text) 40 | embedding_dim = embedding.shape[1] 41 | return embedding_dim 42 | 43 | 44 | # function test 45 | async def test_funcs(): 46 | result = await llm_model_func("How are you?") 47 | print("llm_model_func: ", result) 48 | 49 | result = await embedding_func(["How are you?"]) 50 | print("embedding_func: ", result) 51 | 52 | 53 | # asyncio.run(test_funcs()) 54 | 55 | 56 | async def main(): 57 | try: 58 | embedding_dimension = await get_embedding_dim() 59 | print(f"Detected embedding dimension: {embedding_dimension}") 60 | 61 | rag = LightRAG( 62 | working_dir=WORKING_DIR, 63 | llm_model_func=llm_model_func, 64 | embedding_func=EmbeddingFunc( 65 | 
embedding_dim=embedding_dimension, 66 | max_token_size=8192, 67 | func=embedding_func, 68 | ), 69 | ) 70 | 71 | with open("./dickens/inbox/book.txt", "r", encoding="utf-8") as f: 72 | await rag.ainsert(f.read()) 73 | 74 | # Perform naive search 75 | print( 76 | await rag.aquery( 77 | "What are the top themes in this story?", param=QueryParam(mode="naive") 78 | ) 79 | ) 80 | 81 | # Perform local search 82 | print( 83 | await rag.aquery( 84 | "What are the top themes in this story?", param=QueryParam(mode="local") 85 | ) 86 | ) 87 | 88 | # Perform global search 89 | print( 90 | await rag.aquery( 91 | "What are the top themes in this story?", 92 | param=QueryParam(mode="global"), 93 | ) 94 | ) 95 | 96 | # Perform hybrid search 97 | print( 98 | await rag.aquery( 99 | "What are the top themes in this story?", 100 | param=QueryParam(mode="hybrid"), 101 | ) 102 | ) 103 | except Exception as e: 104 | print(f"An error occurred: {e}") 105 | 106 | 107 | if __name__ == "__main__": 108 | asyncio.run(main()) 109 | -------------------------------------------------------------------------------- /reproduce/Step_3_openai_compatible.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import asyncio 5 | from lightrag import LightRAG, QueryParam 6 | from tqdm import tqdm 7 | from lightrag.llm import openai_complete_if_cache, openai_embedding 8 | from lightrag.utils import EmbeddingFunc 9 | import numpy as np 10 | 11 | 12 | ## For Upstage API 13 | # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry 14 | async def llm_model_func( 15 | prompt, system_prompt=None, history_messages=[], **kwargs 16 | ) -> str: 17 | return await openai_complete_if_cache( 18 | "solar-mini", 19 | prompt, 20 | system_prompt=system_prompt, 21 | history_messages=history_messages, 22 | api_key=os.getenv("UPSTAGE_API_KEY"), 23 | base_url="https://api.upstage.ai/v1/solar", 24 | **kwargs, 25 | ) 26 | 27 | 28 | async def embedding_func(texts: list[str]) -> np.ndarray: 29 | return await openai_embedding( 30 | texts, 31 | model="solar-embedding-1-large-query", 32 | api_key=os.getenv("UPSTAGE_API_KEY"), 33 | base_url="https://api.upstage.ai/v1/solar", 34 | ) 35 | 36 | 37 | ## /For Upstage API 38 | 39 | 40 | def extract_queries(file_path): 41 | with open(file_path, "r") as f: 42 | data = f.read() 43 | 44 | data = data.replace("**", "") 45 | 46 | queries = re.findall(r"- Question \d+: (.+)", data) 47 | 48 | return queries 49 | 50 | 51 | async def process_query(query_text, rag_instance, query_param): 52 | try: 53 | result = await rag_instance.aquery(query_text, param=query_param) 54 | return {"query": query_text, "result": result}, None 55 | except Exception as e: 56 | return None, {"query": query_text, "error": str(e)} 57 | 58 | 59 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 60 | try: 61 | loop = asyncio.get_event_loop() 62 | except RuntimeError: 63 | loop = asyncio.new_event_loop() 64 | asyncio.set_event_loop(loop) 65 | return loop 66 | 67 | 68 | def run_queries_and_save_to_json( 69 | queries, rag_instance, query_param, output_file, error_file 70 | ): 71 | loop = always_get_an_event_loop() 72 | 73 | with open(output_file, "a", encoding="utf-8") as result_file, open( 74 | error_file, "a", encoding="utf-8" 75 | ) as err_file: 76 | result_file.write("[\n") 77 | first_entry = True 78 | 79 | for query_text in tqdm(queries, desc="Processing queries", unit="query"): 80 | result, error = loop.run_until_complete( 81 | 
process_query(query_text, rag_instance, query_param) 82 | ) 83 | 84 | if result: 85 | if not first_entry: 86 | result_file.write(",\n") 87 | json.dump(result, result_file, ensure_ascii=False, indent=4) 88 | first_entry = False 89 | elif error: 90 | json.dump(error, err_file, ensure_ascii=False, indent=4) 91 | err_file.write("\n") 92 | 93 | result_file.write("\n]") 94 | 95 | 96 | if __name__ == "__main__": 97 | cls = "mix" 98 | mode = "hybrid" 99 | WORKING_DIR = f"../{cls}" 100 | 101 | rag = LightRAG(working_dir=WORKING_DIR) 102 | rag = LightRAG( 103 | working_dir=WORKING_DIR, 104 | llm_model_func=llm_model_func, 105 | embedding_func=EmbeddingFunc( 106 | embedding_dim=4096, max_token_size=8192, func=embedding_func 107 | ), 108 | ) 109 | query_param = QueryParam(mode=mode) 110 | 111 | base_dir = "../datasets/questions" 112 | queries = extract_queries(f"{base_dir}/{cls}_questions.txt") 113 | run_queries_and_save_to_json( 114 | queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json" 115 | ) 116 | -------------------------------------------------------------------------------- /examples/batch_eval.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import jsonlines 4 | 5 | from openai import OpenAI 6 | 7 | 8 | def batch_eval(query_file, result1_file, result2_file, output_file_path): 9 | client = OpenAI() 10 | 11 | with open(query_file, "r") as f: 12 | data = f.read() 13 | 14 | queries = re.findall(r"- Question \d+: (.+)", data) 15 | 16 | with open(result1_file, "r") as f: 17 | answers1 = json.load(f) 18 | answers1 = [i["result"] for i in answers1] 19 | 20 | with open(result2_file, "r") as f: 21 | answers2 = json.load(f) 22 | answers2 = [i["result"] for i in answers2] 23 | 24 | requests = [] 25 | for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)): 26 | sys_prompt = """ 27 | ---Role--- 28 | You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. 29 | """ 30 | 31 | prompt = f""" 32 | You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. 33 | 34 | - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question? 35 | - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question? 36 | - **Empowerment**: How well does the answer help the reader understand and make informed judgments about the topic? 37 | 38 | For each criterion, choose the better answer (either Answer 1 or Answer 2) and explain why. Then, select an overall winner based on these three categories. 39 | 40 | Here is the question: 41 | {query} 42 | 43 | Here are the two answers: 44 | 45 | **Answer 1:** 46 | {answer1} 47 | 48 | **Answer 2:** 49 | {answer2} 50 | 51 | Evaluate both answers using the three criteria listed above and provide detailed explanations for each criterion. 
52 | 53 | Output your evaluation in the following JSON format: 54 | 55 | {{ 56 | "Comprehensiveness": {{ 57 | "Winner": "[Answer 1 or Answer 2]", 58 | "Explanation": "[Provide explanation here]" 59 | }}, 60 | "Empowerment": {{ 61 | "Winner": "[Answer 1 or Answer 2]", 62 | "Explanation": "[Provide explanation here]" 63 | }}, 64 | "Overall Winner": {{ 65 | "Winner": "[Answer 1 or Answer 2]", 66 | "Explanation": "[Summarize why this answer is the overall winner based on the three criteria]" 67 | }} 68 | }} 69 | """ 70 | 71 | request_data = { 72 | "custom_id": f"request-{i+1}", 73 | "method": "POST", 74 | "url": "/v1/chat/completions", 75 | "body": { 76 | "model": "gpt-4o-mini", 77 | "messages": [ 78 | {"role": "system", "content": sys_prompt}, 79 | {"role": "user", "content": prompt}, 80 | ], 81 | }, 82 | } 83 | 84 | requests.append(request_data) 85 | 86 | with jsonlines.open(output_file_path, mode="w") as writer: 87 | for request in requests: 88 | writer.write(request) 89 | 90 | print(f"Batch API requests written to {output_file_path}") 91 | 92 | batch_input_file = client.files.create( 93 | file=open(output_file_path, "rb"), purpose="batch" 94 | ) 95 | batch_input_file_id = batch_input_file.id 96 | 97 | batch = client.batches.create( 98 | input_file_id=batch_input_file_id, 99 | endpoint="/v1/chat/completions", 100 | completion_window="24h", 101 | metadata={"description": "nightly eval job"}, 102 | ) 103 | 104 | print(f"Batch {batch.id} has been created.") 105 | 106 | 107 | if __name__ == "__main__": 108 | batch_eval() 109 | -------------------------------------------------------------------------------- /lightrag/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import TypedDict, Union, Literal, Generic, TypeVar 3 | 4 | import numpy as np 5 | 6 | from .utils import EmbeddingFunc 7 | 8 | TextChunkSchema = TypedDict( 9 | "TextChunkSchema", 10 | {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}, 11 | ) 12 | 13 | T = TypeVar("T") 14 | 15 | 16 | @dataclass 17 | class QueryParam: 18 | mode: Literal["local", "global", "hybrid", "naive"] = "global" 19 | only_need_context: bool = False 20 | response_type: str = "Multiple Paragraphs" 21 | # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. 22 | top_k: int = 60 23 | # Number of tokens for the original chunks. 24 | max_token_for_text_unit: int = 4000 25 | # Number of tokens for the relationship descriptions 26 | max_token_for_global_context: int = 4000 27 | # Number of tokens for the entity descriptions 28 | max_token_for_local_context: int = 4000 29 | 30 | 31 | @dataclass 32 | class StorageNameSpace: 33 | namespace: str 34 | global_config: dict 35 | 36 | async def index_done_callback(self): 37 | """commit the storage operations after indexing""" 38 | pass 39 | 40 | async def query_done_callback(self): 41 | """commit the storage operations after querying""" 42 | pass 43 | 44 | 45 | @dataclass 46 | class BaseVectorStorage(StorageNameSpace): 47 | embedding_func: EmbeddingFunc 48 | meta_fields: set = field(default_factory=set) 49 | 50 | async def query(self, query: str, top_k: int) -> list[dict]: 51 | raise NotImplementedError 52 | 53 | async def upsert(self, data: dict[str, dict]): 54 | """Use 'content' field from value for embedding, use key as id. 
55 | If embedding_func is None, use 'embedding' field from value 56 | """ 57 | raise NotImplementedError 58 | 59 | 60 | @dataclass 61 | class BaseKVStorage(Generic[T], StorageNameSpace): 62 | async def all_keys(self) -> list[str]: 63 | raise NotImplementedError 64 | 65 | async def get_by_id(self, id: str) -> Union[T, None]: 66 | raise NotImplementedError 67 | 68 | async def get_by_ids( 69 | self, ids: list[str], fields: Union[set[str], None] = None 70 | ) -> list[Union[T, None]]: 71 | raise NotImplementedError 72 | 73 | async def filter_keys(self, data: list[str]) -> set[str]: 74 | """return un-exist keys""" 75 | raise NotImplementedError 76 | 77 | async def upsert(self, data: dict[str, T]): 78 | raise NotImplementedError 79 | 80 | async def drop(self): 81 | raise NotImplementedError 82 | 83 | 84 | @dataclass 85 | class BaseGraphStorage(StorageNameSpace): 86 | async def has_node(self, node_id: str) -> bool: 87 | raise NotImplementedError 88 | 89 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: 90 | raise NotImplementedError 91 | 92 | async def node_degree(self, node_id: str) -> int: 93 | raise NotImplementedError 94 | 95 | async def edge_degree(self, src_id: str, tgt_id: str) -> int: 96 | raise NotImplementedError 97 | 98 | async def get_node(self, node_id: str) -> Union[dict, None]: 99 | raise NotImplementedError 100 | 101 | async def get_edge( 102 | self, source_node_id: str, target_node_id: str 103 | ) -> Union[dict, None]: 104 | raise NotImplementedError 105 | 106 | async def get_node_edges( 107 | self, source_node_id: str 108 | ) -> Union[list[tuple[str, str]], None]: 109 | raise NotImplementedError 110 | 111 | async def upsert_node(self, node_id: str, node_data: dict[str, str]): 112 | raise NotImplementedError 113 | 114 | async def upsert_edge( 115 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] 116 | ): 117 | raise NotImplementedError 118 | 119 | async def clustering(self, algorithm: str): 120 | raise NotImplementedError 121 | 122 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: 123 | raise NotImplementedError("Node embedding is not used in lightrag.") 124 | -------------------------------------------------------------------------------- /examples/graph_visual_with_neo4j.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from lightrag.utils import xml_to_json 4 | from neo4j import GraphDatabase 5 | 6 | # Constants 7 | WORKING_DIR = "./dickens" 8 | BATCH_SIZE_NODES = 500 9 | BATCH_SIZE_EDGES = 100 10 | 11 | # Neo4j connection credentials 12 | NEO4J_URI = "bolt://localhost:7687" 13 | NEO4J_USERNAME = "neo4j" 14 | NEO4J_PASSWORD = "your_password" 15 | 16 | 17 | def convert_xml_to_json(xml_path, output_path): 18 | """Converts XML file to JSON and saves the output.""" 19 | if not os.path.exists(xml_path): 20 | print(f"Error: File not found - {xml_path}") 21 | return None 22 | 23 | json_data = xml_to_json(xml_path) 24 | if json_data: 25 | with open(output_path, "w", encoding="utf-8") as f: 26 | json.dump(json_data, f, ensure_ascii=False, indent=2) 27 | print(f"JSON file created: {output_path}") 28 | return json_data 29 | else: 30 | print("Failed to create JSON data") 31 | return None 32 | 33 | 34 | def process_in_batches(tx, query, data, batch_size): 35 | """Process data in batches and execute the given query.""" 36 | for i in range(0, len(data), batch_size): 37 | batch = data[i : i + batch_size] 38 | tx.run(query, {"nodes": batch} if 
"nodes" in query else {"edges": batch}) 39 | 40 | 41 | def main(): 42 | # Paths 43 | xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") 44 | json_file = os.path.join(WORKING_DIR, "graph_data.json") 45 | 46 | # Convert XML to JSON 47 | json_data = convert_xml_to_json(xml_file, json_file) 48 | if json_data is None: 49 | return 50 | 51 | # Load nodes and edges 52 | nodes = json_data.get("nodes", []) 53 | edges = json_data.get("edges", []) 54 | 55 | # Neo4j queries 56 | create_nodes_query = """ 57 | UNWIND $nodes AS node 58 | MERGE (e:Entity {id: node.id}) 59 | SET e.entity_type = node.entity_type, 60 | e.description = node.description, 61 | e.source_id = node.source_id, 62 | e.displayName = node.id 63 | REMOVE e:Entity 64 | WITH e, node 65 | CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode 66 | RETURN count(*) 67 | """ 68 | 69 | create_edges_query = """ 70 | UNWIND $edges AS edge 71 | MATCH (source {id: edge.source}) 72 | MATCH (target {id: edge.target}) 73 | WITH source, target, edge, 74 | CASE 75 | WHEN edge.keywords CONTAINS 'lead' THEN 'lead' 76 | WHEN edge.keywords CONTAINS 'participate' THEN 'participate' 77 | WHEN edge.keywords CONTAINS 'uses' THEN 'uses' 78 | WHEN edge.keywords CONTAINS 'located' THEN 'located' 79 | WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs' 80 | ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '') 81 | END AS relType 82 | CALL apoc.create.relationship(source, relType, { 83 | weight: edge.weight, 84 | description: edge.description, 85 | keywords: edge.keywords, 86 | source_id: edge.source_id 87 | }, target) YIELD rel 88 | RETURN count(*) 89 | """ 90 | 91 | set_displayname_and_labels_query = """ 92 | MATCH (n) 93 | SET n.displayName = n.id 94 | WITH n 95 | CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node 96 | RETURN count(*) 97 | """ 98 | 99 | # Create a Neo4j driver 100 | driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) 101 | 102 | try: 103 | # Execute queries in batches 104 | with driver.session() as session: 105 | # Insert nodes in batches 106 | session.execute_write( 107 | process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES 108 | ) 109 | 110 | # Insert edges in batches 111 | session.execute_write( 112 | process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES 113 | ) 114 | 115 | # Set displayName and labels 116 | session.run(set_displayname_and_labels_query) 117 | 118 | except Exception as e: 119 | print(f"Error occurred: {e}") 120 | 121 | finally: 122 | driver.close() 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /examples/lightrag_azure_openai_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from lightrag import LightRAG, QueryParam 4 | from lightrag.utils import EmbeddingFunc 5 | import numpy as np 6 | from dotenv import load_dotenv 7 | import aiohttp 8 | import logging 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | load_dotenv() 13 | 14 | AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") 15 | AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") 16 | AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") 17 | AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") 18 | 19 | AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") 20 | AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") 21 | 22 | 
WORKING_DIR = "./dickens" 23 | 24 | if os.path.exists(WORKING_DIR): 25 | import shutil 26 | 27 | shutil.rmtree(WORKING_DIR) 28 | 29 | os.mkdir(WORKING_DIR) 30 | 31 | 32 | async def llm_model_func( 33 | prompt, system_prompt=None, history_messages=[], **kwargs 34 | ) -> str: 35 | headers = { 36 | "Content-Type": "application/json", 37 | "api-key": AZURE_OPENAI_API_KEY, 38 | } 39 | endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}" 40 | 41 | messages = [] 42 | if system_prompt: 43 | messages.append({"role": "system", "content": system_prompt}) 44 | if history_messages: 45 | messages.extend(history_messages) 46 | messages.append({"role": "user", "content": prompt}) 47 | 48 | payload = { 49 | "messages": messages, 50 | "temperature": kwargs.get("temperature", 0), 51 | "top_p": kwargs.get("top_p", 1), 52 | "n": kwargs.get("n", 1), 53 | } 54 | 55 | async with aiohttp.ClientSession() as session: 56 | async with session.post(endpoint, headers=headers, json=payload) as response: 57 | if response.status != 200: 58 | raise ValueError( 59 | f"Request failed with status {response.status}: {await response.text()}" 60 | ) 61 | result = await response.json() 62 | return result["choices"][0]["message"]["content"] 63 | 64 | 65 | async def embedding_func(texts: list[str]) -> np.ndarray: 66 | headers = { 67 | "Content-Type": "application/json", 68 | "api-key": AZURE_OPENAI_API_KEY, 69 | } 70 | endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}" 71 | 72 | payload = {"input": texts} 73 | 74 | async with aiohttp.ClientSession() as session: 75 | async with session.post(endpoint, headers=headers, json=payload) as response: 76 | if response.status != 200: 77 | raise ValueError( 78 | f"Request failed with status {response.status}: {await response.text()}" 79 | ) 80 | result = await response.json() 81 | embeddings = [item["embedding"] for item in result["data"]] 82 | return np.array(embeddings) 83 | 84 | 85 | async def test_funcs(): 86 | result = await llm_model_func("How are you?") 87 | print("Resposta do llm_model_func: ", result) 88 | 89 | result = await embedding_func(["How are you?"]) 90 | print("Resultado do embedding_func: ", result.shape) 91 | print("Dimensão da embedding: ", result.shape[1]) 92 | 93 | 94 | asyncio.run(test_funcs()) 95 | 96 | embedding_dimension = 3072 97 | 98 | rag = LightRAG( 99 | working_dir=WORKING_DIR, 100 | llm_model_func=llm_model_func, 101 | embedding_func=EmbeddingFunc( 102 | embedding_dim=embedding_dimension, 103 | max_token_size=8192, 104 | func=embedding_func, 105 | ), 106 | ) 107 | 108 | book1 = open("./book_1.txt", encoding="utf-8") 109 | book2 = open("./book_2.txt", encoding="utf-8") 110 | 111 | rag.insert([book1.read(), book2.read()]) 112 | 113 | query_text = "What are the main themes?" 
114 | 115 | print("Result (Naive):") 116 | print(rag.query(query_text, param=QueryParam(mode="naive"))) 117 | 118 | print("\nResult (Local):") 119 | print(rag.query(query_text, param=QueryParam(mode="local"))) 120 | 121 | print("\nResult (Global):") 122 | print(rag.query(query_text, param=QueryParam(mode="global"))) 123 | 124 | print("\nResult (Hybrid):") 125 | print(rag.query(query_text, param=QueryParam(mode="hybrid"))) 126 | -------------------------------------------------------------------------------- /examples/lightrag_api_openai_compatible_demo.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, HTTPException, File, UploadFile 2 | from pydantic import BaseModel 3 | import os 4 | from lightrag import LightRAG, QueryParam 5 | from lightrag.llm import openai_complete_if_cache, openai_embedding 6 | from lightrag.utils import EmbeddingFunc 7 | import numpy as np 8 | from typing import Optional 9 | import asyncio 10 | import nest_asyncio 11 | 12 | # Apply nest_asyncio to solve event loop issues 13 | nest_asyncio.apply() 14 | 15 | DEFAULT_RAG_DIR = "index_default" 16 | app = FastAPI(title="LightRAG API", description="API for RAG operations") 17 | 18 | # Configure working directory 19 | WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") 20 | print(f"WORKING_DIR: {WORKING_DIR}") 21 | LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini") 22 | print(f"LLM_MODEL: {LLM_MODEL}") 23 | EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") 24 | print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") 25 | EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) 26 | print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") 27 | 28 | if not os.path.exists(WORKING_DIR): 29 | os.mkdir(WORKING_DIR) 30 | 31 | 32 | # LLM model function 33 | 34 | 35 | async def llm_model_func( 36 | prompt, system_prompt=None, history_messages=[], **kwargs 37 | ) -> str: 38 | return await openai_complete_if_cache( 39 | LLM_MODEL, 40 | prompt, 41 | system_prompt=system_prompt, 42 | history_messages=history_messages, 43 | **kwargs, 44 | ) 45 | 46 | 47 | # Embedding function 48 | 49 | 50 | async def embedding_func(texts: list[str]) -> np.ndarray: 51 | return await openai_embedding( 52 | texts, 53 | model=EMBEDDING_MODEL, 54 | ) 55 | 56 | 57 | async def get_embedding_dim(): 58 | test_text = ["This is a test sentence."] 59 | embedding = await embedding_func(test_text) 60 | embedding_dim = embedding.shape[1] 61 | print(f"{embedding_dim=}") 62 | return embedding_dim 63 | 64 | 65 | # Initialize RAG instance 66 | rag = LightRAG( 67 | working_dir=WORKING_DIR, 68 | llm_model_func=llm_model_func, 69 | embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()), 70 | max_token_size=EMBEDDING_MAX_TOKEN_SIZE, 71 | func=embedding_func), 72 | ) 73 | 74 | 75 | # Data models 76 | 77 | 78 | class QueryRequest(BaseModel): 79 | query: str 80 | mode: str = "hybrid" 81 | only_need_context: bool = False 82 | 83 | 84 | class InsertRequest(BaseModel): 85 | text: str 86 | 87 | 88 | class Response(BaseModel): 89 | status: str 90 | data: Optional[str] = None 91 | message: Optional[str] = None 92 | 93 | 94 | # API routes 95 | 96 | 97 | @app.post("/query", response_model=Response) 98 | async def query_endpoint(request: QueryRequest): 99 | try: 100 | loop = asyncio.get_event_loop() 101 | result = await loop.run_in_executor( 102 | None, lambda: rag.query(request.query, 103 | param=QueryParam(mode=request.mode, 
only_need_context=request.only_need_context)) 104 | ) 105 | return Response(status="success", data=result) 106 | except Exception as e: 107 | raise HTTPException(status_code=500, detail=str(e)) 108 | 109 | 110 | @app.post("/insert", response_model=Response) 111 | async def insert_endpoint(request: InsertRequest): 112 | try: 113 | loop = asyncio.get_event_loop() 114 | await loop.run_in_executor(None, lambda: rag.insert(request.text)) 115 | return Response(status="success", message="Text inserted successfully") 116 | except Exception as e: 117 | raise HTTPException(status_code=500, detail=str(e)) 118 | 119 | 120 | @app.post("/insert_file", response_model=Response) 121 | async def insert_file(file: UploadFile = File(...)): 122 | try: 123 | file_content = await file.read() 124 | # Read file content 125 | try: 126 | content = file_content.decode("utf-8") 127 | except UnicodeDecodeError: 128 | # If UTF-8 decoding fails, try other encodings 129 | content = file_content.decode("gbk") 130 | # Insert file content 131 | loop = asyncio.get_event_loop() 132 | await loop.run_in_executor(None, lambda: rag.insert(content)) 133 | 134 | return Response( 135 | status="success", 136 | message=f"File content from {file.filename} inserted successfully", 137 | ) 138 | except Exception as e: 139 | raise HTTPException(status_code=500, detail=str(e)) 140 | 141 | 142 | @app.get("/health") 143 | async def health_check(): 144 | return {"status": "healthy"} 145 | 146 | 147 | if __name__ == "__main__": 148 | import uvicorn 149 | 150 | uvicorn.run(app, host="0.0.0.0", port=8020) 151 | 152 | # Usage example 153 | # To run the server, use the following command in your terminal: 154 | # python lightrag_api_openai_compatible_demo.py 155 | 156 | # Example requests: 157 | # 1. Query: 158 | # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' 159 | 160 | # 2. Insert text: 161 | # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' 162 | 163 | # 3. Insert file: 164 | # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' 165 | 166 | # 4. 
Health check: 167 | # curl -X GET "http://127.0.0.1:8020/health" 168 | -------------------------------------------------------------------------------- /lightrag/storage.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import os 4 | from dataclasses import dataclass 5 | from typing import Any, Union, cast 6 | import networkx as nx 7 | import numpy as np 8 | from nano_vectordb import NanoVectorDB 9 | 10 | from .utils import load_json, logger, write_json 11 | from .base import ( 12 | BaseGraphStorage, 13 | BaseKVStorage, 14 | BaseVectorStorage, 15 | ) 16 | 17 | 18 | @dataclass 19 | class JsonKVStorage(BaseKVStorage): 20 | def __post_init__(self): 21 | working_dir = self.global_config["working_dir"] 22 | self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") 23 | self._data = load_json(self._file_name) or {} 24 | logger.info(f"Load KV {self.namespace} with {len(self._data)} data") 25 | 26 | async def all_keys(self) -> list[str]: 27 | return list(self._data.keys()) 28 | 29 | async def index_done_callback(self): 30 | write_json(self._data, self._file_name) 31 | 32 | async def get_by_id(self, id): 33 | return self._data.get(id, None) 34 | 35 | async def get_by_ids(self, ids, fields=None): 36 | if fields is None: 37 | return [self._data.get(id, None) for id in ids] 38 | return [ 39 | ( 40 | {k: v for k, v in self._data[id].items() if k in fields} 41 | if self._data.get(id, None) 42 | else None 43 | ) 44 | for id in ids 45 | ] 46 | 47 | async def filter_keys(self, data: list[str]) -> set[str]: 48 | return set([s for s in data if s not in self._data]) 49 | 50 | async def upsert(self, data: dict[str, dict]): 51 | left_data = {k: v for k, v in data.items() if k not in self._data} 52 | self._data.update(left_data) 53 | return left_data 54 | 55 | async def drop(self): 56 | self._data = {} 57 | 58 | 59 | @dataclass 60 | class NanoVectorDBStorage(BaseVectorStorage): 61 | cosine_better_than_threshold: float = 0.2 62 | 63 | def __post_init__(self): 64 | self._client_file_name = os.path.join( 65 | self.global_config["working_dir"], f"vdb_{self.namespace}.json" 66 | ) 67 | self._max_batch_size = self.global_config["embedding_batch_num"] 68 | self._client = NanoVectorDB( 69 | self.embedding_func.embedding_dim, storage_file=self._client_file_name 70 | ) 71 | self.cosine_better_than_threshold = self.global_config.get( 72 | "cosine_better_than_threshold", self.cosine_better_than_threshold 73 | ) 74 | 75 | async def upsert(self, data: dict[str, dict]): 76 | logger.info(f"Inserting {len(data)} vectors to {self.namespace}") 77 | if not len(data): 78 | logger.warning("You insert an empty data to vector DB") 79 | return [] 80 | list_data = [ 81 | { 82 | "__id__": k, 83 | **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, 84 | } 85 | for k, v in data.items() 86 | ] 87 | contents = [v["content"] for v in data.values()] 88 | batches = [ 89 | contents[i : i + self._max_batch_size] 90 | for i in range(0, len(contents), self._max_batch_size) 91 | ] 92 | embeddings_list = await asyncio.gather( 93 | *[self.embedding_func(batch) for batch in batches] 94 | ) 95 | embeddings = np.concatenate(embeddings_list) 96 | for i, d in enumerate(list_data): 97 | d["__vector__"] = embeddings[i] 98 | results = self._client.upsert(datas=list_data) 99 | return results 100 | 101 | async def query(self, query: str, top_k=5): 102 | embedding = await self.embedding_func([query]) 103 | embedding = embedding[0] 104 | results = 
self._client.query( 105 | query=embedding, 106 | top_k=top_k, 107 | better_than_threshold=self.cosine_better_than_threshold, 108 | ) 109 | results = [ 110 | {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results 111 | ] 112 | return results 113 | 114 | async def index_done_callback(self): 115 | self._client.save() 116 | 117 | 118 | @dataclass 119 | class NetworkXStorage(BaseGraphStorage): 120 | @staticmethod 121 | def load_nx_graph(file_name) -> nx.Graph: 122 | if os.path.exists(file_name): 123 | return nx.read_graphml(file_name) 124 | return None 125 | 126 | @staticmethod 127 | def write_nx_graph(graph: nx.Graph, file_name): 128 | logger.info( 129 | f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" 130 | ) 131 | nx.write_graphml(graph, file_name) 132 | 133 | @staticmethod 134 | def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: 135 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py 136 | Return the largest connected component of the graph, with nodes and edges sorted in a stable way. 137 | """ 138 | from graspologic.utils import largest_connected_component 139 | 140 | graph = graph.copy() 141 | graph = cast(nx.Graph, largest_connected_component(graph)) 142 | node_mapping = { 143 | node: html.unescape(node.upper().strip()) for node in graph.nodes() 144 | } # type: ignore 145 | graph = nx.relabel_nodes(graph, node_mapping) 146 | return NetworkXStorage._stabilize_graph(graph) 147 | 148 | @staticmethod 149 | def _stabilize_graph(graph: nx.Graph) -> nx.Graph: 150 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py 151 | Ensure an undirected graph with the same relationships will always be read the same way. 152 | """ 153 | fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() 154 | 155 | sorted_nodes = graph.nodes(data=True) 156 | sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) 157 | 158 | fixed_graph.add_nodes_from(sorted_nodes) 159 | edges = list(graph.edges(data=True)) 160 | 161 | if not graph.is_directed(): 162 | 163 | def _sort_source_target(edge): 164 | source, target, edge_data = edge 165 | if source > target: 166 | temp = source 167 | source = target 168 | target = temp 169 | return source, target, edge_data 170 | 171 | edges = [_sort_source_target(edge) for edge in edges] 172 | 173 | def _get_edge_key(source: Any, target: Any) -> str: 174 | return f"{source} -> {target}" 175 | 176 | edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) 177 | 178 | fixed_graph.add_edges_from(edges) 179 | return fixed_graph 180 | 181 | def __post_init__(self): 182 | self._graphml_xml_file = os.path.join( 183 | self.global_config["working_dir"], f"graph_{self.namespace}.graphml" 184 | ) 185 | preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) 186 | if preloaded_graph is not None: 187 | logger.info( 188 | f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" 189 | ) 190 | self._graph = preloaded_graph or nx.Graph() 191 | self._node_embed_algorithms = { 192 | "node2vec": self._node2vec_embed, 193 | } 194 | 195 | async def index_done_callback(self): 196 | NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) 197 | 198 | async def has_node(self, node_id: str) -> bool: 199 | return self._graph.has_node(node_id) 200 | 201 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: 202 | return 
self._graph.has_edge(source_node_id, target_node_id) 203 | 204 | async def get_node(self, node_id: str) -> Union[dict, None]: 205 | return self._graph.nodes.get(node_id) 206 | 207 | async def node_degree(self, node_id: str) -> int: 208 | return self._graph.degree(node_id) 209 | 210 | async def edge_degree(self, src_id: str, tgt_id: str) -> int: 211 | return self._graph.degree(src_id) + self._graph.degree(tgt_id) 212 | 213 | async def get_edge( 214 | self, source_node_id: str, target_node_id: str 215 | ) -> Union[dict, None]: 216 | return self._graph.edges.get((source_node_id, target_node_id)) 217 | 218 | async def get_node_edges(self, source_node_id: str): 219 | if self._graph.has_node(source_node_id): 220 | return list(self._graph.edges(source_node_id)) 221 | return None 222 | 223 | async def upsert_node(self, node_id: str, node_data: dict[str, str]): 224 | self._graph.add_node(node_id, **node_data) 225 | 226 | async def upsert_edge( 227 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] 228 | ): 229 | self._graph.add_edge(source_node_id, target_node_id, **edge_data) 230 | 231 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: 232 | if algorithm not in self._node_embed_algorithms: 233 | raise ValueError(f"Node embedding algorithm {algorithm} not supported") 234 | return await self._node_embed_algorithms[algorithm]() 235 | 236 | # @TODO: NOT USED 237 | async def _node2vec_embed(self): 238 | from graspologic import embed 239 | 240 | embeddings, nodes = embed.node2vec_embed( 241 | self._graph, 242 | **self.global_config["node2vec_params"], 243 | ) 244 | 245 | nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] 246 | return embeddings, nodes_ids 247 | -------------------------------------------------------------------------------- /lightrag/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import io 4 | import csv 5 | import json 6 | import logging 7 | import os 8 | import re 9 | from dataclasses import dataclass 10 | from functools import wraps 11 | from hashlib import md5 12 | from typing import Any, Union, List 13 | import xml.etree.ElementTree as ET 14 | 15 | import numpy as np 16 | import tiktoken 17 | 18 | ENCODER = None 19 | 20 | logger = logging.getLogger("lightrag") 21 | 22 | 23 | def set_logger(log_file: str): 24 | logger.setLevel(logging.DEBUG) 25 | 26 | file_handler = logging.FileHandler(log_file) 27 | file_handler.setLevel(logging.DEBUG) 28 | 29 | formatter = logging.Formatter( 30 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 31 | ) 32 | file_handler.setFormatter(formatter) 33 | 34 | if not logger.handlers: 35 | logger.addHandler(file_handler) 36 | 37 | 38 | @dataclass 39 | class EmbeddingFunc: 40 | embedding_dim: int 41 | max_token_size: int 42 | func: callable 43 | 44 | async def __call__(self, *args, **kwargs) -> np.ndarray: 45 | return await self.func(*args, **kwargs) 46 | 47 | 48 | def locate_json_string_body_from_string(content: str) -> Union[str, None]: 49 | """Locate the JSON string body from a string""" 50 | maybe_json_str = re.search(r"{.*}", content, re.DOTALL) 51 | if maybe_json_str is not None: 52 | return maybe_json_str.group(0) 53 | else: 54 | return None 55 | 56 | 57 | def convert_response_to_json(response: str) -> dict: 58 | json_str = locate_json_string_body_from_string(response) 59 | assert json_str is not None, f"Unable to parse JSON from response: {response}" 60 | try: 61 | data = json.loads(json_str) 
62 | return data 63 | except json.JSONDecodeError as e: 64 | logger.error(f"Failed to parse JSON: {json_str}") 65 | raise e from None 66 | 67 | 68 | def compute_args_hash(*args): 69 | return md5(str(args).encode()).hexdigest() 70 | 71 | 72 | def compute_mdhash_id(content, prefix: str = ""): 73 | return prefix + md5(content.encode()).hexdigest() 74 | 75 | 76 | def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): 77 | """Add restriction of maximum async calling times for a async func""" 78 | 79 | def final_decro(func): 80 | """Not using async.Semaphore to aovid use nest-asyncio""" 81 | __current_size = 0 82 | 83 | @wraps(func) 84 | async def wait_func(*args, **kwargs): 85 | nonlocal __current_size 86 | while __current_size >= max_size: 87 | await asyncio.sleep(waitting_time) 88 | __current_size += 1 89 | result = await func(*args, **kwargs) 90 | __current_size -= 1 91 | return result 92 | 93 | return wait_func 94 | 95 | return final_decro 96 | 97 | 98 | def wrap_embedding_func_with_attrs(**kwargs): 99 | """Wrap a function with attributes""" 100 | 101 | def final_decro(func) -> EmbeddingFunc: 102 | new_func = EmbeddingFunc(**kwargs, func=func) 103 | return new_func 104 | 105 | return final_decro 106 | 107 | 108 | def load_json(file_name): 109 | if not os.path.exists(file_name): 110 | return None 111 | with open(file_name, encoding="utf-8") as f: 112 | return json.load(f) 113 | 114 | 115 | def write_json(json_obj, file_name): 116 | with open(file_name, "w", encoding="utf-8") as f: 117 | json.dump(json_obj, f, indent=2, ensure_ascii=False) 118 | 119 | 120 | def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): 121 | global ENCODER 122 | if ENCODER is None: 123 | ENCODER = tiktoken.encoding_for_model(model_name) 124 | tokens = ENCODER.encode(content) 125 | return tokens 126 | 127 | 128 | def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"): 129 | global ENCODER 130 | if ENCODER is None: 131 | ENCODER = tiktoken.encoding_for_model(model_name) 132 | content = ENCODER.decode(tokens) 133 | return content 134 | 135 | 136 | def pack_user_ass_to_openai_messages(*args: str): 137 | roles = ["user", "assistant"] 138 | return [ 139 | {"role": roles[i % 2], "content": content} for i, content in enumerate(args) 140 | ] 141 | 142 | 143 | def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: 144 | """Split a string by multiple markers""" 145 | if not markers: 146 | return [content] 147 | results = re.split("|".join(re.escape(marker) for marker in markers), content) 148 | return [r.strip() for r in results if r.strip()] 149 | 150 | 151 | # Refer the utils functions of the official GraphRAG implementation: 152 | # https://github.com/microsoft/graphrag 153 | def clean_str(input: Any) -> str: 154 | """Clean an input string by removing HTML escapes, control characters, and other unwanted characters.""" 155 | # If we get non-string input, just give it back 156 | if not isinstance(input, str): 157 | return input 158 | 159 | result = html.unescape(input.strip()) 160 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python 161 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) 162 | 163 | 164 | def is_float_regex(value): 165 | return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) 166 | 167 | 168 | def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): 169 | """Truncate a list of data by token size""" 170 | if max_token_size <= 0: 171 | return [] 172 | 
tokens = 0 173 | for i, data in enumerate(list_data): 174 | tokens += len(encode_string_by_tiktoken(key(data))) 175 | if tokens > max_token_size: 176 | return list_data[:i] 177 | return list_data 178 | 179 | 180 | def list_of_list_to_csv(data: List[List[str]]) -> str: 181 | output = io.StringIO() 182 | writer = csv.writer(output) 183 | writer.writerows(data) 184 | return output.getvalue() 185 | 186 | 187 | def csv_string_to_list(csv_string: str) -> List[List[str]]: 188 | output = io.StringIO(csv_string) 189 | reader = csv.reader(output) 190 | return [row for row in reader] 191 | 192 | 193 | def save_data_to_file(data, file_name): 194 | with open(file_name, "w", encoding="utf-8") as f: 195 | json.dump(data, f, ensure_ascii=False, indent=4) 196 | 197 | 198 | def xml_to_json(xml_file): 199 | try: 200 | tree = ET.parse(xml_file) 201 | root = tree.getroot() 202 | 203 | # Print the root element's tag and attributes to confirm the file has been correctly loaded 204 | print(f"Root element: {root.tag}") 205 | print(f"Root attributes: {root.attrib}") 206 | 207 | data = {"nodes": [], "edges": []} 208 | 209 | # Use namespace 210 | namespace = {"": "http://graphml.graphdrawing.org/xmlns"} 211 | 212 | for node in root.findall(".//node", namespace): 213 | node_data = { 214 | "id": node.get("id").strip('"'), 215 | "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') 216 | if node.find("./data[@key='d0']", namespace) is not None 217 | else "", 218 | "description": node.find("./data[@key='d1']", namespace).text 219 | if node.find("./data[@key='d1']", namespace) is not None 220 | else "", 221 | "source_id": node.find("./data[@key='d2']", namespace).text 222 | if node.find("./data[@key='d2']", namespace) is not None 223 | else "", 224 | } 225 | data["nodes"].append(node_data) 226 | 227 | for edge in root.findall(".//edge", namespace): 228 | edge_data = { 229 | "source": edge.get("source").strip('"'), 230 | "target": edge.get("target").strip('"'), 231 | "weight": float(edge.find("./data[@key='d3']", namespace).text) 232 | if edge.find("./data[@key='d3']", namespace) is not None 233 | else 0.0, 234 | "description": edge.find("./data[@key='d4']", namespace).text 235 | if edge.find("./data[@key='d4']", namespace) is not None 236 | else "", 237 | "keywords": edge.find("./data[@key='d5']", namespace).text 238 | if edge.find("./data[@key='d5']", namespace) is not None 239 | else "", 240 | "source_id": edge.find("./data[@key='d6']", namespace).text 241 | if edge.find("./data[@key='d6']", namespace) is not None 242 | else "", 243 | } 244 | data["edges"].append(edge_data) 245 | 246 | # Print the number of nodes and edges found 247 | print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") 248 | 249 | return data 250 | except ET.ParseError as e: 251 | print(f"Error parsing XML file: {e}") 252 | return None 253 | except Exception as e: 254 | print(f"An error occurred: {e}") 255 | return None 256 | 257 | 258 | def process_combine_contexts(hl, ll): 259 | header = None 260 | list_hl = csv_string_to_list(hl.strip()) 261 | list_ll = csv_string_to_list(ll.strip()) 262 | 263 | if list_hl: 264 | header = list_hl[0] 265 | list_hl = list_hl[1:] 266 | if list_ll: 267 | header = list_ll[0] 268 | list_ll = list_ll[1:] 269 | if header is None: 270 | return "" 271 | 272 | if list_hl: 273 | list_hl = [",".join(item[1:]) for item in list_hl if item] 274 | if list_ll: 275 | list_ll = [",".join(item[1:]) for item in list_ll if item] 276 | 277 | combined_sources_set = set(filter(None, list_hl + list_ll)) 
278 | 279 | combined_sources = [",\t".join(header)] 280 | 281 | for i, item in enumerate(combined_sources_set, start=1): 282 | combined_sources.append(f"{i},\t{item}") 283 | 284 | combined_sources = "\n".join(combined_sources) 285 | 286 | return combined_sources 287 | -------------------------------------------------------------------------------- /lightrag/kg/neo4j_impl.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from dataclasses import dataclass 4 | from typing import Any, Union, Tuple, List, Dict 5 | import inspect 6 | from lightrag.utils import logger 7 | from ..base import BaseGraphStorage 8 | from neo4j import ( 9 | AsyncGraphDatabase, 10 | exceptions as neo4jExceptions, 11 | AsyncDriver, 12 | AsyncManagedTransaction, 13 | ) 14 | 15 | 16 | from tenacity import ( 17 | retry, 18 | stop_after_attempt, 19 | wait_exponential, 20 | retry_if_exception_type, 21 | ) 22 | 23 | 24 | @dataclass 25 | class Neo4JStorage(BaseGraphStorage): 26 | @staticmethod 27 | def load_nx_graph(file_name): 28 | print("no preloading of graph with neo4j in production") 29 | 30 | def __init__(self, namespace, global_config): 31 | super().__init__(namespace=namespace, global_config=global_config) 32 | self._driver = None 33 | self._driver_lock = asyncio.Lock() 34 | URI = os.environ["NEO4J_URI"] 35 | USERNAME = os.environ["NEO4J_USERNAME"] 36 | PASSWORD = os.environ["NEO4J_PASSWORD"] 37 | self._driver: AsyncDriver = AsyncGraphDatabase.driver( 38 | URI, auth=(USERNAME, PASSWORD) 39 | ) 40 | return None 41 | 42 | def __post_init__(self): 43 | self._node_embed_algorithms = { 44 | "node2vec": self._node2vec_embed, 45 | } 46 | 47 | async def close(self): 48 | if self._driver: 49 | await self._driver.close() 50 | self._driver = None 51 | 52 | async def __aexit__(self, exc_type, exc, tb): 53 | if self._driver: 54 | await self._driver.close() 55 | 56 | async def index_done_callback(self): 57 | print("KG successfully indexed.") 58 | 59 | async def has_node(self, node_id: str) -> bool: 60 | entity_name_label = node_id.strip('"') 61 | 62 | async with self._driver.session() as session: 63 | query = ( 64 | f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" 65 | ) 66 | result = await session.run(query) 67 | single_result = await result.single() 68 | logger.debug( 69 | f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' 70 | ) 71 | return single_result["node_exists"] 72 | 73 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: 74 | entity_name_label_source = source_node_id.strip('"') 75 | entity_name_label_target = target_node_id.strip('"') 76 | 77 | async with self._driver.session() as session: 78 | query = ( 79 | f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " 80 | "RETURN COUNT(r) > 0 AS edgeExists" 81 | ) 82 | result = await session.run(query) 83 | single_result = await result.single() 84 | logger.debug( 85 | f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' 86 | ) 87 | return single_result["edgeExists"] 88 | 89 | def close(self): 90 | self._driver.close() 91 | 92 | async def get_node(self, node_id: str) -> Union[dict, None]: 93 | async with self._driver.session() as session: 94 | entity_name_label = node_id.strip('"') 95 | query = f"MATCH (n:`{entity_name_label}`) RETURN n" 96 | result = await session.run(query) 97 | record = await result.single() 98 | if record: 99 | node = record["n"] 100 
| node_dict = dict(node) 101 | logger.debug( 102 | f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" 103 | ) 104 | return node_dict 105 | return None 106 | 107 | async def node_degree(self, node_id: str) -> int: 108 | entity_name_label = node_id.strip('"') 109 | 110 | async with self._driver.session() as session: 111 | query = f""" 112 | MATCH (n:`{entity_name_label}`) 113 | RETURN COUNT{{ (n)--() }} AS totalEdgeCount 114 | """ 115 | result = await session.run(query) 116 | record = await result.single() 117 | if record: 118 | edge_count = record["totalEdgeCount"] 119 | logger.debug( 120 | f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}" 121 | ) 122 | return edge_count 123 | else: 124 | return None 125 | 126 | async def edge_degree(self, src_id: str, tgt_id: str) -> int: 127 | entity_name_label_source = src_id.strip('"') 128 | entity_name_label_target = tgt_id.strip('"') 129 | src_degree = await self.node_degree(entity_name_label_source) 130 | trg_degree = await self.node_degree(entity_name_label_target) 131 | 132 | # Convert None to 0 for addition 133 | src_degree = 0 if src_degree is None else src_degree 134 | trg_degree = 0 if trg_degree is None else trg_degree 135 | 136 | degrees = int(src_degree) + int(trg_degree) 137 | logger.debug( 138 | f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}" 139 | ) 140 | return degrees 141 | 142 | async def get_edge( 143 | self, source_node_id: str, target_node_id: str 144 | ) -> Union[dict, None]: 145 | entity_name_label_source = source_node_id.strip('"') 146 | entity_name_label_target = target_node_id.strip('"') 147 | """ 148 | Find all edges between nodes of two given labels 149 | 150 | Args: 151 | source_node_label (str): Label of the source nodes 152 | target_node_label (str): Label of the target nodes 153 | 154 | Returns: 155 | list: List of all relationships/edges found 156 | """ 157 | async with self._driver.session() as session: 158 | query = f""" 159 | MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) 160 | RETURN properties(r) as edge_properties 161 | LIMIT 1 162 | """.format( 163 | entity_name_label_source=entity_name_label_source, 164 | entity_name_label_target=entity_name_label_target, 165 | ) 166 | 167 | result = await session.run(query) 168 | record = await result.single() 169 | if record: 170 | result = dict(record["edge_properties"]) 171 | logger.debug( 172 | f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" 173 | ) 174 | return result 175 | else: 176 | return None 177 | 178 | async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: 179 | node_label = source_node_id.strip('"') 180 | 181 | """ 182 | Retrieves all edges (relationships) for a particular node identified by its label. 
183 | :return: List of dictionaries containing edge information 184 | """ 185 | query = f"""MATCH (n:`{node_label}`) 186 | OPTIONAL MATCH (n)-[r]-(connected) 187 | RETURN n, r, connected""" 188 | async with self._driver.session() as session: 189 | results = await session.run(query) 190 | edges = [] 191 | async for record in results: 192 | source_node = record["n"] 193 | connected_node = record["connected"] 194 | 195 | source_label = ( 196 | list(source_node.labels)[0] if source_node.labels else None 197 | ) 198 | target_label = ( 199 | list(connected_node.labels)[0] 200 | if connected_node and connected_node.labels 201 | else None 202 | ) 203 | 204 | if source_label and target_label: 205 | edges.append((source_label, target_label)) 206 | 207 | return edges 208 | 209 | @retry( 210 | stop=stop_after_attempt(3), 211 | wait=wait_exponential(multiplier=1, min=4, max=10), 212 | retry=retry_if_exception_type( 213 | ( 214 | neo4jExceptions.ServiceUnavailable, 215 | neo4jExceptions.TransientError, 216 | neo4jExceptions.WriteServiceUnavailable, 217 | ) 218 | ), 219 | ) 220 | async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): 221 | """ 222 | Upsert a node in the Neo4j database. 223 | 224 | Args: 225 | node_id: The unique identifier for the node (used as label) 226 | node_data: Dictionary of node properties 227 | """ 228 | label = node_id.strip('"') 229 | properties = node_data 230 | 231 | async def _do_upsert(tx: AsyncManagedTransaction): 232 | query = f""" 233 | MERGE (n:`{label}`) 234 | SET n += $properties 235 | """ 236 | await tx.run(query, properties=properties) 237 | logger.debug( 238 | f"Upserted node with label '{label}' and properties: {properties}" 239 | ) 240 | 241 | try: 242 | async with self._driver.session() as session: 243 | await session.execute_write(_do_upsert) 244 | except Exception as e: 245 | logger.error(f"Error during upsert: {str(e)}") 246 | raise 247 | 248 | @retry( 249 | stop=stop_after_attempt(3), 250 | wait=wait_exponential(multiplier=1, min=4, max=10), 251 | retry=retry_if_exception_type( 252 | ( 253 | neo4jExceptions.ServiceUnavailable, 254 | neo4jExceptions.TransientError, 255 | neo4jExceptions.WriteServiceUnavailable, 256 | ) 257 | ), 258 | ) 259 | async def upsert_edge( 260 | self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] 261 | ): 262 | """ 263 | Upsert an edge and its properties between two nodes identified by their labels. 
264 | 265 | Args: 266 | source_node_id (str): Label of the source node (used as identifier) 267 | target_node_id (str): Label of the target node (used as identifier) 268 | edge_data (dict): Dictionary of properties to set on the edge 269 | """ 270 | source_node_label = source_node_id.strip('"') 271 | target_node_label = target_node_id.strip('"') 272 | edge_properties = edge_data 273 | 274 | async def _do_upsert_edge(tx: AsyncManagedTransaction): 275 | query = f""" 276 | MATCH (source:`{source_node_label}`) 277 | WITH source 278 | MATCH (target:`{target_node_label}`) 279 | MERGE (source)-[r:DIRECTED]->(target) 280 | SET r += $properties 281 | RETURN r 282 | """ 283 | await tx.run(query, properties=edge_properties) 284 | logger.debug( 285 | f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}" 286 | ) 287 | 288 | try: 289 | async with self._driver.session() as session: 290 | await session.execute_write(_do_upsert_edge) 291 | except Exception as e: 292 | logger.error(f"Error during edge upsert: {str(e)}") 293 | raise 294 | 295 | async def _node2vec_embed(self): 296 | print("Implemented but never called.") 297 | -------------------------------------------------------------------------------- /lightrag/lightrag.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from dataclasses import asdict, dataclass, field 4 | from datetime import datetime 5 | from functools import partial 6 | from typing import Type, cast 7 | 8 | from .llm import ( 9 | gpt_4o_mini_complete, 10 | openai_embedding, 11 | ) 12 | from .operate import ( 13 | chunking_by_token_size, 14 | extract_entities, 15 | local_query, 16 | global_query, 17 | hybrid_query, 18 | naive_query, 19 | ) 20 | 21 | from .storage import ( 22 | JsonKVStorage, 23 | NanoVectorDBStorage, 24 | NetworkXStorage, 25 | ) 26 | 27 | from .kg.neo4j_impl import Neo4JStorage 28 | # future KG integrations 29 | 30 | # from .kg.ArangoDB_impl import ( 31 | # GraphStorage as ArangoDBStorage 32 | # ) 33 | 34 | 35 | from .utils import ( 36 | EmbeddingFunc, 37 | compute_mdhash_id, 38 | limit_async_func_call, 39 | convert_response_to_json, 40 | logger, 41 | set_logger, 42 | ) 43 | from .base import ( 44 | BaseGraphStorage, 45 | BaseKVStorage, 46 | BaseVectorStorage, 47 | StorageNameSpace, 48 | QueryParam, 49 | ) 50 | 51 | 52 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 53 | try: 54 | return asyncio.get_event_loop() 55 | 56 | except RuntimeError: 57 | logger.info("Creating a new event loop in main thread.") 58 | loop = asyncio.new_event_loop() 59 | asyncio.set_event_loop(loop) 60 | 61 | return loop 62 | 63 | 64 | 65 | @dataclass 66 | class LightRAG: 67 | working_dir: str = field( 68 | default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" 69 | ) 70 | 71 | kg: str = field(default="NetworkXStorage") 72 | 73 | current_log_level = logger.level 74 | log_level: str = field(default=current_log_level) 75 | 76 | # text chunking 77 | chunk_token_size: int = 1200 78 | chunk_overlap_token_size: int = 100 79 | tiktoken_model_name: str = "gpt-4o-mini" 80 | 81 | # entity extraction 82 | entity_extract_max_gleaning: int = 1 83 | entity_summary_to_max_tokens: int = 500 84 | 85 | # node embedding 86 | node_embedding_algorithm: str = "node2vec" 87 | node2vec_params: dict = field( 88 | default_factory=lambda: { 89 | "dimensions": 1536, 90 | "num_walks": 10, 91 | "walk_length": 40, 92 | "window_size": 2, 93 | "iterations": 
3, 94 | "random_seed": 3, 95 | } 96 | ) 97 | 98 | # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) 99 | embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) 100 | embedding_batch_num: int = 32 101 | embedding_func_max_async: int = 16 102 | 103 | # LLM 104 | llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# 105 | llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' 106 | llm_model_max_token_size: int = 32768 107 | llm_model_max_async: int = 16 108 | llm_model_kwargs: dict = field(default_factory=dict) 109 | 110 | # storage 111 | key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage 112 | vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage 113 | vector_db_storage_cls_kwargs: dict = field(default_factory=dict) 114 | enable_llm_cache: bool = True 115 | 116 | # extension 117 | addon_params: dict = field(default_factory=dict) 118 | convert_response_to_json_func: callable = convert_response_to_json 119 | 120 | def __post_init__(self): 121 | log_file = os.path.join(self.working_dir, "lightrag.log") 122 | set_logger(log_file) 123 | logger.setLevel(self.log_level) 124 | 125 | logger.info(f"Logger initialized for working directory: {self.working_dir}") 126 | 127 | _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) 128 | logger.debug(f"LightRAG init with param:\n {_print_config}\n") 129 | 130 | # @TODO: should move all storage setup here to leverage initial start params attached to self. 131 | self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ 132 | self.kg 133 | ] 134 | 135 | if not os.path.exists(self.working_dir): 136 | logger.info(f"Creating working directory {self.working_dir}") 137 | os.makedirs(self.working_dir) 138 | 139 | self.full_docs = self.key_string_value_json_storage_cls( 140 | namespace="full_docs", global_config=asdict(self) 141 | ) 142 | 143 | self.text_chunks = self.key_string_value_json_storage_cls( 144 | namespace="text_chunks", global_config=asdict(self) 145 | ) 146 | 147 | self.llm_response_cache = ( 148 | self.key_string_value_json_storage_cls( 149 | namespace="llm_response_cache", global_config=asdict(self) 150 | ) 151 | if self.enable_llm_cache 152 | else None 153 | ) 154 | self.chunk_entity_relation_graph = self.graph_storage_cls( 155 | namespace="chunk_entity_relation", global_config=asdict(self) 156 | ) 157 | 158 | self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( 159 | self.embedding_func 160 | ) 161 | 162 | self.entities_vdb = self.vector_db_storage_cls( 163 | namespace="entities", 164 | global_config=asdict(self), 165 | embedding_func=self.embedding_func, 166 | meta_fields={"entity_name"}, 167 | ) 168 | self.relationships_vdb = self.vector_db_storage_cls( 169 | namespace="relationships", 170 | global_config=asdict(self), 171 | embedding_func=self.embedding_func, 172 | meta_fields={"src_id", "tgt_id"}, 173 | ) 174 | self.chunks_vdb = self.vector_db_storage_cls( 175 | namespace="chunks", 176 | global_config=asdict(self), 177 | embedding_func=self.embedding_func, 178 | ) 179 | 180 | self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( 181 | partial( 182 | self.llm_model_func, 183 | hashing_kv=self.llm_response_cache, 184 | **self.llm_model_kwargs, 185 | ) 186 | ) 187 | 188 | def _get_storage_class(self) -> Type[BaseGraphStorage]: 189 | return { 190 | "Neo4JStorage": Neo4JStorage, 191 | "NetworkXStorage": NetworkXStorage, 
192 | # "ArangoDBStorage": ArangoDBStorage 193 | } 194 | 195 | def insert(self, string_or_strings): 196 | loop = always_get_an_event_loop() 197 | return loop.run_until_complete(self.ainsert(string_or_strings)) 198 | 199 | async def ainsert(self, string_or_strings): 200 | try: 201 | if isinstance(string_or_strings, str): 202 | string_or_strings = [string_or_strings] 203 | 204 | new_docs = { 205 | compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} 206 | for c in string_or_strings 207 | } 208 | _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) 209 | new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} 210 | if not len(new_docs): 211 | logger.warning("All docs are already in the storage") 212 | return 213 | logger.info(f"[New Docs] inserting {len(new_docs)} docs") 214 | 215 | inserting_chunks = {} 216 | for doc_key, doc in new_docs.items(): 217 | chunks = { 218 | compute_mdhash_id(dp["content"], prefix="chunk-"): { 219 | **dp, 220 | "full_doc_id": doc_key, 221 | } 222 | for dp in chunking_by_token_size( 223 | doc["content"], 224 | overlap_token_size=self.chunk_overlap_token_size, 225 | max_token_size=self.chunk_token_size, 226 | tiktoken_model=self.tiktoken_model_name, 227 | ) 228 | } 229 | inserting_chunks.update(chunks) 230 | _add_chunk_keys = await self.text_chunks.filter_keys( 231 | list(inserting_chunks.keys()) 232 | ) 233 | inserting_chunks = { 234 | k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys 235 | } 236 | if not len(inserting_chunks): 237 | logger.warning("All chunks are already in the storage") 238 | return 239 | logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") 240 | 241 | await self.chunks_vdb.upsert(inserting_chunks) 242 | 243 | logger.info("[Entity Extraction]...") 244 | maybe_new_kg = await extract_entities( 245 | inserting_chunks, 246 | knowledge_graph_inst=self.chunk_entity_relation_graph, 247 | entity_vdb=self.entities_vdb, 248 | relationships_vdb=self.relationships_vdb, 249 | global_config=asdict(self), 250 | ) 251 | if maybe_new_kg is None: 252 | logger.warning("No new entities and relationships found") 253 | return 254 | self.chunk_entity_relation_graph = maybe_new_kg 255 | 256 | await self.full_docs.upsert(new_docs) 257 | await self.text_chunks.upsert(inserting_chunks) 258 | finally: 259 | await self._insert_done() 260 | 261 | async def _insert_done(self): 262 | tasks = [] 263 | for storage_inst in [ 264 | self.full_docs, 265 | self.text_chunks, 266 | self.llm_response_cache, 267 | self.entities_vdb, 268 | self.relationships_vdb, 269 | self.chunks_vdb, 270 | self.chunk_entity_relation_graph, 271 | ]: 272 | if storage_inst is None: 273 | continue 274 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 275 | await asyncio.gather(*tasks) 276 | 277 | def query(self, query: str, param: QueryParam = QueryParam()): 278 | loop = always_get_an_event_loop() 279 | return loop.run_until_complete(self.aquery(query, param)) 280 | 281 | async def aquery(self, query: str, param: QueryParam = QueryParam()): 282 | if param.mode == "local": 283 | response = await local_query( 284 | query, 285 | self.chunk_entity_relation_graph, 286 | self.entities_vdb, 287 | self.relationships_vdb, 288 | self.text_chunks, 289 | param, 290 | asdict(self), 291 | ) 292 | elif param.mode == "global": 293 | response = await global_query( 294 | query, 295 | self.chunk_entity_relation_graph, 296 | self.entities_vdb, 297 | self.relationships_vdb, 298 | self.text_chunks, 299 | param, 300 | 
asdict(self), 301 | ) 302 | elif param.mode == "hybrid": 303 | response = await hybrid_query( 304 | query, 305 | self.chunk_entity_relation_graph, 306 | self.entities_vdb, 307 | self.relationships_vdb, 308 | self.text_chunks, 309 | param, 310 | asdict(self), 311 | ) 312 | elif param.mode == "naive": 313 | response = await naive_query( 314 | query, 315 | self.chunks_vdb, 316 | self.text_chunks, 317 | param, 318 | asdict(self), 319 | ) 320 | else: 321 | raise ValueError(f"Unknown mode {param.mode}") 322 | await self._query_done() 323 | return response 324 | 325 | async def _query_done(self): 326 | tasks = [] 327 | for storage_inst in [self.llm_response_cache]: 328 | if storage_inst is None: 329 | continue 330 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 331 | await asyncio.gather(*tasks) 332 | -------------------------------------------------------------------------------- /notebooks/taipy-lightrag-app.py: -------------------------------------------------------------------------------- 1 | # Add to state initialization 2 | state.insert_text = "" # For paste dialog 3 | state.website_url = "" # For website dialog 4 | state.upload_content = None # For file upload 5 | state.download_format = "markdown" # For download options 6 | state.status_message = "" 7 | state.query_info = "" 8 | state.show_query_details = False 9 | 10 | # Insert Dialog Handlers 11 | def handle_paste_insert(state): 12 | """Handle text insertion from paste.""" 13 | if not state.insert_text: 14 | state.status_message = "Please enter some text to insert" 15 | return 16 | 17 | try: 18 | with get_event_loop_context() as loop: 19 | loop.run_until_complete(state.rag.ainsert(state.insert_text)) 20 | state.status_message = "Content inserted successfully!" 21 | state.insert_text = "" # Clear the input 22 | state.show_insert = False # Close dialog 23 | except Exception as e: 24 | logger.error(f"Error inserting content: {str(e)}") 25 | state.status_message = f"Error inserting content: {str(e)}" 26 | 27 | def handle_file_upload(state, file): 28 | """Handle file upload insertion.""" 29 | try: 30 | content = file.read() 31 | if isinstance(content, bytes): 32 | content = content.decode('utf-8') 33 | 34 | with get_event_loop_context() as loop: 35 | loop.run_until_complete(state.rag.ainsert(content)) 36 | state.status_message = "File inserted successfully!" 37 | state.show_insert = False 38 | except Exception as e: 39 | logger.error(f"Error inserting file: {str(e)}") 40 | state.status_message = f"Error inserting file: {str(e)}" 41 | 42 | def handle_website_insert(state): 43 | """Handle website content insertion.""" 44 | if not state.website_url: 45 | state.status_message = "Please enter a URL" 46 | return 47 | 48 | try: 49 | response = requests.get(state.website_url) 50 | response.raise_for_status() 51 | with get_event_loop_context() as loop: 52 | loop.run_until_complete(state.rag.ainsert(response.text)) 53 | state.status_message = "Website content inserted successfully!" 
54 | state.website_url = "" 55 | state.show_insert = False 56 | except Exception as e: 57 | logger.error(f"Error inserting website content: {str(e)}") 58 | state.status_message = f"Error inserting website content: {str(e)}" 59 | 60 | def insert_test_book(state): 61 | """Insert test book content.""" 62 | try: 63 | with open("dickens/inbox/book.txt", "r", encoding="utf-8") as f: 64 | content = f.read() 65 | with get_event_loop_context() as loop: 66 | loop.run_until_complete(state.rag.ainsert(content)) 67 | state.status_message = "Test book inserted successfully!" 68 | state.show_insert = False 69 | except Exception as e: 70 | logger.error(f"Error inserting test book: {str(e)}") 71 | state.status_message = f"Error inserting test book: {str(e)}" 72 | 73 | def insert_test_paper(state): 74 | """Insert test paper content.""" 75 | try: 76 | with open("dickens/inbox/2410.05779v2-LightRAG.pdf", "rb") as f: 77 | pdf_reader = PyPDF2.PdfReader(f) 78 | content = [] 79 | for page in pdf_reader.pages: 80 | text = page.extract_text() 81 | if text.strip(): 82 | content.append(text) 83 | 84 | if not content: 85 | state.status_message = "No text could be extracted from the PDF" 86 | return 87 | 88 | combined_content = "\n\n".join(content) 89 | with get_event_loop_context() as loop: 90 | loop.run_until_complete(state.rag.ainsert(combined_content)) 91 | state.status_message = "Test paper inserted successfully!" 92 | state.show_insert = False 93 | except Exception as e: 94 | logger.error(f"Error inserting test paper: {str(e)}") 95 | state.status_message = f"Error inserting test paper: {str(e)}" 96 | 97 | # Knowledge Graph Stats Handler 98 | def show_kg_stats(state): 99 | """Show knowledge graph statistics.""" 100 | try: 101 | if state.rag is None: 102 | state.status_message = "Knowledge Graph not initialized yet" 103 | return 104 | 105 | graph = state.rag.chunk_entity_relation_graph._graph 106 | if graph is None: 107 | state.status_message = "Knowledge Graph is empty" 108 | return 109 | 110 | # Calculate stats 111 | nodes = graph.number_of_nodes() 112 | edges = graph.number_of_edges() 113 | avg_degree = round(sum(dict(graph.degree()).values()) / nodes, 2) if nodes > 0 else 0 114 | 115 | # Create degree distribution for plotting 116 | degrees = dict(graph.degree()) 117 | degree_dist = {} 118 | for d in degrees.values(): 119 | degree_dist[d] = degree_dist.get(d, 0) + 1 120 | 121 | # Create plot 122 | fig = go.Figure(data=[ 123 | go.Bar(x=list(degree_dist.keys()), y=list(degree_dist.values())) 124 | ]) 125 | fig.update_layout( 126 | title="Node Degree Distribution", 127 | xaxis_title="Degree", 128 | yaxis_title="Count" 129 | ) 130 | 131 | # Update state with stats and plot 132 | state.kg_stats = { 133 | "nodes": nodes, 134 | "edges": edges, 135 | "avg_degree": avg_degree, 136 | "plot": fig 137 | } 138 | state.show_kg_stats = True 139 | 140 | except Exception as e: 141 | logger.error(f"Error getting graph stats: {str(e)}") 142 | state.status_message = f"Error getting graph stats: {str(e)}" 143 | 144 | # Download Handlers 145 | def handle_chat_download(state): 146 | """Handle chat history download.""" 147 | if not state.messages: 148 | state.status_message = "No messages to download yet! Start a conversation first." 
149 | return 150 | 151 | try: 152 | # Create markdown content 153 | md_lines = [ 154 | "# LightRAG Chat Session\n", 155 | f"*Exported on {strftime('%Y-%m-%d %H:%M:%S')}*\n", 156 | "\n## Settings\n", 157 | f"- Search Mode: {state.settings['search_mode']}", 158 | f"- LLM Model: {state.settings['llm_model']}", 159 | f"- Embedding Model: {state.settings['embedding_model']}", 160 | f"- Temperature: {state.settings['temperature']}", 161 | f"- System Message: {state.settings['system_message']}\n", 162 | "\n## Conversation\n" 163 | ] 164 | 165 | for msg in state.messages: 166 | role = "User" if msg["role"] == "user" else "Assistant" 167 | md_lines.append(f"\n### {role} ({msg['metadata'].get('timestamp', 'N/A')})") 168 | md_lines.append(f"\n{msg['content']}\n") 169 | 170 | if msg["role"] == "assistant" and "metadata" in msg: 171 | metadata = msg["metadata"] 172 | if "query_info" in metadata: 173 | md_lines.append(f"\n> {metadata['query_info']}") 174 | if "error" in metadata: 175 | md_lines.append(f"\n> ⚠️ Error: {metadata['error']}") 176 | 177 | # Save to file 178 | filename = f"chat_session_{strftime('%Y%m%d_%H%M%S')}.md" 179 | with open(filename, "w", encoding="utf-8") as f: 180 | f.write("\n".join(md_lines)) 181 | 182 | state.status_message = f"Chat history saved to {filename}" 183 | state.show_download = False 184 | 185 | except Exception as e: 186 | logger.error(f"Error downloading chat history: {str(e)}") 187 | state.status_message = f"Error downloading chat history: {str(e)}" 188 | 189 | def handle_records_download(state): 190 | """Handle inserted records download.""" 191 | try: 192 | if state.rag is None: 193 | state.status_message = "No records available. Initialize RAG first." 194 | return 195 | 196 | records = state.rag.get_all_records() 197 | if not records: 198 | state.status_message = "No records found to download." 199 | return 200 | 201 | # Save to file 202 | filename = f"lightrag_records_{strftime('%Y%m%d_%H%M%S')}.json" 203 | with open(filename, "w", encoding="utf-8") as f: 204 | json.dump(records, f, indent=2) 205 | 206 | state.status_message = f"Records saved to {filename}" 207 | state.show_download = False 208 | 209 | except Exception as e: 210 | logger.error(f"Error downloading records: {str(e)}") 211 | state.status_message = f"Error downloading records: {str(e)}" 212 | 213 | # Add these functions for message handling 214 | def handle_prompt(state): 215 | """Handle user prompt and generate response.""" 216 | if not state.current_prompt: 217 | return 218 | 219 | timestamp = strftime("%Y-%m-%d %H:%M:%S") 220 | date_short = strftime("%Y%m%d") 221 | prompt_hash = xxhash.xxh64(state.current_prompt.encode()).hexdigest()[:8] 222 | 223 | # Add user message 224 | state.messages.append({ 225 | "role": "user", 226 | "content": state.current_prompt, 227 | "metadata": { 228 | "timestamp": timestamp 229 | } 230 | }) 231 | 232 | # Update UI to show processing 233 | state.status_message = "Searching and generating response..." 
234 | 235 | try: 236 | # Generate response 237 | query_param = QueryParam(mode=state.settings["search_mode"]) 238 | with get_event_loop_context() as loop: 239 | response = loop.run_until_complete(state.rag.aquery(state.current_prompt, param=query_param)) 240 | 241 | # Create metadata 242 | timestamp = strftime("%Y-%m-%d %H:%M:%S") 243 | query_info = f"{state.settings['search_mode']}@{state.settings['llm_model']} #ds/{prompt_hash}/{date_short}" 244 | 245 | # Add assistant message 246 | state.messages.append({ 247 | "role": "assistant", 248 | "content": response, 249 | "metadata": { 250 | "timestamp": timestamp, 251 | "search_mode": state.settings["search_mode"], 252 | "llm_model": state.settings["llm_model"], 253 | "embedding_model": state.settings["embedding_model"], 254 | "temperature": state.settings["temperature"], 255 | "prompt_hash": prompt_hash, 256 | "query_info": query_info 257 | } 258 | }) 259 | 260 | # Update query info for display 261 | state.query_info = f""" 262 | **Query Details:** 263 | - Search Mode: {state.settings['search_mode']} 264 | - LLM Model: {state.settings['llm_model']} 265 | - Embedding Model: {state.settings['embedding_model']} 266 | - Temperature: {state.settings['temperature']} 267 | - Timestamp: {timestamp} 268 | - Prompt Hash: {prompt_hash} 269 | """ 270 | 271 | except Exception as e: 272 | error_msg = f"Error generating response: {str(e)}" 273 | logger.error(error_msg) 274 | 275 | # Add error message 276 | state.messages.append({ 277 | "role": "assistant", 278 | "content": "I apologize, but I encountered an error while processing your request.", 279 | "metadata": { 280 | "timestamp": timestamp, 281 | "search_mode": state.settings["search_mode"], 282 | "llm_model": state.settings["llm_model"], 283 | "embedding_model": state.settings["embedding_model"], 284 | "error": str(e) 285 | } 286 | }) 287 | 288 | finally: 289 | # Clear prompt and status 290 | state.current_prompt = "" 291 | state.status_message = "" 292 | 293 | def toggle_query_details(state): 294 | """Toggle the visibility of query details.""" 295 | state.show_query_details = not state.show_query_details 296 | 297 | # Update the layout to include dialogs 298 | layout = """ 299 | 300 | 301 | 302 | <|dialog|open={show_insert}| 303 | ### Insert Records 304 | 305 | <|tabs| 306 | <|tab|label=Paste| 307 | <|{insert_text}|text_area|label=Paste text or markdown content:|> 308 | <|Insert|button|on_action=handle_paste_insert|> 309 | |> 310 | 311 | <|tab|label=Upload| 312 | <|Upload file|file_selector|on_change=handle_file_upload|extensions=.txt,.md|> 313 | |> 314 | 315 | <|tab|label=Website| 316 | <|{website_url}|input|label=Website URL:|> 317 | <|Insert|button|on_action=handle_website_insert|> 318 | |> 319 | 320 | <|tab|label=Test Documents| 321 | <|Insert A Christmas Carol|button|on_action=insert_test_book|> 322 | <|Insert LightRAG Paper|button|on_action=insert_test_paper|> 323 | |> 324 | |> 325 | |> 326 | 327 | 328 | <|dialog|open={show_kg_stats}| 329 | ### Knowledge Graph Statistics 330 | 331 | <|{kg_stats}|chart|type=plotly|> 332 | 333 | **Basic Stats:** 334 | - Nodes: <|{kg_stats["nodes"]}|> 335 | - Edges: <|{kg_stats["edges"]}|> 336 | - Average Degree: <|{kg_stats["avg_degree"]}|> 337 | |> 338 | 339 | 340 | <|dialog|open={show_download}| 341 | ### Download Options 342 | 343 | <|tabs| 344 | <|tab|label=Chat History| 345 | Download the current chat session as a markdown file. 
346 | <|Download Chat|button|on_action=handle_chat_download|> 347 | |> 348 | 349 | <|tab|label=Inserted Records| 350 | Download all inserted records as a JSON file. 351 | <|Download Records|button|on_action=handle_records_download|> 352 | |> 353 | |> 354 | |> 355 | """ 356 | 357 | # Add CSS classes for message styling 358 | def message_class(message): 359 | """Return CSS class based on message role.""" 360 | return f"message-{message['role']}" 361 | -------------------------------------------------------------------------------- /lightrag/prompt.py: -------------------------------------------------------------------------------- 1 | GRAPH_FIELD_SEP = "" 2 | 3 | PROMPTS = {} 4 | 5 | PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" 6 | PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" 7 | PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" 8 | PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] 9 | 10 | PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"] 11 | 12 | PROMPTS["entity_extraction"] = """-Goal- 13 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. 14 | 15 | -Steps- 16 | 1. Identify all entities. For each identified entity, extract the following information: 17 | - entity_name: Name of the entity, capitalized 18 | - entity_type: One of the following types: [{entity_types}] 19 | - entity_description: Comprehensive description of the entity's attributes and activities 20 | Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter} 21 | 22 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. 23 | For each pair of related entities, extract the following information: 24 | - source_entity: name of the source entity, as identified in step 1 25 | - target_entity: name of the target entity, as identified in step 1 26 | - relationship_description: explanation as to why you think the source entity and the target entity are related to each other 27 | - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity 28 | - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details 29 | Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 30 | 31 | 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document. 32 | Format the content-level key words as ("content_keywords"{tuple_delimiter}) 33 | 34 | 4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 35 | 36 | 5. When finished, output {completion_delimiter} 37 | 38 | ###################### 39 | -Examples- 40 | ###################### 41 | Example 1: 42 | 43 | Entity_types: [person, technology, mission, organization, location] 44 | Text: 45 | while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. 
It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. 46 | 47 | Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” 48 | 49 | The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. 50 | 51 | It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths 52 | ################ 53 | Output: 54 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} 55 | ("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} 56 | ("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} 57 | ("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} 58 | ("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} 59 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter} 60 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter} 61 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter} 62 | ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} 63 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} 64 | ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} 65 | 
############################# 66 | Example 2: 67 | 68 | Entity_types: [person, technology, mission, organization, location] 69 | Text: 70 | They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve. 71 | 72 | Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril. 73 | 74 | Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly 75 | ############# 76 | Output: 77 | ("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter} 78 | ("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter} 79 | ("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter} 80 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter} 81 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter} 82 | ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter} 83 | ############################# 84 | Example 3: 85 | 86 | Entity_types: [person, role, technology, organization, event, location, concept] 87 | Text: 88 | their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. 89 | 90 | "It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." 91 | 92 | Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." 
93 | 94 | Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. 95 | 96 | The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation 97 | ############# 98 | Output: 99 | ("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter} 100 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter} 101 | ("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter} 102 | ("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter} 103 | ("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter} 104 | ("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter} 105 | ("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter} 106 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter} 107 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} 108 | ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} 109 | ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter} 110 | ############################# 111 | -Real Data- 112 | ###################### 113 | Entity_types: {entity_types} 114 | Text: {input_text} 115 | ###################### 116 | Output: 117 | """ 118 | 119 | PROMPTS[ 120 | "summarize_entity_descriptions" 121 | ] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. 122 | Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. 123 | Please concatenate all of these into a single, comprehensive description. 
Make sure to include information collected from all the descriptions. 124 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. 125 | Make sure it is written in third person, and include the entity names so we the have full context. 126 | 127 | ####### 128 | -Data- 129 | Entities: {entity_name} 130 | Description List: {description_list} 131 | ####### 132 | Output: 133 | """ 134 | 135 | PROMPTS[ 136 | "entiti_continue_extraction" 137 | ] = """MANY entities were missed in the last extraction. Add them below using the same format: 138 | """ 139 | 140 | PROMPTS[ 141 | "entiti_if_loop_extraction" 142 | ] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. 143 | """ 144 | 145 | PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." 146 | 147 | PROMPTS["rag_response"] = """---Role--- 148 | 149 | You are a helpful assistant responding to questions about data in the tables provided. 150 | 151 | 152 | ---Goal--- 153 | 154 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 155 | If you don't know the answer, just say so. Do not make anything up. 156 | Do not include information where the supporting evidence for it is not provided. 157 | 158 | ---Target response length and format--- 159 | 160 | {response_type} 161 | 162 | ---Data tables--- 163 | 164 | {context_data} 165 | 166 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 167 | """ 168 | 169 | PROMPTS["keywords_extraction"] = """---Role--- 170 | 171 | You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query. 172 | 173 | ---Goal--- 174 | 175 | Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms. 176 | 177 | ---Instructions--- 178 | 179 | - Output the keywords in JSON format. 180 | - The JSON should have two keys: 181 | - "high_level_keywords" for overarching concepts or themes. 182 | - "low_level_keywords" for specific entities or details. 183 | 184 | ###################### 185 | -Examples- 186 | ###################### 187 | Example 1: 188 | 189 | Query: "How does international trade influence global economic stability?" 190 | ################ 191 | Output: 192 | {{ 193 | "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"], 194 | "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] 195 | }} 196 | ############################# 197 | Example 2: 198 | 199 | Query: "What are the environmental consequences of deforestation on biodiversity?" 200 | ################ 201 | Output: 202 | {{ 203 | "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], 204 | "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] 205 | }} 206 | ############################# 207 | Example 3: 208 | 209 | Query: "What is the role of education in reducing poverty?" 
210 | ################ 211 | Output: 212 | {{ 213 | "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], 214 | "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] 215 | }} 216 | ############################# 217 | -Real Data- 218 | ###################### 219 | Query: {query} 220 | ###################### 221 | Output: 222 | 223 | """ 224 | 225 | PROMPTS["naive_rag_response"] = """You're a helpful assistant 226 | Below are the knowledge you know: 227 | {content_data} 228 | --- 229 | If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up. 230 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 231 | If you don't know the answer, just say so. Do not make anything up. 232 | Do not include information where the supporting evidence for it is not provided. 233 | ---Target response length and format--- 234 | {response_type} 235 | """ 236 | -------------------------------------------------------------------------------- /lightrag/llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from functools import lru_cache 4 | import json 5 | import aioboto3 6 | import aiohttp 7 | import numpy as np 8 | import ollama 9 | 10 | from openai import ( 11 | AsyncOpenAI, 12 | APIConnectionError, 13 | RateLimitError, 14 | Timeout, 15 | AsyncAzureOpenAI, 16 | ) 17 | 18 | import base64 19 | import struct 20 | 21 | from tenacity import ( 22 | retry, 23 | stop_after_attempt, 24 | wait_exponential, 25 | retry_if_exception_type, 26 | ) 27 | from transformers import AutoTokenizer, AutoModelForCausalLM 28 | import torch 29 | from pydantic import BaseModel, Field 30 | from typing import List, Dict, Callable, Any 31 | from .base import BaseKVStorage 32 | from .utils import compute_args_hash, wrap_embedding_func_with_attrs 33 | 34 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 35 | 36 | 37 | @retry( 38 | stop=stop_after_attempt(3), 39 | wait=wait_exponential(multiplier=1, min=4, max=10), 40 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 41 | ) 42 | async def openai_complete_if_cache( 43 | model, 44 | prompt, 45 | system_prompt=None, 46 | history_messages=[], 47 | base_url=None, 48 | api_key=None, 49 | **kwargs, 50 | ) -> str: 51 | if api_key: 52 | os.environ["OPENAI_API_KEY"] = api_key 53 | 54 | openai_async_client = ( 55 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 56 | ) 57 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 58 | messages = [] 59 | if system_prompt: 60 | messages.append({"role": "system", "content": system_prompt}) 61 | messages.extend(history_messages) 62 | messages.append({"role": "user", "content": prompt}) 63 | if hashing_kv is not None: 64 | args_hash = compute_args_hash(model, messages) 65 | if_cache_return = await hashing_kv.get_by_id(args_hash) 66 | if if_cache_return is not None: 67 | return if_cache_return["return"] 68 | 69 | response = await openai_async_client.chat.completions.create( 70 | model=model, messages=messages, **kwargs 71 | ) 72 | 73 | if hashing_kv is not None: 74 | await hashing_kv.upsert( 75 | {args_hash: {"return": response.choices[0].message.content, "model": model}} 76 | ) 77 | 
return response.choices[0].message.content 78 | 79 | 80 | @retry( 81 | stop=stop_after_attempt(3), 82 | wait=wait_exponential(multiplier=1, min=4, max=10), 83 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 84 | ) 85 | async def azure_openai_complete_if_cache( 86 | model, 87 | prompt, 88 | system_prompt=None, 89 | history_messages=[], 90 | base_url=None, 91 | api_key=None, 92 | **kwargs, 93 | ): 94 | if api_key: 95 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 96 | if base_url: 97 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 98 | 99 | openai_async_client = AsyncAzureOpenAI( 100 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 101 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 102 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 103 | ) 104 | 105 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 106 | messages = [] 107 | if system_prompt: 108 | messages.append({"role": "system", "content": system_prompt}) 109 | messages.extend(history_messages) 110 | if prompt is not None: 111 | messages.append({"role": "user", "content": prompt}) 112 | if hashing_kv is not None: 113 | args_hash = compute_args_hash(model, messages) 114 | if_cache_return = await hashing_kv.get_by_id(args_hash) 115 | if if_cache_return is not None: 116 | return if_cache_return["return"] 117 | 118 | response = await openai_async_client.chat.completions.create( 119 | model=model, messages=messages, **kwargs 120 | ) 121 | 122 | if hashing_kv is not None: 123 | await hashing_kv.upsert( 124 | {args_hash: {"return": response.choices[0].message.content, "model": model}} 125 | ) 126 | return response.choices[0].message.content 127 | 128 | 129 | class BedrockError(Exception): 130 | """Generic error for issues related to Amazon Bedrock""" 131 | 132 | 133 | @retry( 134 | stop=stop_after_attempt(5), 135 | wait=wait_exponential(multiplier=1, max=60), 136 | retry=retry_if_exception_type((BedrockError)), 137 | ) 138 | async def bedrock_complete_if_cache( 139 | model, 140 | prompt, 141 | system_prompt=None, 142 | history_messages=[], 143 | aws_access_key_id=None, 144 | aws_secret_access_key=None, 145 | aws_session_token=None, 146 | **kwargs, 147 | ) -> str: 148 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 149 | "AWS_ACCESS_KEY_ID", aws_access_key_id 150 | ) 151 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 152 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 153 | ) 154 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 155 | "AWS_SESSION_TOKEN", aws_session_token 156 | ) 157 | 158 | # Fix message history format 159 | messages = [] 160 | for history_message in history_messages: 161 | message = copy.copy(history_message) 162 | message["content"] = [{"text": message["content"]}] 163 | messages.append(message) 164 | 165 | # Add user prompt 166 | messages.append({"role": "user", "content": [{"text": prompt}]}) 167 | 168 | # Initialize Converse API arguments 169 | args = {"modelId": model, "messages": messages} 170 | 171 | # Define system prompt 172 | if system_prompt: 173 | args["system"] = [{"text": system_prompt}] 174 | 175 | # Map and set up inference parameters 176 | inference_params_map = { 177 | "max_tokens": "maxTokens", 178 | "top_p": "topP", 179 | "stop_sequences": "stopSequences", 180 | } 181 | if inference_params := list( 182 | set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) 183 | ): 184 | args["inferenceConfig"] = {} 185 | for param in inference_params: 186 | args["inferenceConfig"][inference_params_map.get(param, param)] = ( 187 | 
kwargs.pop(param) 188 | ) 189 | 190 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 191 | if hashing_kv is not None: 192 | args_hash = compute_args_hash(model, messages) 193 | if_cache_return = await hashing_kv.get_by_id(args_hash) 194 | if if_cache_return is not None: 195 | return if_cache_return["return"] 196 | 197 | # Call model via Converse API 198 | session = aioboto3.Session() 199 | async with session.client("bedrock-runtime") as bedrock_async_client: 200 | try: 201 | response = await bedrock_async_client.converse(**args, **kwargs) 202 | except Exception as e: 203 | raise BedrockError(e) 204 | 205 | if hashing_kv is not None: 206 | await hashing_kv.upsert( 207 | { 208 | args_hash: { 209 | "return": response["output"]["message"]["content"][0]["text"], 210 | "model": model, 211 | } 212 | } 213 | ) 214 | 215 | return response["output"]["message"]["content"][0]["text"] 216 | 217 | 218 | @lru_cache(maxsize=1) 219 | def initialize_hf_model(model_name): 220 | hf_tokenizer = AutoTokenizer.from_pretrained( 221 | model_name, device_map="auto", trust_remote_code=True 222 | ) 223 | hf_model = AutoModelForCausalLM.from_pretrained( 224 | model_name, device_map="auto", trust_remote_code=True 225 | ) 226 | if hf_tokenizer.pad_token is None: 227 | hf_tokenizer.pad_token = hf_tokenizer.eos_token 228 | 229 | return hf_model, hf_tokenizer 230 | 231 | 232 | async def hf_model_if_cache( 233 | model, prompt, system_prompt=None, history_messages=[], **kwargs 234 | ) -> str: 235 | model_name = model 236 | hf_model, hf_tokenizer = initialize_hf_model(model_name) 237 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 238 | messages = [] 239 | if system_prompt: 240 | messages.append({"role": "system", "content": system_prompt}) 241 | messages.extend(history_messages) 242 | messages.append({"role": "user", "content": prompt}) 243 | 244 | if hashing_kv is not None: 245 | args_hash = compute_args_hash(model, messages) 246 | if_cache_return = await hashing_kv.get_by_id(args_hash) 247 | if if_cache_return is not None: 248 | return if_cache_return["return"] 249 | input_prompt = "" 250 | try: 251 | input_prompt = hf_tokenizer.apply_chat_template( 252 | messages, tokenize=False, add_generation_prompt=True 253 | ) 254 | except Exception: 255 | try: 256 | ori_message = copy.deepcopy(messages) 257 | if messages[0]["role"] == "system": 258 | messages[1]["content"] = ( 259 | "" 260 | + messages[0]["content"] 261 | + "\n" 262 | + messages[1]["content"] 263 | ) 264 | messages = messages[1:] 265 | input_prompt = hf_tokenizer.apply_chat_template( 266 | messages, tokenize=False, add_generation_prompt=True 267 | ) 268 | except Exception: 269 | len_message = len(ori_message) 270 | for msgid in range(len_message): 271 | input_prompt = ( 272 | input_prompt 273 | + "<" 274 | + ori_message[msgid]["role"] 275 | + ">" 276 | + ori_message[msgid]["content"] 277 | + "\n" 280 | ) 281 | 282 | input_ids = hf_tokenizer( 283 | input_prompt, return_tensors="pt", padding=True, truncation=True 284 | ).to("cuda") 285 | inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} 286 | output = hf_model.generate( 287 | **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True 288 | ) 289 | response_text = hf_tokenizer.decode( 290 | output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True 291 | ) 292 | if hashing_kv is not None: 293 | await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) 294 | return response_text 295 | 296 | 297 | async def ollama_model_if_cache( 298 
| model, prompt, system_prompt=None, history_messages=[], **kwargs 299 | ) -> str: 300 | kwargs.pop("max_tokens", None) 301 | kwargs.pop("response_format", None) 302 | host = kwargs.pop("host", None) 303 | timeout = kwargs.pop("timeout", None) 304 | 305 | ollama_client = ollama.AsyncClient(host=host, timeout=timeout) 306 | messages = [] 307 | if system_prompt: 308 | messages.append({"role": "system", "content": system_prompt}) 309 | 310 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 311 | messages.extend(history_messages) 312 | messages.append({"role": "user", "content": prompt}) 313 | if hashing_kv is not None: 314 | args_hash = compute_args_hash(model, messages) 315 | if_cache_return = await hashing_kv.get_by_id(args_hash) 316 | if if_cache_return is not None: 317 | return if_cache_return["return"] 318 | 319 | response = await ollama_client.chat(model=model, messages=messages, **kwargs) 320 | 321 | result = response["message"]["content"] 322 | 323 | if hashing_kv is not None: 324 | await hashing_kv.upsert({args_hash: {"return": result, "model": model}}) 325 | 326 | return result 327 | 328 | 329 | @lru_cache(maxsize=1) 330 | def initialize_lmdeploy_pipeline( 331 | model, 332 | tp=1, 333 | chat_template=None, 334 | log_level="WARNING", 335 | model_format="hf", 336 | quant_policy=0, 337 | ): 338 | from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig 339 | 340 | lmdeploy_pipe = pipeline( 341 | model_path=model, 342 | backend_config=TurbomindEngineConfig( 343 | tp=tp, model_format=model_format, quant_policy=quant_policy 344 | ), 345 | chat_template_config=ChatTemplateConfig(model_name=chat_template) 346 | if chat_template 347 | else None, 348 | log_level="WARNING", 349 | ) 350 | return lmdeploy_pipe 351 | 352 | 353 | async def lmdeploy_model_if_cache( 354 | model, 355 | prompt, 356 | system_prompt=None, 357 | history_messages=[], 358 | chat_template=None, 359 | model_format="hf", 360 | quant_policy=0, 361 | **kwargs, 362 | ) -> str: 363 | """ 364 | Args: 365 | model (str): The path to the model. 366 | It could be one of the following options: 367 | - i) A local directory path of a turbomind model which is 368 | converted by `lmdeploy convert` command or download 369 | from ii) and iii). 370 | - ii) The model_id of a lmdeploy-quantized model hosted 371 | inside a model repo on huggingface.co, such as 372 | "InternLM/internlm-chat-20b-4bit", 373 | "lmdeploy/llama2-chat-70b-4bit", etc. 374 | - iii) The model_id of a model hosted inside a model repo 375 | on huggingface.co, such as "internlm/internlm-chat-7b", 376 | "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" 377 | and so on. 378 | chat_template (str): needed when model is a pytorch model on 379 | huggingface.co, such as "internlm-chat-7b", 380 | "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, 381 | and when the model name of local path did not match the original model name in HF. 382 | tp (int): tensor parallel 383 | prompt (Union[str, List[str]]): input texts to be completed. 384 | do_preprocess (bool): whether pre-process the messages. Default to 385 | True, which means chat_template will be applied. 386 | skip_special_tokens (bool): Whether or not to remove special tokens 387 | in the decoding. Default to be True. 388 | do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. 389 | Default to be False, which means greedy decoding will be applied. 
390 | """ 391 | try: 392 | import lmdeploy 393 | from lmdeploy import version_info, GenerationConfig 394 | except Exception: 395 | raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") 396 | 397 | kwargs.pop("response_format", None) 398 | max_new_tokens = kwargs.pop("max_tokens", 512) 399 | tp = kwargs.pop("tp", 1) 400 | skip_special_tokens = kwargs.pop("skip_special_tokens", True) 401 | do_preprocess = kwargs.pop("do_preprocess", True) 402 | do_sample = kwargs.pop("do_sample", False) 403 | gen_params = kwargs 404 | 405 | version = version_info 406 | if do_sample is not None and version < (0, 6, 0): 407 | raise RuntimeError( 408 | "`do_sample` parameter is not supported by lmdeploy until " 409 | f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" 410 | ) 411 | else: 412 | do_sample = True 413 | gen_params.update(do_sample=do_sample) 414 | 415 | lmdeploy_pipe = initialize_lmdeploy_pipeline( 416 | model=model, 417 | tp=tp, 418 | chat_template=chat_template, 419 | model_format=model_format, 420 | quant_policy=quant_policy, 421 | log_level="WARNING", 422 | ) 423 | 424 | messages = [] 425 | if system_prompt: 426 | messages.append({"role": "system", "content": system_prompt}) 427 | 428 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 429 | messages.extend(history_messages) 430 | messages.append({"role": "user", "content": prompt}) 431 | if hashing_kv is not None: 432 | args_hash = compute_args_hash(model, messages) 433 | if_cache_return = await hashing_kv.get_by_id(args_hash) 434 | if if_cache_return is not None: 435 | return if_cache_return["return"] 436 | 437 | gen_config = GenerationConfig( 438 | skip_special_tokens=skip_special_tokens, 439 | max_new_tokens=max_new_tokens, 440 | **gen_params, 441 | ) 442 | 443 | response = "" 444 | async for res in lmdeploy_pipe.generate( 445 | messages, 446 | gen_config=gen_config, 447 | do_preprocess=do_preprocess, 448 | stream_response=False, 449 | session_id=1, 450 | ): 451 | response += res.response 452 | 453 | if hashing_kv is not None: 454 | await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) 455 | return response 456 | 457 | 458 | async def gpt_4o_complete( 459 | prompt, system_prompt=None, history_messages=[], **kwargs 460 | ) -> str: 461 | return await openai_complete_if_cache( 462 | "gpt-4o", 463 | prompt, 464 | system_prompt=system_prompt, 465 | history_messages=history_messages, 466 | **kwargs, 467 | ) 468 | 469 | 470 | async def gpt_4o_mini_complete( 471 | prompt, system_prompt=None, history_messages=[], **kwargs 472 | ) -> str: 473 | return await openai_complete_if_cache( 474 | "gpt-4o-mini", 475 | prompt, 476 | system_prompt=system_prompt, 477 | history_messages=history_messages, 478 | **kwargs, 479 | ) 480 | 481 | 482 | async def azure_openai_complete( 483 | prompt, system_prompt=None, history_messages=[], **kwargs 484 | ) -> str: 485 | return await azure_openai_complete_if_cache( 486 | "conversation-4o-mini", 487 | prompt, 488 | system_prompt=system_prompt, 489 | history_messages=history_messages, 490 | **kwargs, 491 | ) 492 | 493 | 494 | async def bedrock_complete( 495 | prompt, system_prompt=None, history_messages=[], **kwargs 496 | ) -> str: 497 | return await bedrock_complete_if_cache( 498 | "anthropic.claude-3-haiku-20240307-v1:0", 499 | prompt, 500 | system_prompt=system_prompt, 501 | history_messages=history_messages, 502 | **kwargs, 503 | ) 504 | 505 | 506 | async def hf_model_complete( 507 | prompt, system_prompt=None, history_messages=[], **kwargs 508 | ) -> 
str: 509 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 510 | return await hf_model_if_cache( 511 | model_name, 512 | prompt, 513 | system_prompt=system_prompt, 514 | history_messages=history_messages, 515 | **kwargs, 516 | ) 517 | 518 | 519 | async def ollama_model_complete( 520 | prompt, system_prompt=None, history_messages=[], **kwargs 521 | ) -> str: 522 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"] 523 | return await ollama_model_if_cache( 524 | model_name, 525 | prompt, 526 | system_prompt=system_prompt, 527 | history_messages=history_messages, 528 | **kwargs, 529 | ) 530 | 531 | 532 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) 533 | @retry( 534 | stop=stop_after_attempt(3), 535 | wait=wait_exponential(multiplier=1, min=4, max=60), 536 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 537 | ) 538 | async def openai_embedding( 539 | texts: list[str], 540 | model: str = "text-embedding-3-small", 541 | base_url: str = None, 542 | api_key: str = None, 543 | ) -> np.ndarray: 544 | if api_key: 545 | os.environ["OPENAI_API_KEY"] = api_key 546 | 547 | openai_async_client = ( 548 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 549 | ) 550 | response = await openai_async_client.embeddings.create( 551 | model=model, input=texts, encoding_format="float" 552 | ) 553 | return np.array([dp.embedding for dp in response.data]) 554 | 555 | 556 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) 557 | @retry( 558 | stop=stop_after_attempt(3), 559 | wait=wait_exponential(multiplier=1, min=4, max=10), 560 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 561 | ) 562 | async def azure_openai_embedding( 563 | texts: list[str], 564 | model: str = "text-embedding-3-small", 565 | base_url: str = None, 566 | api_key: str = None, 567 | ) -> np.ndarray: 568 | if api_key: 569 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 570 | if base_url: 571 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 572 | 573 | openai_async_client = AsyncAzureOpenAI( 574 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 575 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 576 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 577 | ) 578 | 579 | response = await openai_async_client.embeddings.create( 580 | model=model, input=texts, encoding_format="float" 581 | ) 582 | return np.array([dp.embedding for dp in response.data]) 583 | 584 | 585 | @retry( 586 | stop=stop_after_attempt(3), 587 | wait=wait_exponential(multiplier=1, min=4, max=60), 588 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 589 | ) 590 | async def siliconcloud_embedding( 591 | texts: list[str], 592 | model: str = "netease-youdao/bce-embedding-base_v1", 593 | base_url: str = "https://api.siliconflow.cn/v1/embeddings", 594 | max_token_size: int = 512, 595 | api_key: str = None, 596 | ) -> np.ndarray: 597 | if api_key and not api_key.startswith("Bearer "): 598 | api_key = "Bearer " + api_key 599 | 600 | headers = {"Authorization": api_key, "Content-Type": "application/json"} 601 | 602 | truncate_texts = [text[0:max_token_size] for text in texts] 603 | 604 | payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} 605 | 606 | base64_strings = [] 607 | async with aiohttp.ClientSession() as session: 608 | async with session.post(base_url, headers=headers, json=payload) as response: 609 | content = await response.json() 610 | if "code" in 
content: 611 | raise ValueError(content) 612 | base64_strings = [item["embedding"] for item in content["data"]] 613 | 614 | embeddings = [] 615 | for string in base64_strings: 616 | decode_bytes = base64.b64decode(string) 617 | n = len(decode_bytes) // 4 618 | float_array = struct.unpack("<" + "f" * n, decode_bytes) 619 | embeddings.append(float_array) 620 | return np.array(embeddings) 621 | 622 | 623 | # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) 624 | # @retry( 625 | # stop=stop_after_attempt(3), 626 | # wait=wait_exponential(multiplier=1, min=4, max=10), 627 | # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions 628 | # ) 629 | async def bedrock_embedding( 630 | texts: list[str], 631 | model: str = "amazon.titan-embed-text-v2:0", 632 | aws_access_key_id=None, 633 | aws_secret_access_key=None, 634 | aws_session_token=None, 635 | ) -> np.ndarray: 636 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 637 | "AWS_ACCESS_KEY_ID", aws_access_key_id 638 | ) 639 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 640 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 641 | ) 642 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 643 | "AWS_SESSION_TOKEN", aws_session_token 644 | ) 645 | 646 | session = aioboto3.Session() 647 | async with session.client("bedrock-runtime") as bedrock_async_client: 648 | if (model_provider := model.split(".")[0]) == "amazon": 649 | embed_texts = [] 650 | for text in texts: 651 | if "v2" in model: 652 | body = json.dumps( 653 | { 654 | "inputText": text, 655 | # 'dimensions': embedding_dim, 656 | "embeddingTypes": ["float"], 657 | } 658 | ) 659 | elif "v1" in model: 660 | body = json.dumps({"inputText": text}) 661 | else: 662 | raise ValueError(f"Model {model} is not supported!") 663 | 664 | response = await bedrock_async_client.invoke_model( 665 | modelId=model, 666 | body=body, 667 | accept="application/json", 668 | contentType="application/json", 669 | ) 670 | 671 | response_body = await response.get("body").json() 672 | 673 | embed_texts.append(response_body["embedding"]) 674 | elif model_provider == "cohere": 675 | body = json.dumps( 676 | {"texts": texts, "input_type": "search_document", "truncate": "NONE"} 677 | ) 678 | 679 | response = await bedrock_async_client.invoke_model( 680 | model=model, 681 | body=body, 682 | accept="application/json", 683 | contentType="application/json", 684 | ) 685 | 686 | response_body = json.loads(response.get("body").read()) 687 | 688 | embed_texts = response_body["embeddings"] 689 | else: 690 | raise ValueError(f"Model provider '{model_provider}' is not supported!") 691 | 692 | return np.array(embed_texts) 693 | 694 | 695 | async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: 696 | input_ids = tokenizer( 697 | texts, return_tensors="pt", padding=True, truncation=True 698 | ).input_ids 699 | with torch.no_grad(): 700 | outputs = embed_model(input_ids) 701 | embeddings = outputs.last_hidden_state.mean(dim=1) 702 | return embeddings.detach().numpy() 703 | 704 | 705 | async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: 706 | embed_text = [] 707 | ollama_client = ollama.Client(**kwargs) 708 | for text in texts: 709 | data = ollama_client.embeddings(model=embed_model, prompt=text) 710 | embed_text.append(data["embedding"]) 711 | 712 | return embed_text 713 | 714 | 715 | class Model(BaseModel): 716 | """ 717 | This is a Pydantic model class named 'Model' that is used to define a custom 
language model. 718 | 719 | Attributes: 720 | gen_func (Callable[[Any], str]): A callable function that generates the response from the language model. 721 | The function should take any argument and return a string. 722 | kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. 723 | This could include parameters such as the model name, API key, etc. 724 | 725 | Example usage: 726 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}) 727 | 728 | In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model. 729 | The 'kwargs' dictionary contains the model name and API key to be passed to the function. 730 | """ 731 | 732 | gen_func: Callable[[Any], str] = Field( 733 | ..., 734 | description="A function that generates the response from the llm. The response must be a string", 735 | ) 736 | kwargs: Dict[str, Any] = Field( 737 | ..., 738 | description="The arguments to pass to the callable function. Eg. the api key, model name, etc", 739 | ) 740 | 741 | class Config: 742 | arbitrary_types_allowed = True 743 | 744 | 745 | class MultiModel: 746 | """ 747 | Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. 748 | Could also be used for spliting across diffrent models or providers. 749 | 750 | Attributes: 751 | models (List[Model]): A list of language models to be used. 752 | 753 | Usage example: 754 | ```python 755 | models = [ 756 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}), 757 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}), 758 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}), 759 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}), 760 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}), 761 | ] 762 | multi_model = MultiModel(models) 763 | rag = LightRAG( 764 | llm_model_func=multi_model.llm_model_func 765 | / ..other args 766 | ) 767 | ``` 768 | """ 769 | 770 | def __init__(self, models: List[Model]): 771 | self._models = models 772 | self._current_model = 0 773 | 774 | def _next_model(self): 775 | self._current_model = (self._current_model + 1) % len(self._models) 776 | return self._models[self._current_model] 777 | 778 | async def llm_model_func( 779 | self, prompt, system_prompt=None, history_messages=[], **kwargs 780 | ) -> str: 781 | kwargs.pop("model", None) # stop from overwriting the custom model name 782 | next_model = self._next_model() 783 | args = dict( 784 | prompt=prompt, 785 | system_prompt=system_prompt, 786 | history_messages=history_messages, 787 | **kwargs, 788 | **next_model.kwargs, 789 | ) 790 | 791 | return await next_model.gen_func(**args) 792 | 793 | 794 | if __name__ == "__main__": 795 | import asyncio 796 | 797 | async def main(): 798 | result = await gpt_4o_mini_complete("How are you?") 799 | print(result) 800 | 801 | asyncio.run(main()) 802 | -------------------------------------------------------------------------------- /notebooks/design-docs/Streamlit Chat element: -------------------------------------------------------------------------------- 1 | #Chat elements 2 | Streamlit 
provides a few commands to help you build conversational apps. These chat elements are designed to be used in conjunction with each other, but you can also use them separately. 3 | 4 | st.chat_message lets you insert a chat message container into the app so you can display messages from the user or the app. Chat containers can contain other Streamlit elements, including charts, tables, text, and more. st.chat_input lets you display a chat input widget so the user can type in a message. Remember to check out st.status to display output from long-running processes and external API calls. 5 | 6 | screenshot 7 | Chat input 8 | Display a chat input widget. 9 | 10 | prompt = st.chat_input("Say something") 11 | if prompt: 12 | st.write(f"The user has sent: {prompt}") 13 | screenshot 14 | Chat message 15 | Insert a chat message container. 16 | 17 | import numpy as np 18 | with st.chat_message("user"): 19 | st.write("Hello 👋") 20 | st.line_chart(np.random.randn(30, 3)) 21 | screenshot 22 | Status container 23 | Display output of long-running tasks in a container. 24 | 25 | with st.status('Running'): 26 | do_something_slow() 27 | st.write_stream 28 | Write generators or streams to the app with a typewriter effect. 29 | 30 | st.write_stream(my_generator) 31 | st.write_stream(my_llm_stream) 32 | 33 | 34 | 35 | # st.chat_input 36 | Streamlit Version 37 | Version 1.40.0 38 | Display a chat input widget. 39 | 40 | Function signature[source] 41 | st.chat_input(placeholder="Your message", *, key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None) 42 | 43 | Parameters 44 | placeholder (str) 45 | 46 | A placeholder text shown when the chat input is empty. Defaults to "Your message". For accessibility reasons, you should not use an empty string. 47 | 48 | key (str or int) 49 | 50 | An optional string or integer to use as the unique key for the widget. If this is omitted, a key will be generated for the widget based on its content. No two widgets may have the same key. 51 | 52 | max_chars (int or None) 53 | 54 | The maximum number of characters that can be entered. If None (default), there will be no maximum. 55 | 56 | disabled (bool) 57 | 58 | Whether the chat input should be disabled. Defaults to False. 59 | 60 | on_submit (callable) 61 | 62 | An optional callback invoked when the chat input's value is submitted. 63 | 64 | args (tuple) 65 | 66 | An optional tuple of args to pass to the callback. 67 | 68 | kwargs (dict) 69 | 70 | An optional dict of kwargs to pass to the callback. 71 | 72 | Returns 73 | (str or None) 74 | 75 | The current (non-empty) value of the text input widget on the last run of the app. Otherwise, None. 76 | 77 | Examples 78 | When st.chat_input is used in the main body of an app, it will be pinned to the bottom of the page. 79 | 80 | import streamlit as st 81 | 82 | prompt = st.chat_input("Say something") 83 | if prompt: 84 | st.write(f"User has sent the following prompt: {prompt}") 85 | Copy 86 | 87 | Built with Streamlit 🎈 88 | Fullscreen 89 | open_in_new 90 | The chat input can also be used inline by nesting it inside any layout container (container, columns, tabs, sidebar, etc) or fragment. Create chat interfaces embedded next to other content or have multiple chatbots! 
91 | 92 | import streamlit as st 93 | 94 | with st.sidebar: 95 | messages = st.container(height=300) 96 | if prompt := st.chat_input("Say something"): 97 | messages.chat_message("user").write(prompt) 98 | messages.chat_message("assistant").write(f"Echo: {prompt}") 99 | 100 | 101 | --- 102 | st.chat_message 103 | Streamlit Version 104 | Version 1.40.0 105 | Insert a chat message container. 106 | 107 | To add elements to the returned container, you can use with notation (preferred) or just call methods directly on the returned object. See the examples below. 108 | 109 | Function signature[source] 110 | st.chat_message(name, *, avatar=None) 111 | 112 | Parameters 113 | name ("user", "assistant", "ai", "human", or str) 114 | 115 | The name of the message author. Can be "human"/"user" or "ai"/"assistant" to enable preset styling and avatars. 116 | 117 | Currently, the name is not shown in the UI but is only set as an accessibility label. For accessibility reasons, you should not use an empty string. 118 | 119 | avatar (Anything supported by st.image, str, or None) 120 | 121 | The avatar shown next to the message. 122 | 123 | If avatar is None (default), the icon will be determined from name as follows: 124 | 125 | If name is "user" or "human", the message will have a default user icon. 126 | If name is "ai" or "assistant", the message will have a default bot icon. 127 | For all other values of name, the message will show the first letter of the name. 128 | In addition to the types supported by st.image (like URLs or numpy arrays), the following strings are valid: 129 | 130 | A single-character emoji. For example, you can set avatar="🧑‍💻" or avatar="🦖". Emoji short codes are not supported. 131 | 132 | An icon from the Material Symbols library (rounded style) in the format ":material/icon_name:" where "icon_name" is the name of the icon in snake case. 133 | 134 | For example, icon=":material/thumb_up:" will display the Thumb Up icon. Find additional icons in the Material Symbols font library. 135 | 136 | Returns 137 | (Container) 138 | 139 | A single container that can hold multiple elements. 140 | 141 | Examples 142 | You can use with notation to insert any element into an expander 143 | 144 | import streamlit as st 145 | import numpy as np 146 | 147 | with st.chat_message("user"): 148 | st.write("Hello 👋") 149 | st.line_chart(np.random.randn(30, 3)) 150 | 151 | --- 152 | st.status 153 | Streamlit Version 154 | Version 1.40.0 155 | Insert a status container to display output from long-running tasks. 156 | 157 | Inserts a container into your app that is typically used to show the status and details of a process or task. The container can hold multiple elements and can be expanded or collapsed by the user similar to st.expander. When collapsed, all that is visible is the status icon and label. 158 | 159 | The label, state, and expanded state can all be updated by calling .update() on the returned object. To add elements to the returned container, you can use with notation (preferred) or just call methods directly on the returned object. 160 | 161 | By default, st.status() initializes in the "running" state. When called using with notation, it automatically updates to the "complete" state at the end of the "with" block. See examples below for more details. 162 | 163 | Function signature[source] 164 | st.status(label, *, expanded=False, state="running") 165 | 166 | Parameters 167 | label (str) 168 | 169 | The initial label of the status container. 
The label can optionally contain GitHub-flavored Markdown of the following types: Bold, Italics, Strikethroughs, Inline Code, Links, and Images. Images display like icons, with a max height equal to the font height. 170 | 171 | Unsupported Markdown elements are unwrapped so only their children (text contents) render. Display unsupported elements as literal characters by backslash-escaping them. E.g., "1\. Not an ordered list". 172 | 173 | See the body parameter of st.markdown for additional, supported Markdown directives. 174 | 175 | expanded (bool) 176 | 177 | If True, initializes the status container in "expanded" state. Defaults to False (collapsed). 178 | 179 | state ("running", "complete", or "error") 180 | 181 | The initial state of the status container which determines which icon is shown: 182 | 183 | running (default): A spinner icon is shown. 184 | complete: A checkmark icon is shown. 185 | error: An error icon is shown. 186 | Returns 187 | (StatusContainer) 188 | 189 | A mutable status container that can hold multiple elements. The label, state, and expanded state can be updated after creation via .update(). 190 | 191 | Examples 192 | You can use the with notation to insert any element into a status container: 193 | 194 | import time 195 | import streamlit as st 196 | 197 | with st.status("Downloading data..."): 198 | st.write("Searching for data...") 199 | time.sleep(2) 200 | st.write("Found URL.") 201 | time.sleep(1) 202 | st.write("Downloading data...") 203 | time.sleep(1) 204 | 205 | st.button("Rerun") 206 |
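The Returns description above notes that the label, state, and expanded state can be updated after creation via .update(). Here is a minimal sketch of that pattern; the labels and sleep calls are illustrative, not part of the reference:

import time
import streamlit as st

# Create the status container without `with` notation so we keep a handle to it
status = st.status("Downloading data...", expanded=True)
status.write("Searching for data...")
time.sleep(2)
status.write("Found URL.")
time.sleep(1)
# Mark the task as finished and collapse the container
status.update(label="Download complete!", state="complete", expanded=False)

Because the container is not created in a with block, it stays in the "running" state until .update() is called explicitly.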
264 | --- 265 | # Add statefulness to apps 266 | What is State? 267 | We define access to a Streamlit app in a browser tab as a session. For each browser tab that connects to the Streamlit server, a new session is created. Streamlit reruns your script from top to bottom every time you interact with your app. Each rerun takes place in a blank slate: no variables are shared between runs. 268 | 269 | Session State is a way to share variables between reruns, for each user session. In addition to the ability to store and persist state, Streamlit also exposes the ability to manipulate state using Callbacks. Session state also persists across pages inside a multipage app. 270 | 271 | In this guide, we will illustrate the usage of Session State and Callbacks as we build a stateful Counter app. 272 | 273 | For details on the Session State and Callbacks API, please refer to our Session State API Reference Guide. 274 | 275 | Also, check out this Session State basics tutorial video by Streamlit Developer Advocate Dr. Marisa Smith to get started: 276 | 277 | 278 | Build a Counter 279 | Let's call our script counter.py. It initializes a count variable and has a button to increment the value stored in the count variable: 280 | 281 | import streamlit as st 282 | 283 | st.title('Counter Example') 284 | count = 0 285 | 286 | increment = st.button('Increment') 287 | if increment: 288 | count += 1 289 | 290 | st.write('Count = ', count) 291 | No matter how many times we press the Increment button in the above app, the count remains at 1. Let's understand why: 292 | 293 | Each time we press the Increment button, Streamlit reruns counter.py from top to bottom, and with every run, count gets initialized to 0. 294 | Pressing Increment subsequently adds 1 to 0, thus count=1 no matter how many times we press Increment. 295 | As we'll see later, we can avoid this issue by storing count as a Session State variable. By doing so, we're indicating to Streamlit that it should maintain the value stored inside a Session State variable across app reruns. 296 | 297 | Let's learn more about the API to use Session State.
298 | 299 | Initialization 300 | The Session State API follows a field-based API, which is very similar to Python dictionaries: 301 | 302 | import streamlit as st 303 | 304 | # Check if 'key' already exists in session_state 305 | # If not, then initialize it 306 | if 'key' not in st.session_state: 307 | st.session_state['key'] = 'value' 308 | 309 | # Session State also supports the attribute based syntax 310 | if 'key' not in st.session_state: 311 | st.session_state.key = 'value' 312 | Reads and updates 313 | Read the value of an item in Session State by passing the item to st.write : 314 | 315 | import streamlit as st 316 | 317 | if 'key' not in st.session_state: 318 | st.session_state['key'] = 'value' 319 | 320 | # Reads 321 | st.write(st.session_state.key) 322 | 323 | # Outputs: value 324 | Update an item in Session State by assigning it a value: 325 | 326 | import streamlit as st 327 | 328 | if 'key' not in st.session_state: 329 | st.session_state['key'] = 'value' 330 | 331 | # Updates 332 | st.session_state.key = 'value2' # Attribute API 333 | st.session_state['key'] = 'value2' # Dictionary like API 334 | Streamlit throws an exception if an uninitialized variable is accessed: 335 | 336 | import streamlit as st 337 | 338 | st.write(st.session_state['value']) 339 | 340 | # Throws an exception! 341 | state-uninitialized-exception 342 | Let's now take a look at a few examples that illustrate how to add Session State to our Counter app. 343 | 344 | Example 1: Add Session State 345 | Now that we've got a hang of the Session State API, let's update our Counter app to use Session State: 346 | 347 | import streamlit as st 348 | 349 | st.title('Counter Example') 350 | if 'count' not in st.session_state: 351 | st.session_state.count = 0 352 | 353 | increment = st.button('Increment') 354 | if increment: 355 | st.session_state.count += 1 356 | 357 | st.write('Count = ', st.session_state.count) 358 | As you can see in the above example, pressing the Increment button updates the count each time. 359 | 360 | Example 2: Session State and Callbacks 361 | Now that we've built a basic Counter app using Session State, let's move on to something a little more complex. The next example uses Callbacks with Session State. 362 | 363 | Callbacks: A callback is a Python function which gets called when an input widget changes. Callbacks can be used with widgets using the parameters on_change (or on_click), args, and kwargs. The full Callbacks API can be found in our Session State API Reference Guide. 364 | 365 | import streamlit as st 366 | 367 | st.title('Counter Example using Callbacks') 368 | if 'count' not in st.session_state: 369 | st.session_state.count = 0 370 | 371 | def increment_counter(): 372 | st.session_state.count += 1 373 | 374 | st.button('Increment', on_click=increment_counter) 375 | 376 | st.write('Count = ', st.session_state.count) 377 | Now, pressing the Increment button updates the count each time by calling the increment_counter() function. 
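Example 2 wires a callback to a button with on_click. As an extra illustration (not part of the original guide), the same pattern works with on_change on an input widget, reading the widget's value from Session State via its key:

import streamlit as st

st.title('Counter Example using on_change')
if 'count' not in st.session_state:
    st.session_state.count = 0

def add_step():
    # Called whenever the number input changes, before the script reruns
    st.session_state.count += st.session_state.step

st.number_input('Step to add', value=1, step=1, key='step', on_change=add_step)

st.write('Count = ', st.session_state.count)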
378 | 379 | Example 3: Use args and kwargs in Callbacks 380 | Callbacks also support passing arguments using the args parameter in a widget: 381 | 382 | import streamlit as st 383 | 384 | st.title('Counter Example using Callbacks with args') 385 | if 'count' not in st.session_state: 386 | st.session_state.count = 0 387 | 388 | increment_value = st.number_input('Enter a value', value=0, step=1) 389 | 390 | def increment_counter(increment_value): 391 | st.session_state.count += increment_value 392 | 393 | increment = st.button('Increment', on_click=increment_counter, 394 | args=(increment_value, )) 395 | 396 | st.write('Count = ', st.session_state.count) 397 | Additionally, we can also use the kwargs parameter in a widget to pass named arguments to the callback function as shown below: 398 | 399 | import streamlit as st 400 | 401 | st.title('Counter Example using Callbacks with kwargs') 402 | if 'count' not in st.session_state: 403 | st.session_state.count = 0 404 | 405 | def increment_counter(increment_value=0): 406 | st.session_state.count += increment_value 407 | 408 | def decrement_counter(decrement_value=0): 409 | st.session_state.count -= decrement_value 410 | 411 | st.button('Increment', on_click=increment_counter, 412 | kwargs=dict(increment_value=5)) 413 | 414 | st.button('Decrement', on_click=decrement_counter, 415 | kwargs=dict(decrement_value=1)) 416 | 417 | st.write('Count = ', st.session_state.count) 418 | Example 4: Forms and Callbacks 419 | Say we now want to not only increment the count, but also store when it was last updated. We illustrate doing this using Callbacks and st.form: 420 | 421 | import streamlit as st 422 | import datetime 423 | 424 | st.title('Counter Example') 425 | if 'count' not in st.session_state: 426 | st.session_state.count = 0 427 | st.session_state.last_updated = datetime.time(0,0) 428 | 429 | def update_counter(): 430 | st.session_state.count += st.session_state.increment_value 431 | st.session_state.last_updated = st.session_state.update_time 432 | 433 | with st.form(key='my_form'): 434 | st.time_input(label='Enter the time', value=datetime.datetime.now().time(), key='update_time') 435 | st.number_input('Enter a value', value=0, step=1, key='increment_value') 436 | submit = st.form_submit_button(label='Update', on_click=update_counter) 437 | 438 | st.write('Current Count = ', st.session_state.count) 439 | st.write('Last Updated = ', st.session_state.last_updated) 440 | Advanced concepts 441 | Session State and Widget State association 442 | Session State provides the functionality to store variables across reruns. Widget state (i.e. the value of a widget) is also stored in a session. 443 | 444 | For simplicity, we have unified this information in one place. i.e. the Session State. This convenience feature makes it super easy to read or write to the widget's state anywhere in the app's code. Session State variables mirror the widget value using the key argument. 445 | 446 | We illustrate this with the following example. Let's say we have an app with a slider to represent temperature in Celsius. 
We can set and get the value of the temperature widget by using the Session State API, as follows: 447 | 448 | import streamlit as st 449 | 450 | if "celsius" not in st.session_state: 451 | # set the initial default value of the slider widget 452 | st.session_state.celsius = 50.0 453 | 454 | st.slider( 455 | "Temperature in Celsius", 456 | min_value=-100.0, 457 | max_value=100.0, 458 | key="celsius" 459 | ) 460 | 461 | # This will get the value of the slider widget 462 | st.write(st.session_state.celsius) 463 | There is a limitation to setting widget values using the Session State API. 464 | 465 | priority_high 466 | Important 467 | Streamlit does not allow setting widget values via the Session State API for st.button and st.file_uploader. 468 | 469 | The following example will raise a StreamlitAPIException on trying to set the state of st.button via the Session State API: 470 | 471 | import streamlit as st 472 | 473 | if 'my_button' not in st.session_state: 474 | st.session_state.my_button = True 475 | # Streamlit will raise an Exception on trying to set the state of button 476 | 477 | st.button('Submit', key='my_button') 478 | state-button-exception 479 | Serializable Session State 480 | Serialization refers to the process of converting an object or data structure into a format that can be persisted and shared, and allowing you to recover the data’s original structure. Python’s built-in pickle module serializes Python objects to a byte stream ("pickling") and deserializes the stream into an object ("unpickling"). 481 | 482 | By default, Streamlit’s Session State allows you to persist any Python object for the duration of the session, irrespective of the object’s pickle-serializability. This property lets you store Python primitives such as integers, floating-point numbers, complex numbers and booleans, dataframes, and even lambdas returned by functions. However, some execution environments may require serializing all data in Session State, so it may be useful to detect incompatibility during development, or when the execution environment will stop supporting it in the future. 483 | 484 | To that end, Streamlit provides a runner.enforceSerializableSessionState configuration option that, when set to true, only allows pickle-serializable objects in Session State. To enable the option, either create a global or project config file with the following or use it as a command-line flag: 485 | 486 | # .streamlit/config.toml 487 | [runner] 488 | enforceSerializableSessionState = true 489 | By "pickle-serializable", we mean calling pickle.dumps(obj) should not raise a PicklingError exception. When the config option is enabled, adding unserializable data to session state should result in an exception. E.g., 490 | 491 | import streamlit as st 492 | 493 | def unserializable_data(): 494 | return lambda x: x 495 | 496 | #👇 results in an exception when enforceSerializableSessionState is on 497 | st.session_state.unserializable = unserializable_data() 498 | UnserializableSessionStateError 499 | priority_high 500 | Warning 501 | When runner.enforceSerializableSessionState is set to true, Session State implicitly uses the pickle module, which is known to be insecure. Ensure all data saved and retrieved from Session State is trusted because it is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source in an unsafe mode or that could have been tampered with. Only load data you trust. 
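If you want to catch unserializable objects during development before turning that option on, a rough check (an illustrative sketch, not from the guide) is to try pickle.dumps before writing to Session State:

import pickle
import streamlit as st

def is_pickle_serializable(obj) -> bool:
    # pickle.dumps raises an error for objects such as lambdas or open file handles
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

data = {"counts": [1, 2, 3]}
if is_pickle_serializable(data):
    st.session_state.data = data
else:
    st.warning("This object is not pickle-serializable and would fail under enforceSerializableSessionState.")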
502 | 503 | Caveats and limitations 504 | Here are some limitations to keep in mind when using Session State: 505 | 506 | Session State exists for as long as the tab is open and connected to the Streamlit server. As soon as you close the tab, everything stored in Session State is lost. 507 | Session State is not persisted. If the Streamlit server crashes, then everything stored in Session State gets wiped. 508 | For caveats and limitations with the Session State API, please see the API limitations. 509 | 510 | --- 511 | # Session State 512 | Session State is a way to share variables between reruns, for each user session. In addition to the ability to store and persist state, Streamlit also exposes the ability to manipulate state using Callbacks. Session state also persists across pages inside a multipage app. 513 | 514 | Check out this Session State basics tutorial video by Streamlit Developer Advocate Dr. Marisa Smith to get started: 515 | 516 | 517 | Initialize values in Session State 518 | The Session State API follows a field-based API, which is very similar to Python dictionaries: 519 | 520 | # Initialization 521 | if 'key' not in st.session_state: 522 | st.session_state['key'] = 'value' 523 | 524 | # Session State also supports attribute based syntax 525 | if 'key' not in st.session_state: 526 | st.session_state.key = 'value' 527 | Reads and updates 528 | Read the value of an item in Session State and display it by passing it to st.write: 529 | 530 | # Read 531 | st.write(st.session_state.key) 532 | 533 | # Outputs: value 534 | Update an item in Session State by assigning it a value: 535 | 536 | st.session_state.key = 'value2' # Attribute API 537 | st.session_state['key'] = 'value2' # Dictionary like API 538 | Curious about what is in Session State? Use st.write or magic: 539 | 540 | st.write(st.session_state) 541 | 542 | # With magic: 543 | st.session_state 544 | Streamlit throws a handy exception if an uninitialized variable is accessed: 545 | 546 | st.write(st.session_state['value']) 547 | 548 | # Throws an exception! 549 | state-uninitialized-exception 550 | Delete items 551 | Delete items in Session State using the syntax to delete items in any Python dictionary: 552 | 553 | # Delete a single key-value pair 554 | del st.session_state[key] 555 | 556 | # Delete all the items in Session state 557 | for key in st.session_state.keys(): 558 | del st.session_state[key] 559 | Session State can also be cleared by going to Settings → Clear Cache, followed by Rerunning the app. 560 | 561 | state-clear-cache 562 | Session State and Widget State association 563 | Every widget with a key is automatically added to Session State: 564 | 565 | st.text_input("Your name", key="name") 566 | 567 | # This exists now: 568 | st.session_state.name 569 | Use Callbacks to update Session State 570 | A callback is a Python function which gets called when an input widget changes. 571 | 572 | Order of execution: When updating Session state in response to events, a callback function gets executed first, and then the app is executed from top to bottom.
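To make that order of execution concrete, here is a small sketch (not part of the reference) in which the callback updates Session State before the rest of the script reruns:

import streamlit as st

if 'clicks' not in st.session_state:
    st.session_state.clicks = 0

def record_click():
    # Runs first, before the script below is re-executed
    st.session_state.clicks += 1

st.button('Click me', on_click=record_click)

# By the time this line runs, the callback has already updated the value
st.write('Clicks so far:', st.session_state.clicks)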
573 | 574 | Callbacks can be used with widgets using the parameters on_change (or on_click), args, and kwargs: 575 | 576 | Parameters 577 | 578 | on_change or on_click - The function name to be used as a callback 579 | args (tuple) - List of arguments to be passed to the callback function 580 | kwargs (dict) - Named arguments to be passed to the callback function 581 | Widgets which support the on_change event: 582 | 583 | st.checkbox 584 | st.color_picker 585 | st.date_input 586 | st.data_editor 587 | st.file_uploader 588 | st.multiselect 589 | st.number_input 590 | st.radio 591 | st.select_slider 592 | st.selectbox 593 | st.slider 594 | st.text_area 595 | st.text_input 596 | st.time_input 597 | st.toggle 598 | Widgets which support the on_click event: 599 | 600 | st.button 601 | st.download_button 602 | st.form_submit_button 603 | To add a callback, define a callback function above the widget declaration and pass it to the widget via the on_change (or on_click ) parameter. 604 | 605 | Forms and Callbacks 606 | Widgets inside a form can have their values be accessed and set via the Session State API. st.form_submit_button can have a callback associated with it. The callback gets executed upon clicking on the submit button. For example: 607 | 608 | def form_callback(): 609 | st.write(st.session_state.my_slider) 610 | st.write(st.session_state.my_checkbox) 611 | 612 | with st.form(key='my_form'): 613 | slider_input = st.slider('My slider', 0, 10, 5, key='my_slider') 614 | checkbox_input = st.checkbox('Yes or No', key='my_checkbox') 615 | submit_button = st.form_submit_button(label='Submit', on_click=form_callback) 616 | Serializable Session State 617 | Serialization refers to the process of converting an object or data structure into a format that can be persisted and shared, and allowing you to recover the data’s original structure. Python’s built-in pickle module serializes Python objects to a byte stream ("pickling") and deserializes the stream into an object ("unpickling"). 618 | 619 | By default, Streamlit’s Session State allows you to persist any Python object for the duration of the session, irrespective of the object’s pickle-serializability. This property lets you store Python primitives such as integers, floating-point numbers, complex numbers and booleans, dataframes, and even lambdas returned by functions. However, some execution environments may require serializing all data in Session State, so it may be useful to detect incompatibility during development, or when the execution environment will stop supporting it in the future. 620 | 621 | To that end, Streamlit provides a runner.enforceSerializableSessionState configuration option that, when set to true, only allows pickle-serializable objects in Session State. To enable the option, either create a global or project config file with the following or use it as a command-line flag: 622 | 623 | # .streamlit/config.toml 624 | [runner] 625 | enforceSerializableSessionState = true 626 | By "pickle-serializable", we mean calling pickle.dumps(obj) should not raise a PicklingError exception. When the config option is enabled, adding unserializable data to session state should result in an exception. 
E.g., 627 | 628 | import streamlit as st 629 | 630 | def unserializable_data(): 631 | return lambda x: x 632 | 633 | #👇 results in an exception when enforceSerializableSessionState is on 634 | st.session_state.unserializable = unserializable_data() 635 | UnserializableSessionStateError 636 | priority_high 637 | Warning 638 | When runner.enforceSerializableSessionState is set to true, Session State implicitly uses the pickle module, which is known to be insecure. Ensure all data saved and retrieved from Session State is trusted because it is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source in an unsafe mode or that could have been tampered with. Only load data you trust. 639 | 640 | Caveats and limitations 641 | Only the st.form_submit_button has a callback in forms. Other widgets inside a form are not allowed to have callbacks. 642 | 643 | on_change and on_click events are only supported on input type widgets. 644 | 645 | Modifying the value of a widget via the Session state API, after instantiating it, is not allowed and will raise a StreamlitAPIException. For example: 646 | 647 | slider = st.slider( 648 | label='My Slider', min_value=1, 649 | max_value=10, value=5, key='my_slider') 650 | 651 | st.session_state.my_slider = 7 652 | 653 | # Throws an exception! 654 | state-modified-instantiated-exception 655 | Setting the widget state via the Session State API and using the value parameter in the widget declaration is not recommended, and will throw a warning on the first run. For example: 656 | 657 | st.session_state.my_slider = 7 658 | 659 | slider = st.slider( 660 | label='Choose a Value', min_value=1, 661 | max_value=10, value=5, key='my_slider') 662 | state-value-api-exception 663 | Setting the state of button-like widgets: st.button, st.download_button, and st.file_uploader via the Session State API is not allowed. Such type of widgets are by default False and have ephemeral True states which are only valid for a single run. For example: 664 | 665 | if 'my_button' not in st.session_state: 666 | st.session_state.my_button = True 667 | 668 | st.button('My button', key='my_button') 669 | 670 | # Throws an exception! 671 | --------------------------------------------------------------------------------