├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── format-check.yml │ └── python-package.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── co-storm-workflow.jpg ├── logo.svg ├── overview.svg ├── storm_naacl2024_slides.pdf └── two_stages.jpg ├── examples ├── costorm_examples │ └── run_costorm_gpt.py └── storm_examples │ ├── README.md │ ├── helper │ └── process_kaggle_arxiv_abstract_dataset.py │ ├── run_storm_wiki_claude.py │ ├── run_storm_wiki_deepseek.py │ ├── run_storm_wiki_gemini.py │ ├── run_storm_wiki_gpt.py │ ├── run_storm_wiki_gpt_with_VectorRM.py │ ├── run_storm_wiki_groq.py │ ├── run_storm_wiki_mistral.py │ ├── run_storm_wiki_ollama.py │ ├── run_storm_wiki_ollama_with_searxng.py │ └── run_storm_wiki_serper.py ├── frontend └── demo_light │ ├── .streamlit │ └── config.toml │ ├── README.md │ ├── assets │ ├── article_display.jpg │ ├── create_article.jpg │ └── void.jpg │ ├── demo_util.py │ ├── pages_util │ ├── CreateNewArticle.py │ └── MyArticles.py │ ├── requirements.txt │ ├── stoc.py │ └── storm.py ├── knowledge_storm ├── __init__.py ├── collaborative_storm │ ├── __init__.py │ ├── engine.py │ └── modules │ │ ├── __init__.py │ │ ├── article_generation.py │ │ ├── callback.py │ │ ├── co_storm_agents.py │ │ ├── collaborative_storm_utils.py │ │ ├── costorm_expert_utterance_generator.py │ │ ├── expert_generation.py │ │ ├── grounded_question_answering.py │ │ ├── grounded_question_generation.py │ │ ├── information_insertion_module.py │ │ ├── knowledge_base_summary.py │ │ ├── simulate_user.py │ │ └── warmstart_hierarchical_chat.py ├── dataclass.py ├── encoder.py ├── interface.py ├── lm.py ├── logging_wrapper.py ├── rm.py ├── storm_wiki │ ├── __init__.py │ ├── engine.py │ └── modules │ │ ├── __init__.py │ │ ├── article_generation.py │ │ ├── article_polish.py │ │ ├── callback.py │ │ ├── knowledge_curation.py │ │ ├── outline_generation.py │ │ ├── persona_generator.py │ │ ├── retriever.py │ │ └── storm_dataclass.py └── utils.py ├── requirements.txt └── setup.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Report following things 15 | 1. Input topic name 16 | 2. All output files generated for this topic as a zip file. 17 | 18 | **Screenshots** 19 | If applicable, add screenshots to help explain your problem. 20 | 21 | **Environment:** 22 | - OS: [e.g. iOS, Windows] 23 | - Browser [e.g. 
chrome, safari] if the bug report is UI problem 24 | -------------------------------------------------------------------------------- /.github/workflows/format-check.yml: -------------------------------------------------------------------------------- 1 | name: Check Python formatting with Black 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v2 14 | - uses: psf/black@stable 15 | with: 16 | black_args: "knowledge_storm --check" 17 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Build and upload Python package 2 | 3 | on: 4 | workflow_dispatch: # Allows manual triggering of the workflow 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@master 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: "3.11" 17 | - name: Compare versions in setup.py and knowledge_storm/__init__.py 18 | run: | 19 | VERSION_SETUP=$(grep -oP '(?<=version=\").*(?=\")' setup.py) 20 | VERSION_INIT=$(grep -oP '(?<=__version__ = \").*(?=\")' knowledge_storm/__init__.py) 21 | echo "Version in setup.py: $VERSION_SETUP" 22 | echo "Version in __init__.py: $VERSION_INIT" 23 | if [ "$VERSION_SETUP" != "$VERSION_INIT" ]; then 24 | echo "Error: Version mismatch between setup.py ($VERSION_SETUP) and knowledge_storm/__init__.py ($VERSION_INIT)" 25 | exit 1 26 | fi 27 | shell: bash 28 | - name: Install dependencies 29 | run: python3 -m pip install setuptools wheel twine 30 | - name: Install dependencies 31 | run: | 32 | python3 -m pip install --upgrade pip setuptools wheel 33 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 34 | - name: Build a binary wheel 35 | run: python3 setup.py sdist bdist_wheel 36 | - name: Publish package to PyPI 37 | uses: pypa/gh-action-pypi-publish@release/v1 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # mac 7 | .DS_Store 8 | 9 | # Other 10 | .vscode 11 | *.tsv 12 | *.pt 13 | gpt*.txt 14 | *.env 15 | local/ 16 | local_* 17 | build/ 18 | *.egg-info/ 19 | .idea 20 | .venv 21 | 22 | # Project-specific 23 | secrets.toml 24 | *.log 25 | */assertion.log 26 | *results/ 27 | .venv/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.8.0 4 | hooks: 5 | - id: black 6 | name: Format Python code with black 7 | entry: black 8 | args: ["knowledge_storm/"] 9 | language: python 10 | pass_filenames: true -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to STORM! 4 | 5 | Contributions aren't just about code. 
Currently (last edit: 7/22/2024), we are accepting the following forms of contribution: 6 | - Pull requests for additional language model support to `knowledge_storm/lm.py`. 7 | - Pull requests for additional retrieval model/search engine support to `knowledge_storm/rm.py`. 8 | - Pull requests for new features to `frontend/demo_light` to assist other developers. 9 | - Identification and reporting of issues or bugs. 10 | - Helping each other by responding to issues. 11 | 12 | Please note that we are not accepting code refactoring PRs at this time to avoid conflicts with our team's efforts. 13 | 14 | ## Development 15 | This section contains technical instructions & hints for contributors. 16 | 17 | ### Setting up 18 | 1. Fork this repository and clone your forked repository. 19 | 2. Install the required packages: 20 | ``` 21 | conda create -n storm python=3.11 22 | conda activate storm 23 | pip install -r requirements.txt 24 | ``` 25 | 3. If you want to contribute to `frontend/demo_light`, follow its [Setup guide](https://github.com/stanford-oval/storm/tree/main/frontend/demo_light#setup) to install additional packages. 26 | 27 | ### PR suggestions 28 | 29 | Following the suggested format can lead to a faster review process. 30 | 31 | **Title:** 32 | 33 | [New LM/New RM/Demo Enhancement] xxx 34 | 35 | **Description:** 36 | - For new language model support, (1) describe how to use the new LM class, (2) create an example script following the style of existing example scripts under `examples/`, (3) attach an input-output example of the example script. 37 | - For new retrieval model/search engine support, (1) describe how to use the new RM class and (2) attach input-output examples of the RM class. 38 | - For demo light enhancements, (1) describe what's new and (2) attach screenshots to demonstrate the UI change. 39 | - Please clearly describe the required API keys and provide instructions on how to get them (if applicable). This project manages API key with `secrets.toml`. 40 | 41 | **Code Format:** 42 | 43 | We adopt [`black`](https://github.com/psf/black) for arranging and formatting Python code. To streamline the contribution process, we set up a [pre-commit hook](https://pre-commit.com/) to format the code under `knowledge_storm/` before committing. To install the pre-commit hook, run: 44 | ``` 45 | pip install pre-commit 46 | pre-commit install 47 | ``` 48 | The hook will automatically format the code before each commit. 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stanford Open Virtual Assistant Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/co-storm-workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/co-storm-workflow.jpg -------------------------------------------------------------------------------- /assets/storm_naacl2024_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/storm_naacl2024_slides.pdf -------------------------------------------------------------------------------- /assets/two_stages.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/two_stages.jpg -------------------------------------------------------------------------------- /examples/storm_examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | We host a number of example scripts for various customizations of STORM (e.g., using your favorite language models, using your own corpus, etc.). These examples can be starting points for your own customizations, and you are welcome to contribute your own examples by submitting a pull request to this directory. 4 | 5 | ## Run STORM with your own language model 6 | [run_storm_wiki_gpt.py](run_storm_wiki_gpt.py) provides an example of running STORM with GPT models, and [run_storm_wiki_claude.py](run_storm_wiki_claude.py) provides an example of running STORM with Claude models. Besides using closed-source models, you can also run STORM with open-weight models. 7 | 8 | `run_storm_wiki_mistral.py` provides an example of running STORM with `Mistral-7B-Instruct-v0.2` using a [VLLM](https://docs.vllm.ai/en/stable/) server: 9 | 10 | 1. Set up a VLLM server with the `Mistral-7B-Instruct-v0.2` model running. 11 | 2. Run the following command under the root directory of the repository: 12 | 13 | ``` 14 | python examples/storm_examples/run_storm_wiki_mistral.py \ 15 | --url $URL \ 16 | --port $PORT \ 17 | --output-dir $OUTPUT_DIR \ 18 | --retriever you \ 19 | --do-research \ 20 | --do-generate-outline \ 21 | --do-generate-article \ 22 | --do-polish-article 23 | ``` 24 | - `--url` URL of the VLLM server. 25 | - `--port` Port of the VLLM server. 26 | 27 | Besides the VLLM server, STORM is also compatible with a [TGI](https://huggingface.co/docs/text-generation-inference/en/index) server or a [Together.ai](https://www.together.ai/products#inference) endpoint. 28 | 29 | 30 | ## Run STORM with your own corpus 31 | 32 | By default, STORM is grounded on the Internet using a search engine, but it can also be grounded on your own corpus using `VectorRM`. [run_storm_wiki_gpt_with_VectorRM.py](run_storm_wiki_gpt_with_VectorRM.py) provides an example of running STORM grounded on the data you provide. 33 | 34 | 1. Set up API keys. 35 | - Make sure you have set up the OpenAI API key. 36 | - `VectorRM` uses [Qdrant](https://github.com/qdrant/qdrant-client) to create a vector store.
If you want to set up this vector store online on a [Qdrant cloud server](https://cloud.qdrant.io/login), you need to set up `QDRANT_API_KEY` in `secrets.toml` as well; if you want to save the vector store locally, make sure you provide a location for the vector store. 37 | 2. Prepare your corpus. The documents should be provided as a single CSV file with the following format: 38 | 39 | | content | title | url | description | 40 | |------------------------|------------|------------|------------------------------------| 41 | | I am a document. | Document 1 | docu-n-112 | A self-explanatory document. | 42 | | I am another document. | Document 2 | docu-l-13 | Another self-explanatory document. | 43 | | ... | ... | ... | ... | 44 | 45 | - `url` will be used as a unique identifier of the document in the STORM engine, so ensure that different documents have different URLs. 46 | - The contents for the `title` and `description` columns are optional. If not provided, the script will use default empty values. 47 | - The `content` column is essential and must be provided for each document. 48 | 49 | 3. Run the command under the root directory of the repository: 50 | To create the vector store offline, run 51 | 52 | ``` 53 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ 54 | --output-dir $OUTPUT_DIR \ 55 | --vector-db-mode offline \ 56 | --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \ 57 | --csv-file-path $CSV_FILE_PATH \ 58 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \ 59 | --do-research \ 60 | --do-generate-outline \ 61 | --do-generate-article \ 62 | --do-polish-article 63 | ``` 64 | 65 | To create the vector store online on a Qdrant server, run 66 | 67 | ``` 68 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ 69 | --output-dir $OUTPUT_DIR \ 70 | --vector-db-mode online \ 71 | --online-vector-db-url $ONLINE_VECTOR_DB_URL \ 72 | --csv-file-path $CSV_FILE_PATH \ 73 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \ 74 | --do-research \ 75 | --do-generate-outline \ 76 | --do-generate-article \ 77 | --do-polish-article 78 | ``` 79 | 80 | 4. **Quick test with Kaggle arXiv Paper Abstracts dataset**: 81 | 82 | - Download `arxiv_data_210930-054931.csv` from [here](https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts). 83 | - Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a CSV file that matches the format mentioned above. 84 | 85 | ``` 86 | python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV 87 | ``` 88 | - Run the following command to run STORM grounded on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only uses the abstracts of arXiv papers.)
89 | 90 | ``` 91 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ 92 | --output-dir $OUTPUT_DIR \ 93 | --vector-db-mode offline \ 94 | --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \ 95 | --csv-file-path $PATH_TO_THE_PROCESSED_CSV \ 96 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \ 97 | --do-research \ 98 | --do-generate-outline \ 99 | --do-generate-article \ 100 | --do-polish-article 101 | ``` 102 | - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link). 103 | 104 | ``` 105 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ 106 | --output-dir $OUTPUT_DIR \ 107 | --vector-db-mode offline \ 108 | --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DR \ 109 | --do-research \ 110 | --do-generate-outline \ 111 | --do-generate-article \ 112 | --do-polish-article 113 | ``` -------------------------------------------------------------------------------- /examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py: -------------------------------------------------------------------------------- 1 | """Process `arxiv_data_210930-054931.csv` 2 | from https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts 3 | to a csv file that is compatible with VectorRM. 4 | """ 5 | 6 | from argparse import ArgumentParser 7 | 8 | import pandas as pd 9 | 10 | if __name__ == "__main__": 11 | parser = ArgumentParser() 12 | parser.add_argument( 13 | "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv." 14 | ) 15 | parser.add_argument( 16 | "--output-path", 17 | type=str, 18 | help="Path to store the csv file that is compatible with VectorRM.", 19 | ) 20 | args = parser.parse_args() 21 | 22 | df = pd.read_csv(args.input_path) 23 | print(f"The original dataset has {len(df)} samples.") 24 | 25 | # Downsample the dataset. 26 | df = df[df["terms"] == "['cs.CV']"] 27 | 28 | # Reformat the dataset to match the VectorRM input format. 29 | df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True) 30 | df["url"] = [ 31 | "uid_" + str(idx) for idx in range(len(df)) 32 | ] # Ensure the url is unique. 33 | df["description"] = "" 34 | 35 | print(f"The downsampled dataset has {len(df)} samples.") 36 | df.to_csv(args.output_path, index=False) 37 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_claude.py: -------------------------------------------------------------------------------- 1 | """ 2 | STORM Wiki pipeline powered by Claude family models and You.com search engine. 
3 | You need to set up the following environment variables to run this script: 4 | - ANTHROPIC_API_KEY: Anthropic API key 5 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key 6 | 7 | Output will be structured as below 8 | args.output_dir/ 9 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash 10 | conversation_log.json # Log of information-seeking conversation 11 | raw_search_results.json # Raw search results from search engine 12 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge 13 | storm_gen_outline.txt # Outline refined with collected information 14 | url_to_info.json # Sources that are used in the final article 15 | storm_gen_article.txt # Final article generated 16 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) 17 | """ 18 | 19 | import os 20 | from argparse import ArgumentParser 21 | 22 | from knowledge_storm import ( 23 | STORMWikiRunnerArguments, 24 | STORMWikiRunner, 25 | STORMWikiLMConfigs, 26 | ) 27 | from knowledge_storm.lm import ClaudeModel 28 | from knowledge_storm.rm import ( 29 | YouRM, 30 | BingSearch, 31 | BraveRM, 32 | SerperRM, 33 | DuckDuckGoSearchRM, 34 | TavilySearchRM, 35 | SearXNG, 36 | ) 37 | from knowledge_storm.utils import load_api_key 38 | 39 | 40 | def main(args): 41 | load_api_key(toml_file_path="secrets.toml") 42 | lm_configs = STORMWikiLMConfigs() 43 | claude_kwargs = { 44 | "api_key": os.getenv("ANTHROPIC_API_KEY"), 45 | "temperature": 1.0, 46 | "top_p": 0.9, 47 | } 48 | 49 | # STORM is a LM system so different components can be powered by different models. 50 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm 51 | # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models 52 | # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm 53 | # which is responsible for generating sections with citations. 54 | conv_simulator_lm = ClaudeModel( 55 | model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs 56 | ) 57 | question_asker_lm = ClaudeModel( 58 | model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs 59 | ) 60 | outline_gen_lm = ClaudeModel( 61 | model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs 62 | ) 63 | article_gen_lm = ClaudeModel( 64 | model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs 65 | ) 66 | article_polish_lm = ClaudeModel( 67 | model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs 68 | ) 69 | 70 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 71 | lm_configs.set_question_asker_lm(question_asker_lm) 72 | lm_configs.set_outline_gen_lm(outline_gen_lm) 73 | lm_configs.set_article_gen_lm(article_gen_lm) 74 | lm_configs.set_article_polish_lm(article_polish_lm) 75 | 76 | engine_args = STORMWikiRunnerArguments( 77 | output_dir=args.output_dir, 78 | max_conv_turn=args.max_conv_turn, 79 | max_perspective=args.max_perspective, 80 | search_top_k=args.search_top_k, 81 | max_thread_num=args.max_thread_num, 82 | ) 83 | 84 | # STORM is a knowledge curation system which consumes information from the retrieval module. 85 | # Currently, the information source is the Internet and we use search engine API as the retrieval module. 
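# Note: the match/case dispatch below requires Python 3.10+ (the development setup in CONTRIBUTING.md uses Python 3.11).
# Each retriever branch expects its search API key as an environment variable (e.g., BING_SEARCH_API_KEY, YDC_API_KEY),
# as listed in the docstring at the top of this file.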
86 | match args.retriever: 87 | case "bing": 88 | rm = BingSearch( 89 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"), 90 | k=engine_args.search_top_k, 91 | ) 92 | case "you": 93 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) 94 | case "brave": 95 | rm = BraveRM( 96 | brave_search_api_key=os.getenv("BRAVE_API_KEY"), 97 | k=engine_args.search_top_k, 98 | ) 99 | case "duckduckgo": 100 | rm = DuckDuckGoSearchRM( 101 | k=engine_args.search_top_k, safe_search="On", region="us-en" 102 | ) 103 | case "serper": 104 | rm = SerperRM( 105 | serper_search_api_key=os.getenv("SERPER_API_KEY"), 106 | query_params={"autocorrect": True, "num": 10, "page": 1}, 107 | ) 108 | case "tavily": 109 | rm = TavilySearchRM( 110 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"), 111 | k=engine_args.search_top_k, 112 | include_raw_content=True, 113 | ) 114 | case "searxng": 115 | rm = SearXNG( 116 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k 117 | ) 118 | case _: 119 | raise ValueError( 120 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' 121 | ) 122 | 123 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 124 | 125 | topic = input("Topic: ") 126 | runner.run( 127 | topic=topic, 128 | do_research=args.do_research, 129 | do_generate_outline=args.do_generate_outline, 130 | do_generate_article=args.do_generate_article, 131 | do_polish_article=args.do_polish_article, 132 | ) 133 | runner.post_run() 134 | runner.summary() 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = ArgumentParser() 139 | # global arguments 140 | parser.add_argument( 141 | "--output-dir", 142 | type=str, 143 | default="./results/claude", 144 | help="Directory to store the outputs.", 145 | ) 146 | parser.add_argument( 147 | "--max-thread-num", 148 | type=int, 149 | default=3, 150 | help="Maximum number of threads to use. The information seeking part and the article generation" 151 | "part can speed up by using multiple threads. 
Consider reducing it if keep getting " 152 | '"Exceed rate limit" error when calling LM API.', 153 | ) 154 | parser.add_argument( 155 | "--retriever", 156 | type=str, 157 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], 158 | help="The search engine API to use for retrieving information.", 159 | ) 160 | # stage of the pipeline 161 | parser.add_argument( 162 | "--do-research", 163 | action="store_true", 164 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 165 | ) 166 | parser.add_argument( 167 | "--do-generate-outline", 168 | action="store_true", 169 | help="If True, generate an outline for the topic; otherwise, load the results.", 170 | ) 171 | parser.add_argument( 172 | "--do-generate-article", 173 | action="store_true", 174 | help="If True, generate an article for the topic; otherwise, load the results.", 175 | ) 176 | parser.add_argument( 177 | "--do-polish-article", 178 | action="store_true", 179 | help="If True, polish the article by adding a summarization section and (optionally) removing " 180 | "duplicate content.", 181 | ) 182 | # hyperparameters for the pre-writing stage 183 | parser.add_argument( 184 | "--max-conv-turn", 185 | type=int, 186 | default=3, 187 | help="Maximum number of questions in conversational question asking.", 188 | ) 189 | parser.add_argument( 190 | "--max-perspective", 191 | type=int, 192 | default=3, 193 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 194 | ) 195 | parser.add_argument( 196 | "--search-top-k", 197 | type=int, 198 | default=3, 199 | help="Top k search results to consider for each search query.", 200 | ) 201 | # hyperparameters for the writing stage 202 | parser.add_argument( 203 | "--retrieve-top-k", 204 | type=int, 205 | default=3, 206 | help="Top k collected references for each section title.", 207 | ) 208 | parser.add_argument( 209 | "--remove-duplicate", 210 | action="store_true", 211 | help="If True, remove duplicate content from the article.", 212 | ) 213 | 214 | main(parser.parse_args()) 215 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_deepseek.py: -------------------------------------------------------------------------------- 1 | """ 2 | STORM Wiki pipeline powered by DeepSeek models and You.com or Bing search engine. 
3 | You need to set up the following environment variables to run this script: 4 | - DEEPSEEK_API_KEY: DeepSeek API key 5 | - DEEPSEEK_API_BASE: DeepSeek API base URL (default is https://api.deepseek.com) 6 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key 7 | 8 | Output will be structured as below 9 | args.output_dir/ 10 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash 11 | conversation_log.json # Log of information-seeking conversation 12 | raw_search_results.json # Raw search results from search engine 13 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge 14 | storm_gen_outline.txt # Outline refined with collected information 15 | url_to_info.json # Sources that are used in the final article 16 | storm_gen_article.txt # Final article generated 17 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) 18 | """ 19 | 20 | import os 21 | import re 22 | import logging 23 | from argparse import ArgumentParser 24 | 25 | from knowledge_storm import ( 26 | STORMWikiRunnerArguments, 27 | STORMWikiRunner, 28 | STORMWikiLMConfigs, 29 | ) 30 | from knowledge_storm.lm import DeepSeekModel 31 | from knowledge_storm.rm import ( 32 | YouRM, 33 | BingSearch, 34 | BraveRM, 35 | SerperRM, 36 | DuckDuckGoSearchRM, 37 | TavilySearchRM, 38 | SearXNG, 39 | ) 40 | from knowledge_storm.utils import load_api_key 41 | 42 | 43 | def sanitize_topic(topic): 44 | """ 45 | Sanitize the topic name for use in file names. 46 | Remove or replace characters that are not allowed in file names. 47 | """ 48 | # Replace spaces with underscores 49 | topic = topic.replace(" ", "_") 50 | 51 | # Remove any character that isn't alphanumeric, underscore, or hyphen 52 | topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic) 53 | 54 | # Ensure the topic isn't empty after sanitization 55 | if not topic: 56 | topic = "unnamed_topic" 57 | 58 | return topic 59 | 60 | 61 | def main(args): 62 | load_api_key(toml_file_path="secrets.toml") 63 | lm_configs = STORMWikiLMConfigs() 64 | 65 | logger = logging.getLogger(__name__) 66 | 67 | # Ensure DEEPSEEK_API_KEY is set 68 | if not os.getenv("DEEPSEEK_API_KEY"): 69 | raise ValueError( 70 | "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file." 
71 | ) 72 | 73 | deepseek_kwargs = { 74 | "api_key": os.getenv("DEEPSEEK_API_KEY"), 75 | "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"), 76 | "temperature": args.temperature, 77 | "top_p": args.top_p, 78 | } 79 | 80 | # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks 81 | # Users can choose the appropriate model based on their needs 82 | conv_simulator_lm = DeepSeekModel( 83 | model=args.model, max_tokens=500, **deepseek_kwargs 84 | ) 85 | question_asker_lm = DeepSeekModel( 86 | model=args.model, max_tokens=500, **deepseek_kwargs 87 | ) 88 | outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs) 89 | article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs) 90 | article_polish_lm = DeepSeekModel( 91 | model=args.model, max_tokens=4000, **deepseek_kwargs 92 | ) 93 | 94 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 95 | lm_configs.set_question_asker_lm(question_asker_lm) 96 | lm_configs.set_outline_gen_lm(outline_gen_lm) 97 | lm_configs.set_article_gen_lm(article_gen_lm) 98 | lm_configs.set_article_polish_lm(article_polish_lm) 99 | 100 | engine_args = STORMWikiRunnerArguments( 101 | output_dir=args.output_dir, 102 | max_conv_turn=args.max_conv_turn, 103 | max_perspective=args.max_perspective, 104 | search_top_k=args.search_top_k, 105 | max_thread_num=args.max_thread_num, 106 | ) 107 | 108 | # STORM is a knowledge curation system which consumes information from the retrieval module. 109 | # Currently, the information source is the Internet and we use search engine API as the retrieval module. 110 | match args.retriever: 111 | case "bing": 112 | rm = BingSearch( 113 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"), 114 | k=engine_args.search_top_k, 115 | ) 116 | case "you": 117 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) 118 | case "brave": 119 | rm = BraveRM( 120 | brave_search_api_key=os.getenv("BRAVE_API_KEY"), 121 | k=engine_args.search_top_k, 122 | ) 123 | case "duckduckgo": 124 | rm = DuckDuckGoSearchRM( 125 | k=engine_args.search_top_k, safe_search="On", region="us-en" 126 | ) 127 | case "serper": 128 | rm = SerperRM( 129 | serper_search_api_key=os.getenv("SERPER_API_KEY"), 130 | query_params={"autocorrect": True, "num": 10, "page": 1}, 131 | ) 132 | case "tavily": 133 | rm = TavilySearchRM( 134 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"), 135 | k=engine_args.search_top_k, 136 | include_raw_content=True, 137 | ) 138 | case "searxng": 139 | rm = SearXNG( 140 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k 141 | ) 142 | case _: 143 | raise ValueError( 144 | f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' 145 | ) 146 | 147 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 148 | 149 | topic = input("Topic: ") 150 | sanitized_topic = sanitize_topic(topic) 151 | 152 | try: 153 | runner.run( 154 | topic=sanitized_topic, 155 | do_research=args.do_research, 156 | do_generate_outline=args.do_generate_outline, 157 | do_generate_article=args.do_generate_article, 158 | do_polish_article=args.do_polish_article, 159 | remove_duplicate=args.remove_duplicate, 160 | ) 161 | runner.post_run() 162 | runner.summary() 163 | except Exception as e: 164 | logger.exception(f"An error occurred: {str(e)}") 165 | raise 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = ArgumentParser() 170 | # global arguments 171 | parser.add_argument( 172 | "--output-dir", 173 | type=str, 174 | default="./results/deepseek", 175 | help="Directory to store the outputs.", 176 | ) 177 | parser.add_argument( 178 | "--max-thread-num", 179 | type=int, 180 | default=3, 181 | help="Maximum number of threads to use. The information seeking part and the article generation" 182 | "part can speed up by using multiple threads. Consider reducing it if keep getting " 183 | '"Exceed rate limit" error when calling LM API.', 184 | ) 185 | parser.add_argument( 186 | "--retriever", 187 | type=str, 188 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], 189 | help="The search engine API to use for retrieving information.", 190 | ) 191 | parser.add_argument( 192 | "--model", 193 | type=str, 194 | choices=["deepseek-chat", "deepseek-coder"], 195 | default="deepseek-chat", 196 | help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.', 197 | ) 198 | parser.add_argument( 199 | "--temperature", type=float, default=1.0, help="Sampling temperature to use." 200 | ) 201 | parser.add_argument( 202 | "--top_p", type=float, default=0.9, help="Top-p sampling parameter." 
203 | ) 204 | # stage of the pipeline 205 | parser.add_argument( 206 | "--do-research", 207 | action="store_true", 208 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 209 | ) 210 | parser.add_argument( 211 | "--do-generate-outline", 212 | action="store_true", 213 | help="If True, generate an outline for the topic; otherwise, load the results.", 214 | ) 215 | parser.add_argument( 216 | "--do-generate-article", 217 | action="store_true", 218 | help="If True, generate an article for the topic; otherwise, load the results.", 219 | ) 220 | parser.add_argument( 221 | "--do-polish-article", 222 | action="store_true", 223 | help="If True, polish the article by adding a summarization section and (optionally) removing " 224 | "duplicate content.", 225 | ) 226 | # hyperparameters for the pre-writing stage 227 | parser.add_argument( 228 | "--max-conv-turn", 229 | type=int, 230 | default=3, 231 | help="Maximum number of questions in conversational question asking.", 232 | ) 233 | parser.add_argument( 234 | "--max-perspective", 235 | type=int, 236 | default=3, 237 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 238 | ) 239 | parser.add_argument( 240 | "--search-top-k", 241 | type=int, 242 | default=3, 243 | help="Top k search results to consider for each search query.", 244 | ) 245 | # hyperparameters for the writing stage 246 | parser.add_argument( 247 | "--retrieve-top-k", 248 | type=int, 249 | default=3, 250 | help="Top k collected references for each section title.", 251 | ) 252 | parser.add_argument( 253 | "--remove-duplicate", 254 | action="store_true", 255 | help="If True, remove duplicate content from the article.", 256 | ) 257 | 258 | main(parser.parse_args()) 259 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_gemini.py: -------------------------------------------------------------------------------- 1 | """ 2 | STORM Wiki pipeline powered by Google Gemini models and search engine. 
3 | You need to set up the following environment variables to run this script: 4 | - GOOGLE_API_KEY: Google API key (Can be obtained from https://ai.google.dev/gemini-api/docs/api-key) 5 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key 6 | 7 | Output will be structured as below 8 | args.output_dir/ 9 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash 10 | conversation_log.json # Log of information-seeking conversation 11 | raw_search_results.json # Raw search results from search engine 12 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge 13 | storm_gen_outline.txt # Outline refined with collected information 14 | url_to_info.json # Sources that are used in the final article 15 | storm_gen_article.txt # Final article generated 16 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) 17 | """ 18 | 19 | import os 20 | from argparse import ArgumentParser 21 | 22 | from knowledge_storm import ( 23 | STORMWikiRunnerArguments, 24 | STORMWikiRunner, 25 | STORMWikiLMConfigs, 26 | ) 27 | from knowledge_storm.lm import GoogleModel 28 | from knowledge_storm.rm import ( 29 | YouRM, 30 | BingSearch, 31 | BraveRM, 32 | SerperRM, 33 | DuckDuckGoSearchRM, 34 | TavilySearchRM, 35 | SearXNG, 36 | ) 37 | from knowledge_storm.utils import load_api_key 38 | 39 | 40 | def main(args): 41 | load_api_key(toml_file_path="secrets.toml") 42 | lm_configs = STORMWikiLMConfigs() 43 | gemini_kwargs = { 44 | "api_key": os.getenv("GOOGLE_API_KEY"), 45 | "temperature": 1.0, 46 | "top_p": 0.9, 47 | } 48 | 49 | # STORM is a LM system so different components can be powered by different models. 50 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm 51 | # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models 52 | # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm 53 | # which is responsible for generating sections with citations. 
54 | # To check out available Google models, see: 55 | # https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python#list_models 56 | conv_simulator_lm = GoogleModel( 57 | model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs 58 | ) 59 | question_asker_lm = GoogleModel( 60 | model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs 61 | ) 62 | outline_gen_lm = GoogleModel( 63 | model="models/gemini-1.5-pro-exp-0801", max_tokens=400, **gemini_kwargs 64 | ) 65 | article_gen_lm = GoogleModel( 66 | model="models/gemini-1.5-pro-exp-0801", max_tokens=700, **gemini_kwargs 67 | ) 68 | article_polish_lm = GoogleModel( 69 | model="models/gemini-1.5-pro-exp-0801", max_tokens=4000, **gemini_kwargs 70 | ) 71 | 72 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 73 | lm_configs.set_question_asker_lm(question_asker_lm) 74 | lm_configs.set_outline_gen_lm(outline_gen_lm) 75 | lm_configs.set_article_gen_lm(article_gen_lm) 76 | lm_configs.set_article_polish_lm(article_polish_lm) 77 | 78 | engine_args = STORMWikiRunnerArguments( 79 | output_dir=args.output_dir, 80 | max_conv_turn=args.max_conv_turn, 81 | max_perspective=args.max_perspective, 82 | search_top_k=args.search_top_k, 83 | max_thread_num=args.max_thread_num, 84 | ) 85 | 86 | # STORM is a knowledge curation system which consumes information from the retrieval module. 87 | # Currently, the information source is the Internet and we use search engine API as the retrieval module. 88 | match args.retriever: 89 | case "bing": 90 | rm = BingSearch( 91 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"), 92 | k=engine_args.search_top_k, 93 | ) 94 | case "you": 95 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) 96 | case "brave": 97 | rm = BraveRM( 98 | brave_search_api_key=os.getenv("BRAVE_API_KEY"), 99 | k=engine_args.search_top_k, 100 | ) 101 | case "duckduckgo": 102 | rm = DuckDuckGoSearchRM( 103 | k=engine_args.search_top_k, safe_search="On", region="us-en" 104 | ) 105 | case "serper": 106 | rm = SerperRM( 107 | serper_search_api_key=os.getenv("SERPER_API_KEY"), 108 | query_params={"autocorrect": True, "num": 10, "page": 1}, 109 | ) 110 | case "tavily": 111 | rm = TavilySearchRM( 112 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"), 113 | k=engine_args.search_top_k, 114 | include_raw_content=True, 115 | ) 116 | case "searxng": 117 | rm = SearXNG( 118 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k 119 | ) 120 | case _: 121 | raise ValueError( 122 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' 123 | ) 124 | 125 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 126 | 127 | topic = input("Topic: ") 128 | runner.run( 129 | topic=topic, 130 | do_research=args.do_research, 131 | do_generate_outline=args.do_generate_outline, 132 | do_generate_article=args.do_generate_article, 133 | do_polish_article=args.do_polish_article, 134 | ) 135 | runner.post_run() 136 | runner.summary() 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = ArgumentParser() 141 | # global arguments 142 | parser.add_argument( 143 | "--output-dir", 144 | type=str, 145 | default="./results/gemini", 146 | help="Directory to store the outputs.", 147 | ) 148 | parser.add_argument( 149 | "--max-thread-num", 150 | type=int, 151 | default=3, 152 | help="Maximum number of threads to use. The information seeking part and the article generation" 153 | "part can speed up by using multiple threads. 
Consider reducing it if keep getting " 154 | '"Exceed rate limit" error when calling LM API.', 155 | ) 156 | parser.add_argument( 157 | "--retriever", 158 | type=str, 159 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], 160 | help="The search engine API to use for retrieving information.", 161 | ) 162 | # stage of the pipeline 163 | parser.add_argument( 164 | "--do-research", 165 | action="store_true", 166 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 167 | ) 168 | parser.add_argument( 169 | "--do-generate-outline", 170 | action="store_true", 171 | help="If True, generate an outline for the topic; otherwise, load the results.", 172 | ) 173 | parser.add_argument( 174 | "--do-generate-article", 175 | action="store_true", 176 | help="If True, generate an article for the topic; otherwise, load the results.", 177 | ) 178 | parser.add_argument( 179 | "--do-polish-article", 180 | action="store_true", 181 | help="If True, polish the article by adding a summarization section and (optionally) removing " 182 | "duplicate content.", 183 | ) 184 | # hyperparameters for the pre-writing stage 185 | parser.add_argument( 186 | "--max-conv-turn", 187 | type=int, 188 | default=3, 189 | help="Maximum number of questions in conversational question asking.", 190 | ) 191 | parser.add_argument( 192 | "--max-perspective", 193 | type=int, 194 | default=3, 195 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 196 | ) 197 | parser.add_argument( 198 | "--search-top-k", 199 | type=int, 200 | default=3, 201 | help="Top k search results to consider for each search query.", 202 | ) 203 | # hyperparameters for the writing stage 204 | parser.add_argument( 205 | "--retrieve-top-k", 206 | type=int, 207 | default=3, 208 | help="Top k collected references for each section title.", 209 | ) 210 | parser.add_argument( 211 | "--remove-duplicate", 212 | action="store_true", 213 | help="If True, remove duplicate content from the article.", 214 | ) 215 | 216 | main(parser.parse_args()) 217 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | STORM Wiki pipeline powered by GPT-3.5/4 and You.com search engine. 
3 | You need to set up the following environment variables to run this script: 4 | - OPENAI_API_KEY: OpenAI API key 5 | - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure') 6 | - AZURE_API_BASE: Azure API base URL if using Azure API 7 | - AZURE_API_VERSION: Azure API version if using Azure API 8 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key 9 | 10 | Output will be structured as below 11 | args.output_dir/ 12 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash 13 | conversation_log.json # Log of information-seeking conversation 14 | raw_search_results.json # Raw search results from search engine 15 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge 16 | storm_gen_outline.txt # Outline refined with collected information 17 | url_to_info.json # Sources that are used in the final article 18 | storm_gen_article.txt # Final article generated 19 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) 20 | """ 21 | 22 | import os 23 | 24 | from argparse import ArgumentParser 25 | from knowledge_storm import ( 26 | STORMWikiRunnerArguments, 27 | STORMWikiRunner, 28 | STORMWikiLMConfigs, 29 | ) 30 | from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel 31 | from knowledge_storm.rm import ( 32 | YouRM, 33 | BingSearch, 34 | BraveRM, 35 | SerperRM, 36 | DuckDuckGoSearchRM, 37 | TavilySearchRM, 38 | SearXNG, 39 | AzureAISearch, 40 | ) 41 | from knowledge_storm.utils import load_api_key 42 | 43 | 44 | def main(args): 45 | load_api_key(toml_file_path="secrets.toml") 46 | lm_configs = STORMWikiLMConfigs() 47 | openai_kwargs = { 48 | "api_key": os.getenv("OPENAI_API_KEY"), 49 | "temperature": 1.0, 50 | "top_p": 0.9, 51 | } 52 | 53 | ModelClass = ( 54 | OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel 55 | ) 56 | # If you are using Azure service, make sure the model name matches your own deployed model name. 57 | # The default name here is only used for demonstration and may not match your case. 58 | gpt_35_model_name = ( 59 | "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" 60 | ) 61 | gpt_4_model_name = "gpt-4o" 62 | if os.getenv("OPENAI_API_TYPE") == "azure": 63 | openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") 64 | openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") 65 | 66 | # STORM is a LM system so different components can be powered by different models. 67 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm 68 | # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models 69 | # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm 70 | # which is responsible for generating sections with citations. 
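# In this example, the lighter conversational roles (conv_simulator_lm, question_asker_lm) use the GPT-3.5 model name
# chosen above, while outline generation, article generation, and article polishing use gpt-4o.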
71 | conv_simulator_lm = ModelClass( 72 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs 73 | ) 74 | question_asker_lm = ModelClass( 75 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs 76 | ) 77 | outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) 78 | article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) 79 | article_polish_lm = ModelClass( 80 | model=gpt_4_model_name, max_tokens=4000, **openai_kwargs 81 | ) 82 | 83 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 84 | lm_configs.set_question_asker_lm(question_asker_lm) 85 | lm_configs.set_outline_gen_lm(outline_gen_lm) 86 | lm_configs.set_article_gen_lm(article_gen_lm) 87 | lm_configs.set_article_polish_lm(article_polish_lm) 88 | 89 | engine_args = STORMWikiRunnerArguments( 90 | output_dir=args.output_dir, 91 | max_conv_turn=args.max_conv_turn, 92 | max_perspective=args.max_perspective, 93 | search_top_k=args.search_top_k, 94 | max_thread_num=args.max_thread_num, 95 | ) 96 | 97 | # STORM is a knowledge curation system which consumes information from the retrieval module. 98 | # Currently, the information source is the Internet and we use search engine API as the retrieval module. 99 | 100 | match args.retriever: 101 | case "bing": 102 | rm = BingSearch( 103 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"), 104 | k=engine_args.search_top_k, 105 | ) 106 | case "you": 107 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) 108 | case "brave": 109 | rm = BraveRM( 110 | brave_search_api_key=os.getenv("BRAVE_API_KEY"), 111 | k=engine_args.search_top_k, 112 | ) 113 | case "duckduckgo": 114 | rm = DuckDuckGoSearchRM( 115 | k=engine_args.search_top_k, safe_search="On", region="us-en" 116 | ) 117 | case "serper": 118 | rm = SerperRM( 119 | serper_search_api_key=os.getenv("SERPER_API_KEY"), 120 | query_params={"autocorrect": True, "num": 10, "page": 1}, 121 | ) 122 | case "tavily": 123 | rm = TavilySearchRM( 124 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"), 125 | k=engine_args.search_top_k, 126 | include_raw_content=True, 127 | ) 128 | case "searxng": 129 | rm = SearXNG( 130 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k 131 | ) 132 | case "azure_ai_search": 133 | rm = AzureAISearch( 134 | azure_ai_search_api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"), 135 | k=engine_args.search_top_k, 136 | ) 137 | case _: 138 | raise ValueError( 139 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"' 140 | ) 141 | 142 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 143 | 144 | topic = input("Topic: ") 145 | runner.run( 146 | topic=topic, 147 | do_research=args.do_research, 148 | do_generate_outline=args.do_generate_outline, 149 | do_generate_article=args.do_generate_article, 150 | do_polish_article=args.do_polish_article, 151 | ) 152 | runner.post_run() 153 | runner.summary() 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = ArgumentParser() 158 | # global arguments 159 | parser.add_argument( 160 | "--output-dir", 161 | type=str, 162 | default="./results/gpt", 163 | help="Directory to store the outputs.", 164 | ) 165 | parser.add_argument( 166 | "--max-thread-num", 167 | type=int, 168 | default=3, 169 | help="Maximum number of threads to use. The information seeking part and the article generation" 170 | "part can speed up by using multiple threads. 
Consider reducing it if keep getting " 171 | '"Exceed rate limit" error when calling LM API.', 172 | ) 173 | parser.add_argument( 174 | "--retriever", 175 | type=str, 176 | choices=[ 177 | "bing", 178 | "you", 179 | "brave", 180 | "serper", 181 | "duckduckgo", 182 | "tavily", 183 | "searxng", 184 | "azure_ai_search", 185 | ], 186 | help="The search engine API to use for retrieving information.", 187 | ) 188 | # stage of the pipeline 189 | parser.add_argument( 190 | "--do-research", 191 | action="store_true", 192 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 193 | ) 194 | parser.add_argument( 195 | "--do-generate-outline", 196 | action="store_true", 197 | help="If True, generate an outline for the topic; otherwise, load the results.", 198 | ) 199 | parser.add_argument( 200 | "--do-generate-article", 201 | action="store_true", 202 | help="If True, generate an article for the topic; otherwise, load the results.", 203 | ) 204 | parser.add_argument( 205 | "--do-polish-article", 206 | action="store_true", 207 | help="If True, polish the article by adding a summarization section and (optionally) removing " 208 | "duplicate content.", 209 | ) 210 | # hyperparameters for the pre-writing stage 211 | parser.add_argument( 212 | "--max-conv-turn", 213 | type=int, 214 | default=3, 215 | help="Maximum number of questions in conversational question asking.", 216 | ) 217 | parser.add_argument( 218 | "--max-perspective", 219 | type=int, 220 | default=3, 221 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 222 | ) 223 | parser.add_argument( 224 | "--search-top-k", 225 | type=int, 226 | default=3, 227 | help="Top k search results to consider for each search query.", 228 | ) 229 | # hyperparameters for the writing stage 230 | parser.add_argument( 231 | "--retrieve-top-k", 232 | type=int, 233 | default=3, 234 | help="Top k collected references for each section title.", 235 | ) 236 | parser.add_argument( 237 | "--remove-duplicate", 238 | action="store_true", 239 | help="If True, remove duplicate content from the article.", 240 | ) 241 | 242 | main(parser.parse_args()) 243 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py: -------------------------------------------------------------------------------- 1 | """ 2 | This STORM Wiki pipeline powered by GPT-3.5/4 and local retrieval model that uses Qdrant. 3 | You need to set up the following environment variables to run this script: 4 | - OPENAI_API_KEY: OpenAI API key 5 | - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure') 6 | - QDRANT_API_KEY: Qdrant API key (needed ONLY if online vector store was used) 7 | 8 | You will also need an existing Qdrant vector store either saved in a folder locally offline or in a server online. 9 | If not, then you would need a CSV file with documents, and the script is going to create the vector store for you. 10 | The CSV should be in the following format: 11 | content | title | url | description 12 | I am a document. | Document 1 | docu-n-112 | A self-explanatory document. 13 | I am another document. | Document 2 | docu-l-13 | Another self-explanatory document. 14 | 15 | Notice that the URL will be a unique identifier for the document so ensure different documents have different urls. 
16 | 17 | Output will be structured as below 18 | args.output_dir/ 19 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash 20 | conversation_log.json # Log of information-seeking conversation 21 | raw_search_results.json # Raw search results from search engine 22 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge 23 | storm_gen_outline.txt # Outline refined with collected information 24 | url_to_info.json # Sources that are used in the final article 25 | storm_gen_article.txt # Final article generated 26 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) 27 | """ 28 | 29 | import os 30 | from argparse import ArgumentParser 31 | 32 | from knowledge_storm import ( 33 | STORMWikiRunnerArguments, 34 | STORMWikiRunner, 35 | STORMWikiLMConfigs, 36 | ) 37 | from knowledge_storm.rm import VectorRM 38 | from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel 39 | from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager 40 | 41 | 42 | def main(args): 43 | # Load API key from the specified toml file path 44 | load_api_key(toml_file_path="secrets.toml") 45 | 46 | # Initialize the language model configurations 47 | engine_lm_configs = STORMWikiLMConfigs() 48 | openai_kwargs = { 49 | "api_key": os.getenv("OPENAI_API_KEY"), 50 | "temperature": 1.0, 51 | "top_p": 0.9, 52 | } 53 | 54 | ModelClass = ( 55 | OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel 56 | ) 57 | # If you are using Azure service, make sure the model name matches your own deployed model name. 58 | # The default name here is only used for demonstration and may not match your case. 59 | gpt_35_model_name = ( 60 | "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo" 61 | ) 62 | gpt_4_model_name = "gpt-4o" 63 | if os.getenv("OPENAI_API_TYPE") == "azure": 64 | openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE") 65 | openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION") 66 | 67 | # STORM is a LM system so different components can be powered by different models. 68 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm 69 | # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models 70 | # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm 71 | # which is responsible for generating sections with citations. 
72 | conv_simulator_lm = ModelClass( 73 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs 74 | ) 75 | question_asker_lm = ModelClass( 76 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs 77 | ) 78 | outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) 79 | article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) 80 | article_polish_lm = ModelClass( 81 | model=gpt_4_model_name, max_tokens=4000, **openai_kwargs 82 | ) 83 | 84 | engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) 85 | engine_lm_configs.set_question_asker_lm(question_asker_lm) 86 | engine_lm_configs.set_outline_gen_lm(outline_gen_lm) 87 | engine_lm_configs.set_article_gen_lm(article_gen_lm) 88 | engine_lm_configs.set_article_polish_lm(article_polish_lm) 89 | 90 | # Initialize the engine arguments 91 | engine_args = STORMWikiRunnerArguments( 92 | output_dir=args.output_dir, 93 | max_conv_turn=args.max_conv_turn, 94 | max_perspective=args.max_perspective, 95 | search_top_k=args.search_top_k, 96 | max_thread_num=args.max_thread_num, 97 | ) 98 | 99 | # Create / update the vector store with the documents in the csv file 100 | if args.csv_file_path: 101 | kwargs = { 102 | "file_path": args.csv_file_path, 103 | "content_column": "content", 104 | "title_column": "title", 105 | "url_column": "url", 106 | "desc_column": "description", 107 | "batch_size": args.embed_batch_size, 108 | "vector_db_mode": args.vector_db_mode, 109 | "collection_name": args.collection_name, 110 | "embedding_model": args.embedding_model, 111 | "device": args.device, 112 | } 113 | if args.vector_db_mode == "offline": 114 | QdrantVectorStoreManager.create_or_update_vector_store( 115 | vector_store_path=args.offline_vector_db_dir, **kwargs 116 | ) 117 | elif args.vector_db_mode == "online": 118 | QdrantVectorStoreManager.create_or_update_vector_store( 119 | url=args.online_vector_db_url, 120 | api_key=os.getenv("QDRANT_API_KEY"), 121 | **kwargs 122 | ) 123 | 124 | # Setup VectorRM to retrieve information from your own data 125 | rm = VectorRM( 126 | collection_name=args.collection_name, 127 | embedding_model=args.embedding_model, 128 | device=args.device, 129 | k=engine_args.search_top_k, 130 | ) 131 | 132 | # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): 133 | if args.vector_db_mode == "offline": 134 | rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) 135 | elif args.vector_db_mode == "online": 136 | rm.init_online_vector_db( 137 | url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY") 138 | ) 139 | 140 | # Initialize the STORM Wiki Runner 141 | runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) 142 | 143 | # run the pipeline 144 | topic = input("Topic: ") 145 | runner.run( 146 | topic=topic, 147 | do_research=args.do_research, 148 | do_generate_outline=args.do_generate_outline, 149 | do_generate_article=args.do_generate_article, 150 | do_polish_article=args.do_polish_article, 151 | ) 152 | runner.post_run() 153 | runner.summary() 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = ArgumentParser() 158 | # global arguments 159 | parser.add_argument( 160 | "--output-dir", 161 | type=str, 162 | default="./results/gpt_retrieval", 163 | help="Directory to store the outputs.", 164 | ) 165 | parser.add_argument( 166 | "--max-thread-num", 167 | type=int, 168 | default=3, 169 | help="Maximum number of threads to use. 
The information seeking part and the article generation "
170 |         "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
171 |         '"Exceed rate limit" errors when calling the LM API.',
172 |     )
173 |     # provide local corpus and set up vector db
174 |     parser.add_argument(
175 |         "--collection-name",
176 |         type=str,
177 |         default="my_documents",
178 |         help="The collection name for the vector store.",
179 |     )
180 |     parser.add_argument(
181 |         "--embedding_model",
182 |         type=str,
183 |         default="BAAI/bge-m3",
184 |         help="The embedding model used to embed the documents for the vector store.",
185 |     )
186 |     parser.add_argument(
187 |         "--device",
188 |         type=str,
189 |         default="mps",
190 |         help="The device used to run the retrieval model (mps, cuda, cpu, etc).",
191 |     )
192 |     parser.add_argument(
193 |         "--vector-db-mode",
194 |         type=str,
195 |         choices=["offline", "online"],
196 |         help="The mode of the Qdrant vector store (offline or online).",
197 |     )
198 |     parser.add_argument(
199 |         "--offline-vector-db-dir",
200 |         type=str,
201 |         default="./vector_store",
202 |         help="If using offline mode, the directory in which to store the vector store.",
203 |     )
204 |     parser.add_argument(
205 |         "--online-vector-db-url",
206 |         type=str,
207 |         help="If using online mode, the URL of the Qdrant server.",
208 |     )
209 |     parser.add_argument(
210 |         "--csv-file-path",
211 |         type=str,
212 |         default=None,
213 |         help="The path of the custom document corpus in CSV format. The CSV file should include "
214 |         "content, title, url, and description columns.",
215 |     )
216 |     parser.add_argument(
217 |         "--embed-batch-size",
218 |         type=int,
219 |         default=64,
220 |         help="Batch size for embedding the documents in the CSV file.",
221 |     )
222 |     # stage of the pipeline
223 |     parser.add_argument(
224 |         "--do-research",
225 |         action="store_true",
226 |         help="If True, simulate conversation to research the topic; otherwise, load the results.",
227 |     )
228 |     parser.add_argument(
229 |         "--do-generate-outline",
230 |         action="store_true",
231 |         help="If True, generate an outline for the topic; otherwise, load the results.",
232 |     )
233 |     parser.add_argument(
234 |         "--do-generate-article",
235 |         action="store_true",
236 |         help="If True, generate an article for the topic; otherwise, load the results.",
237 |     )
238 |     parser.add_argument(
239 |         "--do-polish-article",
240 |         action="store_true",
241 |         help="If True, polish the article by adding a summarization section and (optionally) removing "
242 |         "duplicate content.",
243 |     )
244 |     # hyperparameters for the pre-writing stage
245 |     parser.add_argument(
246 |         "--max-conv-turn",
247 |         type=int,
248 |         default=3,
249 |         help="Maximum number of questions in conversational question asking.",
250 |     )
251 |     parser.add_argument(
252 |         "--max-perspective",
253 |         type=int,
254 |         default=3,
255 |         help="Maximum number of perspectives to consider in perspective-guided question asking.",
256 |     )
257 |     parser.add_argument(
258 |         "--search-top-k",
259 |         type=int,
260 |         default=3,
261 |         help="Top k search results to consider for each search query.",
262 |     )
263 |     # hyperparameters for the writing stage
264 |     parser.add_argument(
265 |         "--retrieve-top-k",
266 |         type=int,
267 |         default=3,
268 |         help="Top k collected references for each section title.",
269 |     )
270 |     parser.add_argument(
271 |         "--remove-duplicate",
272 |         action="store_true",
273 |         help="If True, remove duplicate content from the article.",
274 |     )
275 |     main(parser.parse_args())
276 | 
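# A minimal sketch of how this script can be invoked, assuming a local corpus in my_corpus.csv
# with the default column names expected above ("content", "title", "url", "description");
# the file name and rows below are hypothetical placeholders, not real documents:
#
#   content,title,url,description
#   "Full text of the first document...","Doc 1","https://example.com/doc1","Short description of doc 1"
#   "Full text of the second document...","Doc 2","https://example.com/doc2","Short description of doc 2"
#
# Example run that builds an offline Qdrant store and executes every pipeline stage:
#   python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
#       --csv-file-path my_corpus.csv --vector-db-mode offline \
#       --offline-vector-db-dir ./vector_store --device cpu \
#       --do-research --do-generate-outline --do-generate-article --do-polish-article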
-------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_groq.py: --------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by llama3-70b-8192 hosted by Groq and the You.com search engine.
3 | You need to set up the following environment variables to run this script:
4 | - GROQ_API_KEY: You can get your Groq API key at https://console.groq.com/keys
5 | - YDC_API_KEY: You.com API key; or BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key, depending on the retriever you choose
6 | No local model server is required; the script calls the llama3-70b-8192 model hosted by Groq.
7 | 
8 | Output will be structured as below
9 | args.output_dir/
10 |     topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash
11 |         conversation_log.json           # Log of information-seeking conversation
12 |         raw_search_results.json         # Raw search results from search engine
13 |         direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
14 |         storm_gen_outline.txt           # Outline refined with collected information
15 |         url_to_info.json                # Sources that are used in the final article
16 |         storm_gen_article.txt           # Final article generated
17 |         storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
18 | """
19 | import logging
20 | import os
21 | import re
22 | from argparse import ArgumentParser
23 | 
24 | from knowledge_storm import (
25 |     STORMWikiRunnerArguments,
26 |     STORMWikiRunner,
27 |     STORMWikiLMConfigs,
28 | )
29 | 
30 | # GroqModel is expected to be exposed by knowledge_storm.lm like the other model wrappers;
31 | # adjust this import if your installed version keeps it elsewhere.
32 | from knowledge_storm.lm import GroqModel
33 | from knowledge_storm.rm import (
34 |     YouRM,
35 |     BingSearch,
36 |     BraveRM,
37 |     SerperRM,
38 |     DuckDuckGoSearchRM,
39 |     TavilySearchRM,
40 |     SearXNG,
41 | )
42 | from knowledge_storm.utils import load_api_key
43 | 
44 | logger = logging.getLogger(__name__)
45 | def sanitize_topic(topic):
46 |     """
47 |     Sanitize the topic name for use in file names.
48 |     Remove or replace characters that are not allowed in file names.
49 |     """
50 |     # Replace spaces with underscores
51 |     topic = topic.replace(" ", "_")
52 | 
53 |     # Remove any character that isn't alphanumeric, underscore, or hyphen
54 |     topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic)
55 | 
56 |     # Ensure the topic isn't empty after sanitization
57 |     if not topic:
58 |         topic = "unnamed_topic"
59 | 
60 |     return topic
61 | 
62 | 
63 | def main(args):
64 |     load_api_key(toml_file_path="secrets.toml")
65 |     lm_configs = STORMWikiLMConfigs()
66 | 
67 |     # Ensure GROQ_API_KEY is set
68 |     if not os.getenv("GROQ_API_KEY"):
69 |         raise ValueError(
70 |             "GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file."
71 | ) 72 | 73 | groq_kwargs = { 74 | "api_key": os.getenv("GROQ_API_KEY"), 75 | "api_base": "https://api.groq.com/openai/v1", 76 | "temperature": args.temperature, 77 | "top_p": args.top_p, 78 | } 79 | 80 | # Groq currently offers the "llama3-70b-8192" model with generous free API credits and the llama3.1 family of models as a preview for paying customers 81 | conv_simulator_lm = GroqModel( 82 | model="llama3-70b-8192", max_tokens=500, **groq_kwargs 83 | ) 84 | question_asker_lm = GroqModel( 85 | model="llama3-70b-8192", max_tokens=500, **groq_kwargs 86 | ) 87 | outline_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=400, **groq_kwargs) 88 | article_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=700, **groq_kwargs) 89 | article_polish_lm = GroqModel( 90 | model="llama3-70b-8192", max_tokens=4000, **groq_kwargs 91 | ) 92 | 93 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 94 | lm_configs.set_question_asker_lm(question_asker_lm) 95 | lm_configs.set_outline_gen_lm(outline_gen_lm) 96 | lm_configs.set_article_gen_lm(article_gen_lm) 97 | lm_configs.set_article_polish_lm(article_polish_lm) 98 | 99 | engine_args = STORMWikiRunnerArguments( 100 | output_dir=args.output_dir, 101 | max_conv_turn=args.max_conv_turn, 102 | max_perspective=args.max_perspective, 103 | search_top_k=args.search_top_k, 104 | max_thread_num=args.max_thread_num, 105 | ) 106 | 107 | # STORM is a knowledge curation system which consumes information from the retrieval module. 108 | # Currently, the information source is the Internet and we use search engine API as the retrieval module. 109 | match args.retriever: 110 | case "bing": 111 | rm = BingSearch( 112 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"), 113 | k=engine_args.search_top_k, 114 | ) 115 | case "you": 116 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k) 117 | case "brave": 118 | rm = BraveRM( 119 | brave_search_api_key=os.getenv("BRAVE_API_KEY"), 120 | k=engine_args.search_top_k, 121 | ) 122 | case "duckduckgo": 123 | rm = DuckDuckGoSearchRM( 124 | k=engine_args.search_top_k, safe_search="On", region="us-en" 125 | ) 126 | case "serper": 127 | rm = SerperRM( 128 | serper_search_api_key=os.getenv("SERPER_API_KEY"), 129 | query_params={"autocorrect": True, "num": 10, "page": 1}, 130 | ) 131 | case "tavily": 132 | rm = TavilySearchRM( 133 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"), 134 | k=engine_args.search_top_k, 135 | include_raw_content=True, 136 | ) 137 | case "searxng": 138 | rm = SearXNG( 139 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k 140 | ) 141 | case _: 142 | raise ValueError( 143 | f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"' 144 | ) 145 | 146 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 147 | 148 | topic = input("Topic: ") 149 | sanitized_topic = sanitize_topic(topic) 150 | 151 | try: 152 | runner.run( 153 | topic=sanitized_topic, 154 | do_research=args.do_research, 155 | do_generate_outline=args.do_generate_outline, 156 | do_generate_article=args.do_generate_article, 157 | do_polish_article=args.do_polish_article, 158 | remove_duplicate=args.remove_duplicate, 159 | ) 160 | runner.post_run() 161 | runner.summary() 162 | except Exception as e: 163 | logger.exception(f"An error occurred: {str(e)}") 164 | raise 165 | 166 | 167 | if __name__ == "__main__": 168 | parser = ArgumentParser() 169 | # global arguments 170 | parser.add_argument( 171 | "--output-dir", 172 | type=str, 173 | default="./results/groq", 174 | help="Directory to store the outputs.", 175 | ) 176 | parser.add_argument( 177 | "--max-thread-num", 178 | type=int, 179 | default=3, 180 | help="Maximum number of threads to use. The information seeking part and the article generation" 181 | "part can speed up by using multiple threads. Consider reducing it if keep getting " 182 | '"Exceed rate limit" error when calling LM API.', 183 | ) 184 | parser.add_argument( 185 | "--retriever", 186 | type=str, 187 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"], 188 | help="The search engine API to use for retrieving information.", 189 | ) 190 | parser.add_argument( 191 | "--temperature", type=float, default=1.0, help="Sampling temperature to use." 192 | ) 193 | parser.add_argument( 194 | "--top_p", type=float, default=0.9, help="Top-p sampling parameter." 195 | ) 196 | # stage of the pipeline 197 | parser.add_argument( 198 | "--do-research", 199 | action="store_true", 200 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 201 | ) 202 | parser.add_argument( 203 | "--do-generate-outline", 204 | action="store_true", 205 | help="If True, generate an outline for the topic; otherwise, load the results.", 206 | ) 207 | parser.add_argument( 208 | "--do-generate-article", 209 | action="store_true", 210 | help="If True, generate an article for the topic; otherwise, load the results.", 211 | ) 212 | parser.add_argument( 213 | "--do-polish-article", 214 | action="store_true", 215 | help="If True, polish the article by adding a summarization section and (optionally) removing " 216 | "duplicate content.", 217 | ) 218 | # hyperparameters for the pre-writing stage 219 | parser.add_argument( 220 | "--max-conv-turn", 221 | type=int, 222 | default=3, 223 | help="Maximum number of questions in conversational question asking.", 224 | ) 225 | parser.add_argument( 226 | "--max-perspective", 227 | type=int, 228 | default=3, 229 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 230 | ) 231 | parser.add_argument( 232 | "--search-top-k", 233 | type=int, 234 | default=3, 235 | help="Top k search results to consider for each search query.", 236 | ) 237 | # hyperparameters for the writing stage 238 | parser.add_argument( 239 | "--retrieve-top-k", 240 | type=int, 241 | default=3, 242 | help="Top k collected references for each section title.", 243 | ) 244 | parser.add_argument( 245 | "--remove-duplicate", 246 | action="store_true", 247 | help="If True, remove duplicate content from the article.", 248 | ) 249 | 250 | main(parser.parse_args()) 251 | 
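# A minimal sketch of the secrets.toml entries this script reads via load_api_key(), assuming the
# You.com retriever is selected; the key values below are placeholders, not real credentials:
#
#   GROQ_API_KEY = "your-groq-api-key"
#   YDC_API_KEY = "your-youcom-api-key"
#
# Example run executing every pipeline stage:
#   python examples/storm_examples/run_storm_wiki_groq.py --retriever you \
#       --do-research --do-generate-outline --do-generate-article --do-polish-article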
-------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_ollama_with_searxng.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | from dspy import Example 5 | 6 | from knowledge_storm import ( 7 | STORMWikiRunnerArguments, 8 | STORMWikiRunner, 9 | STORMWikiLMConfigs, 10 | ) 11 | from knowledge_storm.lm import OllamaClient 12 | from knowledge_storm.rm import SearXNG 13 | from knowledge_storm.utils import load_api_key 14 | 15 | 16 | def main(args): 17 | load_api_key(toml_file_path="secrets.toml") 18 | lm_configs = STORMWikiLMConfigs() 19 | 20 | ollama_kwargs = { 21 | "model": args.model, 22 | "port": args.port, 23 | "url": args.url, 24 | "stop": ("\n\n---",), 25 | } 26 | 27 | conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs) 28 | question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs) 29 | outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs) 30 | article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs) 31 | article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs) 32 | 33 | lm_configs.set_conv_simulator_lm(conv_simulator_lm) 34 | lm_configs.set_question_asker_lm(question_asker_lm) 35 | lm_configs.set_outline_gen_lm(outline_gen_lm) 36 | lm_configs.set_article_gen_lm(article_gen_lm) 37 | lm_configs.set_article_polish_lm(article_polish_lm) 38 | 39 | engine_args = STORMWikiRunnerArguments( 40 | output_dir=args.output_dir, 41 | max_conv_turn=args.max_conv_turn, 42 | max_perspective=args.max_perspective, 43 | search_top_k=args.search_top_k, 44 | max_thread_num=args.max_thread_num, 45 | ) 46 | 47 | rm = SearXNG( 48 | searxng_api_url=args.searxng_api_url, 49 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), 50 | k=engine_args.search_top_k, 51 | ) 52 | 53 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 54 | 55 | find_related_topic_example = Example( 56 | topic="Knowledge Curation", 57 | related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n" 58 | "https://en.wikipedia.org/wiki/Information_science\n" 59 | "https://en.wikipedia.org/wiki/Library_science\n", 60 | ) 61 | gen_persona_example = Example( 62 | topic="Knowledge Curation", 63 | examples="Title: Knowledge management\n" 64 | "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies" 65 | "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n" 66 | " Knowledge protection methods\n Formal methods\n Informal methods\n" 67 | " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks", 68 | personas=( 69 | "1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge " 70 | "curation. They will provide context on how knowledge curation has changed over time and its impact on " 71 | "modern practices.\n" 72 | "2. Information Science Professional: With insights from 'Information science', this editor will " 73 | "explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n" 74 | "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, " 75 | "including software, metadata, digital preservation.\n" 76 | "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, " 77 | "such as common features of content management systems.\n" 78 | "5. 
Museum Curator: The museum curator will contribute expertise on the curation of physical items and " 79 | "the transition of these practices into the digital realm." 80 | ), 81 | ) 82 | runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [ 83 | find_related_topic_example 84 | ] 85 | runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [ 86 | gen_persona_example 87 | ] 88 | 89 | write_page_outline_example = Example( 90 | topic="Example Topic", 91 | conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...", 92 | old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n" 93 | "# Section 2\n## Subsection 1\n## Subsection 2\n" 94 | "# Section 3", 95 | outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n" 96 | "# New Section 2\n" 97 | "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3", 98 | ) 99 | runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [ 100 | write_page_outline_example 101 | ] 102 | write_section_example = Example( 103 | info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3", 104 | topic="Example Topic", 105 | section="Example Section", 106 | output="# Example Topic\n## Subsection 1\n" 107 | "This is an example sentence [1]. This is another example sentence [2][3].\n" 108 | "## Subsection 2\nThis is one more example sentence [1].", 109 | ) 110 | runner.storm_article_generation.section_gen.write_section.demos = [ 111 | write_section_example 112 | ] 113 | 114 | topic = input("Topic: ") 115 | runner.run( 116 | topic=topic, 117 | do_research=args.do_research, 118 | do_generate_outline=args.do_generate_outline, 119 | do_generate_article=args.do_generate_article, 120 | do_polish_article=args.do_polish_article, 121 | ) 122 | runner.post_run() 123 | runner.summary() 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = ArgumentParser() 128 | # global arguments 129 | parser.add_argument( 130 | "--url", type=str, default="http://localhost", help="URL of the Ollama server." 131 | ) 132 | parser.add_argument( 133 | "--port", type=int, default=11434, help="Port of the Ollama server." 134 | ) 135 | parser.add_argument( 136 | "--model", type=str, default="llama3:latest", help="Model of the Ollama server." 137 | ) 138 | parser.add_argument( 139 | "--output-dir", 140 | type=str, 141 | default="./results/ollama", 142 | help="Directory to store the outputs.", 143 | ) 144 | parser.add_argument( 145 | "--max-thread-num", 146 | type=int, 147 | default=3, 148 | help="Maximum number of threads to use. The information seeking part and the article generation" 149 | "part can speed up by using multiple threads. Consider reducing it if keep getting " 150 | '"Exceed rate limit" error when calling LM API.', 151 | ) 152 | parser.add_argument( 153 | "--retriever", 154 | type=str, 155 | choices=["searxng"], 156 | help="The search engine API to use for retrieving information.", 157 | ) 158 | parser.add_argument( 159 | "--searxng-api-url", type=str, required=True, help="URL of the SearXNG API." 
160 | ) 161 | # stage of the pipeline 162 | parser.add_argument( 163 | "--do-research", 164 | action="store_true", 165 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 166 | ) 167 | parser.add_argument( 168 | "--do-generate-outline", 169 | action="store_true", 170 | help="If True, generate an outline for the topic; otherwise, load the results.", 171 | ) 172 | parser.add_argument( 173 | "--do-generate-article", 174 | action="store_true", 175 | help="If True, generate an article for the topic; otherwise, load the results.", 176 | ) 177 | parser.add_argument( 178 | "--do-polish-article", 179 | action="store_true", 180 | help="If True, polish the article by adding a summarization section and (optionally) removing " 181 | "duplicate content.", 182 | ) 183 | # hyperparameters for the pre-writing stage 184 | parser.add_argument( 185 | "--max-conv-turn", 186 | type=int, 187 | default=3, 188 | help="Maximum number of questions in conversational question asking.", 189 | ) 190 | parser.add_argument( 191 | "--max-perspective", 192 | type=int, 193 | default=3, 194 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 195 | ) 196 | parser.add_argument( 197 | "--search-top-k", 198 | type=int, 199 | default=3, 200 | help="Top k search results to consider for each search query.", 201 | ) 202 | # hyperparameters for the writing stage 203 | parser.add_argument( 204 | "--retrieve-top-k", 205 | type=int, 206 | default=3, 207 | help="Top k collected references for each section title.", 208 | ) 209 | parser.add_argument( 210 | "--remove-duplicate", 211 | action="store_true", 212 | help="If True, remove duplicate content from the article.", 213 | ) 214 | 215 | main(parser.parse_args()) 216 | -------------------------------------------------------------------------------- /examples/storm_examples/run_storm_wiki_serper.py: -------------------------------------------------------------------------------- 1 | """ 2 | STORM Wiki pipeline powered by Claude family models and serper search engine. 
3 | You need to set up the following environment variables to run this script:
4 | - ANTHROPIC_API_KEY: Anthropic API key
5 | - SERPER_API_KEY: Serper.dev API key
6 | 
7 | Output will be structured as below
8 | args.output_dir/
9 |     topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash
10 |         conversation_log.json           # Log of information-seeking conversation
11 |         raw_search_results.json         # Raw search results from search engine
12 |         direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
13 |         storm_gen_outline.txt           # Outline refined with collected information
14 |         url_to_info.json                # Sources that are used in the final article
15 |         storm_gen_article.txt           # Final article generated
16 |         storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
17 | """
18 | 
19 | import os
20 | from argparse import ArgumentParser
21 | 
22 | from knowledge_storm import (
23 |     STORMWikiRunnerArguments,
24 |     STORMWikiRunner,
25 |     STORMWikiLMConfigs,
26 | )
27 | from knowledge_storm.lm import ClaudeModel
28 | from knowledge_storm.rm import SerperRM
29 | from knowledge_storm.utils import load_api_key
30 | 
31 | 
32 | def main(args):
33 |     load_api_key(toml_file_path="secrets.toml")
34 |     lm_configs = STORMWikiLMConfigs()
35 |     claude_kwargs = {
36 |         "api_key": os.getenv("ANTHROPIC_API_KEY"),
37 |         "temperature": 1.0,
38 |         "top_p": 0.9,
39 |     }
40 | 
41 |     # STORM is an LM system, so different components can be powered by different models.
42 |     # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
43 |     # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
44 |     # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
45 |     # which is responsible for generating sections with citations.
46 |     conv_simulator_lm = ClaudeModel(
47 |         model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
48 |     )
49 |     question_asker_lm = ClaudeModel(
50 |         model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
51 |     )
52 |     outline_gen_lm = ClaudeModel(
53 |         model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
54 |     )
55 |     article_gen_lm = ClaudeModel(
56 |         model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
57 |     )
58 |     article_polish_lm = ClaudeModel(
59 |         model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
60 |     )
61 | 
62 |     lm_configs.set_conv_simulator_lm(conv_simulator_lm)
63 |     lm_configs.set_question_asker_lm(question_asker_lm)
64 |     lm_configs.set_outline_gen_lm(outline_gen_lm)
65 |     lm_configs.set_article_gen_lm(article_gen_lm)
66 |     lm_configs.set_article_polish_lm(article_polish_lm)
67 | 
68 |     engine_args = STORMWikiRunnerArguments(
69 |         output_dir=args.output_dir,
70 |         max_conv_turn=args.max_conv_turn,
71 |         max_perspective=args.max_perspective,
72 |         search_top_k=args.search_top_k,
73 |         max_thread_num=args.max_thread_num,
74 |     )
75 |     # Documentation for the Serper query parameters is available here:
76 |     # https://serper.dev/playground
77 |     # Important to note that tbs (the date range) takes hardcoded values.
78 |     # num is the number of results per page; it is recommended to use increments of 10 (10, 20, etc.).
79 |     # page is how many pages will be searched.
80 |     # h1 is where the Google search will originate from.
81 | topic = input("topic: ") 82 | data = {"autocorrect": True, "num": 10, "page": 1} 83 | rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data) 84 | 85 | runner = STORMWikiRunner(engine_args, lm_configs, rm) 86 | 87 | runner.run( 88 | topic=topic, 89 | do_research=args.do_research, 90 | do_generate_outline=args.do_generate_outline, 91 | do_generate_article=args.do_generate_article, 92 | do_polish_article=args.do_polish_article, 93 | ) 94 | runner.post_run() 95 | runner.summary() 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = ArgumentParser() 100 | # global arguments 101 | parser.add_argument( 102 | "--output-dir", 103 | type=str, 104 | default="./results/serper", 105 | help="Directory to store the outputs.", 106 | ) 107 | parser.add_argument( 108 | "--max-thread-num", 109 | type=int, 110 | default=3, 111 | help="Maximum number of threads to use. The information seeking part and the article generation" 112 | "part can speed up by using multiple threads. Consider reducing it if keep getting " 113 | '"Exceed rate limit" error when calling LM API.', 114 | ) 115 | parser.add_argument( 116 | "--retriever", 117 | type=str, 118 | choices=["bing", "you", "serper"], 119 | help="The search engine API to use for retrieving information.", 120 | ) 121 | # stage of the pipeline 122 | parser.add_argument( 123 | "--do-research", 124 | action="store_true", 125 | help="If True, simulate conversation to research the topic; otherwise, load the results.", 126 | ) 127 | parser.add_argument( 128 | "--do-generate-outline", 129 | action="store_true", 130 | help="If True, generate an outline for the topic; otherwise, load the results.", 131 | ) 132 | parser.add_argument( 133 | "--do-generate-article", 134 | action="store_true", 135 | help="If True, generate an article for the topic; otherwise, load the results.", 136 | ) 137 | parser.add_argument( 138 | "--do-polish-article", 139 | action="store_true", 140 | help="If True, polish the article by adding a summarization section and (optionally) removing " 141 | "duplicate content.", 142 | ) 143 | # hyperparameters for the pre-writing stage 144 | parser.add_argument( 145 | "--max-conv-turn", 146 | type=int, 147 | default=3, 148 | help="Maximum number of questions in conversational question asking.", 149 | ) 150 | parser.add_argument( 151 | "--max-perspective", 152 | type=int, 153 | default=3, 154 | help="Maximum number of perspectives to consider in perspective-guided question asking.", 155 | ) 156 | parser.add_argument( 157 | "--search-top-k", 158 | type=int, 159 | default=3, 160 | help="Top k search results to consider for each search query.", 161 | ) 162 | # hyperparameters for the writing stage 163 | parser.add_argument( 164 | "--retrieve-top-k", 165 | type=int, 166 | default=3, 167 | help="Top k collected references for each section title.", 168 | ) 169 | parser.add_argument( 170 | "--remove-duplicate", 171 | action="store_true", 172 | help="If True, remove duplicate content from the article.", 173 | ) 174 | 175 | main(parser.parse_args()) 176 | -------------------------------------------------------------------------------- /frontend/demo_light/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [client] 2 | showErrorDetails = false 3 | toolbarMode = "minimal" 4 | 5 | [theme] 6 | primaryColor = "#F63366" 7 | backgroundColor = "#FFFFFF" 8 | secondaryBackgroundColor = "#F0F2F6" 9 | textColor = "#262730" 10 | font = "sans serif" 
-------------------------------------------------------------------------------- /frontend/demo_light/README.md: -------------------------------------------------------------------------------- 1 | # STORM Minimal User Interface 2 | 3 | This is a minimal user interface for `STORMWikiRunner` which includes the following features: 4 | 1. Allowing user to create a new article through the "Create New Article" page. 5 | 2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article. 6 | 3. Displaying the written article and references side by side. 7 | 4. Allowing user to view previously created articles through the "My Articles" page. 8 | 9 |

10 | 11 |

12 | 13 |

14 | 15 |

16 | 17 | ## Setup 18 | 1. Make sure you have installed `knowledge-storm` or set up the source code correctly. 19 | 2. Install additional packages required by the user interface: 20 | ```bash 21 | pip install -r requirements.txt 22 | ``` 23 | 2. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`. 24 | 3. Run the following command to start the user interface: 25 | ```bash 26 | streamlit run storm.py 27 | ``` 28 | The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs. 29 | 30 | ## Customization 31 | 32 | You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file. 33 | 34 | The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need. 35 | -------------------------------------------------------------------------------- /frontend/demo_light/assets/article_display.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/article_display.jpg -------------------------------------------------------------------------------- /frontend/demo_light/assets/create_article.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/create_article.jpg -------------------------------------------------------------------------------- /frontend/demo_light/assets/void.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/void.jpg -------------------------------------------------------------------------------- /frontend/demo_light/pages_util/CreateNewArticle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import demo_util 5 | import streamlit as st 6 | from demo_util import ( 7 | DemoFileIOHelper, 8 | DemoTextProcessingHelper, 9 | DemoUIHelper, 10 | truncate_filename, 11 | ) 12 | 13 | 14 | def handle_not_started(): 15 | if st.session_state["page3_write_article_state"] == "not started": 16 | _, search_form_column, _ = st.columns([2, 5, 2]) 17 | with search_form_column: 18 | with st.form(key="search_form"): 19 | # Text input for the search topic 20 | DemoUIHelper.st_markdown_adjust_size( 21 | content="Enter the topic you want to learn in depth:", font_size=18 22 | ) 23 | st.session_state["page3_topic"] = st.text_input( 24 | label="page3_topic", label_visibility="collapsed" 25 | ) 26 | pass_appropriateness_check = True 27 | 28 | # Submit button for the form 29 | submit_button = st.form_submit_button(label="Research") 30 | # only start new search when button is clicked, not started, or already finished previous one 31 | if submit_button and st.session_state["page3_write_article_state"] in [ 32 | "not started", 33 | "show results", 34 | ]: 35 | if not st.session_state["page3_topic"].strip(): 36 | pass_appropriateness_check = False 37 | 
st.session_state["page3_warning_message"] = ( 38 | "topic could not be empty" 39 | ) 40 | 41 | st.session_state["page3_topic_name_cleaned"] = ( 42 | st.session_state["page3_topic"] 43 | .replace(" ", "_") 44 | .replace("/", "_") 45 | ) 46 | st.session_state["page3_topic_name_truncated"] = truncate_filename( 47 | st.session_state["page3_topic_name_cleaned"] 48 | ) 49 | if not pass_appropriateness_check: 50 | st.session_state["page3_write_article_state"] = "not started" 51 | alert = st.warning( 52 | st.session_state["page3_warning_message"], icon="⚠️" 53 | ) 54 | time.sleep(5) 55 | alert.empty() 56 | else: 57 | st.session_state["page3_write_article_state"] = "initiated" 58 | 59 | 60 | def handle_initiated(): 61 | if st.session_state["page3_write_article_state"] == "initiated": 62 | current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") 63 | if not os.path.exists(current_working_dir): 64 | os.makedirs(current_working_dir) 65 | 66 | if "runner" not in st.session_state: 67 | demo_util.set_storm_runner() 68 | st.session_state["page3_current_working_dir"] = current_working_dir 69 | st.session_state["page3_write_article_state"] = "pre_writing" 70 | 71 | 72 | def handle_pre_writing(): 73 | if st.session_state["page3_write_article_state"] == "pre_writing": 74 | status = st.status( 75 | "I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)" 76 | ) 77 | st_callback_handler = demo_util.StreamlitCallbackHandler(status) 78 | with status: 79 | # STORM main gen outline 80 | st.session_state["runner"].run( 81 | topic=st.session_state["page3_topic"], 82 | do_research=True, 83 | do_generate_outline=True, 84 | do_generate_article=False, 85 | do_polish_article=False, 86 | callback_handler=st_callback_handler, 87 | ) 88 | conversation_log_path = os.path.join( 89 | st.session_state["page3_current_working_dir"], 90 | st.session_state["page3_topic_name_truncated"], 91 | "conversation_log.json", 92 | ) 93 | demo_util._display_persona_conversations( 94 | DemoFileIOHelper.read_json_file(conversation_log_path) 95 | ) 96 | st.session_state["page3_write_article_state"] = "final_writing" 97 | status.update(label="brain**STORM**ing complete!", state="complete") 98 | 99 | 100 | def handle_final_writing(): 101 | if st.session_state["page3_write_article_state"] == "final_writing": 102 | # polish final article 103 | with st.status( 104 | "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)" 105 | ) as status: 106 | st.info( 107 | "Now I will connect the information I found for your reference. 
(This may take 4-5 minutes.)" 108 | ) 109 | st.session_state["runner"].run( 110 | topic=st.session_state["page3_topic"], 111 | do_research=False, 112 | do_generate_outline=False, 113 | do_generate_article=True, 114 | do_polish_article=True, 115 | remove_duplicate=False, 116 | ) 117 | # finish the session 118 | st.session_state["runner"].post_run() 119 | 120 | # update status bar 121 | st.session_state["page3_write_article_state"] = "prepare_to_show_result" 122 | status.update(label="information snythesis complete!", state="complete") 123 | 124 | 125 | def handle_prepare_to_show_result(): 126 | if st.session_state["page3_write_article_state"] == "prepare_to_show_result": 127 | _, show_result_col, _ = st.columns([4, 3, 4]) 128 | with show_result_col: 129 | if st.button("show final article"): 130 | st.session_state["page3_write_article_state"] = "completed" 131 | st.rerun() 132 | 133 | 134 | def handle_completed(): 135 | if st.session_state["page3_write_article_state"] == "completed": 136 | # display polished article 137 | current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict( 138 | st.session_state["page3_current_working_dir"] 139 | ) 140 | current_article_file_path_dict = current_working_dir_paths[ 141 | st.session_state["page3_topic_name_truncated"] 142 | ] 143 | demo_util.display_article_page( 144 | selected_article_name=st.session_state["page3_topic_name_cleaned"], 145 | selected_article_file_path_dict=current_article_file_path_dict, 146 | show_title=True, 147 | show_main_article=True, 148 | ) 149 | 150 | 151 | def create_new_article_page(): 152 | demo_util.clear_other_page_session_state(page_index=3) 153 | 154 | if "page3_write_article_state" not in st.session_state: 155 | st.session_state["page3_write_article_state"] = "not started" 156 | 157 | handle_not_started() 158 | 159 | handle_initiated() 160 | 161 | handle_pre_writing() 162 | 163 | handle_final_writing() 164 | 165 | handle_prepare_to_show_result() 166 | 167 | handle_completed() 168 | -------------------------------------------------------------------------------- /frontend/demo_light/pages_util/MyArticles.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import demo_util 4 | import streamlit as st 5 | from demo_util import DemoFileIOHelper, DemoUIHelper 6 | from streamlit_card import card 7 | 8 | 9 | # set page config and display title 10 | def my_articles_page(): 11 | with st.sidebar: 12 | _, return_button_col = st.columns([2, 5]) 13 | with return_button_col: 14 | if st.button( 15 | "Select another article", 16 | disabled="page2_selected_my_article" not in st.session_state, 17 | ): 18 | if "page2_selected_my_article" in st.session_state: 19 | del st.session_state["page2_selected_my_article"] 20 | st.rerun() 21 | 22 | # sync my articles 23 | if "page2_user_articles_file_path_dict" not in st.session_state: 24 | local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR") 25 | os.makedirs(local_dir, exist_ok=True) 26 | st.session_state["page2_user_articles_file_path_dict"] = ( 27 | DemoFileIOHelper.read_structure_to_dict(local_dir) 28 | ) 29 | 30 | # if no feature demo selected, display all featured articles as info cards 31 | def article_card_setup(column_to_add, card_title, article_name): 32 | with column_to_add: 33 | cleaned_article_title = article_name.replace("_", " ") 34 | hasClicked = card( 35 | title=" / ".join(card_title), 36 | text=article_name.replace("_", " "), 37 | image=DemoFileIOHelper.read_image_as_base64( 38 | 
os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg") 39 | ), 40 | styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1"), 41 | ) 42 | if hasClicked: 43 | st.session_state["page2_selected_my_article"] = article_name 44 | st.rerun() 45 | 46 | if "page2_selected_my_article" not in st.session_state: 47 | # display article cards 48 | my_article_columns = st.columns(3) 49 | if len(st.session_state["page2_user_articles_file_path_dict"]) > 0: 50 | # get article names 51 | article_names = sorted( 52 | list(st.session_state["page2_user_articles_file_path_dict"].keys()) 53 | ) 54 | # configure pagination 55 | pagination = st.container() 56 | bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1] 57 | with bottom_menu[2]: 58 | batch_size = st.selectbox("Page Size", options=[24, 48, 72]) 59 | with bottom_menu[1]: 60 | total_pages = ( 61 | int(len(article_names) / batch_size) 62 | if int(len(article_names) / batch_size) > 0 63 | else 1 64 | ) 65 | current_page = st.number_input( 66 | "Page", min_value=1, max_value=total_pages, step=1 67 | ) 68 | with bottom_menu[0]: 69 | st.markdown(f"Page **{current_page}** of **{total_pages}** ") 70 | # show article cards 71 | with pagination: 72 | my_article_count = 0 73 | start_index = (current_page - 1) * batch_size 74 | end_index = min(current_page * batch_size, len(article_names)) 75 | for article_name in article_names[start_index:end_index]: 76 | column_to_add = my_article_columns[my_article_count % 3] 77 | my_article_count += 1 78 | article_card_setup( 79 | column_to_add=column_to_add, 80 | card_title=["My Article"], 81 | article_name=article_name, 82 | ) 83 | else: 84 | with my_article_columns[0]: 85 | hasClicked = card( 86 | title="Get started", 87 | text="Start your first research!", 88 | image=DemoFileIOHelper.read_image_as_base64( 89 | os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg") 90 | ), 91 | styles=DemoUIHelper.get_article_card_UI_style(), 92 | ) 93 | if hasClicked: 94 | st.session_state.selected_page = 1 95 | st.session_state["manual_selection_override"] = True 96 | st.session_state["rerun_requested"] = True 97 | st.rerun() 98 | else: 99 | selected_article_name = st.session_state["page2_selected_my_article"] 100 | selected_article_file_path_dict = st.session_state[ 101 | "page2_user_articles_file_path_dict" 102 | ][selected_article_name] 103 | 104 | demo_util.display_article_page( 105 | selected_article_name=selected_article_name, 106 | selected_article_file_path_dict=selected_article_file_path_dict, 107 | show_title=True, 108 | show_main_article=True, 109 | ) 110 | -------------------------------------------------------------------------------- /frontend/demo_light/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.31.1 2 | streamlit-card 3 | markdown 4 | unidecode 5 | extra-streamlit-components==0.1.60 6 | streamlit_extras 7 | deprecation==2.1.0 8 | st-pages==0.4.5 9 | streamlit-float 10 | streamlit-option-menu -------------------------------------------------------------------------------- /frontend/demo_light/stoc.py: -------------------------------------------------------------------------------- 1 | """https://github.com/arnaudmiribel/stoc""" 2 | 3 | import re 4 | 5 | import streamlit as st 6 | import unidecode 7 | 8 | DISABLE_LINK_CSS = """ 9 | """ 15 | 16 | 17 | class stoc: 18 | def __init__(self): 19 | self.toc_items = list() 20 | 21 | def h1(self, text: str, write: bool = True): 22 | if write: 23 | st.write(f"# {text}") 24 | 
self.toc_items.append(("h1", text)) 25 | 26 | def h2(self, text: str, write: bool = True): 27 | if write: 28 | st.write(f"## {text}") 29 | self.toc_items.append(("h2", text)) 30 | 31 | def h3(self, text: str, write: bool = True): 32 | if write: 33 | st.write(f"### {text}") 34 | self.toc_items.append(("h3", text)) 35 | 36 | def toc(self, expander): 37 | st.write(DISABLE_LINK_CSS, unsafe_allow_html=True) 38 | # st.sidebar.caption("Table of contents") 39 | if expander is None: 40 | expander = st.sidebar.expander("**Table of contents**", expanded=True) 41 | with expander: 42 | with st.container(height=600, border=False): 43 | markdown_toc = "" 44 | for title_size, title in self.toc_items: 45 | h = int(title_size.replace("h", "")) 46 | markdown_toc += ( 47 | " " * 2 * h 48 | + "- " 49 | + f' {title} \n' 50 | ) 51 | # st.sidebar.write(markdown_toc, unsafe_allow_html=True) 52 | st.write(markdown_toc, unsafe_allow_html=True) 53 | 54 | @classmethod 55 | def get_toc(cls, markdown_text: str, topic=""): 56 | def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading): 57 | lines = markdown_text.splitlines() 58 | # Increase the depth of each heading by adding an extra '#' 59 | increased_depth_lines = [ 60 | "#" + line if line.startswith("#") else line for line in lines 61 | ] 62 | # Add the new top-level heading at the beginning 63 | increased_depth_lines.insert(0, f"# {new_top_heading}") 64 | # Re-join the modified lines back into a single string 65 | modified_text = "\n".join(increased_depth_lines) 66 | return modified_text 67 | 68 | if topic: 69 | markdown_text = increase_heading_depth_and_add_top_heading( 70 | markdown_text, topic 71 | ) 72 | toc = [] 73 | for line in markdown_text.splitlines(): 74 | if line.startswith("#"): 75 | # Remove the '#' characters and strip leading/trailing spaces 76 | heading_text = line.lstrip("#").strip() 77 | # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters) 78 | slug = ( 79 | re.sub(r"[^a-zA-Z0-9\s-]", "", heading_text) 80 | .lower() 81 | .replace(" ", "-") 82 | ) 83 | # Determine heading level for indentation 84 | level = line.count("#") - 1 85 | # Add to the table of contents 86 | toc.append(" " * level + f"- [{heading_text}](#{slug})") 87 | return "\n".join(toc) 88 | 89 | @classmethod 90 | def from_markdown(cls, text: str, expander=None): 91 | self = cls() 92 | for line in text.splitlines(): 93 | if line.startswith("###"): 94 | self.h3(line[3:], write=False) 95 | elif line.startswith("##"): 96 | self.h2(line[2:], write=False) 97 | elif line.startswith("#"): 98 | self.h1(line[1:], write=False) 99 | # customize markdown font size 100 | custom_css = """ 101 | 111 | """ 112 | st.markdown(custom_css, unsafe_allow_html=True) 113 | 114 | st.write(text) 115 | self.toc(expander=expander) 116 | 117 | 118 | def normalize(s): 119 | """ 120 | Normalize titles as valid HTML ids for anchors 121 | >>> normalize("it's a test to spot how Things happ3n héhé") 122 | "it-s-a-test-to-spot-how-things-happ3n-h-h" 123 | """ 124 | 125 | # Replace accents with "-" 126 | s_wo_accents = unidecode.unidecode(s) 127 | accents = [s for s in s if s not in s_wo_accents] 128 | for accent in accents: 129 | s = s.replace(accent, "-") 130 | 131 | # Lowercase 132 | s = s.lower() 133 | 134 | # Keep only alphanum and remove "-" suffix if existing 135 | normalized = ( 136 | "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower() 137 | ) 138 | 139 | return normalized 140 | 
-------------------------------------------------------------------------------- /frontend/demo_light/storm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | script_dir = os.path.dirname(os.path.abspath(__file__)) 4 | wiki_root_dir = os.path.dirname(os.path.dirname(script_dir)) 5 | 6 | import demo_util 7 | from pages_util import MyArticles, CreateNewArticle 8 | from streamlit_float import * 9 | from streamlit_option_menu import option_menu 10 | 11 | 12 | def main(): 13 | global database 14 | st.set_page_config(layout="wide") 15 | 16 | if "first_run" not in st.session_state: 17 | st.session_state["first_run"] = True 18 | 19 | # set api keys from secrets 20 | if st.session_state["first_run"]: 21 | for key, value in st.secrets.items(): 22 | if type(value) == str: 23 | os.environ[key] = value 24 | 25 | # initialize session_state 26 | if "selected_article_index" not in st.session_state: 27 | st.session_state["selected_article_index"] = 0 28 | if "selected_page" not in st.session_state: 29 | st.session_state["selected_page"] = 0 30 | if st.session_state.get("rerun_requested", False): 31 | st.session_state["rerun_requested"] = False 32 | st.rerun() 33 | 34 | st.write( 35 | "", unsafe_allow_html=True 36 | ) 37 | menu_container = st.container() 38 | with menu_container: 39 | pages = ["My Articles", "Create New Article"] 40 | styles = { 41 | "container": {"padding": "0.2rem 0", "background-color": "#22222200"}, 42 | } 43 | menu_selection = option_menu( 44 | None, 45 | pages, 46 | icons=["house", "search"], 47 | menu_icon="cast", 48 | default_index=0, 49 | orientation="horizontal", 50 | manual_select=st.session_state.selected_page, 51 | styles=styles, 52 | key="menu_selection", 53 | ) 54 | if st.session_state.get("manual_selection_override", False): 55 | menu_selection = pages[st.session_state["selected_page"]] 56 | st.session_state["manual_selection_override"] = False 57 | st.session_state["selected_page"] = None 58 | 59 | if menu_selection == "My Articles": 60 | demo_util.clear_other_page_session_state(page_index=2) 61 | MyArticles.my_articles_page() 62 | elif menu_selection == "Create New Article": 63 | demo_util.clear_other_page_session_state(page_index=3) 64 | CreateNewArticle.create_new_article_page() 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /knowledge_storm/__init__.py: -------------------------------------------------------------------------------- 1 | from .storm_wiki import * 2 | from .collaborative_storm import * 3 | from .encoder import * 4 | from .interface import * 5 | from .lm import * 6 | from .rm import * 7 | from .utils import * 8 | from .dataclass import * 9 | 10 | __version__ = "1.1.0" 11 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .engine import * 3 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .article_generation import * 2 | from .grounded_question_answering import * 3 | from .grounded_question_generation import * 4 | from .information_insertion_module import * 5 | from .simulate_user import * 6 | from .warmstart_hierarchical_chat import * 7 
| from .knowledge_base_summary import * 8 | from .costorm_expert_utterance_generator import * 9 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/article_generation.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | from typing import Set, Union 4 | 5 | from .collaborative_storm_utils import clean_up_section 6 | from ...dataclass import KnowledgeBase, KnowledgeNode 7 | 8 | 9 | class ArticleGenerationModule(dspy.Module): 10 | """Use the information collected from the information-seeking conversation to write a section.""" 11 | 12 | def __init__( 13 | self, 14 | engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], 15 | ): 16 | super().__init__() 17 | self.write_section = dspy.Predict(WriteSection) 18 | self.engine = engine 19 | 20 | def _get_cited_information_string( 21 | self, 22 | all_citation_index: Set[int], 23 | knowledge_base: KnowledgeBase, 24 | max_words: int = 4000, 25 | ): 26 | information = [] 27 | cur_word_count = 0 28 | for index in sorted(list(all_citation_index)): 29 | info = knowledge_base.info_uuid_to_info_dict[index] 30 | snippet = info.snippets[0] 31 | info_text = f"[{index}]: {snippet} (Question: {info.meta['question']}. Query: {info.meta['query']})" 32 | cur_snippet_length = len(info_text.split()) 33 | if cur_snippet_length + cur_word_count > max_words: 34 | break 35 | cur_word_count += cur_snippet_length 36 | information.append(info_text) 37 | return "\n".join(information) 38 | 39 | def gen_section( 40 | self, topic: str, node: KnowledgeNode, knowledge_base: KnowledgeBase 41 | ): 42 | if node is None or len(node.content) == 0: 43 | return "" 44 | if ( 45 | node.synthesize_output is not None 46 | and node.synthesize_output 47 | and not node.need_regenerate_synthesize_output 48 | ): 49 | return node.synthesize_output 50 | all_citation_index = node.collect_all_content() 51 | information = self._get_cited_information_string( 52 | all_citation_index=all_citation_index, knowledge_base=knowledge_base 53 | ) 54 | with dspy.settings.context(lm=self.engine): 55 | synthesize_output = clean_up_section( 56 | self.write_section( 57 | topic=topic, info=information, section=node.name 58 | ).output 59 | ) 60 | node.synthesize_output = synthesize_output 61 | node.need_regenerate_synthesize_output = False 62 | return node.synthesize_output 63 | 64 | def forward(self, knowledge_base: KnowledgeBase): 65 | all_nodes = knowledge_base.collect_all_nodes() 66 | node_to_paragraph = {} 67 | 68 | # Define a function to generate paragraphs for nodes 69 | def _node_generate_paragraph(node): 70 | node_gen_paragraph = self.gen_section( 71 | topic=knowledge_base.topic, node=node, knowledge_base=knowledge_base 72 | ) 73 | lines = node_gen_paragraph.split("\n") 74 | if lines[0].strip().replace("*", "").replace("#", "") == node.name: 75 | lines = lines[1:] 76 | node_gen_paragraph = "\n".join(lines) 77 | path = " -> ".join(node.get_path_from_root()) 78 | return path, node_gen_paragraph 79 | 80 | with ThreadPoolExecutor(max_workers=5) as executor: 81 | # Submit all tasks 82 | future_to_node = { 83 | executor.submit(_node_generate_paragraph, node): node 84 | for node in all_nodes 85 | } 86 | 87 | # Collect the results as they complete 88 | for future in as_completed(future_to_node): 89 | path, node_gen_paragraph = future.result() 90 | node_to_paragraph[path] = node_gen_paragraph 91 | 92 | def helper(cur_root, level): 93 
|             to_return = []
94 |             if cur_root is not None:
95 |                 hash_tag = "#" * level + " "
96 |                 cur_path = " -> ".join(cur_root.get_path_from_root())
97 |                 node_gen_paragraph = node_to_paragraph[cur_path]
98 |                 to_return.append(f"{hash_tag}{cur_root.name}\n{node_gen_paragraph}")
99 |                 for child in cur_root.children:
100 |                     to_return.extend(helper(child, level + 1))
101 |             return to_return
102 | 
103 |         to_return = []
104 |         for child in knowledge_base.root.children:
105 |             to_return.extend(helper(child, level=1))
106 | 
107 |         return "\n".join(to_return)
108 | 
109 | 
110 | class WriteSection(dspy.Signature):
111 |     """Write a Wikipedia section based on the collected information. You will be given the topic, the section you are writing and relevant information.
112 |     Each information will be provided with the raw content along with question and query lead to that information.
113 |     Here is the format of your writing:
114 |     Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
115 |     """
116 | 
117 |     info = dspy.InputField(prefix="The collected information:\n", format=str)
118 |     topic = dspy.InputField(prefix="The topic of the page: ", format=str)
119 |     section = dspy.InputField(prefix="The section you need to write: ", format=str)
120 |     output = dspy.OutputField(
121 |         prefix="Write the section with proper inline citations (Start your writing. Don't include the page title, section name, or try to write other sections. Do not start the section with topic name.):\n",
122 |         format=str,
123 |     )
124 | 
-------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/callback.py: --------------------------------------------------------------------------------
1 | from typing import List
2 | from ...interface import Information
3 | 
4 | 
5 | class BaseCallbackHandler:
6 |     """Base callback handler to manage callbacks from the Co-STORM pipeline."""
7 | 
8 |     def on_turn_policy_planning_start(self, **kwargs):
9 |         """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
10 |         pass
11 | 
12 |     def on_expert_action_planning_start(self, **kwargs):
13 |         """Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
14 |         pass
15 | 
16 |     def on_expert_action_planning_end(self, **kwargs):
17 |         """Run when the expert action planning ends, after deciding the actions that each expert should take."""
18 |         pass
19 | 
20 |     def on_expert_information_collection_start(self, **kwargs):
21 |         """Run when the expert information collection starts, before gathering the necessary data from the selected sources."""
22 |         pass
23 | 
24 |     def on_expert_information_collection_end(self, info: List[Information], **kwargs):
25 |         """Run when the expert information collection ends, after gathering all necessary data from selected sources."""
26 |         pass
27 | 
28 |     def on_expert_utterance_generation_end(self, **kwargs):
29 |         """Run when the expert utterance generation ends, after responses or statements from each expert have been generated."""
30 |         pass
31 | 
32 |     def on_expert_utterance_polishing_start(self, **kwargs):
33 |         """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
34 |         pass
35 | 
36 |     def on_mindmap_insert_start(self, **kwargs):
37 |         """Run when the process of inserting new information into 
the mindmap starts."""
38 |         pass
39 | 
40 |     def on_mindmap_insert_end(self, **kwargs):
41 |         """Run when the process of inserting new information into the mindmap ends."""
42 |         pass
43 | 
44 |     def on_mindmap_reorg_start(self, **kwargs):
45 |         """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
46 |         pass
47 | 
48 |     def on_expert_list_update_start(self, **kwargs):
49 |         """Run when the expert list update starts, to modify or refresh the list of active experts."""
50 |         pass
51 | 
52 |     def on_article_generation_start(self, **kwargs):
53 |         """Run when the article generation process begins, to compile and format the final article content."""
54 |         pass
55 | 
56 |     def on_warmstart_update(self, message, **kwargs):
57 |         """Run when the warm start process has an update."""
58 |         pass
59 | 
60 | 
61 | class LocalConsolePrintCallBackHandler(BaseCallbackHandler):
62 |     def __init__(self):
63 |         pass
64 | 
65 |     def on_turn_policy_planning_start(self, **kwargs):
66 |         """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
67 |         print("Start planning next expert; inspect mind map; inspect system state.")
68 | 
69 |     def on_expert_action_planning_start(self, **kwargs):
70 |         """Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
71 |         print("Reviewing discourse history; deciding utterance intent.")
72 | 
73 |     def on_expert_information_collection_start(self, **kwargs):
74 |         """Run when the expert information collection starts, before gathering the necessary data from the selected sources."""
75 |         print("Start searching with the search engine; browsing collected information.")
76 | 
77 |     def on_expert_information_collection_end(self, info: List[Information], **kwargs):
78 |         """Run when the expert information collection ends, after gathering all necessary data from selected sources."""
79 |         if info:
80 |             urls = [i.url for i in info]
81 |             information_string = "\n".join([f"Finish browsing {url}" for url in urls])
82 |             print(information_string)
83 | 
84 |     def on_expert_utterance_generation_end(self, **kwargs):
85 |         """Run when the expert utterance generation ends, after responses or statements from each expert have been generated."""
86 |         print("Finish generating utterance from collected information.")
87 | 
88 |     def on_expert_utterance_polishing_start(self, **kwargs):
89 |         """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
90 |         print("Start polishing utterance.")
91 | 
92 |     def on_mindmap_insert_start(self, **kwargs):
93 |         """Run when the process of inserting new information into the mindmap starts."""
94 |         print("Start inserting information into mind map.")
95 | 
96 |     def on_mindmap_insert_end(self, **kwargs):
97 |         """Run when the process of inserting new information into the mindmap ends."""
98 |         print("Finish inserting information into mind map.")
99 | 
100 |     def on_mindmap_reorg_start(self, **kwargs):
101 |         """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
102 |         print("Start re-organizing mind map.")
103 | 
104 |     def on_expert_list_update_start(self, **kwargs):
105 |         """Run when the expert list update starts, to modify or refresh the list of active experts."""
106 |         print("Start updating expert candidates.")
107 | 
108 |     def on_warmstart_update(self, message, **kwargs):
109 |         """Run when the warm start process has an update."""
110 |         print(f"Warm 
start update: {message}") 111 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | import os 3 | import re 4 | import sys 5 | import toml 6 | from typing import List, Tuple, Dict, Optional, TYPE_CHECKING 7 | 8 | if TYPE_CHECKING: 9 | from ..engine import RunnerArgument 10 | from ...interface import Information, Retriever, LMConfigs 11 | from ...logging_wrapper import LoggingWrapper 12 | from ...rm import BingSearch 13 | 14 | 15 | def extract_storm_info_snippet(info: Information, snippet_index: int) -> Information: 16 | """ 17 | Constructs a new Information instance with only the specified snippet index. 18 | 19 | Args: 20 | storm_info (Information): The original Information instance. 21 | snippet_index (int): The index of the snippet to retain. 22 | 23 | Returns: 24 | Information: A new Information instance with only the specified snippet. 25 | """ 26 | if snippet_index < 0 or snippet_index >= len(info.snippets): 27 | raise ValueError("Snippet index out of range") 28 | 29 | new_snippets = [info.snippets[snippet_index]] 30 | new_storm_info = Information( 31 | info.url, info.description, new_snippets, info.title, info.meta 32 | ) 33 | return new_storm_info 34 | 35 | 36 | def format_search_results( 37 | searched_results: List[Information], 38 | info_max_num_words: int = 1000, 39 | mode: str = "brief", 40 | ) -> Tuple[str, Dict[int, Information]]: 41 | """ 42 | Constructs a string from a list of search results with a specified word limit and returns a mapping of indices to Information. 43 | 44 | Args: 45 | searched_results (List[Information]): List of Information objects to process. 46 | info_max_num_words (int, optional): Maximum number of words allowed in the output string. Defaults to 1000. 47 | mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information. 48 | 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'. 49 | 50 | Returns: 51 | Tuple[str, Dict[int, Information]]: 52 | - Formatted string with search results, constrained by the word limit. 53 | - Dictionary mapping indices to the corresponding Information objects. 
54 | """ 55 | total_length = 0 56 | 57 | extracted_snippet_queue = [] 58 | max_snippets = ( 59 | max(len(info.snippets) for info in searched_results) if searched_results else 0 60 | ) 61 | max_snippets = 1 if mode == "brief" else max_snippets 62 | abort = False 63 | included_snippets = set() 64 | for i in range(max_snippets): 65 | for info in searched_results: 66 | if i < len(info.snippets) and not abort: 67 | cur_snippet = info.snippets[i] 68 | cur_snippet_len = len(info.snippets[i].split()) 69 | if total_length + cur_snippet_len > info_max_num_words: 70 | abort = True 71 | break 72 | if cur_snippet not in included_snippets: 73 | included_snippets.add(cur_snippet) 74 | info = extract_storm_info_snippet(info, snippet_index=i) 75 | extracted_snippet_queue.append(info) 76 | total_length += cur_snippet_len 77 | output = [] 78 | index_mapping = {} 79 | for idx, info in enumerate(extracted_snippet_queue): 80 | output.append(f"[{idx + 1}]: {info.snippets[0]}") 81 | index_mapping[idx + 1] = info 82 | assert -1 not in index_mapping 83 | return "\n".join(output), index_mapping 84 | 85 | 86 | def extract_cited_storm_info( 87 | response: str, index_to_storm_info: Dict[int, Information] 88 | ) -> Dict[int, Information]: 89 | """ 90 | Extracts a sub-dictionary of Information instances that are cited in the response. 91 | 92 | Args: 93 | response (str): The response string containing inline citations like [1], [2], etc. 94 | index_to_storm_info (Dict[int, Information]): A dictionary mapping indices to Information instances. 95 | 96 | Returns: 97 | Dict[int, Information]: A sub-dictionary with only the indices that appear in the response. 98 | """ 99 | cited_indices = set(map(int, re.findall(r"\[(\d+)\]", response))) 100 | cited_storm_info = { 101 | index: info 102 | for index, info in index_to_storm_info.items() 103 | if index in cited_indices 104 | } 105 | return cited_storm_info 106 | 107 | 108 | def trim_output_after_hint(response: str, hint: str) -> str: 109 | """ 110 | Trims the output string to only keep the substring after the given hint (not including the hint). 111 | 112 | Args: 113 | response (str): The original output string. 114 | hint (str): The hint string after which the substring should be kept. 115 | 116 | Returns: 117 | str: The trimmed output string, or the original string if the hint is not found. 118 | """ 119 | if hint in response: 120 | start_index = response.find(hint) + len(hint) 121 | return response[start_index:].strip() 122 | return response.strip("\n") 123 | 124 | 125 | def separate_citations(text: str) -> str: 126 | """ 127 | Separates multiple citations within square brackets into individual citations. 128 | 129 | Args: 130 | text (str): The input string containing citations. 131 | 132 | Returns: 133 | str: The string with separated citations. 134 | """ 135 | 136 | # Define a function to process each match 137 | def replace_citations(match): 138 | citations = match.group(1).split(",") 139 | return "".join(f"[{citation.strip()}]" for citation in citations) 140 | 141 | # Use regular expressions to find and replace citations 142 | pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]") 143 | return pattern.sub(replace_citations, text) 144 | 145 | 146 | def extract_and_remove_citations(text: str) -> Tuple[str, List[int]]: 147 | """ 148 | Removes single inline citations from the input string and returns the modified string and a list of citation integers. 149 | 150 | Args: 151 | text (str): The input string containing citations. 
152 | 153 | Returns: 154 | Tuple[str, List[int]]: The string after removal of citations and a list of citation integers. 155 | """ 156 | citations = [] 157 | 158 | # Define a function to process each match 159 | def extract_citation(match): 160 | citation = int(match.group(1)) 161 | citations.append(citation) 162 | return "" 163 | 164 | # Use regular expressions to find and replace citations 165 | pattern = re.compile(r"\[(\d+)\]") 166 | modified_text = pattern.sub(extract_citation, text) 167 | 168 | return modified_text, citations 169 | 170 | 171 | def keep_first_and_last_paragraph(text: str) -> str: 172 | """ 173 | Processes the input text to keep the first and last paragraphs and replace 174 | the middle paragraphs with '[content omitted due to space limit]'. 175 | 176 | Args: 177 | text (str): The input text containing paragraphs separated by '\n\n'. 178 | 179 | Returns: 180 | str: The processed text. 181 | """ 182 | paragraphs = text.split("\n\n") 183 | 184 | if len(paragraphs) <= 3: 185 | return text 186 | 187 | first_paragraph = paragraphs[0] 188 | last_paragraph = "\n\n".join(paragraphs[-2:]) 189 | return ( 190 | f"{first_paragraph}\n\n[content omitted due to space limit]\n\n{last_paragraph}" 191 | ) 192 | 193 | 194 | def clean_up_section(text): 195 | """Clean up a section: 196 | 1. Remove uncompleted sentences (usually due to output token limitation). 197 | 2. Deduplicate individual groups of citations. 198 | 3. Remove unnecessary summary.""" 199 | 200 | paragraphs = text.split("\n") 201 | output_paragraphs = [] 202 | summary_sec_flag = False 203 | for p in paragraphs: 204 | p = p.strip() 205 | if len(p) == 0: 206 | continue 207 | if not p.startswith("#"): 208 | p = separate_citations(p) 209 | if summary_sec_flag: 210 | if p.startswith("#"): 211 | summary_sec_flag = False 212 | else: 213 | continue 214 | if ( 215 | p.startswith("Overall") 216 | or p.startswith("In summary") 217 | or p.startswith("In conclusion") 218 | ): 219 | continue 220 | if "# Summary" in p or "# Conclusion" in p: 221 | summary_sec_flag = True 222 | continue 223 | output_paragraphs.append(p) 224 | 225 | return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. 
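# Minimal usage sketch for the citation helpers above (illustrative only; the outputs shown are
# assumptions inferred from the regex logic in this module, not captured from a real run):
#
#   separate_citations("Water boils at 100 C[1, 3].")
#   # -> 'Water boils at 100 C[1][3].'
#   extract_and_remove_citations("Water boils at 100 C[1][3].")
#   # -> ('Water boils at 100 C.', [1, 3])
#   clean_up_section("Intro text[1,2].\nIn summary, this restates the intro.")
#   # -> 'Intro text[1][2].'  (summary-style sentences are dropped)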
226 | 227 | 228 | def load_api_key(toml_file_path): 229 | try: 230 | with open(toml_file_path, "r") as file: 231 | data = toml.load(file) 232 | except FileNotFoundError: 233 | print(f"File not found: {toml_file_path}", file=sys.stderr) 234 | return 235 | except toml.TomlDecodeError: 236 | print(f"Error decoding TOML file: {toml_file_path}", file=sys.stderr) 237 | return 238 | # Set environment variables 239 | for key, value in data.items(): 240 | os.environ[key] = str(value) 241 | 242 | 243 | def _get_answer_question_module_instance( 244 | lm_config: LMConfigs, 245 | runner_argument: "RunnerArgument", 246 | logging_wrapper: LoggingWrapper, 247 | rm: Optional[dspy.Retrieve] = None, 248 | ): 249 | from .grounded_question_answering import AnswerQuestionModule 250 | 251 | # configure retriever 252 | if rm is None: 253 | rm = BingSearch(k=runner_argument.retrieve_top_k) 254 | retriever = Retriever(rm=rm, max_thread=runner_argument.max_search_thread) 255 | # return AnswerQuestionModule instance 256 | return AnswerQuestionModule( 257 | retriever=retriever, 258 | max_search_queries=runner_argument.max_search_queries, 259 | question_answering_lm=lm_config.question_answering_lm, 260 | logging_wrapper=logging_wrapper, 261 | ) 262 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | from typing import Union 3 | 4 | from .callback import BaseCallbackHandler 5 | from .collaborative_storm_utils import ( 6 | trim_output_after_hint, 7 | extract_and_remove_citations, 8 | keep_first_and_last_paragraph, 9 | ) 10 | 11 | from .grounded_question_answering import AnswerQuestionModule 12 | from .grounded_question_generation import ConvertUtteranceStyle 13 | from ...dataclass import ConversationTurn 14 | from ...logging_wrapper import LoggingWrapper 15 | 16 | 17 | class GenExpertActionPlanning(dspy.Signature): 18 | """ 19 | You are an invited speaker in the round table conversation. Your task is to make a very short note to your assistant to help you prepare for your turn in the conversation. 20 | You will be given the topic we are discussing, your expertise, and the conversation history. 21 | Take a look at conversation history, especially last few turns, then let your assistant prepare the material for you with one of following ways. 22 | 1. Original Question: Initiates a new question to other speakers. 23 | 2. Further Details: Provides additional information. 24 | 3. Information Request: Requests information from other speakers. 25 | 4. Potential Answer: Offers a possible solution or answer. 26 | 27 | Strictly follow this format: [type of contribution]: [one sentence description]. For example, Original Question: [description] 28 | """ 29 | 30 | topic = dspy.InputField(prefix="topic of discussion: ", format=str) 31 | expert = dspy.InputField(prefix="You are inivited as: ", format=str) 32 | summary = dspy.InputField(prefix="Discussion history: \n", format=str) 33 | last_utterance = dspy.InputField( 34 | prefix="Last utterance in the conversation: \n", format=str 35 | ) 36 | resposne = dspy.OutputField( 37 | prefix="Now give your note. 
Start with one of [Original Question, Further Details, Information Request, Potential Answer] with one sentence description\n", 38 | format=str, 39 | ) 40 | 41 | 42 | class CoStormExpertUtteranceGenerationModule(dspy.Module): 43 | def __init__( 44 | self, 45 | action_planning_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], 46 | utterance_polishing_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], 47 | answer_question_module: AnswerQuestionModule, 48 | logging_wrapper: LoggingWrapper, 49 | callback_handler: BaseCallbackHandler = None, 50 | ): 51 | self.action_planning_lm = action_planning_lm 52 | self.utterance_polishing_lm = utterance_polishing_lm 53 | self.expert_action = dspy.Predict(GenExpertActionPlanning) 54 | self.change_style = dspy.Predict(ConvertUtteranceStyle) 55 | self.answer_question_module = answer_question_module 56 | self.logging_wrapper = logging_wrapper 57 | self.callback_handler = callback_handler 58 | 59 | def parse_action(self, action): 60 | action_types = [ 61 | "Original Question", 62 | "Further Details", 63 | "Information Request", 64 | "Potential Answer", 65 | ] 66 | for action_type in action_types: 67 | if f"{action_type}:" in action: 68 | return action_type, trim_output_after_hint(action, f"{action_type}:") 69 | elif f"[{action_type}]:" in action: 70 | return action_type, trim_output_after_hint(action, f"[{action_type}]:") 71 | return "Undefined", "" 72 | 73 | def polish_utterance( 74 | self, conversation_turn: ConversationTurn, last_conv_turn: ConversationTurn 75 | ): 76 | # change utterance style 77 | action_type = conversation_turn.utterance_type 78 | with self.logging_wrapper.log_event( 79 | "RoundTableConversationModule.ConvertUtteranceStyle" 80 | ): 81 | with dspy.settings.context( 82 | lm=self.utterance_polishing_lm, show_guidelines=False 83 | ): 84 | action_string = ( 85 | f"{action_type} about: {conversation_turn.claim_to_make}" 86 | ) 87 | if action_type in ["Original Question", "Information Request"]: 88 | action_string = f"{action_type}" 89 | last_expert_utterance_wo_citation, _ = extract_and_remove_citations( 90 | last_conv_turn.utterance 91 | ) 92 | trimmed_last_expert_utterance = keep_first_and_last_paragraph( 93 | last_expert_utterance_wo_citation 94 | ) 95 | utterance = self.change_style( 96 | expert=conversation_turn.role, 97 | action=action_string, 98 | prev=trimmed_last_expert_utterance, 99 | content=conversation_turn.raw_utterance, 100 | ).utterance 101 | conversation_turn.utterance = utterance 102 | 103 | def forward( 104 | self, 105 | topic: str, 106 | current_expert: str, 107 | conversation_summary: str, 108 | last_conv_turn: ConversationTurn, 109 | ): 110 | last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance) 111 | if last_conv_turn.utterance_type in [ 112 | "Original Question", 113 | "Information Request", 114 | ]: 115 | action_type = "Potential Answer" 116 | action_content = last_utterance 117 | else: 118 | with self.logging_wrapper.log_event( 119 | "CoStormExpertUtteranceGenerationModule: GenExpertActionPlanning" 120 | ): 121 | with dspy.settings.context( 122 | lm=self.action_planning_lm, show_guidelines=False 123 | ): 124 | action = self.expert_action( 125 | topic=topic, 126 | expert=current_expert, 127 | summary=conversation_summary, 128 | last_utterance=last_utterance, 129 | ).resposne 130 | action_type, action_content = self.parse_action(action) 131 | 132 | if self.callback_handler is not None: 133 | self.callback_handler.on_expert_action_planning_end() 134 | # get response 135 | conversation_turn = ConversationTurn( 136 | 
role=current_expert, raw_utterance="", utterance_type=action_type 137 | ) 138 | 139 | if action_type == "Undefined": 140 | raise Exception(f"unexpected output: {action}") 141 | elif action_type in ["Further Details", "Potential Answer"]: 142 | with self.logging_wrapper.log_event( 143 | "RoundTableConversationModule: QuestionAnswering" 144 | ): 145 | grounded_answer = self.answer_question_module( 146 | topic=topic, 147 | question=action_content, 148 | mode="brief", 149 | style="conversational and concise", 150 | callback_handler=self.callback_handler, 151 | ) 152 | conversation_turn.claim_to_make = action_content 153 | conversation_turn.raw_utterance = grounded_answer.response 154 | conversation_turn.queries = grounded_answer.queries 155 | conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info 156 | conversation_turn.cited_info = grounded_answer.cited_info 157 | elif action_type in ["Original Question", "Information Request"]: 158 | conversation_turn.raw_utterance = action_content 159 | 160 | return dspy.Prediction(conversation_turn=conversation_turn) 161 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/expert_generation.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | import re 3 | from typing import Union 4 | 5 | 6 | class GenerateExpertGeneral(dspy.Signature): 7 | """You need to select a group of diverse experts who will be suitable to be invited to a roundtable discussion on the given topic. 8 | Each expert should represent a different perspective, role, or affiliation related to this topic. 9 | You can use the background information provided about the topic for inspiration. For each expert, add a description of their expertise and what they will focus on during the discussion. 10 | No need to include speakers name in the output. 11 | Strictly follow format below: 12 | 1. [speaker 1 role]: [speaker 1 short description] 13 | 2. [speaker 2 role]: [speaker 2 short description] 14 | """ 15 | 16 | topic = dspy.InputField(prefix="Topic of interest:", format=str) 17 | background_info = dspy.InputField( 18 | prefix="Background information about the topic:\n", format=str 19 | ) 20 | topN = dspy.InputField(prefix="Number of speakers needed: ", format=str) 21 | experts = dspy.OutputField(format=str) 22 | 23 | 24 | class GenerateExpertWithFocus(dspy.Signature): 25 | """ 26 | You need to select a group of speakers who will be suitable to have roundtable discussion on the [topic] of specific [focus]. 27 | You may consider inviting speakers having opposite stands on the topic; speakers representing different interest parties; Ensure that the selected speakers are directly connected to the specific context and scenario provided. 28 | For example, if the discussion focus is about a recent event at a specific university, consider inviting students, faculty members, journalists covering the event, university officials, and local community members. 29 | Use the background information provided about the topic for inspiration. For each speaker, add a description of their interests and what they will focus on during the discussion. 30 | No need to include speakers name in the output. 31 | Strictly follow format below: 32 | 1. [speaker 1 role]: [speaker 1 short description] 33 | 2. 
[speaker 2 role]: [speaker 2 short description] 34 | """ 35 | 36 | topic = dspy.InputField(prefix="Topic of interest:", format=str) 37 | background_info = dspy.InputField(prefix="Background information:\n", format=str) 38 | focus = dspy.InputField(prefix="Discussion focus: ", format=str) 39 | topN = dspy.InputField(prefix="Number of speakers needed: ", format=str) 40 | experts = dspy.OutputField(format=str) 41 | 42 | 43 | class GenerateExpertModule(dspy.Module): 44 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 45 | self.engine = engine 46 | self.generate_expert_general = dspy.Predict(GenerateExpertGeneral) 47 | self.generate_expert_w_focus = dspy.ChainOfThought(GenerateExpertWithFocus) 48 | 49 | def trim_background(self, background: str, max_words: int = 100): 50 | words = background.split() 51 | cur_len = len(words) 52 | if cur_len <= max_words: 53 | return background 54 | trimmed_words = words[: min(cur_len, max_words)] 55 | trimmed_background = " ".join(trimmed_words) 56 | return f"{trimmed_background} [rest content omitted]." 57 | 58 | def forward( 59 | self, topic: str, num_experts: int, background_info: str = "", focus: str = "" 60 | ): 61 | with dspy.settings.context(lm=self.engine, show_guidelines=False): 62 | if not focus: 63 | output = self.generate_expert_general( 64 | topic=topic, background_info=background_info, topN=num_experts 65 | ).experts 66 | else: 67 | background_info = self.trim_background( 68 | background=background_info, max_words=100 69 | ) 70 | output = self.generate_expert_w_focus( 71 | topic=topic, 72 | background_info=background_info, 73 | focus=focus, 74 | topN=num_experts, 75 | ).experts 76 | output = output.replace("*", "").replace("[", "").replace("]", "") 77 | expert_list = [] 78 | for s in output.split("\n"): 79 | match = re.search(r"\d+\.\s*(.*)", s) 80 | if match: 81 | expert_list.append(match.group(1)) 82 | expert_list = [expert.strip() for expert in expert_list if expert.strip()] 83 | return dspy.Prediction(experts=expert_list, raw_output=output) 84 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/grounded_question_answering.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | from typing import Union, List 3 | 4 | from .callback import BaseCallbackHandler 5 | from .collaborative_storm_utils import ( 6 | trim_output_after_hint, 7 | format_search_results, 8 | extract_cited_storm_info, 9 | separate_citations, 10 | ) 11 | from ...logging_wrapper import LoggingWrapper 12 | from ...utils import ArticleTextProcessing 13 | from ...interface import Information 14 | 15 | 16 | class QuestionToQuery(dspy.Signature): 17 | """You want to answer the question or support a claim using Google search. What do you type in the search box? 18 | The question is raised in a round table discussion on a topic. The question may or may not focus on the topic itself. 19 | Write the queries you will use in the following format: 20 | - query 1 21 | - query 2 22 | ... 23 | - query n""" 24 | 25 | topic = dspy.InputField(prefix="Topic context:", format=str) 26 | question = dspy.InputField( 27 | prefix="I want to collect information about: ", format=str 28 | ) 29 | queries = dspy.OutputField(prefix="Queries: \n", format=str) 30 | 31 | 32 | class AnswerQuestion(dspy.Signature): 33 | """You are an expert who can use information effectively. You have gathered the related information and will now use the information to form a response. 
34 | Make your response as informative as possible and make sure every sentence is supported by the gathered information. 35 | If [Gathered information] is not directly related to the [Topic] and [Question], provide the most relevant answer you can based on the available information, and explain any limitations or gaps. 36 | Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). 37 | You DO NOT need to include a References or Sources section to list the sources at the end. The style of writing should be formal. 38 | """ 39 | 40 | topic = dspy.InputField(prefix="Topic you are discussing about:", format=str) 41 | question = dspy.InputField(prefix="You want to provide insight on: ", format=str) 42 | info = dspy.InputField(prefix="Gathered information:\n", format=str) 43 | style = dspy.InputField(prefix="Style of your response should be:", format=str) 44 | answer = dspy.OutputField( 45 | prefix="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)", 46 | format=str, 47 | ) 48 | 49 | 50 | class AnswerQuestionModule(dspy.Module): 51 | def __init__( 52 | self, 53 | retriever: dspy.Retrieve, 54 | max_search_queries: int, 55 | question_answering_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], 56 | logging_wrapper: LoggingWrapper, 57 | ): 58 | super().__init__() 59 | self.question_answering_lm = question_answering_lm 60 | self.question_to_query = dspy.Predict(QuestionToQuery) 61 | self.answer_question = dspy.Predict(AnswerQuestion) 62 | self.retriever = retriever 63 | self.max_search_queries = max_search_queries 64 | self.logging_wrapper = logging_wrapper 65 | 66 | def retrieve_information(self, topic, question): 67 | # decompose question to queries 68 | with self.logging_wrapper.log_event( 69 | f"AnswerQuestionModule.question_to_query ({hash(question)})" 70 | ): 71 | with dspy.settings.context(lm=self.question_answering_lm): 72 | queries = self.question_to_query(topic=topic, question=question).queries 73 | queries = trim_output_after_hint(queries, hint="Queries:") 74 | queries = [ 75 | q.replace("-", "").strip().strip('"').strip('"').strip() 76 | for q in queries.split("\n") 77 | ] 78 | queries = queries[: self.max_search_queries] 79 | self.logging_wrapper.add_query_count(count=len(queries)) 80 | with self.logging_wrapper.log_event( 81 | f"AnswerQuestionModule.retriever.retrieve ({hash(question)})" 82 | ): 83 | # retrieve information using retriever 84 | searched_results: List[Information] = self.retriever.retrieve( 85 | list(set(queries)), exclude_urls=[] 86 | ) 87 | # update storm information meta to include the question 88 | for storm_info in searched_results: 89 | storm_info.meta["question"] = question 90 | return queries, searched_results 91 | 92 | def forward( 93 | self, 94 | topic: str, 95 | question: str, 96 | mode: str = "brief", 97 | style: str = "conversational", 98 | callback_handler: BaseCallbackHandler = None, 99 | ): 100 | """ 101 | Processes a topic and question to generate a response with relevant information and citations. 102 | 103 | Args: 104 | topic (str): The topic of interest. 105 | question (str): The specific question related to the topic. 106 | mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information. 107 | 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'. 
108 | 109 | Returns: 110 | dspy.Prediction: An object containing the following: 111 | - question (str): the question to answer 112 | - queries (List[str]): List of query strings used for information retrieval. 113 | - raw_retrieved_info (List[Information]): List of Information instances retrieved. 114 | - cited_info (Dict[int, Information]): Dictionary of cited Information instances, indexed by their citation number. 115 | - response (str): The generated response string with inline citations. 116 | """ 117 | # retrieve information 118 | if callback_handler is not None: 119 | callback_handler.on_expert_information_collection_start() 120 | queries, searched_results = self.retrieve_information( 121 | topic=topic, question=question 122 | ) 123 | if callback_handler is not None: 124 | callback_handler.on_expert_information_collection_end(searched_results) 125 | # format information string for answer generation 126 | info_text, index_to_information_mapping = format_search_results( 127 | searched_results, mode=mode 128 | ) 129 | answer = "Sorry, there is insufficient information to answer the question." 130 | # generate answer to the question 131 | if info_text: 132 | with self.logging_wrapper.log_event( 133 | f"AnswerQuestionModule.answer_question ({hash(question)})" 134 | ): 135 | with dspy.settings.context( 136 | lm=self.question_answering_lm, show_guidelines=False 137 | ): 138 | answer = self.answer_question( 139 | topic=topic, question=question, info=info_text, style=style 140 | ).answer 141 | answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( 142 | answer 143 | ) 144 | answer = trim_output_after_hint( 145 | answer, 146 | hint="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)", 147 | ) 148 | # enforce single citation index bracket. [1, 2] -> [1][2] 149 | answer = separate_citations(answer) 150 | if callback_handler is not None: 151 | callback_handler.on_expert_utterance_generation_end() 152 | # construct cited search result 153 | cited_searched_results = extract_cited_storm_info( 154 | response=answer, index_to_storm_info=index_to_information_mapping 155 | ) 156 | 157 | return dspy.Prediction( 158 | question=question, 159 | queries=queries, 160 | raw_retrieved_info=searched_results, 161 | cited_info=cited_searched_results, 162 | response=answer, 163 | ) 164 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/grounded_question_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module handles question generation within the Co-STORM framework, specifically designed to support the Moderator role. 3 | 4 | The Moderator generates insightful, thought-provoking questions that introduce new directions into the conversation. 5 | By leveraging uncited or unused snippets of information retrieved during the discussion, the Moderator ensures the conversation remains dynamic and avoids repetitive or overly niche topics. 6 | 7 | For more detailed information, refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232. 
8 | """ 9 | 10 | import dspy 11 | from typing import List, Union 12 | 13 | from .collaborative_storm_utils import ( 14 | format_search_results, 15 | extract_and_remove_citations, 16 | keep_first_and_last_paragraph, 17 | extract_cited_storm_info, 18 | ) 19 | from ...dataclass import ConversationTurn, KnowledgeBase 20 | from ...interface import Information 21 | 22 | 23 | class KnowledgeBaseSummmary(dspy.Signature): 24 | """Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections. 25 | You will be presented with these sections where "#" denotes level of section. 26 | """ 27 | 28 | topic = dspy.InputField(prefix="topic: ", format=str) 29 | structure = dspy.InputField(prefix="Tree structure: \n", format=str) 30 | output = dspy.OutputField(prefix="Now give brief summary:\n", format=str) 31 | 32 | 33 | class ConvertUtteranceStyle(dspy.Signature): 34 | """ 35 | You are an invited speaker in the round table conversation. 36 | Your task is to make the question or the response more conversational and engaging to facilicate the flow of conversation. 37 | Note that this is ongoing conversation so no need to have welcoming and concluding words. Previous speaker utterance is provided only for making the conversation more natural. 38 | Note that do not hallucinate and keep the citation index like [1] as it is. Also, 39 | """ 40 | 41 | expert = dspy.InputField(prefix="You are inivited as: ", format=str) 42 | action = dspy.InputField( 43 | prefix="You want to contribute to conversation by: ", format=str 44 | ) 45 | prev = dspy.InputField(prefix="Previous speaker said: ", format=str) 46 | content = dspy.InputField( 47 | prefix="Question or response you want to say: ", format=str 48 | ) 49 | utterance = dspy.OutputField( 50 | prefix="Your utterance (keep the information as much as you can with citations, prefer shorter answers without loss of information): ", 51 | format=str, 52 | ) 53 | 54 | 55 | class GroundedQuestionGeneration(dspy.Signature): 56 | """Your job is to find next discussion focus in a roundtable conversation. You will be given previous conversation summary and some information that might assist you discover new discussion focus. 57 | Note that the new discussion focus should bring new angle and perspective to the discussion and avoid repetition. The new discussion focus should be grounded on the available information and push the boundaries of the current discussion for broader exploration. 58 | The new discussion focus should have natural flow from last utterance in the conversation. 59 | Use [1][2] in line to ground your question. 
60 | """ 61 | 62 | topic = dspy.InputField(prefix="topic: ", format=str) 63 | summary = dspy.InputField(prefix="Discussion history: \n", format=str) 64 | information = dspy.InputField(prefix="Available information: \n", format=str) 65 | last_utterance = dspy.InputField( 66 | prefix="Last utterance in the conversation: \n", format=str 67 | ) 68 | output = dspy.OutputField( 69 | prefix="Now give next discussion focus in the format of one sentence question:\n", 70 | format=str, 71 | ) 72 | 73 | 74 | class GroundedQuestionGenerationModule(dspy.Module): 75 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 76 | self.engine = engine 77 | self.gen_focus = dspy.Predict(GroundedQuestionGeneration) 78 | self.polish_style = dspy.Predict(ConvertUtteranceStyle) 79 | self.gen_summary = dspy.Predict(KnowledgeBaseSummmary) 80 | 81 | def forward( 82 | self, 83 | topic: str, 84 | knowledge_base: KnowledgeBase, 85 | last_conv_turn: ConversationTurn, 86 | unused_snippets: List[Information], 87 | ): 88 | information, index_to_information_mapping = format_search_results( 89 | unused_snippets, info_max_num_words=1000 90 | ) 91 | summary = knowledge_base.get_knowledge_base_summary() 92 | last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance) 93 | with dspy.settings.context(lm=self.engine, show_guidelines=False): 94 | raw_utterance = self.gen_focus( 95 | topic=topic, 96 | summary=summary, 97 | information=information, 98 | last_utterance=last_utterance, 99 | ).output 100 | utterance = self.polish_style( 101 | expert="Roundtable conversation moderator", 102 | action="Raising a new question by natural transit from previous utterance.", 103 | prev=keep_first_and_last_paragraph(last_utterance), 104 | content=raw_utterance, 105 | ).utterance 106 | cited_searched_results = extract_cited_storm_info( 107 | response=utterance, index_to_storm_info=index_to_information_mapping 108 | ) 109 | return dspy.Prediction( 110 | raw_utterance=raw_utterance, 111 | utterance=utterance, 112 | cited_info=cited_searched_results, 113 | ) 114 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | from typing import Union 3 | from ...dataclass import KnowledgeBase 4 | 5 | 6 | class KnowledgeBaseSummmary(dspy.Signature): 7 | """Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections. 8 | You will be presented with these sections where "#" denotes level of section. 
9 | """ 10 | 11 | topic = dspy.InputField(prefix="topic: ", format=str) 12 | structure = dspy.InputField(prefix="Tree structure: \n", format=str) 13 | output = dspy.OutputField(prefix="Now give brief summary:\n", format=str) 14 | 15 | 16 | class KnowledgeBaseSummaryModule(dspy.Module): 17 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 18 | self.engine = engine 19 | self.gen_summary = dspy.Predict(KnowledgeBaseSummmary) 20 | 21 | def forward(self, knowledge_base: KnowledgeBase): 22 | structure = knowledge_base.get_node_hierarchy_string( 23 | include_indent=False, 24 | include_full_path=False, 25 | include_hash_tag=True, 26 | include_node_content_count=False, 27 | ) 28 | with dspy.settings.context(lm=self.engine, show_guidelines=False): 29 | summary = self.gen_summary( 30 | topic=knowledge_base.topic, structure=structure 31 | ).output 32 | return summary 33 | -------------------------------------------------------------------------------- /knowledge_storm/collaborative_storm/modules/simulate_user.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | from typing import List, Union 3 | 4 | from .collaborative_storm_utils import extract_and_remove_citations 5 | from ...dataclass import ConversationTurn 6 | from ...storm_wiki.modules.knowledge_curation import AskQuestionWithPersona 7 | 8 | 9 | class GenSimulatedUserUtterance(dspy.Module): 10 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 11 | self.engine = engine 12 | self.ask_qeustion = dspy.Predict(AskQuestionWithPersona) 13 | 14 | def gen_conv_history_string(self, conversation_turns: List[ConversationTurn]): 15 | conv_history = [] 16 | total_turns = len(conversation_turns) 17 | 18 | for i, turn in enumerate(conversation_turns): 19 | utterance, _ = extract_and_remove_citations(turn.utterance) 20 | if i >= total_turns - 4: 21 | conv_history.append(f"{turn.role}: {utterance}") 22 | else: 23 | if turn.claim_to_make: 24 | conv_history.append(f"{turn.role}: {turn.claim_to_make}") 25 | else: 26 | conv_history.append(f"{turn.role}: {utterance}") 27 | 28 | return "\n".join(conv_history) 29 | 30 | def forward(self, topic: str, intent: str, conv_history: List[ConversationTurn]): 31 | conv_history_string = self.gen_conv_history_string(conv_history) 32 | with dspy.settings.context(lm=self.engine, show_guidelines=False): 33 | return self.ask_qeustion( 34 | topic=topic, 35 | persona=f"researcher with interest in {intent}", 36 | conv=conv_history_string, 37 | ).question 38 | -------------------------------------------------------------------------------- /knowledge_storm/encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from typing import List, Tuple, Union, Optional, Dict, Literal 6 | from pathlib import Path 7 | 8 | try: 9 | import warnings 10 | 11 | with warnings.catch_warnings(): 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: 14 | os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" 15 | import litellm 16 | 17 | litellm.drop_params = True 18 | litellm.telemetry = False 19 | 20 | from litellm.caching.caching import Cache 21 | 22 | disk_cache_dir = os.path.join(Path.home(), ".storm_local_cache") 23 | litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") 24 | 25 | except ImportError: 26 | 27 | class LitellmPlaceholder: 28 | def 
__getattr__(self, _): 29 | raise ImportError( 30 | "The LiteLLM package is not installed. Run `pip install litellm`." 31 | ) 32 | 33 | litellm = LitellmPlaceholder() 34 | 35 | 36 | class Encoder: 37 | """ 38 | A wrapper class for the LiteLLM embedding model, designed to handle embedding 39 | generation tasks efficiently. It supports parallel processing and local caching of 40 | embedding results for improved performance. 41 | 42 | The Encoder utilizes the LiteLLM library to interact with various embedding models, 43 | such as OpenAI and Azure embeddings. Users can specify the desired encoder type and 44 | provide relevant API credentials during initialization. 45 | 46 | Features: 47 | - Support for multiple embedding models (e.g., OpenAI, Azure). 48 | - Parallel processing for faster embedding generation. 49 | - Local disk caching to store and reuse embedding results. 50 | - Total token usage tracking for cost monitoring. 51 | 52 | Note: 53 | Refer to the LiteLLM documentation for details on supported embedding models: 54 | https://docs.litellm.ai/docs/embedding/supported_embedding 55 | """ 56 | 57 | def __init__( 58 | self, 59 | encoder_type: Optional[str] = None, 60 | api_key: Optional[str] = None, 61 | api_base: Optional[str] = None, 62 | api_version: Optional[str] = None, 63 | ): 64 | """ 65 | Initializes the Encoder with the appropriate embedding model. 66 | 67 | Args: 68 | encoder_type (Optional[str]): Type of encoder ('openai', 'azure', etc.). 69 | api_key (Optional[str]): API key for the encoder service. 70 | api_base (Optional[str]): API base URL for the encoder service. 71 | api_version (Optional[str]): API version for the encoder service. 72 | """ 73 | self.embedding_model_name = None 74 | self.kargs = {} 75 | self.total_token_usage = 0 76 | 77 | # Initialize the appropriate embedding model 78 | encoder_type = encoder_type or os.getenv("ENCODER_API_TYPE") 79 | if not encoder_type: 80 | raise ValueError("ENCODER_API_TYPE environment variable is not set.") 81 | 82 | if encoder_type.lower() == "openai": 83 | self.embedding_model_name = "text-embedding-3-small" 84 | self.kargs = {"api_key": api_key or os.getenv("OPENAI_API_KEY")} 85 | elif encoder_type.lower() == "azure": 86 | self.embedding_model_name = "azure/text-embedding-3-small" 87 | self.kargs = { 88 | "api_key": api_key or os.getenv("AZURE_API_KEY"), 89 | "api_base": api_base or os.getenv("AZURE_API_BASE"), 90 | "api_version": api_version or os.getenv("AZURE_API_VERSION"), 91 | } 92 | else: 93 | raise ValueError( 94 | f"Unsupported ENCODER_API_TYPE '{encoder_type}'. Supported types are 'openai', 'azure', 'together'." 95 | ) 96 | 97 | def get_total_token_usage(self, reset: bool = False) -> int: 98 | """ 99 | Retrieves the total token usage. 100 | 101 | Args: 102 | reset (bool): If True, resets the total token usage counter after retrieval. 103 | 104 | Returns: 105 | int: The total number of tokens used. 106 | """ 107 | token_usage = self.total_token_usage 108 | if reset: 109 | self.total_token_usage = 0 110 | return token_usage 111 | 112 | def encode(self, texts: Union[str, List[str]], max_workers: int = 5) -> np.ndarray: 113 | """ 114 | Public method to get embeddings for the given texts. 115 | 116 | Args: 117 | texts (Union[str, List[str]]): A single text string or a list of text strings to embed. 118 | 119 | Returns: 120 | np.ndarray: The array of embeddings. 
121 | """ 122 | return self._get_text_embeddings(texts, max_workers=max_workers) 123 | 124 | def _get_single_text_embedding(self, text): 125 | response = litellm.embedding( 126 | model=self.embedding_model_name, input=text, caching=True, **self.kargs 127 | ) 128 | embedding = response.data[0]["embedding"] 129 | token_usage = response.get("usage", {}).get("total_tokens", 0) 130 | return text, embedding, token_usage 131 | 132 | def _get_text_embeddings( 133 | self, 134 | texts: Union[str, List[str]], 135 | max_workers: int = 5, 136 | ) -> Tuple[np.ndarray, int]: 137 | """ 138 | Get text embeddings using OpenAI's text-embedding-3-small model. 139 | 140 | Args: 141 | texts (Union[str, List[str]]): A single text string or a list of text strings to embed. 142 | max_workers (int): The maximum number of workers for parallel processing. 143 | api_key (str): The API key for accessing OpenAI's services. 144 | embedding_cache (Optional[Dict[str, np.ndarray]]): A cache to store previously computed embeddings. 145 | 146 | Returns: 147 | Tuple[np.ndarray, int]: The 2D array of embeddings and the total token usage. 148 | """ 149 | 150 | if isinstance(texts, str): 151 | _, embedding, tokens = self._get_single_text_embedding(texts) 152 | self.total_token_usage += tokens 153 | return np.array(embedding) 154 | 155 | embeddings = [] 156 | total_tokens = 0 157 | 158 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 159 | futures = { 160 | executor.submit(self._get_single_text_embedding, text): text 161 | for text in texts 162 | } 163 | 164 | for future in as_completed(futures): 165 | try: 166 | text, embedding, tokens = future.result() 167 | embeddings.append((text, embedding, tokens)) 168 | total_tokens += tokens 169 | except Exception as e: 170 | print(f"An error occurred for text: {futures[future]}") 171 | print(e) 172 | 173 | # Sort results to match the order of the input texts 174 | embeddings.sort(key=lambda x: texts.index(x[0])) 175 | embeddings = [result[1] for result in embeddings] 176 | self.total_token_usage += total_tokens 177 | 178 | return np.array(embeddings) 179 | -------------------------------------------------------------------------------- /knowledge_storm/logging_wrapper.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import time 3 | import pytz 4 | from datetime import datetime 5 | 6 | # Define California timezone 7 | CALIFORNIA_TZ = pytz.timezone("America/Los_Angeles") 8 | 9 | 10 | class EventLog: 11 | def __init__(self, event_name): 12 | self.event_name = event_name 13 | self.start_time = None 14 | self.end_time = None 15 | self.child_events = {} 16 | 17 | def record_start_time(self): 18 | self.start_time = datetime.now( 19 | pytz.utc 20 | ) # Store in UTC for consistent timezone conversion 21 | 22 | def record_end_time(self): 23 | self.end_time = datetime.now( 24 | pytz.utc 25 | ) # Store in UTC for consistent timezone conversion 26 | 27 | def get_total_time(self): 28 | if self.start_time and self.end_time: 29 | return (self.end_time - self.start_time).total_seconds() 30 | return 0 31 | 32 | def get_start_time(self): 33 | if self.start_time: 34 | # Format to milliseconds 35 | return self.start_time.astimezone(CALIFORNIA_TZ).strftime( 36 | "%Y-%m-%d %H:%M:%S.%f" 37 | )[:-3] 38 | return None 39 | 40 | def get_end_time(self): 41 | if self.end_time: 42 | # Format to milliseconds 43 | return self.end_time.astimezone(CALIFORNIA_TZ).strftime( 44 | "%Y-%m-%d %H:%M:%S.%f" 45 | )[:-3] 46 | return None 
47 | 48 | def add_child_event(self, child_event): 49 | self.child_events[child_event.event_name] = child_event 50 | 51 | def get_child_events(self): 52 | return self.child_events 53 | 54 | 55 | class LoggingWrapper: 56 | def __init__(self, lm_config): 57 | self.logging_dict = {} 58 | self.lm_config = lm_config 59 | self.current_pipeline_stage = None 60 | self.event_stack = [] 61 | self.pipeline_stage_active = False 62 | 63 | def _pipeline_stage_start(self, pipeline_stage: str): 64 | if self.pipeline_stage_active: 65 | raise RuntimeError( 66 | "A pipeline stage is already active. End the current stage before starting a new one." 67 | ) 68 | 69 | self.current_pipeline_stage = pipeline_stage 70 | self.logging_dict[pipeline_stage] = { 71 | "time_usage": {}, 72 | "lm_usage": {}, 73 | "lm_history": [], 74 | "query_count": 0, 75 | } 76 | self.pipeline_stage_active = True 77 | 78 | def _event_start(self, event_name: str): 79 | if not self.pipeline_stage_active: 80 | raise RuntimeError("No pipeline stage is currently active.") 81 | 82 | if not self.event_stack and self.current_pipeline_stage: 83 | # Top-level event (directly under the pipeline stage) 84 | if ( 85 | event_name 86 | not in self.logging_dict[self.current_pipeline_stage]["time_usage"] 87 | ): 88 | event = EventLog(event_name=event_name) 89 | event.record_start_time() 90 | self.logging_dict[self.current_pipeline_stage]["time_usage"][ 91 | event_name 92 | ] = event 93 | self.event_stack.append(event) 94 | else: 95 | self.logging_dict[self.current_pipeline_stage]["time_usage"][ 96 | event_name 97 | ].record_start_time() 98 | elif self.event_stack: 99 | # Nested event (under another event) 100 | parent_event = self.event_stack[-1] 101 | if event_name not in parent_event.get_child_events(): 102 | event = EventLog(event_name=event_name) 103 | event.record_start_time() 104 | parent_event.add_child_event(event) 105 | self.logging_dict[self.current_pipeline_stage]["time_usage"][ 106 | event_name 107 | ] = event 108 | self.event_stack.append(event) 109 | else: 110 | parent_event.get_child_events()[event_name].record_start_time() 111 | else: 112 | raise RuntimeError( 113 | "Cannot start an event without an active pipeline stage or parent event." 114 | ) 115 | 116 | def _event_end(self, event_name: str): 117 | if not self.pipeline_stage_active: 118 | raise RuntimeError("No pipeline stage is currently active.") 119 | 120 | if not self.event_stack: 121 | raise RuntimeError("No parent event is currently active.") 122 | 123 | if self.event_stack: 124 | current_event_log = self.event_stack[-1] 125 | if event_name in current_event_log.get_child_events(): 126 | current_event_log.get_child_events()[event_name].record_end_time() 127 | elif ( 128 | event_name 129 | in self.logging_dict[self.current_pipeline_stage]["time_usage"] 130 | ): 131 | self.logging_dict[self.current_pipeline_stage]["time_usage"][ 132 | event_name 133 | ].record_end_time() 134 | else: 135 | raise AssertionError( 136 | f"Failure to record end time for event {event_name}. Start time is not recorded." 
137 | ) 138 | if current_event_log.event_name == event_name: 139 | self.event_stack.pop() 140 | else: 141 | raise RuntimeError("Cannot end an event without an active parent event.") 142 | 143 | def _pipeline_stage_end(self): 144 | if not self.pipeline_stage_active: 145 | raise RuntimeError("No pipeline stage is currently active to end.") 146 | 147 | self.logging_dict[self.current_pipeline_stage][ 148 | "lm_usage" 149 | ] = self.lm_config.collect_and_reset_lm_usage() 150 | self.logging_dict[self.current_pipeline_stage][ 151 | "lm_history" 152 | ] = self.lm_config.collect_and_reset_lm_history() 153 | self.pipeline_stage_active = False 154 | 155 | def add_query_count(self, count): 156 | if not self.pipeline_stage_active: 157 | raise RuntimeError( 158 | "No pipeline stage is currently active to add query count." 159 | ) 160 | 161 | self.logging_dict[self.current_pipeline_stage]["query_count"] += count 162 | 163 | @contextmanager 164 | def log_event(self, event_name): 165 | if not self.pipeline_stage_active: 166 | raise RuntimeError("No pipeline stage is currently active.") 167 | 168 | self._event_start(event_name) 169 | yield 170 | self._event_end(event_name) 171 | 172 | @contextmanager 173 | def log_pipeline_stage(self, pipeline_stage): 174 | if self.pipeline_stage_active: 175 | print( 176 | "A pipeline stage is already active, ending the current stage safely." 177 | ) 178 | self._pipeline_stage_end() 179 | 180 | start_time = time.time() 181 | try: 182 | self._pipeline_stage_start(pipeline_stage) 183 | yield 184 | except Exception as e: 185 | print(f"Error occurred during pipeline stage '{pipeline_stage}': {e}") 186 | finally: 187 | self.logging_dict[self.current_pipeline_stage]["total_wall_time"] = ( 188 | time.time() - start_time 189 | ) 190 | self._pipeline_stage_end() 191 | 192 | def dump_logging_and_reset(self, reset_logging=True): 193 | log_dump = {} 194 | for pipeline_stage, pipeline_log in self.logging_dict.items(): 195 | time_stamp_log = { 196 | event_name: { 197 | "total_time_seconds": event.get_total_time(), 198 | "start_time": event.get_start_time(), 199 | "end_time": event.get_end_time(), 200 | } 201 | for event_name, event in pipeline_log["time_usage"].items() 202 | } 203 | log_dump[pipeline_stage] = { 204 | "time_usage": time_stamp_log, 205 | "lm_usage": pipeline_log["lm_usage"], 206 | "lm_history": pipeline_log["lm_history"], 207 | "query_count": pipeline_log["query_count"], 208 | "total_wall_time": pipeline_log["total_wall_time"], 209 | } 210 | if reset_logging: 211 | self.logging_dict.clear() 212 | return log_dump 213 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .knowledge_curation import * 2 | from .persona_generator import * 3 | from .retriever import * 4 | from .storm_dataclass import * 5 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/article_generation.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import copy 3 | import logging 4 | from concurrent.futures import as_completed 5 | from typing import 
List, Union 6 | 7 | import dspy 8 | 9 | from .callback import BaseCallbackHandler 10 | from .storm_dataclass import StormInformationTable, StormArticle 11 | from ...interface import ArticleGenerationModule, Information 12 | from ...utils import ArticleTextProcessing 13 | 14 | 15 | class StormArticleGenerationModule(ArticleGenerationModule): 16 | """ 17 | The interface for the article generation stage. Given the topic, the information collected in the 18 | knowledge curation stage, and the outline from the outline generation stage, it generates the article section by section. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], 24 | retrieve_top_k: int = 5, 25 | max_thread_num: int = 10, 26 | ): 27 | super().__init__() 28 | self.retrieve_top_k = retrieve_top_k 29 | self.article_gen_lm = article_gen_lm 30 | self.max_thread_num = max_thread_num 31 | self.section_gen = ConvToSection(engine=self.article_gen_lm) 32 | 33 | def generate_section( 34 | self, topic, section_name, information_table, section_outline, section_query 35 | ): 36 | collected_info: List[Information] = [] 37 | if information_table is not None: 38 | collected_info = information_table.retrieve_information( 39 | queries=section_query, search_top_k=self.retrieve_top_k 40 | ) 41 | output = self.section_gen( 42 | topic=topic, 43 | outline=section_outline, 44 | section=section_name, 45 | collected_info=collected_info, 46 | ) 47 | return { 48 | "section_name": section_name, 49 | "section_content": output.section, 50 | "collected_info": collected_info, 51 | } 52 | 53 | def generate_article( 54 | self, 55 | topic: str, 56 | information_table: StormInformationTable, 57 | article_with_outline: StormArticle, 58 | callback_handler: BaseCallbackHandler = None, 59 | ) -> StormArticle: 60 | """ 61 | Generate article for the topic based on the information table and article outline. 62 | 63 | Args: 64 | topic (str): The topic of the article. 65 | information_table (StormInformationTable): The information table containing the collected information. 66 | article_with_outline (StormArticle): The article with specified outline. 67 | callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger 68 | custom callbacks at various stages of the article generation process. Defaults to None. 69 | """ 70 | information_table.prepare_table_for_retrieval() 71 | 72 | if article_with_outline is None: 73 | article_with_outline = StormArticle(topic_name=topic) 74 | 75 | sections_to_write = article_with_outline.get_first_level_section_names() 76 | 77 | section_output_dict_collection = [] 78 | if len(sections_to_write) == 0: 79 | logging.error( 80 | f"No outline for {topic}. Will directly search with the topic." 81 | ) 82 | section_output_dict = self.generate_section( 83 | topic=topic, 84 | section_name=topic, 85 | information_table=information_table, 86 | section_outline="", 87 | section_query=[topic], 88 | ) 89 | section_output_dict_collection = [section_output_dict] 90 | else: 91 | with concurrent.futures.ThreadPoolExecutor( 92 | max_workers=self.max_thread_num 93 | ) as executor: 94 | future_to_sec_title = {} 95 | for section_title in sections_to_write: 96 | # We don't want to write a separate introduction section. 97 | if section_title.lower().strip() == "introduction": 98 | continue 99 | # We don't want to write a separate conclusion section. 
100 | if section_title.lower().strip().startswith( 101 | "conclusion" 102 | ) or section_title.lower().strip().startswith("summary"): 103 | continue 104 | section_query = article_with_outline.get_outline_as_list( 105 | root_section_name=section_title, add_hashtags=False 106 | ) 107 | queries_with_hashtags = article_with_outline.get_outline_as_list( 108 | root_section_name=section_title, add_hashtags=True 109 | ) 110 | section_outline = "\n".join(queries_with_hashtags) 111 | future_to_sec_title[ 112 | executor.submit( 113 | self.generate_section, 114 | topic, 115 | section_title, 116 | information_table, 117 | section_outline, 118 | section_query, 119 | ) 120 | ] = section_title 121 | 122 | for future in as_completed(future_to_sec_title): 123 | section_output_dict_collection.append(future.result()) 124 | 125 | article = copy.deepcopy(article_with_outline) 126 | for section_output_dict in section_output_dict_collection: 127 | article.update_section( 128 | parent_section_name=topic, 129 | current_section_content=section_output_dict["section_content"], 130 | current_section_info_list=section_output_dict["collected_info"], 131 | ) 132 | article.post_processing() 133 | return article 134 | 135 | 136 | class ConvToSection(dspy.Module): 137 | """Use the information collected from the information-seeking conversation to write a section.""" 138 | 139 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 140 | super().__init__() 141 | self.write_section = dspy.Predict(WriteSection) 142 | self.engine = engine 143 | 144 | def forward( 145 | self, topic: str, outline: str, section: str, collected_info: List[Information] 146 | ): 147 | info = "" 148 | for idx, storm_info in enumerate(collected_info): 149 | info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets) 150 | info += "\n\n" 151 | 152 | info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500) 153 | 154 | with dspy.settings.context(lm=self.engine): 155 | section = ArticleTextProcessing.clean_up_section( 156 | self.write_section(topic=topic, info=info, section=section).output 157 | ) 158 | 159 | return dspy.Prediction(section=section) 160 | 161 | 162 | class WriteSection(dspy.Signature): 163 | """Write a Wikipedia section based on the collected information. 164 | 165 | Here is the format of your writing: 166 | 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. 167 | 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 168 | """ 169 | 170 | info = dspy.InputField(prefix="The collected information:\n", format=str) 171 | topic = dspy.InputField(prefix="The topic of the page: ", format=str) 172 | section = dspy.InputField(prefix="The section you need to write: ", format=str) 173 | output = dspy.OutputField( 174 | prefix="Write the section with proper inline citations (Start your writing with # section title. 
Don't include the page title or try to write other sections):\n",
175 | format=str,
176 | )
177 | 
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/article_polish.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Union
3 | 
4 | import dspy
5 | 
6 | from .storm_dataclass import StormArticle
7 | from ...interface import ArticlePolishingModule
8 | from ...utils import ArticleTextProcessing
9 | 
10 | 
11 | class StormArticlePolishingModule(ArticlePolishingModule):
12 | """
13 | The interface for the article polishing stage. Given a draft article, polish it by adding a
14 | concise lead section and optionally removing duplicated content from the article.
15 | """
16 | 
17 | def __init__(
18 | self,
19 | article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
20 | article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
21 | ):
22 | self.article_gen_lm = article_gen_lm
23 | self.article_polish_lm = article_polish_lm
24 | 
25 | self.polish_page = PolishPageModule(
26 | write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm
27 | )
28 | 
29 | def polish_article(
30 | self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False
31 | ) -> StormArticle:
32 | """
33 | Polish article.
34 | 
35 | Args:
36 | topic (str): The topic of the article.
37 | draft_article (StormArticle): The draft article.
38 | remove_duplicate (bool): Whether to use one additional LM call to remove duplicates from the article.
39 | """
40 | 
41 | article_text = draft_article.to_string()
42 | polish_result = self.polish_page(
43 | topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate
44 | )
45 | lead_section = f"# summary\n{polish_result.lead_section}"
46 | polished_article = "\n\n".join([lead_section, polish_result.page])
47 | polished_article_dict = ArticleTextProcessing.parse_article_into_dict(
48 | polished_article
49 | )
50 | polished_article = copy.deepcopy(draft_article)
51 | polished_article.insert_or_create_section(article_dict=polished_article_dict)
52 | polished_article.post_processing()
53 | return polished_article
54 | 
55 | 
56 | class WriteLeadSection(dspy.Signature):
57 | """Write a lead section for the given Wikipedia page with the following guidelines:
58 | 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies.
59 | 2. The lead section should be concise and contain no more than four well-composed paragraphs.
60 | 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.
61 | """
62 | 
63 | topic = dspy.InputField(prefix="The topic of the page: ", format=str)
64 | draft_page = dspy.InputField(prefix="The draft page:\n", format=str)
65 | lead_section = dspy.OutputField(prefix="Write the lead section:\n", format=str)
66 | 
67 | 
68 | class PolishPage(dspy.Signature):
69 | """You are a faithful text editor that is good at finding repeated information in the article and deleting them to make sure there is no repetition in the article. You won't delete any non-repeated part in the article. You will keep the inline citations and article structure (indicated by "#", "##", etc.) appropriately.
Do your job for the following article.""" 70 | 71 | draft_page = dspy.InputField(prefix="The draft article:\n", format=str) 72 | page = dspy.OutputField(prefix="Your revised article:\n", format=str) 73 | 74 | 75 | class PolishPageModule(dspy.Module): 76 | def __init__( 77 | self, 78 | write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], 79 | polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], 80 | ): 81 | super().__init__() 82 | self.write_lead_engine = write_lead_engine 83 | self.polish_engine = polish_engine 84 | self.write_lead = dspy.Predict(WriteLeadSection) 85 | self.polish_page = dspy.Predict(PolishPage) 86 | 87 | def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): 88 | # NOTE: Change show_guidelines to false to make the generation more robust to different LM families. 89 | with dspy.settings.context(lm=self.write_lead_engine, show_guidelines=False): 90 | lead_section = self.write_lead( 91 | topic=topic, draft_page=draft_page 92 | ).lead_section 93 | if "The lead section:" in lead_section: 94 | lead_section = lead_section.split("The lead section:")[1].strip() 95 | if polish_whole_page: 96 | # NOTE: Change show_guidelines to false to make the generation more robust to different LM families. 97 | with dspy.settings.context(lm=self.polish_engine, show_guidelines=False): 98 | page = self.polish_page(draft_page=draft_page).page 99 | else: 100 | page = draft_page 101 | 102 | return dspy.Prediction(lead_section=lead_section, page=page) 103 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/callback.py: -------------------------------------------------------------------------------- 1 | class BaseCallbackHandler: 2 | """Base callback handler that can be used to handle callbacks from the STORM pipeline.""" 3 | 4 | def on_identify_perspective_start(self, **kwargs): 5 | """Run when the perspective identification starts.""" 6 | pass 7 | 8 | def on_identify_perspective_end(self, perspectives: list[str], **kwargs): 9 | """Run when the perspective identification finishes.""" 10 | pass 11 | 12 | def on_information_gathering_start(self, **kwargs): 13 | """Run when the information gathering starts.""" 14 | pass 15 | 16 | def on_dialogue_turn_end(self, dlg_turn, **kwargs): 17 | """Run when a question asking and answering turn finishes.""" 18 | pass 19 | 20 | def on_information_gathering_end(self, **kwargs): 21 | """Run when the information gathering finishes.""" 22 | pass 23 | 24 | def on_information_organization_start(self, **kwargs): 25 | """Run when the information organization starts.""" 26 | pass 27 | 28 | def on_direct_outline_generation_end(self, outline: str, **kwargs): 29 | """Run when the direct outline generation finishes.""" 30 | pass 31 | 32 | def on_outline_refinement_end(self, outline: str, **kwargs): 33 | """Run when the outline refinement finishes.""" 34 | pass 35 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/outline_generation.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, Tuple 2 | 3 | import dspy 4 | 5 | from .callback import BaseCallbackHandler 6 | from .storm_dataclass import StormInformationTable, StormArticle 7 | from ...interface import OutlineGenerationModule 8 | from ...utils import ArticleTextProcessing 9 | 10 | 11 | class StormOutlineGenerationModule(OutlineGenerationModule): 12 | """ 13 | The interface for outline generation 
stage. Given topic, collected information from knowledge 14 | curation stage, generate outline for the article. 15 | """ 16 | 17 | def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 18 | super().__init__() 19 | self.outline_gen_lm = outline_gen_lm 20 | self.write_outline = WriteOutline(engine=self.outline_gen_lm) 21 | 22 | def generate_outline( 23 | self, 24 | topic: str, 25 | information_table: StormInformationTable, 26 | old_outline: Optional[StormArticle] = None, 27 | callback_handler: BaseCallbackHandler = None, 28 | return_draft_outline=False, 29 | ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: 30 | """ 31 | Generates an outline for an article based on the specified topic and the information 32 | gathered during the knowledge curation stage. This method can optionally return both the 33 | final article outline and a draft outline if required. 34 | 35 | Args: 36 | topic (str): The topic of the article. 37 | information_table (StormInformationTable): The information table containing the collected information. 38 | old_outline (Optional[StormArticle]): An optional previous version of the article outline that can 39 | be used for reference or comparison. Defaults to None. 40 | callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger 41 | custom callbacks at various stages of the outline generation process, such as when the information 42 | organization starts. Defaults to None. 43 | return_draft_outline (bool): A flag indicating whether the method should return both the final article 44 | outline and a draft version of the outline. If False, only the final article outline is returned. 45 | Defaults to False. 46 | 47 | Returns: 48 | Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, 49 | this method returns either a single `StormArticle` object containing the final outline or a tuple of 50 | two `StormArticle` objects, the first containing the final outline and the second containing the 51 | draft outline. 
52 | """ 53 | if callback_handler is not None: 54 | callback_handler.on_information_organization_start() 55 | 56 | concatenated_dialogue_turns = sum( 57 | [conv for (_, conv) in information_table.conversations], [] 58 | ) 59 | result = self.write_outline( 60 | topic=topic, 61 | dlg_history=concatenated_dialogue_turns, 62 | callback_handler=callback_handler, 63 | ) 64 | article_with_outline_only = StormArticle.from_outline_str( 65 | topic=topic, outline_str=result.outline 66 | ) 67 | article_with_draft_outline_only = StormArticle.from_outline_str( 68 | topic=topic, outline_str=result.old_outline 69 | ) 70 | if not return_draft_outline: 71 | return article_with_outline_only 72 | return article_with_outline_only, article_with_draft_outline_only 73 | 74 | 75 | class WriteOutline(dspy.Module): 76 | """Generate the outline for the Wikipedia page.""" 77 | 78 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 79 | super().__init__() 80 | self.draft_page_outline = dspy.Predict(WritePageOutline) 81 | self.write_page_outline = dspy.Predict(WritePageOutlineFromConv) 82 | self.engine = engine 83 | 84 | def forward( 85 | self, 86 | topic: str, 87 | dlg_history, 88 | old_outline: Optional[str] = None, 89 | callback_handler: BaseCallbackHandler = None, 90 | ): 91 | trimmed_dlg_history = [] 92 | for turn in dlg_history: 93 | if ( 94 | "topic you" in turn.agent_utterance.lower() 95 | or "topic you" in turn.user_utterance.lower() 96 | ): 97 | continue 98 | trimmed_dlg_history.append(turn) 99 | conv = "\n".join( 100 | [ 101 | f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}" 102 | for turn in trimmed_dlg_history 103 | ] 104 | ) 105 | conv = ArticleTextProcessing.remove_citations(conv) 106 | conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000) 107 | 108 | with dspy.settings.context(lm=self.engine): 109 | if old_outline is None: 110 | old_outline = ArticleTextProcessing.clean_up_outline( 111 | self.draft_page_outline(topic=topic).outline 112 | ) 113 | if callback_handler: 114 | callback_handler.on_direct_outline_generation_end( 115 | outline=old_outline 116 | ) 117 | outline = ArticleTextProcessing.clean_up_outline( 118 | self.write_page_outline( 119 | topic=topic, old_outline=old_outline, conv=conv 120 | ).outline 121 | ) 122 | if callback_handler: 123 | callback_handler.on_outline_refinement_end(outline=outline) 124 | 125 | return dspy.Prediction(outline=outline, old_outline=old_outline) 126 | 127 | 128 | class WritePageOutline(dspy.Signature): 129 | """Write an outline for a Wikipedia page. 130 | Here is the format of your writing: 131 | 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. 132 | 2. Do not include other information. 133 | 3. Do not include topic name itself in the outline. 
134 | """ 135 | 136 | topic = dspy.InputField(prefix="The topic you want to write: ", format=str) 137 | outline = dspy.OutputField(prefix="Write the Wikipedia page outline:\n", format=str) 138 | 139 | 140 | class NaiveOutlineGen(dspy.Module): 141 | """Generate the outline with LLM's parametric knowledge directly.""" 142 | 143 | def __init__(self): 144 | super().__init__() 145 | self.write_outline = dspy.Predict(WritePageOutline) 146 | 147 | def forward(self, topic: str): 148 | outline = self.write_outline(topic=topic).outline 149 | 150 | return dspy.Prediction(outline=outline) 151 | 152 | 153 | class WritePageOutlineFromConv(dspy.Signature): 154 | """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative. 155 | Here is the format of your writing: 156 | 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. 157 | 2. Do not include other information. 158 | 3. Do not include topic name itself in the outline. 159 | """ 160 | 161 | topic = dspy.InputField(prefix="The topic you want to write: ", format=str) 162 | conv = dspy.InputField(prefix="Conversation history:\n", format=str) 163 | old_outline = dspy.OutputField(prefix="Current outline:\n", format=str) 164 | outline = dspy.OutputField( 165 | prefix='Write the Wikipedia page outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', 166 | format=str, 167 | ) 168 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/persona_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Union, List 4 | 5 | import dspy 6 | import requests 7 | from bs4 import BeautifulSoup 8 | 9 | 10 | def get_wiki_page_title_and_toc(url): 11 | """Get the main title and table of contents from an url of a Wikipedia page.""" 12 | 13 | response = requests.get(url) 14 | soup = BeautifulSoup(response.content, "html.parser") 15 | 16 | # Get the main title from the first h1 tag 17 | main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ") 18 | 19 | toc = "" 20 | levels = [] 21 | excluded_sections = { 22 | "Contents", 23 | "See also", 24 | "Notes", 25 | "References", 26 | "External links", 27 | } 28 | 29 | # Start processing from h2 to exclude the main title from TOC 30 | for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]): 31 | level = int( 32 | header.name[1] 33 | ) # Extract the numeric part of the header tag (e.g., '2' from 'h2') 34 | section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ") 35 | if section_title in excluded_sections: 36 | continue 37 | 38 | while levels and level <= levels[-1]: 39 | levels.pop() 40 | levels.append(level) 41 | 42 | indentation = " " * (len(levels) - 1) 43 | toc += f"{indentation}{section_title}\n" 44 | 45 | return main_title, toc.strip() 46 | 47 | 48 | class FindRelatedTopic(dspy.Signature): 49 | """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. 50 | Please list the urls in separate lines.""" 51 | 52 | topic = dspy.InputField(prefix="Topic of interest:", format=str) 53 | related_topics = dspy.OutputField(format=str) 54 | 55 | 56 | class GenPersona(dspy.Signature): 57 | """You need to select a group of Wikipedia editors who will work together to create a comprehensive article on the topic. Each of them represents a different perspective, role, or affiliation related to this topic. You can use other Wikipedia pages of related topics for inspiration. For each editor, add a description of what they will focus on. 58 | Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n... 59 | """ 60 | 61 | topic = dspy.InputField(prefix="Topic of interest:", format=str) 62 | examples = dspy.InputField( 63 | prefix="Wiki page outlines of related topics for inspiration:\n", format=str 64 | ) 65 | personas = dspy.OutputField(format=str) 66 | 67 | 68 | class CreateWriterWithPersona(dspy.Module): 69 | """Discover different perspectives of researching the topic by reading Wikipedia pages of related topics.""" 70 | 71 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 72 | super().__init__() 73 | self.find_related_topic = dspy.ChainOfThought(FindRelatedTopic) 74 | self.gen_persona = dspy.ChainOfThought(GenPersona) 75 | self.engine = engine 76 | 77 | def forward(self, topic: str, draft=None): 78 | with dspy.settings.context(lm=self.engine): 79 | # Get section names from wiki pages of relevant topics for inspiration. 80 | related_topics = self.find_related_topic(topic=topic).related_topics 81 | urls = [] 82 | for s in related_topics.split("\n"): 83 | if "http" in s: 84 | urls.append(s[s.find("http") :]) 85 | examples = [] 86 | for url in urls: 87 | try: 88 | title, toc = get_wiki_page_title_and_toc(url) 89 | examples.append(f"Title: {title}\nTable of Contents: {toc}") 90 | except Exception as e: 91 | logging.error(f"Error occurs when processing {url}: {e}") 92 | continue 93 | if len(examples) == 0: 94 | examples.append("N/A") 95 | gen_persona_output = self.gen_persona( 96 | topic=topic, examples="\n----------\n".join(examples) 97 | ).personas 98 | 99 | personas = [] 100 | for s in gen_persona_output.split("\n"): 101 | match = re.search(r"\d+\.\s*(.*)", s) 102 | if match: 103 | personas.append(match.group(1)) 104 | 105 | sorted_personas = personas 106 | 107 | return dspy.Prediction( 108 | personas=personas, 109 | raw_personas_output=sorted_personas, 110 | related_topics=related_topics, 111 | ) 112 | 113 | 114 | class StormPersonaGenerator: 115 | """ 116 | A generator class for creating personas based on a given topic. 117 | 118 | This class uses an underlying engine to generate personas tailored to the specified topic. 119 | The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, 120 | including a default 'Basic fact writer' persona. 121 | 122 | Attributes: 123 | create_writer_with_persona (CreateWriterWithPersona): An instance responsible for 124 | generating personas based on the provided engine and topic. 125 | 126 | Args: 127 | engine (Union[dspy.dsp.LM, dspy.dsp.HFModel]): The underlying engine used for generating 128 | personas. It must be an instance of either `dspy.dsp.LM` or `dspy.dsp.HFModel`. 
129 | """ 130 | 131 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): 132 | self.create_writer_with_persona = CreateWriterWithPersona(engine=engine) 133 | 134 | def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]: 135 | """ 136 | Generates a list of personas based on the provided topic, up to a maximum number specified. 137 | 138 | This method first creates personas using the underlying `create_writer_with_persona` instance 139 | and then prepends a default 'Basic fact writer' persona to the list before returning it. 140 | The number of personas returned is limited to `max_num_persona`, excluding the default persona. 141 | 142 | Args: 143 | topic (str): The topic for which personas are to be generated. 144 | max_num_persona (int): The maximum number of personas to generate, excluding the 145 | default 'Basic fact writer' persona. 146 | 147 | Returns: 148 | List[str]: A list of persona descriptions, including the default 'Basic fact writer' persona 149 | and up to `max_num_persona` additional personas generated based on the topic. 150 | """ 151 | personas = self.create_writer_with_persona(topic=topic) 152 | default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic." 153 | considered_personas = [default_persona] + personas.personas[:max_num_persona] 154 | return considered_personas 155 | -------------------------------------------------------------------------------- /knowledge_storm/storm_wiki/modules/retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from urllib.parse import urlparse 3 | 4 | import dspy 5 | 6 | from ...interface import Retriever, Information 7 | from ...utils import ArticleTextProcessing 8 | 9 | # Internet source restrictions according to Wikipedia standard: 10 | # https://en.wikipedia.org/wiki/Wikipedia:Reliable_sources/Perennial_sources 11 | GENERALLY_UNRELIABLE = { 12 | "112_Ukraine", 13 | "Ad_Fontes_Media", 14 | "AlterNet", 15 | "Amazon", 16 | "Anadolu_Agency_(controversial_topics)", 17 | "Ancestry.com", 18 | "Answers.com", 19 | "Antiwar.com", 20 | "Anti-Defamation_League", 21 | "arXiv", 22 | "Atlas_Obscura_places", 23 | "Bild", 24 | "Blaze_Media", 25 | "Blogger", 26 | "BroadwayWorld", 27 | "California_Globe", 28 | "The_Canary", 29 | "CelebrityNetWorth", 30 | "CESNUR", 31 | "ChatGPT", 32 | "CNET_(November_2022\u2013present)", 33 | "CoinDesk", 34 | "Consortium_News", 35 | "CounterPunch", 36 | "Correo_del_Orinoco", 37 | "Cracked.com", 38 | "Daily_Express", 39 | "Daily_Kos", 40 | "Daily_Sabah", 41 | "The_Daily_Wire", 42 | "Discogs", 43 | "Distractify", 44 | "The_Electronic_Intifada", 45 | "Encyclopaedia_Metallum", 46 | "Ethnicity_of_Celebs", 47 | "Facebook", 48 | "FamilySearch", 49 | "Fandom", 50 | "The_Federalist", 51 | "Find_a_Grave", 52 | "Findmypast", 53 | "Flags_of_the_World", 54 | "Flickr", 55 | "Forbes.com_contributors", 56 | "Fox_News_(politics_and_science)", 57 | "Fox_News_(talk_shows)", 58 | "Gawker", 59 | "GB_News", 60 | "Geni.com", 61 | "gnis-class", 62 | "gns-class", 63 | "GlobalSecurity.org", 64 | "Goodreads", 65 | "Guido_Fawkes", 66 | "Heat_Street", 67 | "History", 68 | "HuffPost_contributors", 69 | "IMDb", 70 | "Independent_Media_Center", 71 | "Inquisitr", 72 | "International_Business_Times", 73 | "Investopedia", 74 | "Jewish_Virtual_Library", 75 | "Joshua_Project", 76 | "Know_Your_Meme", 77 | "Land_Transport_Guru", 78 | "LinkedIn", 79 | "LiveJournal", 80 | 
"Marquis_Who's_Who", 81 | "Mashable_sponsored_content", 82 | "MEAWW", 83 | "Media_Bias/Fact_Check", 84 | "Media_Research_Center", 85 | "Medium", 86 | "metal-experience", 87 | "Metro", 88 | "The_New_American", 89 | "New_York_Post", 90 | "NGO_Monitor", 91 | "The_Onion", 92 | "Our_Campaigns", 93 | "PanAm_Post", 94 | "Patheos", 95 | "An_Phoblacht", 96 | "The_Post_Millennial", 97 | "arXiv", 98 | "bioRxiv", 99 | "medRxiv", 100 | "PeerJ Preprints", 101 | "Preprints.org", 102 | "SSRN", 103 | "PR_Newswire", 104 | "Quadrant", 105 | "Quillette", 106 | "Quora", 107 | "Raw_Story", 108 | "Reddit", 109 | "RedState", 110 | "ResearchGate", 111 | "Rolling_Stone_(politics_and_society,_2011\u2013present)", 112 | "Rolling_Stone_(Culture_Council)", 113 | "Scribd", 114 | "Scriptural_texts", 115 | "Simple_Flying", 116 | "Sixth_Tone_(politics)", 117 | "The_Skwawkbox", 118 | "SourceWatch", 119 | "Spirit_of_Metal", 120 | "Sportskeeda", 121 | "Stack_Exchange", 122 | "Stack_Overflow", 123 | "MathOverflow", 124 | "Ask_Ubuntu", 125 | "starsunfolded.com", 126 | "Statista", 127 | "TASS", 128 | "The_Truth_About_Guns", 129 | "TV.com", 130 | "TV_Tropes", 131 | "Twitter", 132 | "X.com", 133 | "Urban_Dictionary", 134 | "Venezuelanalysis", 135 | "VGChartz", 136 | "VoC", 137 | "Washington_Free_Beacon", 138 | "Weather2Travel", 139 | "The_Western_Journal", 140 | "We_Got_This_Covered", 141 | "WhatCulture", 142 | "Who's_Who_(UK)", 143 | "WhoSampled", 144 | "Wikidata", 145 | "WikiLeaks", 146 | "Wikinews", 147 | "Wikipedia", 148 | "WordPress.com", 149 | "Worldometer", 150 | "YouTube", 151 | "ZDNet", 152 | } 153 | DEPRECATED = { 154 | "Al_Mayadeen", 155 | "ANNA_News", 156 | "Baidu_Baike", 157 | "China_Global_Television_Network", 158 | "The_Cradle", 159 | "Crunchbase", 160 | "The_Daily_Caller", 161 | "Daily_Mail", 162 | "Daily_Star", 163 | "The_Epoch_Times", 164 | "FrontPage_Magazine", 165 | "The_Gateway_Pundit", 166 | "Global_Times", 167 | "The_Grayzone", 168 | "HispanTV", 169 | "Jihad_Watch", 170 | "Last.fm", 171 | "LifeSiteNews", 172 | "The_Mail_on_Sunday", 173 | "MintPress_News", 174 | "National_Enquirer", 175 | "New_Eastern_Outlook", 176 | "News_Break", 177 | "NewsBlaze", 178 | "News_of_the_World", 179 | "Newsmax", 180 | "NNDB", 181 | "Occupy_Democrats", 182 | "Office_of_Cuba_Broadcasting", 183 | "One_America_News_Network", 184 | "Peerage_websites", 185 | "Press_TV", 186 | "Project_Veritas", 187 | "Rate_Your_Music", 188 | "Republic_TV", 189 | "Royal_Central", 190 | "RT", 191 | "Sputnik", 192 | "The_Sun", 193 | "Taki's_Magazine", 194 | "Tasnim_News_Agency", 195 | "Telesur", 196 | "The_Unz_Review", 197 | "VDARE", 198 | "Voltaire_Network", 199 | "WorldNetDaily", 200 | "Zero_Hedge", 201 | } 202 | BLACKLISTED = { 203 | "Advameg", 204 | "bestgore.com", 205 | "Breitbart_News", 206 | "Centre_for_Research_on_Globalization", 207 | "Examiner.com", 208 | "Famous_Birthdays", 209 | "Healthline", 210 | "InfoWars", 211 | "Lenta.ru", 212 | "LiveLeak", 213 | "Lulu.com", 214 | "MyLife", 215 | "Natural_News", 216 | "OpIndia", 217 | "The_Points_Guy", 218 | "The_Points_Guy_(sponsored_content)", 219 | "Swarajya", 220 | "Veterans_Today", 221 | "ZoomInfo", 222 | } 223 | 224 | 225 | def is_valid_wikipedia_source(url): 226 | parsed_url = urlparse(url) 227 | # Check if the URL is from a reliable domain 228 | combined_set = GENERALLY_UNRELIABLE | DEPRECATED | BLACKLISTED 229 | for domain in combined_set: 230 | if domain in parsed_url.netloc: 231 | return False 232 | 233 | return True 234 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dspy_ai==2.4.9
2 | wikipedia==1.4.0
3 | sentence-transformers
4 | toml
5 | langchain-text-splitters
6 | trafilatura
7 | langchain-huggingface
8 | qdrant-client
9 | langchain-qdrant
10 | numpy==1.26.4
11 | litellm==1.59.3
12 | diskcache
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | from setuptools import setup, find_packages
4 | 
5 | # Read the content of the README file
6 | with open("README.md", encoding="utf-8") as f:
7 | long_description = f.read()
8 | # Remove p tags.
9 | pattern = re.compile(r"<p.*?>.*?</p>
", re.DOTALL) 10 | long_description = re.sub(pattern, "", long_description) 11 | 12 | # Read the content of the requirements.txt file 13 | with open("requirements.txt", encoding="utf-8") as f: 14 | requirements = f.read().splitlines() 15 | 16 | 17 | setup( 18 | name="knowledge-storm", 19 | version="1.1.0", 20 | author="Yijia Shao, Yucheng Jiang", 21 | author_email="shaoyj@stanford.edu, yuchengj@stanford.edu", 22 | description="STORM: A language model-powered knowledge curation engine.", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/stanford-oval/storm", 26 | license="MIT License", 27 | packages=find_packages(), 28 | classifiers=[ 29 | "Development Status :: 3 - Alpha", 30 | "License :: OSI Approved :: MIT License", 31 | "Operating System :: OS Independent", 32 | "Programming Language :: Python :: 3", 33 | "Programming Language :: Python :: 3.10", 34 | "Programming Language :: Python :: 3.11", 35 | ], 36 | python_requires=">=3.10", 37 | install_requires=requirements, 38 | ) 39 | --------------------------------------------------------------------------------