├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       ├── format-check.yml
│       └── python-package.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   ├── co-storm-workflow.jpg
│   ├── logo.svg
│   ├── overview.svg
│   ├── storm_naacl2024_slides.pdf
│   └── two_stages.jpg
├── examples
│   ├── costorm_examples
│   │   └── run_costorm_gpt.py
│   └── storm_examples
│       ├── README.md
│       ├── helper
│       │   └── process_kaggle_arxiv_abstract_dataset.py
│       ├── run_storm_wiki_claude.py
│       ├── run_storm_wiki_deepseek.py
│       ├── run_storm_wiki_gemini.py
│       ├── run_storm_wiki_gpt.py
│       ├── run_storm_wiki_gpt_with_VectorRM.py
│       ├── run_storm_wiki_groq.py
│       ├── run_storm_wiki_mistral.py
│       ├── run_storm_wiki_ollama.py
│       ├── run_storm_wiki_ollama_with_searxng.py
│       └── run_storm_wiki_serper.py
├── frontend
│   └── demo_light
│       ├── .streamlit
│       │   └── config.toml
│       ├── README.md
│       ├── assets
│       │   ├── article_display.jpg
│       │   ├── create_article.jpg
│       │   └── void.jpg
│       ├── demo_util.py
│       ├── pages_util
│       │   ├── CreateNewArticle.py
│       │   └── MyArticles.py
│       ├── requirements.txt
│       ├── stoc.py
│       └── storm.py
├── knowledge_storm
│   ├── __init__.py
│   ├── collaborative_storm
│   │   ├── __init__.py
│   │   ├── engine.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── article_generation.py
│   │       ├── callback.py
│   │       ├── co_storm_agents.py
│   │       ├── collaborative_storm_utils.py
│   │       ├── costorm_expert_utterance_generator.py
│   │       ├── expert_generation.py
│   │       ├── grounded_question_answering.py
│   │       ├── grounded_question_generation.py
│   │       ├── information_insertion_module.py
│   │       ├── knowledge_base_summary.py
│   │       ├── simulate_user.py
│   │       └── warmstart_hierarchical_chat.py
│   ├── dataclass.py
│   ├── encoder.py
│   ├── interface.py
│   ├── lm.py
│   ├── logging_wrapper.py
│   ├── rm.py
│   ├── storm_wiki
│   │   ├── __init__.py
│   │   ├── engine.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── article_generation.py
│   │       ├── article_polish.py
│   │       ├── callback.py
│   │       ├── knowledge_curation.py
│   │       ├── outline_generation.py
│   │       ├── persona_generator.py
│   │       ├── retriever.py
│   │       └── storm_dataclass.py
│   └── utils.py
├── requirements.txt
└── setup.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Report the following:
15 | 1. The input topic name.
16 | 2. All output files generated for this topic, as a zip file.
17 |
18 | **Screenshots**
19 | If applicable, add screenshots to help explain your problem.
20 |
21 | **Environment:**
22 | - OS: [e.g. iOS, Windows]
23 |  - Browser [e.g. chrome, safari] if the bug is a UI problem
24 |
--------------------------------------------------------------------------------
/.github/workflows/format-check.yml:
--------------------------------------------------------------------------------
1 | name: Check Python formatting with Black
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | lint:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 | - uses: actions/setup-python@v2
14 | - uses: psf/black@stable
15 | with:
16 | black_args: "knowledge_storm --check"
17 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | name: Build and upload Python package
2 |
3 | on:
4 | workflow_dispatch: # Allows manual triggering of the workflow
5 |
6 | jobs:
7 | build:
8 |
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - uses: actions/checkout@master
13 | - name: Set up Python 3.11
14 | uses: actions/setup-python@v3
15 | with:
16 | python-version: "3.11"
17 | - name: Compare versions in setup.py and knowledge_storm/__init__.py
18 | run: |
19 | VERSION_SETUP=$(grep -oP '(?<=version=\").*(?=\")' setup.py)
20 | VERSION_INIT=$(grep -oP '(?<=__version__ = \").*(?=\")' knowledge_storm/__init__.py)
21 | echo "Version in setup.py: $VERSION_SETUP"
22 | echo "Version in __init__.py: $VERSION_INIT"
23 | if [ "$VERSION_SETUP" != "$VERSION_INIT" ]; then
24 | echo "Error: Version mismatch between setup.py ($VERSION_SETUP) and knowledge_storm/__init__.py ($VERSION_INIT)"
25 | exit 1
26 | fi
27 | shell: bash
28 | - name: Install dependencies
29 | run: python3 -m pip install setuptools wheel twine
30 | - name: Install dependencies
31 | run: |
32 | python3 -m pip install --upgrade pip setuptools wheel
33 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
34 | - name: Build a binary wheel
35 | run: python3 setup.py sdist bdist_wheel
36 | - name: Publish package to PyPI
37 | uses: pypa/gh-action-pypi-publish@release/v1
38 | with:
39 | user: __token__
40 | password: ${{ secrets.PYPI_API_TOKEN }}
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # mac
7 | .DS_Store
8 |
9 | # Other
10 | .vscode
11 | *.tsv
12 | *.pt
13 | gpt*.txt
14 | *.env
15 | local/
16 | local_*
17 | build/
18 | *.egg-info/
19 | .idea
20 | .venv
21 |
22 | # Project-specific
23 | secrets.toml
24 | *.log
25 | */assertion.log
26 | *results/
27 | .venv/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.8.0
4 | hooks:
5 | - id: black
6 | name: Format Python code with black
7 | entry: black
8 | args: ["knowledge_storm/"]
9 | language: python
10 | pass_filenames: true
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Thank you for your interest in contributing to STORM!
4 |
5 | Contributions aren't just about code. Currently (last edit: 7/22/2024), we are accepting the following forms of contribution:
6 | - Pull requests for additional language model support to `knowledge_storm/lm.py`.
7 | - Pull requests for additional retrieval model/search engine support to `knowledge_storm/rm.py`.
8 | - Pull requests for new features to `frontend/demo_light` to assist other developers.
9 | - Identification and reporting of issues or bugs.
10 | - Helping each other by responding to issues.
11 |
12 | Please note that we are not accepting code refactoring PRs at this time to avoid conflicts with our team's efforts.
13 |
14 | ## Development
15 | This section contains technical instructions & hints for contributors.
16 |
17 | ### Setting up
18 | 1. Fork this repository and clone your forked repository.
19 | 2. Install the required packages:
20 | ```
21 | conda create -n storm python=3.11
22 | conda activate storm
23 | pip install -r requirements.txt
24 | ```
25 | 3. If you want to contribute to `frontend/demo_light`, follow its [Setup guide](https://github.com/stanford-oval/storm/tree/main/frontend/demo_light#setup) to install additional packages.
26 |
27 | ### PR suggestions
28 |
29 | Following the suggested format can lead to a faster review process.
30 |
31 | **Title:**
32 |
33 | [New LM/New RM/Demo Enhancement] xxx
34 |
35 | **Description:**
36 | - For new language model support, (1) describe how to use the new LM class, (2) create an example script following the style of existing example scripts under `examples/`, (3) attach an input-output example of the example script.
37 | - For new retrieval model/search engine support, (1) describe how to use the new RM class and (2) attach input-output examples of the RM class.
38 | - For demo light enhancements, (1) describe what's new and (2) attach screenshots to demonstrate the UI change.
39 | - Please clearly describe the required API keys and provide instructions on how to get them (if applicable). This project manages API keys with `secrets.toml`.
40 |
41 | **Code Format:**
42 |
43 | We adopt [`black`](https://github.com/psf/black) for formatting Python code. To streamline the contribution process, we set up a [pre-commit hook](https://pre-commit.com/) to format the code under `knowledge_storm/` before committing. To install the pre-commit hook, run:
44 | ```
45 | pip install pre-commit
46 | pre-commit install
47 | ```
48 | The hook will automatically format the code before each commit.
49 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Stanford Open Virtual Assistant Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/assets/co-storm-workflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/co-storm-workflow.jpg
--------------------------------------------------------------------------------
/assets/storm_naacl2024_slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/storm_naacl2024_slides.pdf
--------------------------------------------------------------------------------
/assets/two_stages.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/assets/two_stages.jpg
--------------------------------------------------------------------------------
/examples/storm_examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | We host a number of example scripts for various customizations of STORM (e.g., using your favorite language models, using your own corpus, etc.). These examples can serve as starting points for your own customizations, and you are welcome to contribute your own examples by submitting a pull request to this directory.
4 |
5 | ## Run STORM with your own language model
6 | [run_storm_wiki_gpt.py](run_storm_wiki_gpt.py) provides an example of running STORM with GPT models, and [run_storm_wiki_claude.py](run_storm_wiki_claude.py) provides an example of running STORM with Claude models. Besides using closed-source models, you can also run STORM with open-weight models.
7 |
8 | `run_storm_wiki_mistral.py` provides an example of running STORM with `Mistral-7B-Instruct-v0.2` using a [VLLM](https://docs.vllm.ai/en/stable/) server:
9 |
10 | 1. Set up a VLLM server with the `Mistral-7B-Instruct-v0.2` model running.
11 | 2. Run the following command under the root directory of the repository:
12 |
13 | ```
14 | python examples/storm_examples/run_storm_wiki_mistral.py \
15 | --url $URL \
16 | --port $PORT \
17 | --output-dir $OUTPUT_DIR \
18 | --retriever you \
19 | --do-research \
20 | --do-generate-outline \
21 | --do-generate-article \
22 | --do-polish-article
23 | ```
24 | - `--url` URL of the VLLM server.
25 | - `--port` Port of the VLLM server.
26 |
27 | Besides a VLLM server, STORM is also compatible with a [TGI](https://huggingface.co/docs/text-generation-inference/en/index) server or a [Together.ai](https://www.together.ai/products#inference) endpoint.
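
If you want to wire a locally served model into your own script rather than using the ready-made example, the sketch below shows the general pattern. It assumes `knowledge_storm.lm` exposes a `VLLMClient` that accepts `url` and `port` arguments (mirroring the `--url`/`--port` flags above); check `run_storm_wiki_mistral.py` for the exact class name and signature.

```
from knowledge_storm import STORMWikiLMConfigs
from knowledge_storm.lm import VLLMClient  # assumed class name; see run_storm_wiki_mistral.py

# Point every STORM component at the locally served Mistral model.
# The url/port values below are placeholders for your own VLLM server.
vllm_kwargs = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "url": "http://localhost",
    "port": 8000,
}

lm_configs = STORMWikiLMConfigs()
lm_configs.set_conv_simulator_lm(VLLMClient(max_tokens=500, **vllm_kwargs))
lm_configs.set_question_asker_lm(VLLMClient(max_tokens=500, **vllm_kwargs))
lm_configs.set_outline_gen_lm(VLLMClient(max_tokens=400, **vllm_kwargs))
lm_configs.set_article_gen_lm(VLLMClient(max_tokens=700, **vllm_kwargs))
lm_configs.set_article_polish_lm(VLLMClient(max_tokens=4000, **vllm_kwargs))
```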
28 |
29 |
30 | ## Run STORM with your own corpus
31 |
32 | By default, STORM is grounded on the Internet using a search engine, but it can also be grounded on your own corpus using `VectorRM`. [run_storm_wiki_gpt_with_VectorRM.py](run_storm_wiki_gpt_with_VectorRM.py) provides an example of running STORM grounded on the data you provide. (A short sketch of the underlying `VectorRM` calls is included at the bottom of this README.)
33 |
34 | 1. Set up API keys.
35 | - Make sure you have set up the OpenAI API key.
36 | - `VectorRM` uses [Qdrant](https://github.com/qdrant/qdrant-client) to create a vector store. If you want to host this vector store online on a [Qdrant cloud server](https://cloud.qdrant.io/login), you also need to set `QDRANT_API_KEY` in `secrets.toml`; if you want to save the vector store locally, make sure you provide a location for it.
37 | 2. Prepare your corpus. The documents should be provided as a single CSV file with the following format:
38 |
39 | | content | title | url | description |
40 | |------------------------|------------|------------|------------------------------------|
41 | | I am a document. | Document 1 | docu-n-112 | A self-explanatory document. |
42 | | I am another document. | Document 2 | docu-l-13 | Another self-explanatory document. |
43 | | ... | ... | ... | ... |
44 |
45 | - `url` will be used as the unique identifier of each document in the STORM engine, so ensure different documents have different URLs.
46 | - The `title` and `description` columns are optional. If not provided, the script will use default empty values.
47 | - The `content` column is required and must be provided for each document. (A minimal sketch for building such a CSV programmatically is shown below.)
48 |
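If you are assembling the corpus programmatically, the following is a minimal sketch (using `pandas`; the file name `my_corpus.csv` is only an illustration) of how to build a CSV in this format:

```
import pandas as pd

# Each row is one document; `url` must be unique per document.
docs = [
    {
        "content": "I am a document.",
        "title": "Document 1",
        "url": "docu-n-112",
        "description": "A self-explanatory document.",
    },
    {
        "content": "I am another document.",
        "title": "Document 2",
        "url": "docu-l-13",
        "description": "Another self-explanatory document.",
    },
]

pd.DataFrame(docs, columns=["content", "title", "url", "description"]).to_csv(
    "my_corpus.csv", index=False
)
```
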
49 | 3. Run the command under the root directory of the repository:
50 | To create the vector store offline, run
51 |
52 | ```
53 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
54 | --output-dir $OUTPUT_DIR \
55 | --vector-db-mode offline \
56 | --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
57 | --csv-file-path $CSV_FILE_PATH \
58 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
59 | --do-research \
60 | --do-generate-outline \
61 | --do-generate-article \
62 | --do-polish-article
63 | ```
64 |
65 | To create the vector store online on a Qdrant server, run
66 |
67 | ```
68 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
69 | --output-dir $OUTPUT_DIR \
70 | --vector-db-mode online \
71 | --online-vector-db-url $ONLINE_VECTOR_DB_URL \
72 | --csv-file-path $CSV_FILE_PATH \
73 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
74 | --do-research \
75 | --do-generate-outline \
76 | --do-generate-article \
77 | --do-polish-article
78 | ```
79 |
80 | 4. **Quick test with Kaggle arXiv Paper Abstracts dataset**:
81 |
82 | - Download `arxiv_data_210930-054931.csv` from [here](https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts).
83 | - Run the following command under the root directory to downsample the dataset by keeping only papers whose `terms` field is `['cs.CV']`, producing a CSV file that matches the format described above.
84 |
85 | ```
86 | python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
87 | ```
88 | - Run the following command to run STORM grounded on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough detail since the quick test only uses the abstracts of arXiv papers.)
89 |
90 | ```
91 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
92 | --output-dir $OUTPUT_DIR \
93 | --vector-db-mode offline \
94 | --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
95 | --csv-file-path $PATH_TO_THE_PROCESSED_CSV \
96 | --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
97 | --do-research \
98 | --do-generate-outline \
99 | --do-generate-article \
100 | --do-polish-article
101 | ```
102 | - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).
103 |
104 | ```
105 | python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
106 | --output-dir $OUTPUT_DIR \
107 | --vector-db-mode offline \
108 | --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DIR \
109 | --do-research \
110 | --do-generate-outline \
111 | --do-generate-article \
112 | --do-polish-article
113 | ```
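
For reference, the `VectorRM` calls that `run_storm_wiki_gpt_with_VectorRM.py` makes under the hood follow this rough pattern (the collection name, embedding model, and paths below are illustrative placeholders, not project defaults):

```
import os

from knowledge_storm.rm import VectorRM

# Retrieve from a locally stored Qdrant collection (offline mode).
rm = VectorRM(
    collection_name="corpus",                                  # placeholder name
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative choice
    device="cpu",
    k=3,
)
rm.init_offline_vector_db(vector_store_path="./vector_db")

# For a Qdrant server instead, initialize the online store:
# rm.init_online_vector_db(url="<qdrant-url>", api_key=os.getenv("QDRANT_API_KEY"))
```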
--------------------------------------------------------------------------------
/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py:
--------------------------------------------------------------------------------
1 | """Process `arxiv_data_210930-054931.csv`
2 | from https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts
3 | to a csv file that is compatible with VectorRM.
4 | """
5 |
6 | from argparse import ArgumentParser
7 |
8 | import pandas as pd
9 |
10 | if __name__ == "__main__":
11 | parser = ArgumentParser()
12 | parser.add_argument(
13 | "--input-path", type=str, help="Path to arxiv_data_210930-054931.csv."
14 | )
15 | parser.add_argument(
16 | "--output-path",
17 | type=str,
18 | help="Path to store the csv file that is compatible with VectorRM.",
19 | )
20 | args = parser.parse_args()
21 |
22 | df = pd.read_csv(args.input_path)
23 | print(f"The original dataset has {len(df)} samples.")
24 |
25 | # Downsample the dataset.
26 | df = df[df["terms"] == "['cs.CV']"].copy()  # Copy to avoid SettingWithCopyWarning on the rename below.
27 |
28 | # Reformat the dataset to match the VectorRM input format.
29 | df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True)
30 | df["url"] = [
31 | "uid_" + str(idx) for idx in range(len(df))
32 | ] # Ensure the url is unique.
33 | df["description"] = ""
34 |
35 | print(f"The downsampled dataset has {len(df)} samples.")
36 | df.to_csv(args.output_path, index=False)
37 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_claude.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by Claude family models and You.com search engine.
3 | You need to set up the following environment variables to run this script:
4 | - ANTHROPIC_API_KEY: Anthropic API key
5 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
6 |
7 | Output will be structured as below
8 | args.output_dir/
9 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
10 | conversation_log.json # Log of information-seeking conversation
11 | raw_search_results.json # Raw search results from search engine
12 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
13 | storm_gen_outline.txt # Outline refined with collected information
14 | url_to_info.json # Sources that are used in the final article
15 | storm_gen_article.txt # Final article generated
16 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
17 | """
18 |
19 | import os
20 | from argparse import ArgumentParser
21 |
22 | from knowledge_storm import (
23 | STORMWikiRunnerArguments,
24 | STORMWikiRunner,
25 | STORMWikiLMConfigs,
26 | )
27 | from knowledge_storm.lm import ClaudeModel
28 | from knowledge_storm.rm import (
29 | YouRM,
30 | BingSearch,
31 | BraveRM,
32 | SerperRM,
33 | DuckDuckGoSearchRM,
34 | TavilySearchRM,
35 | SearXNG,
36 | )
37 | from knowledge_storm.utils import load_api_key
38 |
39 |
40 | def main(args):
41 | load_api_key(toml_file_path="secrets.toml")
42 | lm_configs = STORMWikiLMConfigs()
43 | claude_kwargs = {
44 | "api_key": os.getenv("ANTHROPIC_API_KEY"),
45 | "temperature": 1.0,
46 | "top_p": 0.9,
47 | }
48 |
49 | # STORM is an LM system, so different components can be powered by different models.
50 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
51 | # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
52 | # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
53 | # which is responsible for generating sections with citations.
54 | conv_simulator_lm = ClaudeModel(
55 | model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
56 | )
57 | question_asker_lm = ClaudeModel(
58 | model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
59 | )
60 | outline_gen_lm = ClaudeModel(
61 | model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
62 | )
63 | article_gen_lm = ClaudeModel(
64 | model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
65 | )
66 | article_polish_lm = ClaudeModel(
67 | model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
68 | )
69 |
70 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
71 | lm_configs.set_question_asker_lm(question_asker_lm)
72 | lm_configs.set_outline_gen_lm(outline_gen_lm)
73 | lm_configs.set_article_gen_lm(article_gen_lm)
74 | lm_configs.set_article_polish_lm(article_polish_lm)
75 |
76 | engine_args = STORMWikiRunnerArguments(
77 | output_dir=args.output_dir,
78 | max_conv_turn=args.max_conv_turn,
79 | max_perspective=args.max_perspective,
80 | search_top_k=args.search_top_k,
81 | max_thread_num=args.max_thread_num,
82 | )
83 |
84 | # STORM is a knowledge curation system that consumes information from the retrieval module.
85 | # Currently, the information source is the Internet, and we use a search engine API as the retrieval module.
86 | match args.retriever:
87 | case "bing":
88 | rm = BingSearch(
89 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
90 | k=engine_args.search_top_k,
91 | )
92 | case "you":
93 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
94 | case "brave":
95 | rm = BraveRM(
96 | brave_search_api_key=os.getenv("BRAVE_API_KEY"),
97 | k=engine_args.search_top_k,
98 | )
99 | case "duckduckgo":
100 | rm = DuckDuckGoSearchRM(
101 | k=engine_args.search_top_k, safe_search="On", region="us-en"
102 | )
103 | case "serper":
104 | rm = SerperRM(
105 | serper_search_api_key=os.getenv("SERPER_API_KEY"),
106 | query_params={"autocorrect": True, "num": 10, "page": 1},
107 | )
108 | case "tavily":
109 | rm = TavilySearchRM(
110 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
111 | k=engine_args.search_top_k,
112 | include_raw_content=True,
113 | )
114 | case "searxng":
115 | rm = SearXNG(
116 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
117 | )
118 | case _:
119 | raise ValueError(
120 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
121 | )
122 |
123 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
124 |
125 | topic = input("Topic: ")
126 | runner.run(
127 | topic=topic,
128 | do_research=args.do_research,
129 | do_generate_outline=args.do_generate_outline,
130 | do_generate_article=args.do_generate_article,
131 | do_polish_article=args.do_polish_article,
132 | )
133 | runner.post_run()
134 | runner.summary()
135 |
136 |
137 | if __name__ == "__main__":
138 | parser = ArgumentParser()
139 | # global arguments
140 | parser.add_argument(
141 | "--output-dir",
142 | type=str,
143 | default="./results/claude",
144 | help="Directory to store the outputs.",
145 | )
146 | parser.add_argument(
147 | "--max-thread-num",
148 | type=int,
149 | default=3,
150 | help="Maximum number of threads to use. The information seeking part and the article generation "
151 | "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
152 | '"Exceed rate limit" errors when calling the LM API.',
153 | )
154 | parser.add_argument(
155 | "--retriever",
156 | type=str,
157 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
158 | help="The search engine API to use for retrieving information.",
159 | )
160 | # stage of the pipeline
161 | parser.add_argument(
162 | "--do-research",
163 | action="store_true",
164 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
165 | )
166 | parser.add_argument(
167 | "--do-generate-outline",
168 | action="store_true",
169 | help="If True, generate an outline for the topic; otherwise, load the results.",
170 | )
171 | parser.add_argument(
172 | "--do-generate-article",
173 | action="store_true",
174 | help="If True, generate an article for the topic; otherwise, load the results.",
175 | )
176 | parser.add_argument(
177 | "--do-polish-article",
178 | action="store_true",
179 | help="If True, polish the article by adding a summarization section and (optionally) removing "
180 | "duplicate content.",
181 | )
182 | # hyperparameters for the pre-writing stage
183 | parser.add_argument(
184 | "--max-conv-turn",
185 | type=int,
186 | default=3,
187 | help="Maximum number of questions in conversational question asking.",
188 | )
189 | parser.add_argument(
190 | "--max-perspective",
191 | type=int,
192 | default=3,
193 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
194 | )
195 | parser.add_argument(
196 | "--search-top-k",
197 | type=int,
198 | default=3,
199 | help="Top k search results to consider for each search query.",
200 | )
201 | # hyperparameters for the writing stage
202 | parser.add_argument(
203 | "--retrieve-top-k",
204 | type=int,
205 | default=3,
206 | help="Top k collected references for each section title.",
207 | )
208 | parser.add_argument(
209 | "--remove-duplicate",
210 | action="store_true",
211 | help="If True, remove duplicate content from the article.",
212 | )
213 |
214 | main(parser.parse_args())
215 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_deepseek.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by DeepSeek models and You.com or Bing search engine.
3 | You need to set up the following environment variables to run this script:
4 | - DEEPSEEK_API_KEY: DeepSeek API key
5 | - DEEPSEEK_API_BASE: DeepSeek API base URL (default is https://api.deepseek.com)
6 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
7 |
8 | Output will be structured as below
9 | args.output_dir/
10 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
11 | conversation_log.json # Log of information-seeking conversation
12 | raw_search_results.json # Raw search results from search engine
13 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
14 | storm_gen_outline.txt # Outline refined with collected information
15 | url_to_info.json # Sources that are used in the final article
16 | storm_gen_article.txt # Final article generated
17 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
18 | """
19 |
20 | import os
21 | import re
22 | import logging
23 | from argparse import ArgumentParser
24 |
25 | from knowledge_storm import (
26 | STORMWikiRunnerArguments,
27 | STORMWikiRunner,
28 | STORMWikiLMConfigs,
29 | )
30 | from knowledge_storm.lm import DeepSeekModel
31 | from knowledge_storm.rm import (
32 | YouRM,
33 | BingSearch,
34 | BraveRM,
35 | SerperRM,
36 | DuckDuckGoSearchRM,
37 | TavilySearchRM,
38 | SearXNG,
39 | )
40 | from knowledge_storm.utils import load_api_key
41 |
42 |
43 | def sanitize_topic(topic):
44 | """
45 | Sanitize the topic name for use in file names.
46 | Remove or replace characters that are not allowed in file names.
47 | """
48 | # Replace spaces with underscores
49 | topic = topic.replace(" ", "_")
50 |
51 | # Remove any character that isn't alphanumeric, underscore, or hyphen
52 | topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic)
53 |
54 | # Ensure the topic isn't empty after sanitization
55 | if not topic:
56 | topic = "unnamed_topic"
57 |
58 | return topic
59 |
60 |
61 | def main(args):
62 | load_api_key(toml_file_path="secrets.toml")
63 | lm_configs = STORMWikiLMConfigs()
64 |
65 | logger = logging.getLogger(__name__)
66 |
67 | # Ensure DEEPSEEK_API_KEY is set
68 | if not os.getenv("DEEPSEEK_API_KEY"):
69 | raise ValueError(
70 | "DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file."
71 | )
72 |
73 | deepseek_kwargs = {
74 | "api_key": os.getenv("DEEPSEEK_API_KEY"),
75 | "api_base": os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"),
76 | "temperature": args.temperature,
77 | "top_p": args.top_p,
78 | }
79 |
80 | # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks
81 | # Users can choose the appropriate model based on their needs
82 | conv_simulator_lm = DeepSeekModel(
83 | model=args.model, max_tokens=500, **deepseek_kwargs
84 | )
85 | question_asker_lm = DeepSeekModel(
86 | model=args.model, max_tokens=500, **deepseek_kwargs
87 | )
88 | outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs)
89 | article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs)
90 | article_polish_lm = DeepSeekModel(
91 | model=args.model, max_tokens=4000, **deepseek_kwargs
92 | )
93 |
94 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
95 | lm_configs.set_question_asker_lm(question_asker_lm)
96 | lm_configs.set_outline_gen_lm(outline_gen_lm)
97 | lm_configs.set_article_gen_lm(article_gen_lm)
98 | lm_configs.set_article_polish_lm(article_polish_lm)
99 |
100 | engine_args = STORMWikiRunnerArguments(
101 | output_dir=args.output_dir,
102 | max_conv_turn=args.max_conv_turn,
103 | max_perspective=args.max_perspective,
104 | search_top_k=args.search_top_k,
105 | max_thread_num=args.max_thread_num,
106 | )
107 |
108 | # STORM is a knowledge curation system that consumes information from the retrieval module.
109 | # Currently, the information source is the Internet, and we use a search engine API as the retrieval module.
110 | match args.retriever:
111 | case "bing":
112 | rm = BingSearch(
113 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
114 | k=engine_args.search_top_k,
115 | )
116 | case "you":
117 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
118 | case "brave":
119 | rm = BraveRM(
120 | brave_search_api_key=os.getenv("BRAVE_API_KEY"),
121 | k=engine_args.search_top_k,
122 | )
123 | case "duckduckgo":
124 | rm = DuckDuckGoSearchRM(
125 | k=engine_args.search_top_k, safe_search="On", region="us-en"
126 | )
127 | case "serper":
128 | rm = SerperRM(
129 | serper_search_api_key=os.getenv("SERPER_API_KEY"),
130 | query_params={"autocorrect": True, "num": 10, "page": 1},
131 | )
132 | case "tavily":
133 | rm = TavilySearchRM(
134 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
135 | k=engine_args.search_top_k,
136 | include_raw_content=True,
137 | )
138 | case "searxng":
139 | rm = SearXNG(
140 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
141 | )
142 | case _:
143 | raise ValueError(
144 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
145 | )
146 |
147 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
148 |
149 | topic = input("Topic: ")
150 | sanitized_topic = sanitize_topic(topic)
151 |
152 | try:
153 | runner.run(
154 | topic=sanitized_topic,
155 | do_research=args.do_research,
156 | do_generate_outline=args.do_generate_outline,
157 | do_generate_article=args.do_generate_article,
158 | do_polish_article=args.do_polish_article,
159 | remove_duplicate=args.remove_duplicate,
160 | )
161 | runner.post_run()
162 | runner.summary()
163 | except Exception as e:
164 | logger.exception(f"An error occurred: {str(e)}")
165 | raise
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = ArgumentParser()
170 | # global arguments
171 | parser.add_argument(
172 | "--output-dir",
173 | type=str,
174 | default="./results/deepseek",
175 | help="Directory to store the outputs.",
176 | )
177 | parser.add_argument(
178 | "--max-thread-num",
179 | type=int,
180 | default=3,
181 | help="Maximum number of threads to use. The information seeking part and the article generation "
182 | "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
183 | '"Exceed rate limit" errors when calling the LM API.',
184 | )
185 | parser.add_argument(
186 | "--retriever",
187 | type=str,
188 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
189 | help="The search engine API to use for retrieving information.",
190 | )
191 | parser.add_argument(
192 | "--model",
193 | type=str,
194 | choices=["deepseek-chat", "deepseek-coder"],
195 | default="deepseek-chat",
196 | help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.',
197 | )
198 | parser.add_argument(
199 | "--temperature", type=float, default=1.0, help="Sampling temperature to use."
200 | )
201 | parser.add_argument(
202 | "--top_p", type=float, default=0.9, help="Top-p sampling parameter."
203 | )
204 | # stage of the pipeline
205 | parser.add_argument(
206 | "--do-research",
207 | action="store_true",
208 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
209 | )
210 | parser.add_argument(
211 | "--do-generate-outline",
212 | action="store_true",
213 | help="If True, generate an outline for the topic; otherwise, load the results.",
214 | )
215 | parser.add_argument(
216 | "--do-generate-article",
217 | action="store_true",
218 | help="If True, generate an article for the topic; otherwise, load the results.",
219 | )
220 | parser.add_argument(
221 | "--do-polish-article",
222 | action="store_true",
223 | help="If True, polish the article by adding a summarization section and (optionally) removing "
224 | "duplicate content.",
225 | )
226 | # hyperparameters for the pre-writing stage
227 | parser.add_argument(
228 | "--max-conv-turn",
229 | type=int,
230 | default=3,
231 | help="Maximum number of questions in conversational question asking.",
232 | )
233 | parser.add_argument(
234 | "--max-perspective",
235 | type=int,
236 | default=3,
237 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
238 | )
239 | parser.add_argument(
240 | "--search-top-k",
241 | type=int,
242 | default=3,
243 | help="Top k search results to consider for each search query.",
244 | )
245 | # hyperparameters for the writing stage
246 | parser.add_argument(
247 | "--retrieve-top-k",
248 | type=int,
249 | default=3,
250 | help="Top k collected references for each section title.",
251 | )
252 | parser.add_argument(
253 | "--remove-duplicate",
254 | action="store_true",
255 | help="If True, remove duplicate content from the article.",
256 | )
257 |
258 | main(parser.parse_args())
259 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_gemini.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by Google Gemini models and search engine.
3 | You need to set up the following environment variables to run this script:
4 | - GOOGLE_API_KEY: Google API key (Can be obtained from https://ai.google.dev/gemini-api/docs/api-key)
5 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
6 |
7 | Output will be structured as below
8 | args.output_dir/
9 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
10 | conversation_log.json # Log of information-seeking conversation
11 | raw_search_results.json # Raw search results from search engine
12 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
13 | storm_gen_outline.txt # Outline refined with collected information
14 | url_to_info.json # Sources that are used in the final article
15 | storm_gen_article.txt # Final article generated
16 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
17 | """
18 |
19 | import os
20 | from argparse import ArgumentParser
21 |
22 | from knowledge_storm import (
23 | STORMWikiRunnerArguments,
24 | STORMWikiRunner,
25 | STORMWikiLMConfigs,
26 | )
27 | from knowledge_storm.lm import GoogleModel
28 | from knowledge_storm.rm import (
29 | YouRM,
30 | BingSearch,
31 | BraveRM,
32 | SerperRM,
33 | DuckDuckGoSearchRM,
34 | TavilySearchRM,
35 | SearXNG,
36 | )
37 | from knowledge_storm.utils import load_api_key
38 |
39 |
40 | def main(args):
41 | load_api_key(toml_file_path="secrets.toml")
42 | lm_configs = STORMWikiLMConfigs()
43 | gemini_kwargs = {
44 | "api_key": os.getenv("GOOGLE_API_KEY"),
45 | "temperature": 1.0,
46 | "top_p": 0.9,
47 | }
48 |
49 | # STORM is an LM system, so different components can be powered by different models.
50 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
51 | # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
52 | # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
53 | # which is responsible for generating sections with citations.
54 | # To check out available Google models, see:
55 | # https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python#list_models
56 | conv_simulator_lm = GoogleModel(
57 | model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs
58 | )
59 | question_asker_lm = GoogleModel(
60 | model="models/gemini-1.5-flash", max_tokens=500, **gemini_kwargs
61 | )
62 | outline_gen_lm = GoogleModel(
63 | model="models/gemini-1.5-pro-exp-0801", max_tokens=400, **gemini_kwargs
64 | )
65 | article_gen_lm = GoogleModel(
66 | model="models/gemini-1.5-pro-exp-0801", max_tokens=700, **gemini_kwargs
67 | )
68 | article_polish_lm = GoogleModel(
69 | model="models/gemini-1.5-pro-exp-0801", max_tokens=4000, **gemini_kwargs
70 | )
71 |
72 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
73 | lm_configs.set_question_asker_lm(question_asker_lm)
74 | lm_configs.set_outline_gen_lm(outline_gen_lm)
75 | lm_configs.set_article_gen_lm(article_gen_lm)
76 | lm_configs.set_article_polish_lm(article_polish_lm)
77 |
78 | engine_args = STORMWikiRunnerArguments(
79 | output_dir=args.output_dir,
80 | max_conv_turn=args.max_conv_turn,
81 | max_perspective=args.max_perspective,
82 | search_top_k=args.search_top_k,
83 | max_thread_num=args.max_thread_num,
84 | )
85 |
86 | # STORM is a knowledge curation system that consumes information from the retrieval module.
87 | # Currently, the information source is the Internet, and we use a search engine API as the retrieval module.
88 | match args.retriever:
89 | case "bing":
90 | rm = BingSearch(
91 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
92 | k=engine_args.search_top_k,
93 | )
94 | case "you":
95 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
96 | case "brave":
97 | rm = BraveRM(
98 | brave_search_api_key=os.getenv("BRAVE_API_KEY"),
99 | k=engine_args.search_top_k,
100 | )
101 | case "duckduckgo":
102 | rm = DuckDuckGoSearchRM(
103 | k=engine_args.search_top_k, safe_search="On", region="us-en"
104 | )
105 | case "serper":
106 | rm = SerperRM(
107 | serper_search_api_key=os.getenv("SERPER_API_KEY"),
108 | query_params={"autocorrect": True, "num": 10, "page": 1},
109 | )
110 | case "tavily":
111 | rm = TavilySearchRM(
112 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
113 | k=engine_args.search_top_k,
114 | include_raw_content=True,
115 | )
116 | case "searxng":
117 | rm = SearXNG(
118 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
119 | )
120 | case _:
121 | raise ValueError(
122 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
123 | )
124 |
125 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
126 |
127 | topic = input("Topic: ")
128 | runner.run(
129 | topic=topic,
130 | do_research=args.do_research,
131 | do_generate_outline=args.do_generate_outline,
132 | do_generate_article=args.do_generate_article,
133 | do_polish_article=args.do_polish_article,
134 | )
135 | runner.post_run()
136 | runner.summary()
137 |
138 |
139 | if __name__ == "__main__":
140 | parser = ArgumentParser()
141 | # global arguments
142 | parser.add_argument(
143 | "--output-dir",
144 | type=str,
145 | default="./results/gemini",
146 | help="Directory to store the outputs.",
147 | )
148 | parser.add_argument(
149 | "--max-thread-num",
150 | type=int,
151 | default=3,
152 | help="Maximum number of threads to use. The information seeking part and the article generation "
153 | "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
154 | '"Exceed rate limit" errors when calling the LM API.',
155 | )
156 | parser.add_argument(
157 | "--retriever",
158 | type=str,
159 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
160 | help="The search engine API to use for retrieving information.",
161 | )
162 | # stage of the pipeline
163 | parser.add_argument(
164 | "--do-research",
165 | action="store_true",
166 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
167 | )
168 | parser.add_argument(
169 | "--do-generate-outline",
170 | action="store_true",
171 | help="If True, generate an outline for the topic; otherwise, load the results.",
172 | )
173 | parser.add_argument(
174 | "--do-generate-article",
175 | action="store_true",
176 | help="If True, generate an article for the topic; otherwise, load the results.",
177 | )
178 | parser.add_argument(
179 | "--do-polish-article",
180 | action="store_true",
181 | help="If True, polish the article by adding a summarization section and (optionally) removing "
182 | "duplicate content.",
183 | )
184 | # hyperparameters for the pre-writing stage
185 | parser.add_argument(
186 | "--max-conv-turn",
187 | type=int,
188 | default=3,
189 | help="Maximum number of questions in conversational question asking.",
190 | )
191 | parser.add_argument(
192 | "--max-perspective",
193 | type=int,
194 | default=3,
195 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
196 | )
197 | parser.add_argument(
198 | "--search-top-k",
199 | type=int,
200 | default=3,
201 | help="Top k search results to consider for each search query.",
202 | )
203 | # hyperparameters for the writing stage
204 | parser.add_argument(
205 | "--retrieve-top-k",
206 | type=int,
207 | default=3,
208 | help="Top k collected references for each section title.",
209 | )
210 | parser.add_argument(
211 | "--remove-duplicate",
212 | action="store_true",
213 | help="If True, remove duplicate content from the article.",
214 | )
215 |
216 | main(parser.parse_args())
217 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_gpt.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by GPT-3.5/4 and You.com search engine.
3 | You need to set up the following environment variables to run this script:
4 | - OPENAI_API_KEY: OpenAI API key
5 | - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
6 | - AZURE_API_BASE: Azure API base URL if using Azure API
7 | - AZURE_API_VERSION: Azure API version if using Azure API
8 | - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
9 |
10 | Output will be structured as below
11 | args.output_dir/
12 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
13 | conversation_log.json # Log of information-seeking conversation
14 | raw_search_results.json # Raw search results from search engine
15 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
16 | storm_gen_outline.txt # Outline refined with collected information
17 | url_to_info.json # Sources that are used in the final article
18 | storm_gen_article.txt # Final article generated
19 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
20 | """
21 |
22 | import os
23 |
24 | from argparse import ArgumentParser
25 | from knowledge_storm import (
26 | STORMWikiRunnerArguments,
27 | STORMWikiRunner,
28 | STORMWikiLMConfigs,
29 | )
30 | from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
31 | from knowledge_storm.rm import (
32 | YouRM,
33 | BingSearch,
34 | BraveRM,
35 | SerperRM,
36 | DuckDuckGoSearchRM,
37 | TavilySearchRM,
38 | SearXNG,
39 | AzureAISearch,
40 | )
41 | from knowledge_storm.utils import load_api_key
42 |
43 |
44 | def main(args):
45 | load_api_key(toml_file_path="secrets.toml")
46 | lm_configs = STORMWikiLMConfigs()
47 | openai_kwargs = {
48 | "api_key": os.getenv("OPENAI_API_KEY"),
49 | "temperature": 1.0,
50 | "top_p": 0.9,
51 | }
52 |
53 | ModelClass = (
54 | OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel
55 | )
56 | # If you are using Azure service, make sure the model name matches your own deployed model name.
57 | # The default name here is only used for demonstration and may not match your case.
58 | gpt_35_model_name = (
59 | "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo"
60 | )
61 | gpt_4_model_name = "gpt-4o"
62 | if os.getenv("OPENAI_API_TYPE") == "azure":
63 | openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE")
64 | openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION")
65 |
66 | # STORM is an LM system, so different components can be powered by different models.
67 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
68 | # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
69 | # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
70 | # which is responsible for generating sections with citations.
71 | conv_simulator_lm = ModelClass(
72 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs
73 | )
74 | question_asker_lm = ModelClass(
75 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs
76 | )
77 | outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)
78 | article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)
79 | article_polish_lm = ModelClass(
80 | model=gpt_4_model_name, max_tokens=4000, **openai_kwargs
81 | )
82 |
83 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
84 | lm_configs.set_question_asker_lm(question_asker_lm)
85 | lm_configs.set_outline_gen_lm(outline_gen_lm)
86 | lm_configs.set_article_gen_lm(article_gen_lm)
87 | lm_configs.set_article_polish_lm(article_polish_lm)
88 |
89 | engine_args = STORMWikiRunnerArguments(
90 | output_dir=args.output_dir,
91 | max_conv_turn=args.max_conv_turn,
92 | max_perspective=args.max_perspective,
93 | search_top_k=args.search_top_k,
94 | max_thread_num=args.max_thread_num,
95 | )
96 |
97 | # STORM is a knowledge curation system that consumes information from the retrieval module.
98 | # Currently, the information source is the Internet, and we use a search engine API as the retrieval module.
99 |
100 | match args.retriever:
101 | case "bing":
102 | rm = BingSearch(
103 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
104 | k=engine_args.search_top_k,
105 | )
106 | case "you":
107 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
108 | case "brave":
109 | rm = BraveRM(
110 | brave_search_api_key=os.getenv("BRAVE_API_KEY"),
111 | k=engine_args.search_top_k,
112 | )
113 | case "duckduckgo":
114 | rm = DuckDuckGoSearchRM(
115 | k=engine_args.search_top_k, safe_search="On", region="us-en"
116 | )
117 | case "serper":
118 | rm = SerperRM(
119 | serper_search_api_key=os.getenv("SERPER_API_KEY"),
120 | query_params={"autocorrect": True, "num": 10, "page": 1},
121 | )
122 | case "tavily":
123 | rm = TavilySearchRM(
124 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
125 | k=engine_args.search_top_k,
126 | include_raw_content=True,
127 | )
128 | case "searxng":
129 | rm = SearXNG(
130 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
131 | )
132 | case "azure_ai_search":
133 | rm = AzureAISearch(
134 | azure_ai_search_api_key=os.getenv("AZURE_AI_SEARCH_API_KEY"),
135 | k=engine_args.search_top_k,
136 | )
137 | case _:
138 | raise ValueError(
139 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"'
140 | )
141 |
142 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
143 |
144 | topic = input("Topic: ")
145 | runner.run(
146 | topic=topic,
147 | do_research=args.do_research,
148 | do_generate_outline=args.do_generate_outline,
149 | do_generate_article=args.do_generate_article,
150 | do_polish_article=args.do_polish_article,
151 | )
152 | runner.post_run()
153 | runner.summary()
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = ArgumentParser()
158 | # global arguments
159 | parser.add_argument(
160 | "--output-dir",
161 | type=str,
162 | default="./results/gpt",
163 | help="Directory to store the outputs.",
164 | )
165 | parser.add_argument(
166 | "--max-thread-num",
167 | type=int,
168 | default=3,
169 | help="Maximum number of threads to use. The information seeking part and the article generation "
170 | "part can be sped up by using multiple threads. Consider reducing it if you keep getting "
171 | '"Exceed rate limit" errors when calling the LM API.',
172 | )
173 | parser.add_argument(
174 | "--retriever",
175 | type=str,
176 | choices=[
177 | "bing",
178 | "you",
179 | "brave",
180 | "serper",
181 | "duckduckgo",
182 | "tavily",
183 | "searxng",
184 | "azure_ai_search",
185 | ],
186 | help="The search engine API to use for retrieving information.",
187 | )
188 | # stage of the pipeline
189 | parser.add_argument(
190 | "--do-research",
191 | action="store_true",
192 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
193 | )
194 | parser.add_argument(
195 | "--do-generate-outline",
196 | action="store_true",
197 | help="If True, generate an outline for the topic; otherwise, load the results.",
198 | )
199 | parser.add_argument(
200 | "--do-generate-article",
201 | action="store_true",
202 | help="If True, generate an article for the topic; otherwise, load the results.",
203 | )
204 | parser.add_argument(
205 | "--do-polish-article",
206 | action="store_true",
207 | help="If True, polish the article by adding a summarization section and (optionally) removing "
208 | "duplicate content.",
209 | )
210 | # hyperparameters for the pre-writing stage
211 | parser.add_argument(
212 | "--max-conv-turn",
213 | type=int,
214 | default=3,
215 | help="Maximum number of questions in conversational question asking.",
216 | )
217 | parser.add_argument(
218 | "--max-perspective",
219 | type=int,
220 | default=3,
221 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
222 | )
223 | parser.add_argument(
224 | "--search-top-k",
225 | type=int,
226 | default=3,
227 | help="Top k search results to consider for each search query.",
228 | )
229 | # hyperparameters for the writing stage
230 | parser.add_argument(
231 | "--retrieve-top-k",
232 | type=int,
233 | default=3,
234 | help="Top k collected references for each section title.",
235 | )
236 | parser.add_argument(
237 | "--remove-duplicate",
238 | action="store_true",
239 | help="If True, remove duplicate content from the article.",
240 | )
241 |
242 | main(parser.parse_args())
243 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py:
--------------------------------------------------------------------------------
1 | """
2 | This STORM Wiki pipeline is powered by GPT-3.5/4 and a local retrieval model that uses Qdrant.
3 | You need to set up the following environment variables to run this script:
4 | - OPENAI_API_KEY: OpenAI API key
5 | - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
6 | - QDRANT_API_KEY: Qdrant API key (needed ONLY if online vector store was used)
7 |
8 | You will also need an existing Qdrant vector store, either saved locally in a folder (offline) or hosted on a server (online).
9 | If you do not have one, you will need a CSV file with your documents, and the script will create the vector store for you.
10 | The CSV should be in the following format:
11 | content | title | url | description
12 | I am a document. | Document 1 | docu-n-112 | A self-explanatory document.
13 | I am another document. | Document 2 | docu-l-13 | Another self-explanatory document.
14 |
15 | Note that the URL is used as a unique identifier for each document, so ensure different documents have different URLs.
16 |
17 | Output will be structured as below
18 | args.output_dir/
19 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
20 | conversation_log.json # Log of information-seeking conversation
21 | raw_search_results.json # Raw search results from search engine
22 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
23 | storm_gen_outline.txt # Outline refined with collected information
24 | url_to_info.json # Sources that are used in the final article
25 | storm_gen_article.txt # Final article generated
26 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
27 | """
28 |
29 | import os
30 | from argparse import ArgumentParser
31 |
32 | from knowledge_storm import (
33 | STORMWikiRunnerArguments,
34 | STORMWikiRunner,
35 | STORMWikiLMConfigs,
36 | )
37 | from knowledge_storm.rm import VectorRM
38 | from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
39 | from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager
40 |
41 |
42 | def main(args):
43 | # Load API key from the specified toml file path
44 | load_api_key(toml_file_path="secrets.toml")
45 |
46 | # Initialize the language model configurations
47 | engine_lm_configs = STORMWikiLMConfigs()
48 | openai_kwargs = {
49 | "api_key": os.getenv("OPENAI_API_KEY"),
50 | "temperature": 1.0,
51 | "top_p": 0.9,
52 | }
53 |
54 | ModelClass = (
55 | OpenAIModel if os.getenv("OPENAI_API_TYPE") == "openai" else AzureOpenAIModel
56 | )
57 | # If you are using Azure service, make sure the model name matches your own deployed model name.
58 | # The default name here is only used for demonstration and may not match your case.
59 | gpt_35_model_name = (
60 | "gpt-3.5-turbo" if os.getenv("OPENAI_API_TYPE") == "openai" else "gpt-35-turbo"
61 | )
62 | gpt_4_model_name = "gpt-4o"
63 | if os.getenv("OPENAI_API_TYPE") == "azure":
64 | openai_kwargs["api_base"] = os.getenv("AZURE_API_BASE")
65 | openai_kwargs["api_version"] = os.getenv("AZURE_API_VERSION")
66 |
67 | # STORM is an LM system, so different components can be powered by different models.
68 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
69 | # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
70 | # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
71 | # which is responsible for generating sections with citations.
72 | conv_simulator_lm = ModelClass(
73 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs
74 | )
75 | question_asker_lm = ModelClass(
76 | model=gpt_35_model_name, max_tokens=500, **openai_kwargs
77 | )
78 | outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)
79 | article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)
80 | article_polish_lm = ModelClass(
81 | model=gpt_4_model_name, max_tokens=4000, **openai_kwargs
82 | )
83 |
84 | engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)
85 | engine_lm_configs.set_question_asker_lm(question_asker_lm)
86 | engine_lm_configs.set_outline_gen_lm(outline_gen_lm)
87 | engine_lm_configs.set_article_gen_lm(article_gen_lm)
88 | engine_lm_configs.set_article_polish_lm(article_polish_lm)
89 |
90 | # Initialize the engine arguments
91 | engine_args = STORMWikiRunnerArguments(
92 | output_dir=args.output_dir,
93 | max_conv_turn=args.max_conv_turn,
94 | max_perspective=args.max_perspective,
95 | search_top_k=args.search_top_k,
96 | max_thread_num=args.max_thread_num,
97 | )
98 |
99 | # Create / update the vector store with the documents in the csv file
100 | if args.csv_file_path:
101 | kwargs = {
102 | "file_path": args.csv_file_path,
103 | "content_column": "content",
104 | "title_column": "title",
105 | "url_column": "url",
106 | "desc_column": "description",
107 | "batch_size": args.embed_batch_size,
108 | "vector_db_mode": args.vector_db_mode,
109 | "collection_name": args.collection_name,
110 | "embedding_model": args.embedding_model,
111 | "device": args.device,
112 | }
113 | if args.vector_db_mode == "offline":
114 | QdrantVectorStoreManager.create_or_update_vector_store(
115 | vector_store_path=args.offline_vector_db_dir, **kwargs
116 | )
117 | elif args.vector_db_mode == "online":
118 | QdrantVectorStoreManager.create_or_update_vector_store(
119 | url=args.online_vector_db_url,
120 | api_key=os.getenv("QDRANT_API_KEY"),
121 | **kwargs
122 | )
123 |
124 | # Setup VectorRM to retrieve information from your own data
125 | rm = VectorRM(
126 | collection_name=args.collection_name,
127 | embedding_model=args.embedding_model,
128 | device=args.device,
129 | k=engine_args.search_top_k,
130 | )
131 |
132 | # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally):
133 | if args.vector_db_mode == "offline":
134 | rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir)
135 | elif args.vector_db_mode == "online":
136 | rm.init_online_vector_db(
137 | url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY")
138 | )
139 |
140 | # Initialize the STORM Wiki Runner
141 | runner = STORMWikiRunner(engine_args, engine_lm_configs, rm)
142 |
143 | # run the pipeline
144 | topic = input("Topic: ")
145 | runner.run(
146 | topic=topic,
147 | do_research=args.do_research,
148 | do_generate_outline=args.do_generate_outline,
149 | do_generate_article=args.do_generate_article,
150 | do_polish_article=args.do_polish_article,
151 | )
152 | runner.post_run()
153 | runner.summary()
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = ArgumentParser()
158 | # global arguments
159 | parser.add_argument(
160 | "--output-dir",
161 | type=str,
162 | default="./results/gpt_retrieval",
163 | help="Directory to store the outputs.",
164 | )
165 | parser.add_argument(
166 | "--max-thread-num",
167 | type=int,
168 | default=3,
169 | help="Maximum number of threads to use. The information seeking part and the article generation"
170 | "part can speed up by using multiple threads. Consider reducing it if keep getting "
171 | '"Exceed rate limit" error when calling LM API.',
172 | )
173 | # provide local corpus and set up vector db
174 | parser.add_argument(
175 | "--collection-name",
176 | type=str,
177 | default="my_documents",
178 | help="The collection name for vector store.",
179 | )
180 | parser.add_argument(
181 | "--embedding_model",
182 | type=str,
183 | default="BAAI/bge-m3",
184 | help="The collection name for vector store.",
185 | )
186 | parser.add_argument(
187 | "--device",
188 | type=str,
189 | default="mps",
190 | help="The device used to run the retrieval model (mps, cuda, cpu, etc).",
191 | )
192 | parser.add_argument(
193 | "--vector-db-mode",
194 | type=str,
195 | choices=["offline", "online"],
196 | help="The mode of the Qdrant vector store (offline or online).",
197 | )
198 | parser.add_argument(
199 | "--offline-vector-db-dir",
200 | type=str,
201 | default="./vector_store",
202 | help="If use offline mode, please provide the directory to store the vector store.",
203 | )
204 | parser.add_argument(
205 | "--online-vector-db-url",
206 | type=str,
207 | help="If use online mode, please provide the url of the Qdrant server.",
208 | )
209 | parser.add_argument(
210 | "--csv-file-path",
211 | type=str,
212 | default=None,
213 | help="The path of the custom document corpus in CSV format. The CSV file should include "
214 | "content, title, url, and description columns.",
215 | )
216 | parser.add_argument(
217 | "--embed-batch-size",
218 | type=int,
219 | default=64,
220 | help="Batch size for embedding the documents in the csv file.",
221 | )
222 | # stage of the pipeline
223 | parser.add_argument(
224 | "--do-research",
225 | action="store_true",
226 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
227 | )
228 | parser.add_argument(
229 | "--do-generate-outline",
230 | action="store_true",
231 | help="If True, generate an outline for the topic; otherwise, load the results.",
232 | )
233 | parser.add_argument(
234 | "--do-generate-article",
235 | action="store_true",
236 | help="If True, generate an article for the topic; otherwise, load the results.",
237 | )
238 | parser.add_argument(
239 | "--do-polish-article",
240 | action="store_true",
241 | help="If True, polish the article by adding a summarization section and (optionally) removing "
242 | "duplicate content.",
243 | )
244 | # hyperparameters for the pre-writing stage
245 | parser.add_argument(
246 | "--max-conv-turn",
247 | type=int,
248 | default=3,
249 | help="Maximum number of questions in conversational question asking.",
250 | )
251 | parser.add_argument(
252 | "--max-perspective",
253 | type=int,
254 | default=3,
255 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
256 | )
257 | parser.add_argument(
258 | "--search-top-k",
259 | type=int,
260 | default=3,
261 | help="Top k search results to consider for each search query.",
262 | )
263 | # hyperparameters for the writing stage
264 | parser.add_argument(
265 | "--retrieve-top-k",
266 | type=int,
267 | default=3,
268 | help="Top k collected references for each section title.",
269 | )
270 | parser.add_argument(
271 | "--remove-duplicate",
272 | action="store_true",
273 | help="If True, remove duplicate content from the article.",
274 | )
275 | main(parser.parse_args())
276 |
--------------------------------------------------------------------------------
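As a reference for the CSV format described in the docstring of `run_storm_wiki_gpt_with_VectorRM.py` above, here is a minimal sketch that writes such a file with the standard library; the file name `documents.csv` and the row contents are illustrative only:

```python
import csv

# Columns must match what the script passes to QdrantVectorStoreManager:
# content, title, url, description. The url serves as the unique identifier.
rows = [
    {
        "content": "I am a document.",
        "title": "Document 1",
        "url": "docu-n-112",
        "description": "A self-explanatory document.",
    },
    {
        "content": "I am another document.",
        "title": "Document 2",
        "url": "docu-l-13",
        "description": "Another self-explanatory document.",
    },
]

with open("documents.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["content", "title", "url", "description"])
    writer.writeheader()
    writer.writerows(rows)
```

The resulting file can then be passed to the script via `--csv-file-path documents.csv`.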
/examples/storm_examples/run_storm_wiki_groq.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by llama3-70b-8192 hosted on the Groq server and a search engine selected with --retriever.
3 | You need to set up the following environment variables to run this script:
4 | - GROQ_API_KEY: You can get your Groq API Key at https://console.groq.com/keys
5 | - The API key for the search engine you choose with --retriever: YDC_API_KEY (You.com), BING_SEARCH_API_KEY (Bing Search), SERPER_API_KEY (Serper), BRAVE_API_KEY (Brave), or TAVILY_API_KEY (Tavily)
6 | (DuckDuckGo requires no key; SearXNG reads SEARXNG_API_KEY if your instance requires one.)
7 |
8 | Output will be structured as below
9 | args.output_dir/
10 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
11 | conversation_log.json # Log of information-seeking conversation
12 | raw_search_results.json # Raw search results from search engine
13 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
14 | storm_gen_outline.txt # Outline refined with collected information
15 | url_to_info.json # Sources that are used in the final article
16 | storm_gen_article.txt # Final article generated
17 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
18 | """
19 |
20 | import os
21 | import re
22 | from argparse import ArgumentParser
23 |
24 | from knowledge_storm import (
25 | STORMWikiRunnerArguments,
26 | STORMWikiRunner,
27 | STORMWikiLMConfigs,
28 | )
29 |
30 | import logging
31 |
32 | from knowledge_storm.lm import GroqModel
33 | from knowledge_storm.rm import (
34 | YouRM,
35 | BingSearch,
36 | BraveRM,
37 | SerperRM,
38 | DuckDuckGoSearchRM,
39 | TavilySearchRM,
40 | SearXNG,
41 | )
42 | from knowledge_storm.utils import load_api_key
43 |
44 |
45 | def sanitize_topic(topic):
46 | """
47 | Sanitize the topic name for use in file names.
48 | Remove or replace characters that are not allowed in file names.
49 | """
50 | # Replace spaces with underscores
51 | topic = topic.replace(" ", "_")
52 |
53 | # Remove any character that isn't alphanumeric, underscore, or hyphen
54 | topic = re.sub(r"[^a-zA-Z0-9_-]", "", topic)
55 |
56 | # Ensure the topic isn't empty after sanitization
57 | if not topic:
58 | topic = "unnamed_topic"
59 |
60 | return topic
61 |
62 |
63 | def main(args):
64 | load_api_key(toml_file_path="secrets.toml")
65 | lm_configs = STORMWikiLMConfigs()
66 |
67 | # Ensure GROQ_API_KEY is set
68 | if not os.getenv("GROQ_API_KEY"):
69 | raise ValueError(
70 | "GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file."
71 | )
72 |
73 | groq_kwargs = {
74 | "api_key": os.getenv("GROQ_API_KEY"),
75 | "api_base": "https://api.groq.com/openai/v1",
76 | "temperature": args.temperature,
77 | "top_p": args.top_p,
78 | }
79 |
80 | # Groq currently offers the "llama3-70b-8192" model with generous free API credits and the llama3.1 family of models as a preview for paying customers
81 | conv_simulator_lm = GroqModel(
82 | model="llama3-70b-8192", max_tokens=500, **groq_kwargs
83 | )
84 | question_asker_lm = GroqModel(
85 | model="llama3-70b-8192", max_tokens=500, **groq_kwargs
86 | )
87 | outline_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=400, **groq_kwargs)
88 | article_gen_lm = GroqModel(model="llama3-70b-8192", max_tokens=700, **groq_kwargs)
89 | article_polish_lm = GroqModel(
90 | model="llama3-70b-8192", max_tokens=4000, **groq_kwargs
91 | )
92 |
93 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
94 | lm_configs.set_question_asker_lm(question_asker_lm)
95 | lm_configs.set_outline_gen_lm(outline_gen_lm)
96 | lm_configs.set_article_gen_lm(article_gen_lm)
97 | lm_configs.set_article_polish_lm(article_polish_lm)
98 |
99 | engine_args = STORMWikiRunnerArguments(
100 | output_dir=args.output_dir,
101 | max_conv_turn=args.max_conv_turn,
102 | max_perspective=args.max_perspective,
103 | search_top_k=args.search_top_k,
104 | max_thread_num=args.max_thread_num,
105 | )
106 |
107 | # STORM is a knowledge curation system which consumes information from the retrieval module.
108 | # Currently, the information source is the Internet and we use search engine API as the retrieval module.
109 | match args.retriever:
110 | case "bing":
111 | rm = BingSearch(
112 | bing_search_api=os.getenv("BING_SEARCH_API_KEY"),
113 | k=engine_args.search_top_k,
114 | )
115 | case "you":
116 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
117 | case "brave":
118 | rm = BraveRM(
119 | brave_search_api_key=os.getenv("BRAVE_API_KEY"),
120 | k=engine_args.search_top_k,
121 | )
122 | case "duckduckgo":
123 | rm = DuckDuckGoSearchRM(
124 | k=engine_args.search_top_k, safe_search="On", region="us-en"
125 | )
126 | case "serper":
127 | rm = SerperRM(
128 | serper_search_api_key=os.getenv("SERPER_API_KEY"),
129 | query_params={"autocorrect": True, "num": 10, "page": 1},
130 | )
131 | case "tavily":
132 | rm = TavilySearchRM(
133 | tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
134 | k=engine_args.search_top_k,
135 | include_raw_content=True,
136 | )
137 | case "searxng":
138 | rm = SearXNG(
139 | searxng_api_key=os.getenv("SEARXNG_API_KEY"), k=engine_args.search_top_k
140 | )
141 | case _:
142 | raise ValueError(
143 | f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"'
144 | )
145 |
146 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
147 |
148 | topic = input("Topic: ")
149 | sanitized_topic = sanitize_topic(topic)
150 |
151 | try:
152 | runner.run(
153 | topic=sanitized_topic,
154 | do_research=args.do_research,
155 | do_generate_outline=args.do_generate_outline,
156 | do_generate_article=args.do_generate_article,
157 | do_polish_article=args.do_polish_article,
158 | remove_duplicate=args.remove_duplicate,
159 | )
160 | runner.post_run()
161 | runner.summary()
162 | except Exception as e:
163 | logger.exception(f"An error occurred: {str(e)}")
164 | raise
165 |
166 |
167 | if __name__ == "__main__":
168 | parser = ArgumentParser()
169 | # global arguments
170 | parser.add_argument(
171 | "--output-dir",
172 | type=str,
173 | default="./results/groq",
174 | help="Directory to store the outputs.",
175 | )
176 | parser.add_argument(
177 | "--max-thread-num",
178 | type=int,
179 | default=3,
180 | help="Maximum number of threads to use. The information seeking part and the article generation"
181 | "part can speed up by using multiple threads. Consider reducing it if keep getting "
182 | '"Exceed rate limit" error when calling LM API.',
183 | )
184 | parser.add_argument(
185 | "--retriever",
186 | type=str,
187 | choices=["bing", "you", "brave", "serper", "duckduckgo", "tavily", "searxng"],
188 | help="The search engine API to use for retrieving information.",
189 | )
190 | parser.add_argument(
191 | "--temperature", type=float, default=1.0, help="Sampling temperature to use."
192 | )
193 | parser.add_argument(
194 | "--top_p", type=float, default=0.9, help="Top-p sampling parameter."
195 | )
196 | # stage of the pipeline
197 | parser.add_argument(
198 | "--do-research",
199 | action="store_true",
200 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
201 | )
202 | parser.add_argument(
203 | "--do-generate-outline",
204 | action="store_true",
205 | help="If True, generate an outline for the topic; otherwise, load the results.",
206 | )
207 | parser.add_argument(
208 | "--do-generate-article",
209 | action="store_true",
210 | help="If True, generate an article for the topic; otherwise, load the results.",
211 | )
212 | parser.add_argument(
213 | "--do-polish-article",
214 | action="store_true",
215 | help="If True, polish the article by adding a summarization section and (optionally) removing "
216 | "duplicate content.",
217 | )
218 | # hyperparameters for the pre-writing stage
219 | parser.add_argument(
220 | "--max-conv-turn",
221 | type=int,
222 | default=3,
223 | help="Maximum number of questions in conversational question asking.",
224 | )
225 | parser.add_argument(
226 | "--max-perspective",
227 | type=int,
228 | default=3,
229 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
230 | )
231 | parser.add_argument(
232 | "--search-top-k",
233 | type=int,
234 | default=3,
235 | help="Top k search results to consider for each search query.",
236 | )
237 | # hyperparameters for the writing stage
238 | parser.add_argument(
239 | "--retrieve-top-k",
240 | type=int,
241 | default=3,
242 | help="Top k collected references for each section title.",
243 | )
244 | parser.add_argument(
245 | "--remove-duplicate",
246 | action="store_true",
247 | help="If True, remove duplicate content from the article.",
248 | )
249 |
250 | main(parser.parse_args())
251 |
--------------------------------------------------------------------------------
/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | from dspy import Example
5 |
6 | from knowledge_storm import (
7 | STORMWikiRunnerArguments,
8 | STORMWikiRunner,
9 | STORMWikiLMConfigs,
10 | )
11 | from knowledge_storm.lm import OllamaClient
12 | from knowledge_storm.rm import SearXNG
13 | from knowledge_storm.utils import load_api_key
14 |
15 |
16 | def main(args):
17 | load_api_key(toml_file_path="secrets.toml")
18 | lm_configs = STORMWikiLMConfigs()
19 |
20 | ollama_kwargs = {
21 | "model": args.model,
22 | "port": args.port,
23 | "url": args.url,
24 | "stop": ("\n\n---",),
25 | }
26 |
27 | conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs)
28 | question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs)
29 | outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs)
30 | article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs)
31 | article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs)
32 |
33 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
34 | lm_configs.set_question_asker_lm(question_asker_lm)
35 | lm_configs.set_outline_gen_lm(outline_gen_lm)
36 | lm_configs.set_article_gen_lm(article_gen_lm)
37 | lm_configs.set_article_polish_lm(article_polish_lm)
38 |
39 | engine_args = STORMWikiRunnerArguments(
40 | output_dir=args.output_dir,
41 | max_conv_turn=args.max_conv_turn,
42 | max_perspective=args.max_perspective,
43 | search_top_k=args.search_top_k,
44 | max_thread_num=args.max_thread_num,
45 | )
46 |
47 | rm = SearXNG(
48 | searxng_api_url=args.searxng_api_url,
49 | searxng_api_key=os.getenv("SEARXNG_API_KEY"),
50 | k=engine_args.search_top_k,
51 | )
52 |
53 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
54 |
55 | find_related_topic_example = Example(
56 | topic="Knowledge Curation",
57 | related_topics="https://en.wikipedia.org/wiki/Knowledge_management\n"
58 | "https://en.wikipedia.org/wiki/Information_science\n"
59 | "https://en.wikipedia.org/wiki/Library_science\n",
60 | )
61 | gen_persona_example = Example(
62 | topic="Knowledge Curation",
63 | examples="Title: Knowledge management\n"
64 | "Table of Contents: History\nResearch\n Dimensions\n Strategies\n Motivations\nKM technologies"
65 | "\nKnowledge barriers\nKnowledge retention\nKnowledge audit\nKnowledge protection\n"
66 | " Knowledge protection methods\n Formal methods\n Informal methods\n"
67 | " Balancing knowledge protection and knowledge sharing\n Knowledge protection risks",
68 | personas=(
69 | "1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge "
70 | "curation. They will provide context on how knowledge curation has changed over time and its impact on "
71 | "modern practices.\n"
72 | "2. Information Science Professional: With insights from 'Information science', this editor will "
73 | "explore the foundational theories, definitions, and philosophy that underpin knowledge curation\n"
74 | "3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, "
75 | "including software, metadata, digital preservation.\n"
76 | "4. Technical expert: This editor will focus on the technical aspects of knowledge curation, "
77 | "such as common features of content management systems.\n"
78 | "5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and "
79 | "the transition of these practices into the digital realm."
80 | ),
81 | )
82 | runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [
83 | find_related_topic_example
84 | ]
85 | runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [
86 | gen_persona_example
87 | ]
88 |
89 | write_page_outline_example = Example(
90 | topic="Example Topic",
91 | conv="Wikipedia Writer: ...\nExpert: ...\nWikipedia Writer: ...\nExpert: ...",
92 | old_outline="# Section 1\n## Subsection 1\n## Subsection 2\n"
93 | "# Section 2\n## Subsection 1\n## Subsection 2\n"
94 | "# Section 3",
95 | outline="# New Section 1\n## New Subsection 1\n## New Subsection 2\n"
96 | "# New Section 2\n"
97 | "# New Section 3\n## New Subsection 1\n## New Subsection 2\n## New Subsection 3",
98 | )
99 | runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [
100 | write_page_outline_example
101 | ]
102 | write_section_example = Example(
103 | info="[1]\nInformation in document 1\n[2]\nInformation in document 2\n[3]\nInformation in document 3",
104 | topic="Example Topic",
105 | section="Example Section",
106 | output="# Example Topic\n## Subsection 1\n"
107 | "This is an example sentence [1]. This is another example sentence [2][3].\n"
108 | "## Subsection 2\nThis is one more example sentence [1].",
109 | )
110 | runner.storm_article_generation.section_gen.write_section.demos = [
111 | write_section_example
112 | ]
113 |
114 | topic = input("Topic: ")
115 | runner.run(
116 | topic=topic,
117 | do_research=args.do_research,
118 | do_generate_outline=args.do_generate_outline,
119 | do_generate_article=args.do_generate_article,
120 | do_polish_article=args.do_polish_article,
121 | )
122 | runner.post_run()
123 | runner.summary()
124 |
125 |
126 | if __name__ == "__main__":
127 | parser = ArgumentParser()
128 | # global arguments
129 | parser.add_argument(
130 | "--url", type=str, default="http://localhost", help="URL of the Ollama server."
131 | )
132 | parser.add_argument(
133 | "--port", type=int, default=11434, help="Port of the Ollama server."
134 | )
135 | parser.add_argument(
136 | "--model", type=str, default="llama3:latest", help="Model of the Ollama server."
137 | )
138 | parser.add_argument(
139 | "--output-dir",
140 | type=str,
141 | default="./results/ollama",
142 | help="Directory to store the outputs.",
143 | )
144 | parser.add_argument(
145 | "--max-thread-num",
146 | type=int,
147 | default=3,
148 | help="Maximum number of threads to use. The information seeking part and the article generation"
149 | "part can speed up by using multiple threads. Consider reducing it if keep getting "
150 | '"Exceed rate limit" error when calling LM API.',
151 | )
152 | parser.add_argument(
153 | "--retriever",
154 | type=str,
155 | choices=["searxng"],
156 | help="The search engine API to use for retrieving information.",
157 | )
158 | parser.add_argument(
159 | "--searxng-api-url", type=str, required=True, help="URL of the SearXNG API."
160 | )
161 | # stage of the pipeline
162 | parser.add_argument(
163 | "--do-research",
164 | action="store_true",
165 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
166 | )
167 | parser.add_argument(
168 | "--do-generate-outline",
169 | action="store_true",
170 | help="If True, generate an outline for the topic; otherwise, load the results.",
171 | )
172 | parser.add_argument(
173 | "--do-generate-article",
174 | action="store_true",
175 | help="If True, generate an article for the topic; otherwise, load the results.",
176 | )
177 | parser.add_argument(
178 | "--do-polish-article",
179 | action="store_true",
180 | help="If True, polish the article by adding a summarization section and (optionally) removing "
181 | "duplicate content.",
182 | )
183 | # hyperparameters for the pre-writing stage
184 | parser.add_argument(
185 | "--max-conv-turn",
186 | type=int,
187 | default=3,
188 | help="Maximum number of questions in conversational question asking.",
189 | )
190 | parser.add_argument(
191 | "--max-perspective",
192 | type=int,
193 | default=3,
194 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
195 | )
196 | parser.add_argument(
197 | "--search-top-k",
198 | type=int,
199 | default=3,
200 | help="Top k search results to consider for each search query.",
201 | )
202 | # hyperparameters for the writing stage
203 | parser.add_argument(
204 | "--retrieve-top-k",
205 | type=int,
206 | default=3,
207 | help="Top k collected references for each section title.",
208 | )
209 | parser.add_argument(
210 | "--remove-duplicate",
211 | action="store_true",
212 | help="If True, remove duplicate content from the article.",
213 | )
214 |
215 | main(parser.parse_args())
216 |
--------------------------------------------------------------------------------
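The script above (`run_storm_wiki_ollama_with_searxng.py`) illustrates a useful pattern for weaker local models: build `dspy.Example` few-shot demos and assign them to the pipeline's internal prompt modules. A minimal sketch of extending one of those demo lists with an additional, purely illustrative example (it assumes the `runner` constructed in the script):

```python
from dspy import Example

# An entirely illustrative second outline demo; the attribute path is the
# same one the script already uses for write_page_outline.
extra_outline_example = Example(
    topic="Another Example Topic",
    conv="Wikipedia Writer: ...\nExpert: ...",
    old_outline="# Background\n# Applications",
    outline="# Background\n## History\n# Applications\n## Industry\n## Research",
)
runner.storm_outline_generation_module.write_outline.write_page_outline.demos.append(
    extra_outline_example
)
```

The same pattern applies to the persona-generation and section-writing demos set earlier in the script.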
/examples/storm_examples/run_storm_wiki_serper.py:
--------------------------------------------------------------------------------
1 | """
2 | STORM Wiki pipeline powered by Claude family models and the Serper search engine.
3 | You need to set up the following environment variables to run this script:
4 | - ANTHROPIC_API_KEY: Anthropic API key
5 | - SERPER_API_KEY: Serper.dev api key
6 |
7 | Output will be structured as below
8 | args.output_dir/
9 | topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
10 | conversation_log.json # Log of information-seeking conversation
11 | raw_search_results.json # Raw search results from search engine
12 | direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
13 | storm_gen_outline.txt # Outline refined with collected information
14 | url_to_info.json # Sources that are used in the final article
15 | storm_gen_article.txt # Final article generated
16 | storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
17 | """
18 |
19 | import os
20 | from argparse import ArgumentParser
21 |
22 | from knowledge_storm import (
23 | STORMWikiRunnerArguments,
24 | STORMWikiRunner,
25 | STORMWikiLMConfigs,
26 | )
27 | from knowledge_storm.lm import ClaudeModel
28 | from knowledge_storm.rm import SerperRM
29 | from knowledge_storm.utils import load_api_key
30 |
31 |
32 | def main(args):
33 | load_api_key(toml_file_path="secrets.toml")
34 | lm_configs = STORMWikiLMConfigs()
35 | claude_kwargs = {
36 | "api_key": os.getenv("ANTHROPIC_API_KEY"),
37 | "temperature": 1.0,
38 | "top_p": 0.9,
39 | }
40 |
41 | # STORM is an LM system, so different components can be powered by different models.
42 | # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
43 | # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
44 | # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
45 | # which is responsible for generating sections with citations.
46 | conv_simulator_lm = ClaudeModel(
47 | model="claude-3-haiku-20240307", max_tokens=500, **claude_kwargs
48 | )
49 | question_asker_lm = ClaudeModel(
50 | model="claude-3-sonnet-20240229", max_tokens=500, **claude_kwargs
51 | )
52 | outline_gen_lm = ClaudeModel(
53 | model="claude-3-opus-20240229", max_tokens=400, **claude_kwargs
54 | )
55 | article_gen_lm = ClaudeModel(
56 | model="claude-3-opus-20240229", max_tokens=700, **claude_kwargs
57 | )
58 | article_polish_lm = ClaudeModel(
59 | model="claude-3-opus-20240229", max_tokens=4000, **claude_kwargs
60 | )
61 |
62 | lm_configs.set_conv_simulator_lm(conv_simulator_lm)
63 | lm_configs.set_question_asker_lm(question_asker_lm)
64 | lm_configs.set_outline_gen_lm(outline_gen_lm)
65 | lm_configs.set_article_gen_lm(article_gen_lm)
66 | lm_configs.set_article_polish_lm(article_polish_lm)
67 |
68 | engine_args = STORMWikiRunnerArguments(
69 | output_dir=args.output_dir,
70 | max_conv_turn=args.max_conv_turn,
71 | max_perspective=args.max_perspective,
72 | search_top_k=args.search_top_k,
73 | max_thread_num=args.max_thread_num,
74 | )
75 | # Documentation for generating the query parameters is available here:
76 | # https://serper.dev/playground
77 | # Important to note that tbs (the date range) takes hardcoded values.
78 | # num is the number of results per page; it is recommended to use it in increments of 10 (10, 20, etc.).
79 | # page is how many pages will be searched.
80 | # h1 is where the Google search will originate from.
81 | topic = input("topic: ")
82 | data = {"autocorrect": True, "num": 10, "page": 1}
83 | rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), query_params=data)
84 |
85 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
86 |
87 | runner.run(
88 | topic=topic,
89 | do_research=args.do_research,
90 | do_generate_outline=args.do_generate_outline,
91 | do_generate_article=args.do_generate_article,
92 | do_polish_article=args.do_polish_article,
93 | )
94 | runner.post_run()
95 | runner.summary()
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = ArgumentParser()
100 | # global arguments
101 | parser.add_argument(
102 | "--output-dir",
103 | type=str,
104 | default="./results/serper",
105 | help="Directory to store the outputs.",
106 | )
107 | parser.add_argument(
108 | "--max-thread-num",
109 | type=int,
110 | default=3,
111 | help="Maximum number of threads to use. The information seeking part and the article generation"
112 | "part can speed up by using multiple threads. Consider reducing it if keep getting "
113 | '"Exceed rate limit" error when calling LM API.',
114 | )
115 | parser.add_argument(
116 | "--retriever",
117 | type=str,
118 | choices=["bing", "you", "serper"],
119 | help="The search engine API to use for retrieving information.",
120 | )
121 | # stage of the pipeline
122 | parser.add_argument(
123 | "--do-research",
124 | action="store_true",
125 | help="If True, simulate conversation to research the topic; otherwise, load the results.",
126 | )
127 | parser.add_argument(
128 | "--do-generate-outline",
129 | action="store_true",
130 | help="If True, generate an outline for the topic; otherwise, load the results.",
131 | )
132 | parser.add_argument(
133 | "--do-generate-article",
134 | action="store_true",
135 | help="If True, generate an article for the topic; otherwise, load the results.",
136 | )
137 | parser.add_argument(
138 | "--do-polish-article",
139 | action="store_true",
140 | help="If True, polish the article by adding a summarization section and (optionally) removing "
141 | "duplicate content.",
142 | )
143 | # hyperparameters for the pre-writing stage
144 | parser.add_argument(
145 | "--max-conv-turn",
146 | type=int,
147 | default=3,
148 | help="Maximum number of questions in conversational question asking.",
149 | )
150 | parser.add_argument(
151 | "--max-perspective",
152 | type=int,
153 | default=3,
154 | help="Maximum number of perspectives to consider in perspective-guided question asking.",
155 | )
156 | parser.add_argument(
157 | "--search-top-k",
158 | type=int,
159 | default=3,
160 | help="Top k search results to consider for each search query.",
161 | )
162 | # hyperparameters for the writing stage
163 | parser.add_argument(
164 | "--retrieve-top-k",
165 | type=int,
166 | default=3,
167 | help="Top k collected references for each section title.",
168 | )
169 | parser.add_argument(
170 | "--remove-duplicate",
171 | action="store_true",
172 | help="If True, remove duplicate content from the article.",
173 | )
174 |
175 | main(parser.parse_args())
176 |
--------------------------------------------------------------------------------
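The comment block inside `run_storm_wiki_serper.py`'s `main` lists the `query_params` fields accepted by the Serper playground (`autocorrect`, `num`, `page`, `tbs`). A minimal sketch of a narrower configuration is shown below; the `tbs` value `"qdr:m"` (past month) is an assumption based on Google's standard `tbs` syntax, so verify it in the Serper playground before relying on it:

```python
import os

from knowledge_storm.rm import SerperRM

# Hypothetical parameters: 20 results per page, first page only, and a
# date-range restriction to the past month (assumed tbs value "qdr:m").
query_params = {"autocorrect": True, "num": 20, "page": 1, "tbs": "qdr:m"}
rm = SerperRM(
    serper_search_api_key=os.getenv("SERPER_API_KEY"),
    query_params=query_params,
)
```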
/frontend/demo_light/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [client]
2 | showErrorDetails = false
3 | toolbarMode = "minimal"
4 |
5 | [theme]
6 | primaryColor = "#F63366"
7 | backgroundColor = "#FFFFFF"
8 | secondaryBackgroundColor = "#F0F2F6"
9 | textColor = "#262730"
10 | font = "sans serif"
--------------------------------------------------------------------------------
/frontend/demo_light/README.md:
--------------------------------------------------------------------------------
1 | # STORM Minimal User Interface
2 |
3 | This is a minimal user interface for `STORMWikiRunner` which includes the following features:
4 | 1. Allowing users to create a new article through the "Create New Article" page.
5 | 2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article.
6 | 3. Displaying the written article and references side by side.
7 | 4. Allowing users to view previously created articles through the "My Articles" page.
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | ## Setup
18 | 1. Make sure you have installed `knowledge-storm` or set up the source code correctly.
19 | 2. Install additional packages required by the user interface:
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 | 3. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`.
24 | 4. Run the following command to start the user interface:
25 | ```bash
26 | streamlit run storm.py
27 | ```
28 | The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs.
29 |
30 | ## Customization
31 |
32 | You can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file.
33 |
34 | The `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need.
35 |
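36 | For example, a minimal sketch of a customized runner (assuming OpenAI models and the You.com retriever; adapt the model names, API keys, and engine arguments to your own setup) could look like:
37 | ```python
38 | import os
39 |
40 | from knowledge_storm import STORMWikiRunner, STORMWikiRunnerArguments, STORMWikiLMConfigs
41 | from knowledge_storm.lm import OpenAIModel
42 | from knowledge_storm.rm import YouRM
43 |
44 | lm_configs = STORMWikiLMConfigs()
45 | gpt_4 = OpenAIModel(model="gpt-4o", max_tokens=700, api_key=os.getenv("OPENAI_API_KEY"))
46 | lm_configs.set_conv_simulator_lm(gpt_4)
47 | lm_configs.set_question_asker_lm(gpt_4)
48 | lm_configs.set_outline_gen_lm(gpt_4)
49 | lm_configs.set_article_gen_lm(gpt_4)
50 | lm_configs.set_article_polish_lm(gpt_4)
51 |
52 | engine_args = STORMWikiRunnerArguments(output_dir="DEMO_WORKING_DIR")
53 | rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=engine_args.search_top_k)
54 | runner = STORMWikiRunner(engine_args, lm_configs, rm)
55 | # set_storm_runner() would then register this runner for the UI to use,
56 | # e.g., st.session_state["runner"] = runner.
57 | ```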
--------------------------------------------------------------------------------
/frontend/demo_light/assets/article_display.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/article_display.jpg
--------------------------------------------------------------------------------
/frontend/demo_light/assets/create_article.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/create_article.jpg
--------------------------------------------------------------------------------
/frontend/demo_light/assets/void.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-oval/storm/45ee413100f0287da9ca5250290a56ac4fa73c48/frontend/demo_light/assets/void.jpg
--------------------------------------------------------------------------------
/frontend/demo_light/pages_util/CreateNewArticle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import demo_util
5 | import streamlit as st
6 | from demo_util import (
7 | DemoFileIOHelper,
8 | DemoTextProcessingHelper,
9 | DemoUIHelper,
10 | truncate_filename,
11 | )
12 |
13 |
14 | def handle_not_started():
15 | if st.session_state["page3_write_article_state"] == "not started":
16 | _, search_form_column, _ = st.columns([2, 5, 2])
17 | with search_form_column:
18 | with st.form(key="search_form"):
19 | # Text input for the search topic
20 | DemoUIHelper.st_markdown_adjust_size(
21 | content="Enter the topic you want to learn in depth:", font_size=18
22 | )
23 | st.session_state["page3_topic"] = st.text_input(
24 | label="page3_topic", label_visibility="collapsed"
25 | )
26 | pass_appropriateness_check = True
27 |
28 | # Submit button for the form
29 | submit_button = st.form_submit_button(label="Research")
30 | # only start new search when button is clicked, not started, or already finished previous one
31 | if submit_button and st.session_state["page3_write_article_state"] in [
32 | "not started",
33 | "show results",
34 | ]:
35 | if not st.session_state["page3_topic"].strip():
36 | pass_appropriateness_check = False
37 | st.session_state["page3_warning_message"] = (
38 | "topic could not be empty"
39 | )
40 |
41 | st.session_state["page3_topic_name_cleaned"] = (
42 | st.session_state["page3_topic"]
43 | .replace(" ", "_")
44 | .replace("/", "_")
45 | )
46 | st.session_state["page3_topic_name_truncated"] = truncate_filename(
47 | st.session_state["page3_topic_name_cleaned"]
48 | )
49 | if not pass_appropriateness_check:
50 | st.session_state["page3_write_article_state"] = "not started"
51 | alert = st.warning(
52 | st.session_state["page3_warning_message"], icon="⚠️"
53 | )
54 | time.sleep(5)
55 | alert.empty()
56 | else:
57 | st.session_state["page3_write_article_state"] = "initiated"
58 |
59 |
60 | def handle_initiated():
61 | if st.session_state["page3_write_article_state"] == "initiated":
62 | current_working_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
63 | if not os.path.exists(current_working_dir):
64 | os.makedirs(current_working_dir)
65 |
66 | if "runner" not in st.session_state:
67 | demo_util.set_storm_runner()
68 | st.session_state["page3_current_working_dir"] = current_working_dir
69 | st.session_state["page3_write_article_state"] = "pre_writing"
70 |
71 |
72 | def handle_pre_writing():
73 | if st.session_state["page3_write_article_state"] == "pre_writing":
74 | status = st.status(
75 | "I am brain**STORM**ing now to research the topic. (This may take 2-3 minutes.)"
76 | )
77 | st_callback_handler = demo_util.StreamlitCallbackHandler(status)
78 | with status:
79 | # STORM main gen outline
80 | st.session_state["runner"].run(
81 | topic=st.session_state["page3_topic"],
82 | do_research=True,
83 | do_generate_outline=True,
84 | do_generate_article=False,
85 | do_polish_article=False,
86 | callback_handler=st_callback_handler,
87 | )
88 | conversation_log_path = os.path.join(
89 | st.session_state["page3_current_working_dir"],
90 | st.session_state["page3_topic_name_truncated"],
91 | "conversation_log.json",
92 | )
93 | demo_util._display_persona_conversations(
94 | DemoFileIOHelper.read_json_file(conversation_log_path)
95 | )
96 | st.session_state["page3_write_article_state"] = "final_writing"
97 | status.update(label="brain**STORM**ing complete!", state="complete")
98 |
99 |
100 | def handle_final_writing():
101 | if st.session_state["page3_write_article_state"] == "final_writing":
102 | # polish final article
103 | with st.status(
104 | "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)"
105 | ) as status:
106 | st.info(
107 | "Now I will connect the information I found for your reference. (This may take 4-5 minutes.)"
108 | )
109 | st.session_state["runner"].run(
110 | topic=st.session_state["page3_topic"],
111 | do_research=False,
112 | do_generate_outline=False,
113 | do_generate_article=True,
114 | do_polish_article=True,
115 | remove_duplicate=False,
116 | )
117 | # finish the session
118 | st.session_state["runner"].post_run()
119 |
120 | # update status bar
121 | st.session_state["page3_write_article_state"] = "prepare_to_show_result"
122 | status.update(label="information snythesis complete!", state="complete")
123 |
124 |
125 | def handle_prepare_to_show_result():
126 | if st.session_state["page3_write_article_state"] == "prepare_to_show_result":
127 | _, show_result_col, _ = st.columns([4, 3, 4])
128 | with show_result_col:
129 | if st.button("show final article"):
130 | st.session_state["page3_write_article_state"] = "completed"
131 | st.rerun()
132 |
133 |
134 | def handle_completed():
135 | if st.session_state["page3_write_article_state"] == "completed":
136 | # display polished article
137 | current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict(
138 | st.session_state["page3_current_working_dir"]
139 | )
140 | current_article_file_path_dict = current_working_dir_paths[
141 | st.session_state["page3_topic_name_truncated"]
142 | ]
143 | demo_util.display_article_page(
144 | selected_article_name=st.session_state["page3_topic_name_cleaned"],
145 | selected_article_file_path_dict=current_article_file_path_dict,
146 | show_title=True,
147 | show_main_article=True,
148 | )
149 |
150 |
151 | def create_new_article_page():
152 | demo_util.clear_other_page_session_state(page_index=3)
153 |
154 | if "page3_write_article_state" not in st.session_state:
155 | st.session_state["page3_write_article_state"] = "not started"
156 |
157 | handle_not_started()
158 |
159 | handle_initiated()
160 |
161 | handle_pre_writing()
162 |
163 | handle_final_writing()
164 |
165 | handle_prepare_to_show_result()
166 |
167 | handle_completed()
168 |
--------------------------------------------------------------------------------
/frontend/demo_light/pages_util/MyArticles.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import demo_util
4 | import streamlit as st
5 | from demo_util import DemoFileIOHelper, DemoUIHelper
6 | from streamlit_card import card
7 |
8 |
9 | # set page config and display title
10 | def my_articles_page():
11 | with st.sidebar:
12 | _, return_button_col = st.columns([2, 5])
13 | with return_button_col:
14 | if st.button(
15 | "Select another article",
16 | disabled="page2_selected_my_article" not in st.session_state,
17 | ):
18 | if "page2_selected_my_article" in st.session_state:
19 | del st.session_state["page2_selected_my_article"]
20 | st.rerun()
21 |
22 | # sync my articles
23 | if "page2_user_articles_file_path_dict" not in st.session_state:
24 | local_dir = os.path.join(demo_util.get_demo_dir(), "DEMO_WORKING_DIR")
25 | os.makedirs(local_dir, exist_ok=True)
26 | st.session_state["page2_user_articles_file_path_dict"] = (
27 | DemoFileIOHelper.read_structure_to_dict(local_dir)
28 | )
29 |
30 | # if no feature demo selected, display all featured articles as info cards
31 | def article_card_setup(column_to_add, card_title, article_name):
32 | with column_to_add:
33 | cleaned_article_title = article_name.replace("_", " ")
34 | hasClicked = card(
35 | title=" / ".join(card_title),
36 | text=article_name.replace("_", " "),
37 | image=DemoFileIOHelper.read_image_as_base64(
38 | os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")
39 | ),
40 | styles=DemoUIHelper.get_article_card_UI_style(boarder_color="#9AD8E1"),
41 | )
42 | if hasClicked:
43 | st.session_state["page2_selected_my_article"] = article_name
44 | st.rerun()
45 |
46 | if "page2_selected_my_article" not in st.session_state:
47 | # display article cards
48 | my_article_columns = st.columns(3)
49 | if len(st.session_state["page2_user_articles_file_path_dict"]) > 0:
50 | # get article names
51 | article_names = sorted(
52 | list(st.session_state["page2_user_articles_file_path_dict"].keys())
53 | )
54 | # configure pagination
55 | pagination = st.container()
56 | bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1]
57 | with bottom_menu[2]:
58 | batch_size = st.selectbox("Page Size", options=[24, 48, 72])
59 | with bottom_menu[1]:
60 | total_pages = (
61 | (len(article_names) + batch_size - 1) // batch_size
62 | if len(article_names) > 0
63 | else 1
64 | )
65 | current_page = st.number_input(
66 | "Page", min_value=1, max_value=total_pages, step=1
67 | )
68 | with bottom_menu[0]:
69 | st.markdown(f"Page **{current_page}** of **{total_pages}** ")
70 | # show article cards
71 | with pagination:
72 | my_article_count = 0
73 | start_index = (current_page - 1) * batch_size
74 | end_index = min(current_page * batch_size, len(article_names))
75 | for article_name in article_names[start_index:end_index]:
76 | column_to_add = my_article_columns[my_article_count % 3]
77 | my_article_count += 1
78 | article_card_setup(
79 | column_to_add=column_to_add,
80 | card_title=["My Article"],
81 | article_name=article_name,
82 | )
83 | else:
84 | with my_article_columns[0]:
85 | hasClicked = card(
86 | title="Get started",
87 | text="Start your first research!",
88 | image=DemoFileIOHelper.read_image_as_base64(
89 | os.path.join(demo_util.get_demo_dir(), "assets", "void.jpg")
90 | ),
91 | styles=DemoUIHelper.get_article_card_UI_style(),
92 | )
93 | if hasClicked:
94 | st.session_state.selected_page = 1
95 | st.session_state["manual_selection_override"] = True
96 | st.session_state["rerun_requested"] = True
97 | st.rerun()
98 | else:
99 | selected_article_name = st.session_state["page2_selected_my_article"]
100 | selected_article_file_path_dict = st.session_state[
101 | "page2_user_articles_file_path_dict"
102 | ][selected_article_name]
103 |
104 | demo_util.display_article_page(
105 | selected_article_name=selected_article_name,
106 | selected_article_file_path_dict=selected_article_file_path_dict,
107 | show_title=True,
108 | show_main_article=True,
109 | )
110 |
--------------------------------------------------------------------------------
/frontend/demo_light/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.31.1
2 | streamlit-card
3 | markdown
4 | unidecode
5 | extra-streamlit-components==0.1.60
6 | streamlit_extras
7 | deprecation==2.1.0
8 | st-pages==0.4.5
9 | streamlit-float
10 | streamlit-option-menu
--------------------------------------------------------------------------------
/frontend/demo_light/stoc.py:
--------------------------------------------------------------------------------
1 | """https://github.com/arnaudmiribel/stoc"""
2 |
3 | import re
4 |
5 | import streamlit as st
6 | import unidecode
7 |
8 | DISABLE_LINK_CSS = """
9 | """
15 |
16 |
17 | class stoc:
18 | def __init__(self):
19 | self.toc_items = list()
20 |
21 | def h1(self, text: str, write: bool = True):
22 | if write:
23 | st.write(f"# {text}")
24 | self.toc_items.append(("h1", text))
25 |
26 | def h2(self, text: str, write: bool = True):
27 | if write:
28 | st.write(f"## {text}")
29 | self.toc_items.append(("h2", text))
30 |
31 | def h3(self, text: str, write: bool = True):
32 | if write:
33 | st.write(f"### {text}")
34 | self.toc_items.append(("h3", text))
35 |
36 | def toc(self, expander):
37 | st.write(DISABLE_LINK_CSS, unsafe_allow_html=True)
38 | # st.sidebar.caption("Table of contents")
39 | if expander is None:
40 | expander = st.sidebar.expander("**Table of contents**", expanded=True)
41 | with expander:
42 | with st.container(height=600, border=False):
43 | markdown_toc = ""
44 | for title_size, title in self.toc_items:
45 | h = int(title_size.replace("h", ""))
46 | markdown_toc += (
47 | " " * 2 * h
48 | + "- "
49 | + f'<a href="#{normalize(title)}"> {title}</a> \n'
50 | )
51 | # st.sidebar.write(markdown_toc, unsafe_allow_html=True)
52 | st.write(markdown_toc, unsafe_allow_html=True)
53 |
54 | @classmethod
55 | def get_toc(cls, markdown_text: str, topic=""):
56 | def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading):
57 | lines = markdown_text.splitlines()
58 | # Increase the depth of each heading by adding an extra '#'
59 | increased_depth_lines = [
60 | "#" + line if line.startswith("#") else line for line in lines
61 | ]
62 | # Add the new top-level heading at the beginning
63 | increased_depth_lines.insert(0, f"# {new_top_heading}")
64 | # Re-join the modified lines back into a single string
65 | modified_text = "\n".join(increased_depth_lines)
66 | return modified_text
67 |
68 | if topic:
69 | markdown_text = increase_heading_depth_and_add_top_heading(
70 | markdown_text, topic
71 | )
72 | toc = []
73 | for line in markdown_text.splitlines():
74 | if line.startswith("#"):
75 | # Remove the '#' characters and strip leading/trailing spaces
76 | heading_text = line.lstrip("#").strip()
77 | # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters)
78 | slug = (
79 | re.sub(r"[^a-zA-Z0-9\s-]", "", heading_text)
80 | .lower()
81 | .replace(" ", "-")
82 | )
83 | # Determine heading level for indentation
84 | level = line.count("#") - 1
85 | # Add to the table of contents
86 | toc.append(" " * level + f"- [{heading_text}](#{slug})")
87 | return "\n".join(toc)
88 |
89 | @classmethod
90 | def from_markdown(cls, text: str, expander=None):
91 | self = cls()
92 | for line in text.splitlines():
93 | if line.startswith("###"):
94 | self.h3(line[3:], write=False)
95 | elif line.startswith("##"):
96 | self.h2(line[2:], write=False)
97 | elif line.startswith("#"):
98 | self.h1(line[1:], write=False)
99 | # customize markdown font size
100 | custom_css = """
101 |
111 | """
112 | st.markdown(custom_css, unsafe_allow_html=True)
113 |
114 | st.write(text)
115 | self.toc(expander=expander)
116 |
117 |
118 | def normalize(s):
119 | """
120 | Normalize titles as valid HTML ids for anchors
121 | >>> normalize("it's a test to spot how Things happ3n héhé")
122 | "it-s-a-test-to-spot-how-things-happ3n-h-h"
123 | """
124 |
125 | # Replace accents with "-"
126 | s_wo_accents = unidecode.unidecode(s)
127 | accents = [s for s in s if s not in s_wo_accents]
128 | for accent in accents:
129 | s = s.replace(accent, "-")
130 |
131 | # Lowercase
132 | s = s.lower()
133 |
134 | # Keep only alphanum and remove "-" suffix if existing
135 | normalized = (
136 | "".join([char if char.isalnum() else "-" for char in s]).strip("-").lower()
137 | )
138 |
139 | return normalized
140 |
--------------------------------------------------------------------------------
/frontend/demo_light/storm.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | script_dir = os.path.dirname(os.path.abspath(__file__))
4 | wiki_root_dir = os.path.dirname(os.path.dirname(script_dir))
5 |
6 | import demo_util
7 | from pages_util import MyArticles, CreateNewArticle
8 | from streamlit_float import *
9 | from streamlit_option_menu import option_menu
10 |
11 |
12 | def main():
13 | global database
14 | st.set_page_config(layout="wide")
15 |
16 | if "first_run" not in st.session_state:
17 | st.session_state["first_run"] = True
18 |
19 | # set api keys from secrets
20 | if st.session_state["first_run"]:
21 | for key, value in st.secrets.items():
22 | if type(value) == str:
23 | os.environ[key] = value
24 |
25 | # initialize session_state
26 | if "selected_article_index" not in st.session_state:
27 | st.session_state["selected_article_index"] = 0
28 | if "selected_page" not in st.session_state:
29 | st.session_state["selected_page"] = 0
30 | if st.session_state.get("rerun_requested", False):
31 | st.session_state["rerun_requested"] = False
32 | st.rerun()
33 |
34 | st.write(
35 | "", unsafe_allow_html=True
36 | )
37 | menu_container = st.container()
38 | with menu_container:
39 | pages = ["My Articles", "Create New Article"]
40 | styles = {
41 | "container": {"padding": "0.2rem 0", "background-color": "#22222200"},
42 | }
43 | menu_selection = option_menu(
44 | None,
45 | pages,
46 | icons=["house", "search"],
47 | menu_icon="cast",
48 | default_index=0,
49 | orientation="horizontal",
50 | manual_select=st.session_state.selected_page,
51 | styles=styles,
52 | key="menu_selection",
53 | )
54 | if st.session_state.get("manual_selection_override", False):
55 | menu_selection = pages[st.session_state["selected_page"]]
56 | st.session_state["manual_selection_override"] = False
57 | st.session_state["selected_page"] = None
58 |
59 | if menu_selection == "My Articles":
60 | demo_util.clear_other_page_session_state(page_index=2)
61 | MyArticles.my_articles_page()
62 | elif menu_selection == "Create New Article":
63 | demo_util.clear_other_page_session_state(page_index=3)
64 | CreateNewArticle.create_new_article_page()
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/knowledge_storm/__init__.py:
--------------------------------------------------------------------------------
1 | from .storm_wiki import *
2 | from .collaborative_storm import *
3 | from .encoder import *
4 | from .interface import *
5 | from .lm import *
6 | from .rm import *
7 | from .utils import *
8 | from .dataclass import *
9 |
10 | __version__ = "1.1.0"
11 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .engine import *
3 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .article_generation import *
2 | from .grounded_question_answering import *
3 | from .grounded_question_generation import *
4 | from .information_insertion_module import *
5 | from .simulate_user import *
6 | from .warmstart_hierarchical_chat import *
7 | from .knowledge_base_summary import *
8 | from .costorm_expert_utterance_generator import *
9 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/article_generation.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from concurrent.futures import ThreadPoolExecutor, as_completed
3 | from typing import Set, Union
4 |
5 | from .collaborative_storm_utils import clean_up_section
6 | from ...dataclass import KnowledgeBase, KnowledgeNode
7 |
8 |
9 | class ArticleGenerationModule(dspy.Module):
10 | """Use the information collected from the information-seeking conversation to write a section."""
11 |
12 | def __init__(
13 | self,
14 | engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
15 | ):
16 | super().__init__()
17 | self.write_section = dspy.Predict(WriteSection)
18 | self.engine = engine
19 |
20 | def _get_cited_information_string(
21 | self,
22 | all_citation_index: Set[int],
23 | knowledge_base: KnowledgeBase,
24 | max_words: int = 4000,
25 | ):
26 | information = []
27 | cur_word_count = 0
28 | for index in sorted(list(all_citation_index)):
29 | info = knowledge_base.info_uuid_to_info_dict[index]
30 | snippet = info.snippets[0]
31 | info_text = f"[{index}]: {snippet} (Question: {info.meta['question']}. Query: {info.meta['query']})"
32 | cur_snippet_length = len(info_text.split())
33 | if cur_snippet_length + cur_word_count > max_words:
34 | break
35 | cur_word_count += cur_snippet_length
36 | information.append(info_text)
37 | return "\n".join(information)
38 |
39 | def gen_section(
40 | self, topic: str, node: KnowledgeNode, knowledge_base: KnowledgeBase
41 | ):
42 | if node is None or len(node.content) == 0:
43 | return ""
44 | if (
45 | node.synthesize_output is not None
46 | and node.synthesize_output
47 | and not node.need_regenerate_synthesize_output
48 | ):
49 | return node.synthesize_output
50 | all_citation_index = node.collect_all_content()
51 | information = self._get_cited_information_string(
52 | all_citation_index=all_citation_index, knowledge_base=knowledge_base
53 | )
54 | with dspy.settings.context(lm=self.engine):
55 | synthesize_output = clean_up_section(
56 | self.write_section(
57 | topic=topic, info=information, section=node.name
58 | ).output
59 | )
60 | node.synthesize_output = synthesize_output
61 | node.need_regenerate_synthesize_output = False
62 | return node.synthesize_output
63 |
64 | def forward(self, knowledge_base: KnowledgeBase):
65 | all_nodes = knowledge_base.collect_all_nodes()
66 | node_to_paragraph = {}
67 |
68 | # Define a function to generate paragraphs for nodes
69 | def _node_generate_paragraph(node):
70 | node_gen_paragraph = self.gen_section(
71 | topic=knowledge_base.topic, node=node, knowledge_base=knowledge_base
72 | )
73 | lines = node_gen_paragraph.split("\n")
74 | if lines[0].strip().replace("*", "").replace("#", "") == node.name:
75 | lines = lines[1:]
76 | node_gen_paragraph = "\n".join(lines)
77 | path = " -> ".join(node.get_path_from_root())
78 | return path, node_gen_paragraph
79 |
80 | with ThreadPoolExecutor(max_workers=5) as executor:
81 | # Submit all tasks
82 | future_to_node = {
83 | executor.submit(_node_generate_paragraph, node): node
84 | for node in all_nodes
85 | }
86 |
87 | # Collect the results as they complete
88 | for future in as_completed(future_to_node):
89 | path, node_gen_paragraph = future.result()
90 | node_to_paragraph[path] = node_gen_paragraph
91 |
92 | def helper(cur_root, level):
93 | to_return = []
94 | if cur_root is not None:
95 | hash_tag = "#" * level + " "
96 | cur_path = " -> ".join(cur_root.get_path_from_root())
97 | node_gen_paragraph = node_to_paragraph[cur_path]
98 | to_return.append(f"{hash_tag}{cur_root.name}\n{node_gen_paragraph}")
99 | for child in cur_root.children:
100 | to_return.extend(helper(child, level + 1))
101 | return to_return
102 |
103 | to_return = []
104 | for child in knowledge_base.root.children:
105 | to_return.extend(helper(child, level=1))
106 |
107 | return "\n".join(to_return)
108 |
109 |
110 | class WriteSection(dspy.Signature):
111 | """Write a Wikipedia section based on the collected information. You will be given the topic, the section you are writing and relevant information.
112 | Each piece of information is provided as raw content along with the question and query that led to it.
113 | Here is the format of your writing:
114 | Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
115 | """
116 |
117 | info = dspy.InputField(prefix="The collected information:\n", format=str)
118 | topic = dspy.InputField(prefix="The topic of the page: ", format=str)
119 | section = dspy.InputField(prefix="The section you need to write: ", format=str)
120 | output = dspy.OutputField(
121 | prefix="Write the section with proper inline citations (Start your writing. Don't include the page title, section name, or try to write other sections. Do not start the section with topic name.):\n",
122 | format=str,
123 | )
124 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/callback.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from ...interface import Information
3 |
4 |
5 | class BaseCallbackHandler:
6 | """Base callback handler to manage callbacks from the Co-STORM pipeline."""
7 |
8 | def on_turn_policy_planning_start(self, **kwargs):
9 | """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
10 | pass
11 |
12 | def on_expert_action_planning_start(self, **kwargs):
13 | """Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
14 | pass
15 |
16 | def on_expert_action_planning_end(self, **kwargs):
17 | """Run when the expert action planning ends, after deciding the actions that each expert should take."""
18 | pass
19 |
20 | def on_expert_information_collection_start(self, **kwargs):
21 |         """Run when the expert information collection starts, to begin gathering all necessary data from selected sources."""
22 | pass
23 |
24 | def on_expert_information_collection_end(self, info: List[Information], **kwargs):
25 | """Run when the expert information collection ends, after gathering all necessary data from selected sources."""
26 | pass
27 |
28 | def on_expert_utterance_generation_end(self, **kwargs):
29 |         """Run when the expert utterance generation ends, after creating responses or statements from each expert."""
30 | pass
31 |
32 | def on_expert_utterance_polishing_start(self, **kwargs):
33 | """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
34 | pass
35 |
36 | def on_mindmap_insert_start(self, **kwargs):
37 | """Run when the process of inserting new information into the mindmap starts."""
38 | pass
39 |
40 | def on_mindmap_insert_end(self, **kwargs):
41 | """Run when the process of inserting new information into the mindmap ends."""
42 | pass
43 |
44 | def on_mindmap_reorg_start(self, **kwargs):
45 | """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
46 | pass
47 |
48 | def on_expert_list_update_start(self, **kwargs):
49 | """Run when the expert list update starts, to modify or refresh the list of active experts."""
50 | pass
51 |
52 | def on_article_generation_start(self, **kwargs):
53 | """Run when the article generation process begins, to compile and format the final article content."""
54 | pass
55 |
56 | def on_warmstart_update(self, message, **kwargs):
57 |         """Run when the warm start process has an update."""
58 | pass
59 |
60 |
61 | class LocalConsolePrintCallBackHandler(BaseCallbackHandler):
62 | def __init__(self):
63 | pass
64 |
65 | def on_turn_policy_planning_start(self, **kwargs):
66 | """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
67 | print("Start planning next expert; inspect mind map; inspect system state.")
68 |
69 | def on_expert_action_planning_start(self, **kwargs):
70 | """Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
71 | print("Reviewing discourse history; Deciding utterance intent.")
72 |
73 | def on_expert_information_collection_start(self, **kwargs):
74 |         """Run when the expert information collection starts, to begin gathering all necessary data from selected sources."""
75 | print("Start searching with the search engine; browsing collected information.")
76 |
77 | def on_expert_information_collection_end(self, info: List[Information], **kwargs):
78 | """Run when the expert information collection ends, after gathering all necessary data from selected sources."""
79 | if info:
80 | urls = [i.url for i in info]
81 | information_string = "\n".join([f"Finish browsing {url}" for url in urls])
82 | print(information_string)
83 |
84 | def on_expert_utterance_generation_end(self, **kwargs):
85 | """Run when the expert utterance generation ends, before creating responses or statements from each expert."""
86 | print("Finish generating utterance from collected information.")
87 |
88 | def on_expert_utterance_polishing_start(self, **kwargs):
89 | """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
90 | print("Start polishing utterance.")
91 |
92 | def on_mindmap_insert_start(self, **kwargs):
93 | """Run when the process of inserting new information into the mindmap starts."""
94 | print("Start inserting information into mind map.")
95 |
96 | def on_mindmap_insert_end(self, **kwargs):
97 | """Run when the process of inserting new information into the mindmap ends."""
98 | print("Finish inserting information into mind map.")
99 |
100 | def on_mindmap_reorg_start(self, **kwargs):
101 | """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
102 | print("Start re-organizing mind map.")
103 |
104 | def on_expert_list_update_start(self, **kwargs):
105 | """Run when the expert list update starts, to modify or refresh the list of active experts."""
106 | print("Start updating expert candidates.")
107 |
108 | def on_warmstart_update(self, message, **kwargs):
109 | """Run when the warm start process has update."""
110 | print(f"Warm start update: {message}")
111 |
--------------------------------------------------------------------------------
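A minimal usage sketch for callback.py above: a custom handler can subclass BaseCallbackHandler and override only the hooks it needs. The handler class and its counting logic are illustrative additions, not part of the library.

from typing import List

from knowledge_storm.collaborative_storm.modules.callback import BaseCallbackHandler
from knowledge_storm.interface import Information


class SourceCountingCallBackHandler(BaseCallbackHandler):
    """Hypothetical handler that tracks how many sources each turn retrieves."""

    def __init__(self):
        self.total_sources = 0

    def on_expert_information_collection_end(self, info: List[Information], **kwargs):
        # `info` is the list of retrieved Information objects for this turn.
        self.total_sources += len(info)
        print(f"Retrieved {len(info)} sources (running total: {self.total_sources}).")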
/knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | import os
3 | import re
4 | import sys
5 | import toml
6 | from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
7 |
8 | if TYPE_CHECKING:
9 | from ..engine import RunnerArgument
10 | from ...interface import Information, Retriever, LMConfigs
11 | from ...logging_wrapper import LoggingWrapper
12 | from ...rm import BingSearch
13 |
14 |
15 | def extract_storm_info_snippet(info: Information, snippet_index: int) -> Information:
16 | """
17 | Constructs a new Information instance with only the specified snippet index.
18 |
19 | Args:
20 |         info (Information): The original Information instance.
21 | snippet_index (int): The index of the snippet to retain.
22 |
23 | Returns:
24 | Information: A new Information instance with only the specified snippet.
25 | """
26 | if snippet_index < 0 or snippet_index >= len(info.snippets):
27 | raise ValueError("Snippet index out of range")
28 |
29 | new_snippets = [info.snippets[snippet_index]]
30 | new_storm_info = Information(
31 | info.url, info.description, new_snippets, info.title, info.meta
32 | )
33 | return new_storm_info
34 |
35 |
36 | def format_search_results(
37 | searched_results: List[Information],
38 | info_max_num_words: int = 1000,
39 | mode: str = "brief",
40 | ) -> Tuple[str, Dict[int, Information]]:
41 | """
42 | Constructs a string from a list of search results with a specified word limit and returns a mapping of indices to Information.
43 |
44 | Args:
45 | searched_results (List[Information]): List of Information objects to process.
46 | info_max_num_words (int, optional): Maximum number of words allowed in the output string. Defaults to 1000.
47 | mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.
48 | 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'.
49 |
50 | Returns:
51 | Tuple[str, Dict[int, Information]]:
52 | - Formatted string with search results, constrained by the word limit.
53 | - Dictionary mapping indices to the corresponding Information objects.
54 | """
55 | total_length = 0
56 |
57 | extracted_snippet_queue = []
58 | max_snippets = (
59 | max(len(info.snippets) for info in searched_results) if searched_results else 0
60 | )
61 | max_snippets = 1 if mode == "brief" else max_snippets
62 | abort = False
63 | included_snippets = set()
64 | for i in range(max_snippets):
65 | for info in searched_results:
66 | if i < len(info.snippets) and not abort:
67 | cur_snippet = info.snippets[i]
68 | cur_snippet_len = len(info.snippets[i].split())
69 | if total_length + cur_snippet_len > info_max_num_words:
70 | abort = True
71 | break
72 | if cur_snippet not in included_snippets:
73 | included_snippets.add(cur_snippet)
74 | info = extract_storm_info_snippet(info, snippet_index=i)
75 | extracted_snippet_queue.append(info)
76 | total_length += cur_snippet_len
77 | output = []
78 | index_mapping = {}
79 | for idx, info in enumerate(extracted_snippet_queue):
80 | output.append(f"[{idx + 1}]: {info.snippets[0]}")
81 | index_mapping[idx + 1] = info
82 | assert -1 not in index_mapping
83 | return "\n".join(output), index_mapping
84 |
85 |
86 | def extract_cited_storm_info(
87 | response: str, index_to_storm_info: Dict[int, Information]
88 | ) -> Dict[int, Information]:
89 | """
90 | Extracts a sub-dictionary of Information instances that are cited in the response.
91 |
92 | Args:
93 | response (str): The response string containing inline citations like [1], [2], etc.
94 | index_to_storm_info (Dict[int, Information]): A dictionary mapping indices to Information instances.
95 |
96 | Returns:
97 | Dict[int, Information]: A sub-dictionary with only the indices that appear in the response.
98 | """
99 | cited_indices = set(map(int, re.findall(r"\[(\d+)\]", response)))
100 | cited_storm_info = {
101 | index: info
102 | for index, info in index_to_storm_info.items()
103 | if index in cited_indices
104 | }
105 | return cited_storm_info
106 |
107 |
108 | def trim_output_after_hint(response: str, hint: str) -> str:
109 | """
110 | Trims the output string to only keep the substring after the given hint (not including the hint).
111 |
112 | Args:
113 | response (str): The original output string.
114 | hint (str): The hint string after which the substring should be kept.
115 |
116 | Returns:
117 | str: The trimmed output string, or the original string if the hint is not found.
118 | """
119 | if hint in response:
120 | start_index = response.find(hint) + len(hint)
121 | return response[start_index:].strip()
122 | return response.strip("\n")
123 |
124 |
125 | def separate_citations(text: str) -> str:
126 | """
127 | Separates multiple citations within square brackets into individual citations.
128 |
129 | Args:
130 | text (str): The input string containing citations.
131 |
132 | Returns:
133 | str: The string with separated citations.
134 | """
135 |
136 | # Define a function to process each match
137 | def replace_citations(match):
138 | citations = match.group(1).split(",")
139 | return "".join(f"[{citation.strip()}]" for citation in citations)
140 |
141 | # Use regular expressions to find and replace citations
142 | pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
143 | return pattern.sub(replace_citations, text)
144 |
145 |
146 | def extract_and_remove_citations(text: str) -> Tuple[str, List[int]]:
147 | """
148 | Removes single inline citations from the input string and returns the modified string and a list of citation integers.
149 |
150 | Args:
151 | text (str): The input string containing citations.
152 |
153 | Returns:
154 | Tuple[str, List[int]]: The string after removal of citations and a list of citation integers.
155 | """
156 | citations = []
157 |
158 | # Define a function to process each match
159 | def extract_citation(match):
160 | citation = int(match.group(1))
161 | citations.append(citation)
162 | return ""
163 |
164 | # Use regular expressions to find and replace citations
165 | pattern = re.compile(r"\[(\d+)\]")
166 | modified_text = pattern.sub(extract_citation, text)
167 |
168 | return modified_text, citations
169 |
170 |
171 | def keep_first_and_last_paragraph(text: str) -> str:
172 | """
173 |     Processes the input text to keep the first paragraph and the last two paragraphs, replacing
174 |     the middle paragraphs with '[content omitted due to space limit]'.
175 |
176 | Args:
177 | text (str): The input text containing paragraphs separated by '\n\n'.
178 |
179 | Returns:
180 | str: The processed text.
181 | """
182 | paragraphs = text.split("\n\n")
183 |
184 | if len(paragraphs) <= 3:
185 | return text
186 |
187 | first_paragraph = paragraphs[0]
188 | last_paragraph = "\n\n".join(paragraphs[-2:])
189 | return (
190 | f"{first_paragraph}\n\n[content omitted due to space limit]\n\n{last_paragraph}"
191 | )
192 |
193 |
194 | def clean_up_section(text):
195 | """Clean up a section:
196 | 1. Remove uncompleted sentences (usually due to output token limitation).
197 | 2. Deduplicate individual groups of citations.
198 | 3. Remove unnecessary summary."""
199 |
200 | paragraphs = text.split("\n")
201 | output_paragraphs = []
202 | summary_sec_flag = False
203 | for p in paragraphs:
204 | p = p.strip()
205 | if len(p) == 0:
206 | continue
207 | if not p.startswith("#"):
208 | p = separate_citations(p)
209 | if summary_sec_flag:
210 | if p.startswith("#"):
211 | summary_sec_flag = False
212 | else:
213 | continue
214 | if (
215 | p.startswith("Overall")
216 | or p.startswith("In summary")
217 | or p.startswith("In conclusion")
218 | ):
219 | continue
220 | if "# Summary" in p or "# Conclusion" in p:
221 | summary_sec_flag = True
222 | continue
223 | output_paragraphs.append(p)
224 |
225 | return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format.
226 |
227 |
228 | def load_api_key(toml_file_path):
229 | try:
230 | with open(toml_file_path, "r") as file:
231 | data = toml.load(file)
232 | except FileNotFoundError:
233 | print(f"File not found: {toml_file_path}", file=sys.stderr)
234 | return
235 | except toml.TomlDecodeError:
236 | print(f"Error decoding TOML file: {toml_file_path}", file=sys.stderr)
237 | return
238 | # Set environment variables
239 | for key, value in data.items():
240 | os.environ[key] = str(value)
241 |
242 |
243 | def _get_answer_question_module_instance(
244 | lm_config: LMConfigs,
245 | runner_argument: "RunnerArgument",
246 | logging_wrapper: LoggingWrapper,
247 | rm: Optional[dspy.Retrieve] = None,
248 | ):
249 | from .grounded_question_answering import AnswerQuestionModule
250 |
251 | # configure retriever
252 | if rm is None:
253 | rm = BingSearch(k=runner_argument.retrieve_top_k)
254 | retriever = Retriever(rm=rm, max_thread=runner_argument.max_search_thread)
255 | # return AnswerQuestionModule instance
256 | return AnswerQuestionModule(
257 | retriever=retriever,
258 | max_search_queries=runner_argument.max_search_queries,
259 | question_answering_lm=lm_config.question_answering_lm,
260 | logging_wrapper=logging_wrapper,
261 | )
262 |
--------------------------------------------------------------------------------
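To illustrate the citation helpers in collaborative_storm_utils.py above, a short sketch with a made-up answer string:

from knowledge_storm.collaborative_storm.modules.collaborative_storm_utils import (
    extract_and_remove_citations,
    separate_citations,
    trim_output_after_hint,
)

answer = "Answer: Solar capacity grew rapidly in 2023 [1, 3]."

# Split grouped citations into individual brackets: "... in 2023 [1][3]."
separated = separate_citations(answer)

# Strip single-index citations and collect them:
# ("Answer: Solar capacity grew rapidly in 2023 .", [1, 3])
plain_text, cited_indices = extract_and_remove_citations(separated)

# Keep only the text after the hint: "Solar capacity grew rapidly in 2023 [1, 3]."
body = trim_output_after_hint(answer, hint="Answer:")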
/knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from typing import Union
3 |
4 | from .callback import BaseCallbackHandler
5 | from .collaborative_storm_utils import (
6 | trim_output_after_hint,
7 | extract_and_remove_citations,
8 | keep_first_and_last_paragraph,
9 | )
10 |
11 | from .grounded_question_answering import AnswerQuestionModule
12 | from .grounded_question_generation import ConvertUtteranceStyle
13 | from ...dataclass import ConversationTurn
14 | from ...logging_wrapper import LoggingWrapper
15 |
16 |
17 | class GenExpertActionPlanning(dspy.Signature):
18 | """
19 | You are an invited speaker in the round table conversation. Your task is to make a very short note to your assistant to help you prepare for your turn in the conversation.
20 | You will be given the topic we are discussing, your expertise, and the conversation history.
21 |     Review the conversation history, especially the last few turns, then have your assistant prepare the material for you in one of the following ways.
22 | 1. Original Question: Initiates a new question to other speakers.
23 | 2. Further Details: Provides additional information.
24 | 3. Information Request: Requests information from other speakers.
25 | 4. Potential Answer: Offers a possible solution or answer.
26 |
27 | Strictly follow this format: [type of contribution]: [one sentence description]. For example, Original Question: [description]
28 | """
29 |
30 | topic = dspy.InputField(prefix="topic of discussion: ", format=str)
31 |     expert = dspy.InputField(prefix="You are invited as: ", format=str)
32 | summary = dspy.InputField(prefix="Discussion history: \n", format=str)
33 | last_utterance = dspy.InputField(
34 | prefix="Last utterance in the conversation: \n", format=str
35 | )
36 |     response = dspy.OutputField(
37 | prefix="Now give your note. Start with one of [Original Question, Further Details, Information Request, Potential Answer] with one sentence description\n",
38 | format=str,
39 | )
40 |
41 |
42 | class CoStormExpertUtteranceGenerationModule(dspy.Module):
43 | def __init__(
44 | self,
45 | action_planning_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
46 | utterance_polishing_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
47 | answer_question_module: AnswerQuestionModule,
48 | logging_wrapper: LoggingWrapper,
49 | callback_handler: BaseCallbackHandler = None,
50 | ):
51 | self.action_planning_lm = action_planning_lm
52 | self.utterance_polishing_lm = utterance_polishing_lm
53 | self.expert_action = dspy.Predict(GenExpertActionPlanning)
54 | self.change_style = dspy.Predict(ConvertUtteranceStyle)
55 | self.answer_question_module = answer_question_module
56 | self.logging_wrapper = logging_wrapper
57 | self.callback_handler = callback_handler
58 |
59 | def parse_action(self, action):
60 | action_types = [
61 | "Original Question",
62 | "Further Details",
63 | "Information Request",
64 | "Potential Answer",
65 | ]
66 | for action_type in action_types:
67 | if f"{action_type}:" in action:
68 | return action_type, trim_output_after_hint(action, f"{action_type}:")
69 | elif f"[{action_type}]:" in action:
70 | return action_type, trim_output_after_hint(action, f"[{action_type}]:")
71 | return "Undefined", ""
72 |
73 | def polish_utterance(
74 | self, conversation_turn: ConversationTurn, last_conv_turn: ConversationTurn
75 | ):
76 | # change utterance style
77 | action_type = conversation_turn.utterance_type
78 | with self.logging_wrapper.log_event(
79 | "RoundTableConversationModule.ConvertUtteranceStyle"
80 | ):
81 | with dspy.settings.context(
82 | lm=self.utterance_polishing_lm, show_guidelines=False
83 | ):
84 | action_string = (
85 | f"{action_type} about: {conversation_turn.claim_to_make}"
86 | )
87 | if action_type in ["Original Question", "Information Request"]:
88 | action_string = f"{action_type}"
89 | last_expert_utterance_wo_citation, _ = extract_and_remove_citations(
90 | last_conv_turn.utterance
91 | )
92 | trimmed_last_expert_utterance = keep_first_and_last_paragraph(
93 | last_expert_utterance_wo_citation
94 | )
95 | utterance = self.change_style(
96 | expert=conversation_turn.role,
97 | action=action_string,
98 | prev=trimmed_last_expert_utterance,
99 | content=conversation_turn.raw_utterance,
100 | ).utterance
101 | conversation_turn.utterance = utterance
102 |
103 | def forward(
104 | self,
105 | topic: str,
106 | current_expert: str,
107 | conversation_summary: str,
108 | last_conv_turn: ConversationTurn,
109 | ):
110 | last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)
111 | if last_conv_turn.utterance_type in [
112 | "Original Question",
113 | "Information Request",
114 | ]:
115 | action_type = "Potential Answer"
116 | action_content = last_utterance
117 | else:
118 | with self.logging_wrapper.log_event(
119 | "CoStormExpertUtteranceGenerationModule: GenExpertActionPlanning"
120 | ):
121 | with dspy.settings.context(
122 | lm=self.action_planning_lm, show_guidelines=False
123 | ):
124 | action = self.expert_action(
125 | topic=topic,
126 | expert=current_expert,
127 | summary=conversation_summary,
128 | last_utterance=last_utterance,
129 |                     ).response
130 | action_type, action_content = self.parse_action(action)
131 |
132 | if self.callback_handler is not None:
133 | self.callback_handler.on_expert_action_planning_end()
134 | # get response
135 | conversation_turn = ConversationTurn(
136 | role=current_expert, raw_utterance="", utterance_type=action_type
137 | )
138 |
139 | if action_type == "Undefined":
140 | raise Exception(f"unexpected output: {action}")
141 | elif action_type in ["Further Details", "Potential Answer"]:
142 | with self.logging_wrapper.log_event(
143 | "RoundTableConversationModule: QuestionAnswering"
144 | ):
145 | grounded_answer = self.answer_question_module(
146 | topic=topic,
147 | question=action_content,
148 | mode="brief",
149 | style="conversational and concise",
150 | callback_handler=self.callback_handler,
151 | )
152 | conversation_turn.claim_to_make = action_content
153 | conversation_turn.raw_utterance = grounded_answer.response
154 | conversation_turn.queries = grounded_answer.queries
155 | conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info
156 | conversation_turn.cited_info = grounded_answer.cited_info
157 | elif action_type in ["Original Question", "Information Request"]:
158 | conversation_turn.raw_utterance = action_content
159 |
160 | return dspy.Prediction(conversation_turn=conversation_turn)
161 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/expert_generation.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | import re
3 | from typing import Union
4 |
5 |
6 | class GenerateExpertGeneral(dspy.Signature):
7 | """You need to select a group of diverse experts who will be suitable to be invited to a roundtable discussion on the given topic.
8 | Each expert should represent a different perspective, role, or affiliation related to this topic.
9 | You can use the background information provided about the topic for inspiration. For each expert, add a description of their expertise and what they will focus on during the discussion.
10 |     No need to include speaker names in the output.
11 | Strictly follow format below:
12 | 1. [speaker 1 role]: [speaker 1 short description]
13 | 2. [speaker 2 role]: [speaker 2 short description]
14 | """
15 |
16 | topic = dspy.InputField(prefix="Topic of interest:", format=str)
17 | background_info = dspy.InputField(
18 | prefix="Background information about the topic:\n", format=str
19 | )
20 | topN = dspy.InputField(prefix="Number of speakers needed: ", format=str)
21 | experts = dspy.OutputField(format=str)
22 |
23 |
24 | class GenerateExpertWithFocus(dspy.Signature):
25 | """
26 |     You need to select a group of speakers who will be suitable for a roundtable discussion on the [topic] with a specific [focus].
27 |     You may consider inviting speakers with opposing stances on the topic, or speakers representing different interested parties. Ensure that the selected speakers are directly connected to the specific context and scenario provided.
28 | For example, if the discussion focus is about a recent event at a specific university, consider inviting students, faculty members, journalists covering the event, university officials, and local community members.
29 | Use the background information provided about the topic for inspiration. For each speaker, add a description of their interests and what they will focus on during the discussion.
30 |     No need to include speaker names in the output.
31 | Strictly follow format below:
32 | 1. [speaker 1 role]: [speaker 1 short description]
33 | 2. [speaker 2 role]: [speaker 2 short description]
34 | """
35 |
36 | topic = dspy.InputField(prefix="Topic of interest:", format=str)
37 | background_info = dspy.InputField(prefix="Background information:\n", format=str)
38 | focus = dspy.InputField(prefix="Discussion focus: ", format=str)
39 | topN = dspy.InputField(prefix="Number of speakers needed: ", format=str)
40 | experts = dspy.OutputField(format=str)
41 |
42 |
43 | class GenerateExpertModule(dspy.Module):
44 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
45 | self.engine = engine
46 | self.generate_expert_general = dspy.Predict(GenerateExpertGeneral)
47 | self.generate_expert_w_focus = dspy.ChainOfThought(GenerateExpertWithFocus)
48 |
49 | def trim_background(self, background: str, max_words: int = 100):
50 | words = background.split()
51 | cur_len = len(words)
52 | if cur_len <= max_words:
53 | return background
54 | trimmed_words = words[: min(cur_len, max_words)]
55 | trimmed_background = " ".join(trimmed_words)
56 | return f"{trimmed_background} [rest content omitted]."
57 |
58 | def forward(
59 | self, topic: str, num_experts: int, background_info: str = "", focus: str = ""
60 | ):
61 | with dspy.settings.context(lm=self.engine, show_guidelines=False):
62 | if not focus:
63 | output = self.generate_expert_general(
64 | topic=topic, background_info=background_info, topN=num_experts
65 | ).experts
66 | else:
67 | background_info = self.trim_background(
68 | background=background_info, max_words=100
69 | )
70 | output = self.generate_expert_w_focus(
71 | topic=topic,
72 | background_info=background_info,
73 | focus=focus,
74 | topN=num_experts,
75 | ).experts
76 | output = output.replace("*", "").replace("[", "").replace("]", "")
77 | expert_list = []
78 | for s in output.split("\n"):
79 | match = re.search(r"\d+\.\s*(.*)", s)
80 | if match:
81 | expert_list.append(match.group(1))
82 | expert_list = [expert.strip() for expert in expert_list if expert.strip()]
83 | return dspy.Prediction(experts=expert_list, raw_output=output)
84 |
--------------------------------------------------------------------------------
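A small sketch for expert_generation.py above. Only the pure text helper is exercised here, so the engine is passed as a placeholder; in the real pipeline the module is constructed with a configured dspy LM before forward() is called.

from knowledge_storm.collaborative_storm.modules.expert_generation import GenerateExpertModule

# Placeholder engine: fine for the pure helper below, illustrative only.
expert_gen = GenerateExpertModule(engine=None)

long_background = "word " * 300
print(expert_gen.trim_background(long_background, max_words=100))
# -> the first 100 words followed by " [rest content omitted]."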
/knowledge_storm/collaborative_storm/modules/grounded_question_answering.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from typing import Union, List
3 |
4 | from .callback import BaseCallbackHandler
5 | from .collaborative_storm_utils import (
6 | trim_output_after_hint,
7 | format_search_results,
8 | extract_cited_storm_info,
9 | separate_citations,
10 | )
11 | from ...logging_wrapper import LoggingWrapper
12 | from ...utils import ArticleTextProcessing
13 | from ...interface import Information
14 |
15 |
16 | class QuestionToQuery(dspy.Signature):
17 | """You want to answer the question or support a claim using Google search. What do you type in the search box?
18 | The question is raised in a round table discussion on a topic. The question may or may not focus on the topic itself.
19 | Write the queries you will use in the following format:
20 | - query 1
21 | - query 2
22 | ...
23 | - query n"""
24 |
25 | topic = dspy.InputField(prefix="Topic context:", format=str)
26 | question = dspy.InputField(
27 | prefix="I want to collect information about: ", format=str
28 | )
29 | queries = dspy.OutputField(prefix="Queries: \n", format=str)
30 |
31 |
32 | class AnswerQuestion(dspy.Signature):
33 | """You are an expert who can use information effectively. You have gathered the related information and will now use the information to form a response.
34 | Make your response as informative as possible and make sure every sentence is supported by the gathered information.
35 | If [Gathered information] is not directly related to the [Topic] and [Question], provide the most relevant answer you can based on the available information, and explain any limitations or gaps.
36 | Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3].").
37 | You DO NOT need to include a References or Sources section to list the sources at the end. The style of writing should be formal.
38 | """
39 |
40 | topic = dspy.InputField(prefix="Topic you are discussing about:", format=str)
41 | question = dspy.InputField(prefix="You want to provide insight on: ", format=str)
42 | info = dspy.InputField(prefix="Gathered information:\n", format=str)
43 | style = dspy.InputField(prefix="Style of your response should be:", format=str)
44 | answer = dspy.OutputField(
45 | prefix="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)",
46 | format=str,
47 | )
48 |
49 |
50 | class AnswerQuestionModule(dspy.Module):
51 | def __init__(
52 | self,
53 | retriever: dspy.Retrieve,
54 | max_search_queries: int,
55 | question_answering_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
56 | logging_wrapper: LoggingWrapper,
57 | ):
58 | super().__init__()
59 | self.question_answering_lm = question_answering_lm
60 | self.question_to_query = dspy.Predict(QuestionToQuery)
61 | self.answer_question = dspy.Predict(AnswerQuestion)
62 | self.retriever = retriever
63 | self.max_search_queries = max_search_queries
64 | self.logging_wrapper = logging_wrapper
65 |
66 | def retrieve_information(self, topic, question):
67 | # decompose question to queries
68 | with self.logging_wrapper.log_event(
69 | f"AnswerQuestionModule.question_to_query ({hash(question)})"
70 | ):
71 | with dspy.settings.context(lm=self.question_answering_lm):
72 | queries = self.question_to_query(topic=topic, question=question).queries
73 | queries = trim_output_after_hint(queries, hint="Queries:")
74 | queries = [
75 | q.replace("-", "").strip().strip('"').strip('"').strip()
76 | for q in queries.split("\n")
77 | ]
78 | queries = queries[: self.max_search_queries]
79 | self.logging_wrapper.add_query_count(count=len(queries))
80 | with self.logging_wrapper.log_event(
81 | f"AnswerQuestionModule.retriever.retrieve ({hash(question)})"
82 | ):
83 | # retrieve information using retriever
84 | searched_results: List[Information] = self.retriever.retrieve(
85 | list(set(queries)), exclude_urls=[]
86 | )
87 | # update storm information meta to include the question
88 | for storm_info in searched_results:
89 | storm_info.meta["question"] = question
90 | return queries, searched_results
91 |
92 | def forward(
93 | self,
94 | topic: str,
95 | question: str,
96 | mode: str = "brief",
97 | style: str = "conversational",
98 | callback_handler: BaseCallbackHandler = None,
99 | ):
100 | """
101 | Processes a topic and question to generate a response with relevant information and citations.
102 |
103 | Args:
104 | topic (str): The topic of interest.
105 | question (str): The specific question related to the topic.
106 | mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.
107 | 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'.
108 |
109 | Returns:
110 | dspy.Prediction: An object containing the following:
111 | - question (str): the question to answer
112 | - queries (List[str]): List of query strings used for information retrieval.
113 | - raw_retrieved_info (List[Information]): List of Information instances retrieved.
114 | - cited_info (Dict[int, Information]): Dictionary of cited Information instances, indexed by their citation number.
115 | - response (str): The generated response string with inline citations.
116 | """
117 | # retrieve information
118 | if callback_handler is not None:
119 | callback_handler.on_expert_information_collection_start()
120 | queries, searched_results = self.retrieve_information(
121 | topic=topic, question=question
122 | )
123 | if callback_handler is not None:
124 | callback_handler.on_expert_information_collection_end(searched_results)
125 | # format information string for answer generation
126 | info_text, index_to_information_mapping = format_search_results(
127 | searched_results, mode=mode
128 | )
129 | answer = "Sorry, there is insufficient information to answer the question."
130 | # generate answer to the question
131 | if info_text:
132 | with self.logging_wrapper.log_event(
133 | f"AnswerQuestionModule.answer_question ({hash(question)})"
134 | ):
135 | with dspy.settings.context(
136 | lm=self.question_answering_lm, show_guidelines=False
137 | ):
138 | answer = self.answer_question(
139 | topic=topic, question=question, info=info_text, style=style
140 | ).answer
141 | answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
142 | answer
143 | )
144 | answer = trim_output_after_hint(
145 | answer,
146 | hint="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)",
147 | )
148 | # enforce single citation index bracket. [1, 2] -> [1][2]
149 | answer = separate_citations(answer)
150 | if callback_handler is not None:
151 | callback_handler.on_expert_utterance_generation_end()
152 | # construct cited search result
153 | cited_searched_results = extract_cited_storm_info(
154 | response=answer, index_to_storm_info=index_to_information_mapping
155 | )
156 |
157 | return dspy.Prediction(
158 | question=question,
159 | queries=queries,
160 | raw_retrieved_info=searched_results,
161 | cited_info=cited_searched_results,
162 | response=answer,
163 | )
164 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/grounded_question_generation.py:
--------------------------------------------------------------------------------
1 | """
2 | This module handles question generation within the Co-STORM framework, specifically designed to support the Moderator role.
3 |
4 | The Moderator generates insightful, thought-provoking questions that introduce new directions into the conversation.
5 | By leveraging uncited or unused snippets of information retrieved during the discussion, the Moderator ensures the conversation remains dynamic and avoids repetitive or overly niche topics.
6 |
7 | For more detailed information, refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232.
8 | """
9 |
10 | import dspy
11 | from typing import List, Union
12 |
13 | from .collaborative_storm_utils import (
14 | format_search_results,
15 | extract_and_remove_citations,
16 | keep_first_and_last_paragraph,
17 | extract_cited_storm_info,
18 | )
19 | from ...dataclass import ConversationTurn, KnowledgeBase
20 | from ...interface import Information
21 |
22 |
23 | class KnowledgeBaseSummmary(dspy.Signature):
24 |     """Your job is to give a brief summary of what's been discussed in a roundtable conversation. Contents are thematically organized into hierarchical sections.
25 |     You will be presented with these sections, where "#" denotes the level of the section.
26 | """
27 |
28 | topic = dspy.InputField(prefix="topic: ", format=str)
29 | structure = dspy.InputField(prefix="Tree structure: \n", format=str)
30 | output = dspy.OutputField(prefix="Now give brief summary:\n", format=str)
31 |
32 |
33 | class ConvertUtteranceStyle(dspy.Signature):
34 | """
35 | You are an invited speaker in the round table conversation.
36 |     Your task is to make the question or the response more conversational and engaging to facilitate the flow of conversation.
37 |     Note that this is an ongoing conversation, so there is no need for welcoming or concluding words. The previous speaker's utterance is provided only to make the conversation more natural.
38 |     Do not hallucinate, and keep citation indices like [1] as they are.
39 | """
40 |
41 |     expert = dspy.InputField(prefix="You are invited as: ", format=str)
42 | action = dspy.InputField(
43 | prefix="You want to contribute to conversation by: ", format=str
44 | )
45 | prev = dspy.InputField(prefix="Previous speaker said: ", format=str)
46 | content = dspy.InputField(
47 | prefix="Question or response you want to say: ", format=str
48 | )
49 | utterance = dspy.OutputField(
50 | prefix="Your utterance (keep the information as much as you can with citations, prefer shorter answers without loss of information): ",
51 | format=str,
52 | )
53 |
54 |
55 | class GroundedQuestionGeneration(dspy.Signature):
56 |     """Your job is to find the next discussion focus in a roundtable conversation. You will be given the previous conversation summary and some information that might help you discover a new discussion focus.
57 |     Note that the new discussion focus should bring a new angle and perspective to the discussion and avoid repetition. The new discussion focus should be grounded in the available information and push the boundaries of the current discussion for broader exploration.
58 |     The new discussion focus should flow naturally from the last utterance in the conversation.
59 | Use [1][2] in line to ground your question.
60 | """
61 |
62 | topic = dspy.InputField(prefix="topic: ", format=str)
63 | summary = dspy.InputField(prefix="Discussion history: \n", format=str)
64 | information = dspy.InputField(prefix="Available information: \n", format=str)
65 | last_utterance = dspy.InputField(
66 | prefix="Last utterance in the conversation: \n", format=str
67 | )
68 | output = dspy.OutputField(
69 | prefix="Now give next discussion focus in the format of one sentence question:\n",
70 | format=str,
71 | )
72 |
73 |
74 | class GroundedQuestionGenerationModule(dspy.Module):
75 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
76 | self.engine = engine
77 | self.gen_focus = dspy.Predict(GroundedQuestionGeneration)
78 | self.polish_style = dspy.Predict(ConvertUtteranceStyle)
79 | self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)
80 |
81 | def forward(
82 | self,
83 | topic: str,
84 | knowledge_base: KnowledgeBase,
85 | last_conv_turn: ConversationTurn,
86 | unused_snippets: List[Information],
87 | ):
88 | information, index_to_information_mapping = format_search_results(
89 | unused_snippets, info_max_num_words=1000
90 | )
91 | summary = knowledge_base.get_knowledge_base_summary()
92 | last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)
93 | with dspy.settings.context(lm=self.engine, show_guidelines=False):
94 | raw_utterance = self.gen_focus(
95 | topic=topic,
96 | summary=summary,
97 | information=information,
98 | last_utterance=last_utterance,
99 | ).output
100 | utterance = self.polish_style(
101 | expert="Roundtable conversation moderator",
102 | action="Raising a new question by natural transit from previous utterance.",
103 | prev=keep_first_and_last_paragraph(last_utterance),
104 | content=raw_utterance,
105 | ).utterance
106 | cited_searched_results = extract_cited_storm_info(
107 | response=utterance, index_to_storm_info=index_to_information_mapping
108 | )
109 | return dspy.Prediction(
110 | raw_utterance=raw_utterance,
111 | utterance=utterance,
112 | cited_info=cited_searched_results,
113 | )
114 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from typing import Union
3 | from ...dataclass import KnowledgeBase
4 |
5 |
6 | class KnowledgeBaseSummmary(dspy.Signature):
7 |     """Your job is to give a brief summary of what's been discussed in a roundtable conversation. Contents are thematically organized into hierarchical sections.
8 |     You will be presented with these sections, where "#" denotes the level of the section.
9 | """
10 |
11 | topic = dspy.InputField(prefix="topic: ", format=str)
12 | structure = dspy.InputField(prefix="Tree structure: \n", format=str)
13 | output = dspy.OutputField(prefix="Now give brief summary:\n", format=str)
14 |
15 |
16 | class KnowledgeBaseSummaryModule(dspy.Module):
17 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
18 | self.engine = engine
19 | self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)
20 |
21 | def forward(self, knowledge_base: KnowledgeBase):
22 | structure = knowledge_base.get_node_hierarchy_string(
23 | include_indent=False,
24 | include_full_path=False,
25 | include_hash_tag=True,
26 | include_node_content_count=False,
27 | )
28 | with dspy.settings.context(lm=self.engine, show_guidelines=False):
29 | summary = self.gen_summary(
30 | topic=knowledge_base.topic, structure=structure
31 | ).output
32 | return summary
33 |
--------------------------------------------------------------------------------
/knowledge_storm/collaborative_storm/modules/simulate_user.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | from typing import List, Union
3 |
4 | from .collaborative_storm_utils import extract_and_remove_citations
5 | from ...dataclass import ConversationTurn
6 | from ...storm_wiki.modules.knowledge_curation import AskQuestionWithPersona
7 |
8 |
9 | class GenSimulatedUserUtterance(dspy.Module):
10 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
11 | self.engine = engine
12 |         self.ask_question = dspy.Predict(AskQuestionWithPersona)
13 |
14 | def gen_conv_history_string(self, conversation_turns: List[ConversationTurn]):
15 | conv_history = []
16 | total_turns = len(conversation_turns)
17 |
18 | for i, turn in enumerate(conversation_turns):
19 | utterance, _ = extract_and_remove_citations(turn.utterance)
20 | if i >= total_turns - 4:
21 | conv_history.append(f"{turn.role}: {utterance}")
22 | else:
23 | if turn.claim_to_make:
24 | conv_history.append(f"{turn.role}: {turn.claim_to_make}")
25 | else:
26 | conv_history.append(f"{turn.role}: {utterance}")
27 |
28 | return "\n".join(conv_history)
29 |
30 | def forward(self, topic: str, intent: str, conv_history: List[ConversationTurn]):
31 | conv_history_string = self.gen_conv_history_string(conv_history)
32 | with dspy.settings.context(lm=self.engine, show_guidelines=False):
33 |             return self.ask_question(
34 | topic=topic,
35 | persona=f"researcher with interest in {intent}",
36 | conv=conv_history_string,
37 | ).question
38 |
--------------------------------------------------------------------------------
/knowledge_storm/encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from concurrent.futures import ThreadPoolExecutor, as_completed
5 | from typing import List, Tuple, Union, Optional, Dict, Literal
6 | from pathlib import Path
7 |
8 | try:
9 | import warnings
10 |
11 | with warnings.catch_warnings():
12 | warnings.filterwarnings("ignore", category=UserWarning)
13 | if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
14 | os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
15 | import litellm
16 |
17 | litellm.drop_params = True
18 | litellm.telemetry = False
19 |
20 | from litellm.caching.caching import Cache
21 |
22 | disk_cache_dir = os.path.join(Path.home(), ".storm_local_cache")
23 | litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
24 |
25 | except ImportError:
26 |
27 | class LitellmPlaceholder:
28 | def __getattr__(self, _):
29 | raise ImportError(
30 | "The LiteLLM package is not installed. Run `pip install litellm`."
31 | )
32 |
33 | litellm = LitellmPlaceholder()
34 |
35 |
36 | class Encoder:
37 | """
38 | A wrapper class for the LiteLLM embedding model, designed to handle embedding
39 | generation tasks efficiently. It supports parallel processing and local caching of
40 | embedding results for improved performance.
41 |
42 | The Encoder utilizes the LiteLLM library to interact with various embedding models,
43 | such as OpenAI and Azure embeddings. Users can specify the desired encoder type and
44 | provide relevant API credentials during initialization.
45 |
46 | Features:
47 | - Support for multiple embedding models (e.g., OpenAI, Azure).
48 | - Parallel processing for faster embedding generation.
49 | - Local disk caching to store and reuse embedding results.
50 | - Total token usage tracking for cost monitoring.
51 |
52 | Note:
53 | Refer to the LiteLLM documentation for details on supported embedding models:
54 | https://docs.litellm.ai/docs/embedding/supported_embedding
55 | """
56 |
57 | def __init__(
58 | self,
59 | encoder_type: Optional[str] = None,
60 | api_key: Optional[str] = None,
61 | api_base: Optional[str] = None,
62 | api_version: Optional[str] = None,
63 | ):
64 | """
65 | Initializes the Encoder with the appropriate embedding model.
66 |
67 | Args:
68 | encoder_type (Optional[str]): Type of encoder ('openai', 'azure', etc.).
69 | api_key (Optional[str]): API key for the encoder service.
70 | api_base (Optional[str]): API base URL for the encoder service.
71 | api_version (Optional[str]): API version for the encoder service.
72 | """
73 | self.embedding_model_name = None
74 | self.kargs = {}
75 | self.total_token_usage = 0
76 |
77 | # Initialize the appropriate embedding model
78 | encoder_type = encoder_type or os.getenv("ENCODER_API_TYPE")
79 | if not encoder_type:
80 | raise ValueError("ENCODER_API_TYPE environment variable is not set.")
81 |
82 | if encoder_type.lower() == "openai":
83 | self.embedding_model_name = "text-embedding-3-small"
84 | self.kargs = {"api_key": api_key or os.getenv("OPENAI_API_KEY")}
85 | elif encoder_type.lower() == "azure":
86 | self.embedding_model_name = "azure/text-embedding-3-small"
87 | self.kargs = {
88 | "api_key": api_key or os.getenv("AZURE_API_KEY"),
89 | "api_base": api_base or os.getenv("AZURE_API_BASE"),
90 | "api_version": api_version or os.getenv("AZURE_API_VERSION"),
91 | }
92 | else:
93 | raise ValueError(
94 |                 f"Unsupported ENCODER_API_TYPE '{encoder_type}'. Supported types are 'openai' and 'azure'."
95 | )
96 |
97 | def get_total_token_usage(self, reset: bool = False) -> int:
98 | """
99 | Retrieves the total token usage.
100 |
101 | Args:
102 | reset (bool): If True, resets the total token usage counter after retrieval.
103 |
104 | Returns:
105 | int: The total number of tokens used.
106 | """
107 | token_usage = self.total_token_usage
108 | if reset:
109 | self.total_token_usage = 0
110 | return token_usage
111 |
112 | def encode(self, texts: Union[str, List[str]], max_workers: int = 5) -> np.ndarray:
113 | """
114 | Public method to get embeddings for the given texts.
115 |
116 | Args:
117 | texts (Union[str, List[str]]): A single text string or a list of text strings to embed.
118 |
119 | Returns:
120 | np.ndarray: The array of embeddings.
121 | """
122 | return self._get_text_embeddings(texts, max_workers=max_workers)
123 |
124 | def _get_single_text_embedding(self, text):
125 | response = litellm.embedding(
126 | model=self.embedding_model_name, input=text, caching=True, **self.kargs
127 | )
128 | embedding = response.data[0]["embedding"]
129 | token_usage = response.get("usage", {}).get("total_tokens", 0)
130 | return text, embedding, token_usage
131 |
132 | def _get_text_embeddings(
133 | self,
134 | texts: Union[str, List[str]],
135 | max_workers: int = 5,
136 |     ) -> np.ndarray:
137 |         """
138 |         Get text embeddings using the configured embedding model (e.g., text-embedding-3-small).
139 | 
140 |         Args:
141 |             texts (Union[str, List[str]]): A single text string or a list of text strings to embed.
142 |             max_workers (int): The maximum number of workers for parallel processing.
143 | 
144 |         Returns:
145 |             np.ndarray: A 2D array of embeddings for a list of texts, or a 1D array for a single
146 |                 text string. Total token usage is accumulated in self.total_token_usage rather
147 |                 than returned.
148 |         """
149 |
150 | if isinstance(texts, str):
151 | _, embedding, tokens = self._get_single_text_embedding(texts)
152 | self.total_token_usage += tokens
153 | return np.array(embedding)
154 |
155 | embeddings = []
156 | total_tokens = 0
157 |
158 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
159 | futures = {
160 | executor.submit(self._get_single_text_embedding, text): text
161 | for text in texts
162 | }
163 |
164 | for future in as_completed(futures):
165 | try:
166 | text, embedding, tokens = future.result()
167 | embeddings.append((text, embedding, tokens))
168 | total_tokens += tokens
169 | except Exception as e:
170 | print(f"An error occurred for text: {futures[future]}")
171 | print(e)
172 |
173 | # Sort results to match the order of the input texts
174 | embeddings.sort(key=lambda x: texts.index(x[0]))
175 | embeddings = [result[1] for result in embeddings]
176 | self.total_token_usage += total_tokens
177 |
178 | return np.array(embeddings)
179 |
--------------------------------------------------------------------------------
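A minimal usage sketch for encoder.py above, assuming the OpenAI backend and an OPENAI_API_KEY already present in the environment:

import os

from knowledge_storm.encoder import Encoder

os.environ["ENCODER_API_TYPE"] = "openai"  # or "azure" with the matching AZURE_* variables

encoder = Encoder()  # picks up OPENAI_API_KEY from the environment
vectors = encoder.encode(["mind map", "knowledge curation"])  # 2D array, one row per text
single = encoder.encode("roundtable discussion")  # 1D array for a single string
print(vectors.shape, single.shape, encoder.get_total_token_usage())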
/knowledge_storm/logging_wrapper.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import time
3 | import pytz
4 | from datetime import datetime
5 |
6 | # Define California timezone
7 | CALIFORNIA_TZ = pytz.timezone("America/Los_Angeles")
8 |
9 |
10 | class EventLog:
11 | def __init__(self, event_name):
12 | self.event_name = event_name
13 | self.start_time = None
14 | self.end_time = None
15 | self.child_events = {}
16 |
17 | def record_start_time(self):
18 | self.start_time = datetime.now(
19 | pytz.utc
20 | ) # Store in UTC for consistent timezone conversion
21 |
22 | def record_end_time(self):
23 | self.end_time = datetime.now(
24 | pytz.utc
25 | ) # Store in UTC for consistent timezone conversion
26 |
27 | def get_total_time(self):
28 | if self.start_time and self.end_time:
29 | return (self.end_time - self.start_time).total_seconds()
30 | return 0
31 |
32 | def get_start_time(self):
33 | if self.start_time:
34 | # Format to milliseconds
35 | return self.start_time.astimezone(CALIFORNIA_TZ).strftime(
36 | "%Y-%m-%d %H:%M:%S.%f"
37 | )[:-3]
38 | return None
39 |
40 | def get_end_time(self):
41 | if self.end_time:
42 | # Format to milliseconds
43 | return self.end_time.astimezone(CALIFORNIA_TZ).strftime(
44 | "%Y-%m-%d %H:%M:%S.%f"
45 | )[:-3]
46 | return None
47 |
48 | def add_child_event(self, child_event):
49 | self.child_events[child_event.event_name] = child_event
50 |
51 | def get_child_events(self):
52 | return self.child_events
53 |
54 |
55 | class LoggingWrapper:
56 | def __init__(self, lm_config):
57 | self.logging_dict = {}
58 | self.lm_config = lm_config
59 | self.current_pipeline_stage = None
60 | self.event_stack = []
61 | self.pipeline_stage_active = False
62 |
63 | def _pipeline_stage_start(self, pipeline_stage: str):
64 | if self.pipeline_stage_active:
65 | raise RuntimeError(
66 | "A pipeline stage is already active. End the current stage before starting a new one."
67 | )
68 |
69 | self.current_pipeline_stage = pipeline_stage
70 | self.logging_dict[pipeline_stage] = {
71 | "time_usage": {},
72 | "lm_usage": {},
73 | "lm_history": [],
74 | "query_count": 0,
75 | }
76 | self.pipeline_stage_active = True
77 |
78 | def _event_start(self, event_name: str):
79 | if not self.pipeline_stage_active:
80 | raise RuntimeError("No pipeline stage is currently active.")
81 |
82 | if not self.event_stack and self.current_pipeline_stage:
83 | # Top-level event (directly under the pipeline stage)
84 | if (
85 | event_name
86 | not in self.logging_dict[self.current_pipeline_stage]["time_usage"]
87 | ):
88 | event = EventLog(event_name=event_name)
89 | event.record_start_time()
90 | self.logging_dict[self.current_pipeline_stage]["time_usage"][
91 | event_name
92 | ] = event
93 | self.event_stack.append(event)
94 | else:
95 | self.logging_dict[self.current_pipeline_stage]["time_usage"][
96 | event_name
97 | ].record_start_time()
98 | elif self.event_stack:
99 | # Nested event (under another event)
100 | parent_event = self.event_stack[-1]
101 | if event_name not in parent_event.get_child_events():
102 | event = EventLog(event_name=event_name)
103 | event.record_start_time()
104 | parent_event.add_child_event(event)
105 | self.logging_dict[self.current_pipeline_stage]["time_usage"][
106 | event_name
107 | ] = event
108 | self.event_stack.append(event)
109 | else:
110 | parent_event.get_child_events()[event_name].record_start_time()
111 | else:
112 | raise RuntimeError(
113 | "Cannot start an event without an active pipeline stage or parent event."
114 | )
115 |
116 | def _event_end(self, event_name: str):
117 | if not self.pipeline_stage_active:
118 | raise RuntimeError("No pipeline stage is currently active.")
119 |
120 | if not self.event_stack:
121 | raise RuntimeError("No parent event is currently active.")
122 |
123 | if self.event_stack:
124 | current_event_log = self.event_stack[-1]
125 | if event_name in current_event_log.get_child_events():
126 | current_event_log.get_child_events()[event_name].record_end_time()
127 | elif (
128 | event_name
129 | in self.logging_dict[self.current_pipeline_stage]["time_usage"]
130 | ):
131 | self.logging_dict[self.current_pipeline_stage]["time_usage"][
132 | event_name
133 | ].record_end_time()
134 | else:
135 | raise AssertionError(
136 | f"Failure to record end time for event {event_name}. Start time is not recorded."
137 | )
138 | if current_event_log.event_name == event_name:
139 | self.event_stack.pop()
140 | else:
141 | raise RuntimeError("Cannot end an event without an active parent event.")
142 |
143 | def _pipeline_stage_end(self):
144 | if not self.pipeline_stage_active:
145 | raise RuntimeError("No pipeline stage is currently active to end.")
146 |
147 | self.logging_dict[self.current_pipeline_stage][
148 | "lm_usage"
149 | ] = self.lm_config.collect_and_reset_lm_usage()
150 | self.logging_dict[self.current_pipeline_stage][
151 | "lm_history"
152 | ] = self.lm_config.collect_and_reset_lm_history()
153 | self.pipeline_stage_active = False
154 |
155 | def add_query_count(self, count):
156 | if not self.pipeline_stage_active:
157 | raise RuntimeError(
158 | "No pipeline stage is currently active to add query count."
159 | )
160 |
161 | self.logging_dict[self.current_pipeline_stage]["query_count"] += count
162 |
163 | @contextmanager
164 | def log_event(self, event_name):
165 | if not self.pipeline_stage_active:
166 | raise RuntimeError("No pipeline stage is currently active.")
167 |
168 | self._event_start(event_name)
169 | yield
170 | self._event_end(event_name)
171 |
172 | @contextmanager
173 | def log_pipeline_stage(self, pipeline_stage):
174 | if self.pipeline_stage_active:
175 | print(
176 | "A pipeline stage is already active, ending the current stage safely."
177 | )
178 | self._pipeline_stage_end()
179 |
180 | start_time = time.time()
181 | try:
182 | self._pipeline_stage_start(pipeline_stage)
183 | yield
184 | except Exception as e:
185 | print(f"Error occurred during pipeline stage '{pipeline_stage}': {e}")
186 | finally:
187 | self.logging_dict[self.current_pipeline_stage]["total_wall_time"] = (
188 | time.time() - start_time
189 | )
190 | self._pipeline_stage_end()
191 |
192 | def dump_logging_and_reset(self, reset_logging=True):
193 | log_dump = {}
194 | for pipeline_stage, pipeline_log in self.logging_dict.items():
195 | time_stamp_log = {
196 | event_name: {
197 | "total_time_seconds": event.get_total_time(),
198 | "start_time": event.get_start_time(),
199 | "end_time": event.get_end_time(),
200 | }
201 | for event_name, event in pipeline_log["time_usage"].items()
202 | }
203 | log_dump[pipeline_stage] = {
204 | "time_usage": time_stamp_log,
205 | "lm_usage": pipeline_log["lm_usage"],
206 | "lm_history": pipeline_log["lm_history"],
207 | "query_count": pipeline_log["query_count"],
208 | "total_wall_time": pipeline_log["total_wall_time"],
209 | }
210 | if reset_logging:
211 | self.logging_dict.clear()
212 | return log_dump
213 |
--------------------------------------------------------------------------------
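A self-contained sketch of logging_wrapper.py above. The stub LM config is hypothetical and only implements the two methods the wrapper calls at the end of a stage; the real pipeline passes its actual LM configuration object instead.

import time

from knowledge_storm.logging_wrapper import LoggingWrapper


class _StubLMConfig:
    """Hypothetical stand-in exposing the hooks LoggingWrapper expects."""

    def collect_and_reset_lm_usage(self):
        return {}

    def collect_and_reset_lm_history(self):
        return []


logger = LoggingWrapper(lm_config=_StubLMConfig())

with logger.log_pipeline_stage("warm_start"):
    with logger.log_event("background_search"):
        time.sleep(0.1)  # stands in for real work
    logger.add_query_count(3)

print(logger.dump_logging_and_reset())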
/knowledge_storm/storm_wiki/__init__.py:
--------------------------------------------------------------------------------
1 | from .engine import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .knowledge_curation import *
2 | from .persona_generator import *
3 | from .retriever import *
4 | from .storm_dataclass import *
5 |
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/article_generation.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import copy
3 | import logging
4 | from concurrent.futures import as_completed
5 | from typing import List, Union
6 |
7 | import dspy
8 |
9 | from .callback import BaseCallbackHandler
10 | from .storm_dataclass import StormInformationTable, StormArticle
11 | from ...interface import ArticleGenerationModule, Information
12 | from ...utils import ArticleTextProcessing
13 |
14 |
15 | class StormArticleGenerationModule(ArticleGenerationModule):
16 | """
17 |     The interface for the article generation stage. Given the topic, the information collected in the
18 |     knowledge curation stage, and the outline from the outline generation stage, write the article section by section.
19 | """
20 |
21 | def __init__(
22 | self,
23 |         article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
24 | retrieve_top_k: int = 5,
25 | max_thread_num: int = 10,
26 | ):
27 | super().__init__()
28 | self.retrieve_top_k = retrieve_top_k
29 | self.article_gen_lm = article_gen_lm
30 | self.max_thread_num = max_thread_num
31 | self.section_gen = ConvToSection(engine=self.article_gen_lm)
32 |
33 | def generate_section(
34 | self, topic, section_name, information_table, section_outline, section_query
35 | ):
36 | collected_info: List[Information] = []
37 | if information_table is not None:
38 | collected_info = information_table.retrieve_information(
39 | queries=section_query, search_top_k=self.retrieve_top_k
40 | )
41 | output = self.section_gen(
42 | topic=topic,
43 | outline=section_outline,
44 | section=section_name,
45 | collected_info=collected_info,
46 | )
47 | return {
48 | "section_name": section_name,
49 | "section_content": output.section,
50 | "collected_info": collected_info,
51 | }
52 |
53 | def generate_article(
54 | self,
55 | topic: str,
56 | information_table: StormInformationTable,
57 | article_with_outline: StormArticle,
58 | callback_handler: BaseCallbackHandler = None,
59 | ) -> StormArticle:
60 | """
61 | Generate article for the topic based on the information table and article outline.
62 |
63 | Args:
64 | topic (str): The topic of the article.
65 | information_table (StormInformationTable): The information table containing the collected information.
66 | article_with_outline (StormArticle): The article with specified outline.
67 | callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger
68 | custom callbacks at various stages of the article generation process. Defaults to None.
69 | """
70 | information_table.prepare_table_for_retrieval()
71 |
72 | if article_with_outline is None:
73 | article_with_outline = StormArticle(topic_name=topic)
74 |
75 | sections_to_write = article_with_outline.get_first_level_section_names()
76 |
77 | section_output_dict_collection = []
78 | if len(sections_to_write) == 0:
79 | logging.error(
80 | f"No outline for {topic}. Will directly search with the topic."
81 | )
82 | section_output_dict = self.generate_section(
83 | topic=topic,
84 | section_name=topic,
85 | information_table=information_table,
86 | section_outline="",
87 | section_query=[topic],
88 | )
89 | section_output_dict_collection = [section_output_dict]
90 | else:
91 | with concurrent.futures.ThreadPoolExecutor(
92 | max_workers=self.max_thread_num
93 | ) as executor:
94 | future_to_sec_title = {}
95 | for section_title in sections_to_write:
96 | # We don't want to write a separate introduction section.
97 | if section_title.lower().strip() == "introduction":
98 | continue
99 | # We don't want to write a separate conclusion section.
100 | if section_title.lower().strip().startswith(
101 | "conclusion"
102 | ) or section_title.lower().strip().startswith("summary"):
103 | continue
104 | section_query = article_with_outline.get_outline_as_list(
105 | root_section_name=section_title, add_hashtags=False
106 | )
107 | queries_with_hashtags = article_with_outline.get_outline_as_list(
108 | root_section_name=section_title, add_hashtags=True
109 | )
110 | section_outline = "\n".join(queries_with_hashtags)
111 | future_to_sec_title[
112 | executor.submit(
113 | self.generate_section,
114 | topic,
115 | section_title,
116 | information_table,
117 | section_outline,
118 | section_query,
119 | )
120 | ] = section_title
121 |
122 | for future in as_completed(future_to_sec_title):
123 | section_output_dict_collection.append(future.result())
124 |
125 | article = copy.deepcopy(article_with_outline)
126 | for section_output_dict in section_output_dict_collection:
127 | article.update_section(
128 | parent_section_name=topic,
129 | current_section_content=section_output_dict["section_content"],
130 | current_section_info_list=section_output_dict["collected_info"],
131 | )
132 | article.post_processing()
133 | return article
134 |
135 |
136 | class ConvToSection(dspy.Module):
137 | """Use the information collected from the information-seeking conversation to write a section."""
138 |
139 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
140 | super().__init__()
141 | self.write_section = dspy.Predict(WriteSection)
142 | self.engine = engine
143 |
144 | def forward(
145 | self, topic: str, outline: str, section: str, collected_info: List[Information]
146 | ):
147 | info = ""
148 | for idx, storm_info in enumerate(collected_info):
149 | info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets)
150 | info += "\n\n"
151 |
152 | info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500)
153 |
154 | with dspy.settings.context(lm=self.engine):
155 | section = ArticleTextProcessing.clean_up_section(
156 | self.write_section(topic=topic, info=info, section=section).output
157 | )
158 |
159 | return dspy.Prediction(section=section)
160 |
161 |
162 | class WriteSection(dspy.Signature):
163 | """Write a Wikipedia section based on the collected information.
164 |
165 | Here is the format of your writing:
166 |     1. Use "# Title" to indicate section title, "## Title" to indicate subsection title, "### Title" to indicate subsubsection title, and so on.
167 |     2. Use [1], [2], ..., [n] inline (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
168 | """
169 |
170 | info = dspy.InputField(prefix="The collected information:\n", format=str)
171 | topic = dspy.InputField(prefix="The topic of the page: ", format=str)
172 | section = dspy.InputField(prefix="The section you need to write: ", format=str)
173 | output = dspy.OutputField(
174 | prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n",
175 | format=str,
176 | )
177 |
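# Usage sketch (illustrative). It assumes `article_gen_lm` is a dspy-compatible LM
# instance configured elsewhere and `information_table` is the StormInformationTable
# produced by the knowledge curation stage; only signatures defined in this module
# and in storm_dataclass are used.
#
#     article_gen = StormArticleGenerationModule(
#         article_gen_lm=article_gen_lm,
#         retrieve_top_k=5,      # passages retrieved per section query
#         max_thread_num=10,     # first-level sections are written in parallel
#     )
#     outline = StormArticle.from_outline_str(
#         topic="Example topic", outline_str="# History\n# Applications"
#     )
#     draft = article_gen.generate_article(
#         topic="Example topic",
#         information_table=information_table,
#         article_with_outline=outline,
#     )
#     print(draft.to_string())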
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/article_polish.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Union
3 |
4 | import dspy
5 |
6 | from .storm_dataclass import StormArticle
7 | from ...interface import ArticlePolishingModule
8 | from ...utils import ArticleTextProcessing
9 |
10 |
11 | class StormArticlePolishingModule(ArticlePolishingModule):
12 | """
13 |     The interface for the article polishing stage. Given the topic and a draft article,
14 |     polish the article by writing a lead section and, optionally, removing duplicated content.
15 | """
16 |
17 | def __init__(
18 | self,
19 | article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
20 | article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
21 | ):
22 | self.article_gen_lm = article_gen_lm
23 | self.article_polish_lm = article_polish_lm
24 |
25 | self.polish_page = PolishPageModule(
26 | write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm
27 | )
28 |
29 | def polish_article(
30 | self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False
31 | ) -> StormArticle:
32 | """
33 | Polish article.
34 |
35 | Args:
36 | topic (str): The topic of the article.
37 | draft_article (StormArticle): The draft article.
38 | remove_duplicate (bool): Whether to use one additional LM call to remove duplicates from the article.
39 | """
40 |
41 | article_text = draft_article.to_string()
42 | polish_result = self.polish_page(
43 | topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate
44 | )
45 | lead_section = f"# summary\n{polish_result.lead_section}"
46 | polished_article = "\n\n".join([lead_section, polish_result.page])
47 | polished_article_dict = ArticleTextProcessing.parse_article_into_dict(
48 | polished_article
49 | )
50 | polished_article = copy.deepcopy(draft_article)
51 | polished_article.insert_or_create_section(article_dict=polished_article_dict)
52 | polished_article.post_processing()
53 | return polished_article
54 |
55 |
56 | class WriteLeadSection(dspy.Signature):
57 | """Write a lead section for the given Wikipedia page with the following guidelines:
58 | 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies.
59 | 2. The lead section should be concise and contain no more than four well-composed paragraphs.
60 | 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.
61 | """
62 |
63 | topic = dspy.InputField(prefix="The topic of the page: ", format=str)
64 | draft_page = dspy.InputField(prefix="The draft page:\n", format=str)
65 | lead_section = dspy.OutputField(prefix="Write the lead section:\n", format=str)
66 |
67 |
68 | class PolishPage(dspy.Signature):
69 | """You are a faithful text editor that is good at finding repeated information in the article and deleting them to make sure there is no repetition in the article. You won't delete any non-repeated part in the article. You will keep the inline citations and article structure (indicated by "#", "##", etc.) appropriately. Do your job for the following article."""
70 |
71 | draft_page = dspy.InputField(prefix="The draft article:\n", format=str)
72 | page = dspy.OutputField(prefix="Your revised article:\n", format=str)
73 |
74 |
75 | class PolishPageModule(dspy.Module):
76 | def __init__(
77 | self,
78 | write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
79 | polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
80 | ):
81 | super().__init__()
82 | self.write_lead_engine = write_lead_engine
83 | self.polish_engine = polish_engine
84 | self.write_lead = dspy.Predict(WriteLeadSection)
85 | self.polish_page = dspy.Predict(PolishPage)
86 |
87 | def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True):
88 | # NOTE: Change show_guidelines to false to make the generation more robust to different LM families.
89 | with dspy.settings.context(lm=self.write_lead_engine, show_guidelines=False):
90 | lead_section = self.write_lead(
91 | topic=topic, draft_page=draft_page
92 | ).lead_section
93 | if "The lead section:" in lead_section:
94 | lead_section = lead_section.split("The lead section:")[1].strip()
95 | if polish_whole_page:
96 | # NOTE: Change show_guidelines to false to make the generation more robust to different LM families.
97 | with dspy.settings.context(lm=self.polish_engine, show_guidelines=False):
98 | page = self.polish_page(draft_page=draft_page).page
99 | else:
100 | page = draft_page
101 |
102 | return dspy.Prediction(lead_section=lead_section, page=page)
103 |
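# Usage sketch (illustrative). Both engines are assumed to be dspy-compatible LM
# instances; `draft_article` is the StormArticle returned by the article generation stage.
#
#     polisher = StormArticlePolishingModule(
#         article_gen_lm=article_gen_lm,        # writes the lead section
#         article_polish_lm=article_polish_lm,  # optionally de-duplicates the body
#     )
#     polished = polisher.polish_article(
#         topic="Example topic",
#         draft_article=draft_article,
#         remove_duplicate=True,  # one extra LM call over the whole page
#     )
#     print(polished.to_string())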
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/callback.py:
--------------------------------------------------------------------------------
1 | class BaseCallbackHandler:
2 | """Base callback handler that can be used to handle callbacks from the STORM pipeline."""
3 |
4 | def on_identify_perspective_start(self, **kwargs):
5 | """Run when the perspective identification starts."""
6 | pass
7 |
8 | def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
9 | """Run when the perspective identification finishes."""
10 | pass
11 |
12 | def on_information_gathering_start(self, **kwargs):
13 | """Run when the information gathering starts."""
14 | pass
15 |
16 | def on_dialogue_turn_end(self, dlg_turn, **kwargs):
17 | """Run when a question asking and answering turn finishes."""
18 | pass
19 |
20 | def on_information_gathering_end(self, **kwargs):
21 | """Run when the information gathering finishes."""
22 | pass
23 |
24 | def on_information_organization_start(self, **kwargs):
25 | """Run when the information organization starts."""
26 | pass
27 |
28 | def on_direct_outline_generation_end(self, outline: str, **kwargs):
29 | """Run when the direct outline generation finishes."""
30 | pass
31 |
32 | def on_outline_refinement_end(self, outline: str, **kwargs):
33 | """Run when the outline refinement finishes."""
34 | pass
35 |
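# Example sketch (not part of the module): a concrete handler that reports progress to
# stdout. It overrides only the hooks it cares about; the pipeline falls back to the
# base-class no-ops for the rest.
class PrintProgressCallbackHandler(BaseCallbackHandler):
    """Minimal callback handler that prints pipeline progress."""

    def on_identify_perspective_end(self, perspectives: list[str], **kwargs):
        print(f"Identified {len(perspectives)} perspectives.")

    def on_dialogue_turn_end(self, dlg_turn, **kwargs):
        print("Completed one question-answering turn.")

    def on_outline_refinement_end(self, outline: str, **kwargs):
        print(f"Refined outline:\n{outline}")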
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/outline_generation.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional, Tuple
2 |
3 | import dspy
4 |
5 | from .callback import BaseCallbackHandler
6 | from .storm_dataclass import StormInformationTable, StormArticle
7 | from ...interface import OutlineGenerationModule
8 | from ...utils import ArticleTextProcessing
9 |
10 |
11 | class StormOutlineGenerationModule(OutlineGenerationModule):
12 | """
13 |     The interface for the outline generation stage. Given the topic and the information collected in the
14 |     knowledge curation stage, generate an outline for the article.
15 | """
16 |
17 | def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
18 | super().__init__()
19 | self.outline_gen_lm = outline_gen_lm
20 | self.write_outline = WriteOutline(engine=self.outline_gen_lm)
21 |
22 | def generate_outline(
23 | self,
24 | topic: str,
25 | information_table: StormInformationTable,
26 | old_outline: Optional[StormArticle] = None,
27 | callback_handler: BaseCallbackHandler = None,
28 | return_draft_outline=False,
29 | ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]:
30 | """
31 | Generates an outline for an article based on the specified topic and the information
32 | gathered during the knowledge curation stage. This method can optionally return both the
33 | final article outline and a draft outline if required.
34 |
35 | Args:
36 | topic (str): The topic of the article.
37 | information_table (StormInformationTable): The information table containing the collected information.
38 | old_outline (Optional[StormArticle]): An optional previous version of the article outline that can
39 | be used for reference or comparison. Defaults to None.
40 | callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger
41 | custom callbacks at various stages of the outline generation process, such as when the information
42 | organization starts. Defaults to None.
43 | return_draft_outline (bool): A flag indicating whether the method should return both the final article
44 | outline and a draft version of the outline. If False, only the final article outline is returned.
45 | Defaults to False.
46 |
47 | Returns:
48 | Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`,
49 | this method returns either a single `StormArticle` object containing the final outline or a tuple of
50 | two `StormArticle` objects, the first containing the final outline and the second containing the
51 | draft outline.
52 | """
53 | if callback_handler is not None:
54 | callback_handler.on_information_organization_start()
55 |
56 | concatenated_dialogue_turns = sum(
57 | [conv for (_, conv) in information_table.conversations], []
58 | )
59 | result = self.write_outline(
60 | topic=topic,
61 | dlg_history=concatenated_dialogue_turns,
62 | callback_handler=callback_handler,
63 | )
64 | article_with_outline_only = StormArticle.from_outline_str(
65 | topic=topic, outline_str=result.outline
66 | )
67 | article_with_draft_outline_only = StormArticle.from_outline_str(
68 | topic=topic, outline_str=result.old_outline
69 | )
70 | if not return_draft_outline:
71 | return article_with_outline_only
72 | return article_with_outline_only, article_with_draft_outline_only
73 |
74 |
75 | class WriteOutline(dspy.Module):
76 | """Generate the outline for the Wikipedia page."""
77 |
78 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
79 | super().__init__()
80 | self.draft_page_outline = dspy.Predict(WritePageOutline)
81 | self.write_page_outline = dspy.Predict(WritePageOutlineFromConv)
82 | self.engine = engine
83 |
84 | def forward(
85 | self,
86 | topic: str,
87 | dlg_history,
88 | old_outline: Optional[str] = None,
89 | callback_handler: BaseCallbackHandler = None,
90 | ):
91 | trimmed_dlg_history = []
92 | for turn in dlg_history:
93 | if (
94 | "topic you" in turn.agent_utterance.lower()
95 | or "topic you" in turn.user_utterance.lower()
96 | ):
97 | continue
98 | trimmed_dlg_history.append(turn)
99 | conv = "\n".join(
100 | [
101 | f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}"
102 | for turn in trimmed_dlg_history
103 | ]
104 | )
105 | conv = ArticleTextProcessing.remove_citations(conv)
106 | conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000)
107 |
108 | with dspy.settings.context(lm=self.engine):
109 | if old_outline is None:
110 | old_outline = ArticleTextProcessing.clean_up_outline(
111 | self.draft_page_outline(topic=topic).outline
112 | )
113 | if callback_handler:
114 | callback_handler.on_direct_outline_generation_end(
115 | outline=old_outline
116 | )
117 | outline = ArticleTextProcessing.clean_up_outline(
118 | self.write_page_outline(
119 | topic=topic, old_outline=old_outline, conv=conv
120 | ).outline
121 | )
122 | if callback_handler:
123 | callback_handler.on_outline_refinement_end(outline=outline)
124 |
125 | return dspy.Prediction(outline=outline, old_outline=old_outline)
126 |
127 |
128 | class WritePageOutline(dspy.Signature):
129 | """Write an outline for a Wikipedia page.
130 | Here is the format of your writing:
131 |     1. Use "# Title" to indicate section title, "## Title" to indicate subsection title, "### Title" to indicate subsubsection title, and so on.
132 | 2. Do not include other information.
133 | 3. Do not include topic name itself in the outline.
134 | """
135 |
136 | topic = dspy.InputField(prefix="The topic you want to write: ", format=str)
137 | outline = dspy.OutputField(prefix="Write the Wikipedia page outline:\n", format=str)
138 |
139 |
140 | class NaiveOutlineGen(dspy.Module):
141 | """Generate the outline with LLM's parametric knowledge directly."""
142 |
143 | def __init__(self):
144 | super().__init__()
145 | self.write_outline = dspy.Predict(WritePageOutline)
146 |
147 | def forward(self, topic: str):
148 | outline = self.write_outline(topic=topic).outline
149 |
150 | return dspy.Prediction(outline=outline)
151 |
152 |
153 | class WritePageOutlineFromConv(dspy.Signature):
154 | """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative.
155 | Here is the format of your writing:
156 |     1. Use "# Title" to indicate section title, "## Title" to indicate subsection title, "### Title" to indicate subsubsection title, and so on.
157 | 2. Do not include other information.
158 | 3. Do not include topic name itself in the outline.
159 | """
160 |
161 | topic = dspy.InputField(prefix="The topic you want to write: ", format=str)
162 | conv = dspy.InputField(prefix="Conversation history:\n", format=str)
163 | old_outline = dspy.OutputField(prefix="Current outline:\n", format=str)
164 | outline = dspy.OutputField(
165 |         prefix='Write the Wikipedia page outline (Use "# Title" to indicate section title, "## Title" to indicate subsection title, ...):\n',
166 | format=str,
167 | )
168 |
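# Usage sketch (illustrative). `outline_gen_lm` is assumed to be a dspy-compatible LM
# instance and `information_table` a StormInformationTable whose `conversations`
# attribute yields (_, dialogue_turns) pairs, as consumed in generate_outline().
#
#     outline_gen = StormOutlineGenerationModule(outline_gen_lm=outline_gen_lm)
#     outline, draft_outline = outline_gen.generate_outline(
#         topic="Example topic",
#         information_table=information_table,
#         return_draft_outline=True,   # also return the LM-only draft outline
#     )
#     print(outline.to_string())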
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/persona_generator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from typing import Union, List
4 |
5 | import dspy
6 | import requests
7 | from bs4 import BeautifulSoup
8 |
9 |
10 | def get_wiki_page_title_and_toc(url):
11 | """Get the main title and table of contents from an url of a Wikipedia page."""
12 |
13 | response = requests.get(url)
14 | soup = BeautifulSoup(response.content, "html.parser")
15 |
16 | # Get the main title from the first h1 tag
17 | main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ")
18 |
19 | toc = ""
20 | levels = []
21 | excluded_sections = {
22 | "Contents",
23 | "See also",
24 | "Notes",
25 | "References",
26 | "External links",
27 | }
28 |
29 | # Start processing from h2 to exclude the main title from TOC
30 | for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]):
31 | level = int(
32 | header.name[1]
33 | ) # Extract the numeric part of the header tag (e.g., '2' from 'h2')
34 | section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ")
35 | if section_title in excluded_sections:
36 | continue
37 |
38 | while levels and level <= levels[-1]:
39 | levels.pop()
40 | levels.append(level)
41 |
42 | indentation = " " * (len(levels) - 1)
43 | toc += f"{indentation}{section_title}\n"
44 |
45 | return main_title, toc.strip()
46 |
47 |
48 | class FindRelatedTopic(dspy.Signature):
49 | """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics.
50 | Please list the urls in separate lines."""
51 |
52 | topic = dspy.InputField(prefix="Topic of interest:", format=str)
53 | related_topics = dspy.OutputField(format=str)
54 |
55 |
56 | class GenPersona(dspy.Signature):
57 | """You need to select a group of Wikipedia editors who will work together to create a comprehensive article on the topic. Each of them represents a different perspective, role, or affiliation related to this topic. You can use other Wikipedia pages of related topics for inspiration. For each editor, add a description of what they will focus on.
58 | Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n...
59 | """
60 |
61 | topic = dspy.InputField(prefix="Topic of interest:", format=str)
62 | examples = dspy.InputField(
63 | prefix="Wiki page outlines of related topics for inspiration:\n", format=str
64 | )
65 | personas = dspy.OutputField(format=str)
66 |
67 |
68 | class CreateWriterWithPersona(dspy.Module):
69 | """Discover different perspectives of researching the topic by reading Wikipedia pages of related topics."""
70 |
71 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
72 | super().__init__()
73 | self.find_related_topic = dspy.ChainOfThought(FindRelatedTopic)
74 | self.gen_persona = dspy.ChainOfThought(GenPersona)
75 | self.engine = engine
76 |
77 | def forward(self, topic: str, draft=None):
78 | with dspy.settings.context(lm=self.engine):
79 | # Get section names from wiki pages of relevant topics for inspiration.
80 | related_topics = self.find_related_topic(topic=topic).related_topics
81 | urls = []
82 | for s in related_topics.split("\n"):
83 | if "http" in s:
84 | urls.append(s[s.find("http") :])
85 | examples = []
86 | for url in urls:
87 | try:
88 | title, toc = get_wiki_page_title_and_toc(url)
89 | examples.append(f"Title: {title}\nTable of Contents: {toc}")
90 | except Exception as e:
91 |                     logging.error(f"Error occurred while processing {url}: {e}")
92 | continue
93 | if len(examples) == 0:
94 | examples.append("N/A")
95 | gen_persona_output = self.gen_persona(
96 | topic=topic, examples="\n----------\n".join(examples)
97 | ).personas
98 |
99 | personas = []
100 | for s in gen_persona_output.split("\n"):
101 | match = re.search(r"\d+\.\s*(.*)", s)
102 | if match:
103 | personas.append(match.group(1))
104 |
105 |             sorted_personas = personas  # No re-ranking is applied; personas keep their generated order.
106 |
107 | return dspy.Prediction(
108 | personas=personas,
109 | raw_personas_output=sorted_personas,
110 | related_topics=related_topics,
111 | )
112 |
113 |
114 | class StormPersonaGenerator:
115 | """
116 | A generator class for creating personas based on a given topic.
117 |
118 | This class uses an underlying engine to generate personas tailored to the specified topic.
119 | The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas,
120 | including a default 'Basic fact writer' persona.
121 |
122 | Attributes:
123 | create_writer_with_persona (CreateWriterWithPersona): An instance responsible for
124 | generating personas based on the provided engine and topic.
125 |
126 | Args:
127 | engine (Union[dspy.dsp.LM, dspy.dsp.HFModel]): The underlying engine used for generating
128 | personas. It must be an instance of either `dspy.dsp.LM` or `dspy.dsp.HFModel`.
129 | """
130 |
131 | def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
132 | self.create_writer_with_persona = CreateWriterWithPersona(engine=engine)
133 |
134 | def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]:
135 | """
136 | Generates a list of personas based on the provided topic, up to a maximum number specified.
137 |
138 | This method first creates personas using the underlying `create_writer_with_persona` instance
139 | and then prepends a default 'Basic fact writer' persona to the list before returning it.
140 | The number of personas returned is limited to `max_num_persona`, excluding the default persona.
141 |
142 | Args:
143 | topic (str): The topic for which personas are to be generated.
144 | max_num_persona (int): The maximum number of personas to generate, excluding the
145 | default 'Basic fact writer' persona.
146 |
147 | Returns:
148 | List[str]: A list of persona descriptions, including the default 'Basic fact writer' persona
149 | and up to `max_num_persona` additional personas generated based on the topic.
150 | """
151 | personas = self.create_writer_with_persona(topic=topic)
152 | default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic."
153 | considered_personas = [default_persona] + personas.personas[:max_num_persona]
154 | return considered_personas
155 |
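# Usage sketch (illustrative). `engine` is assumed to be any dspy-compatible LM
# instance; network access is required because related Wikipedia pages are fetched
# for their tables of contents.
#
#     persona_generator = StormPersonaGenerator(engine=engine)
#     personas = persona_generator.generate_persona(
#         topic="Example topic",
#         max_num_persona=3,   # excludes the default 'Basic fact writer' persona
#     )
#     for persona in personas:
#         print(persona)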
--------------------------------------------------------------------------------
/knowledge_storm/storm_wiki/modules/retriever.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 | from urllib.parse import urlparse
3 |
4 | import dspy
5 |
6 | from ...interface import Retriever, Information
7 | from ...utils import ArticleTextProcessing
8 |
9 | # Internet source restrictions according to Wikipedia standard:
10 | # https://en.wikipedia.org/wiki/Wikipedia:Reliable_sources/Perennial_sources
11 | GENERALLY_UNRELIABLE = {
12 | "112_Ukraine",
13 | "Ad_Fontes_Media",
14 | "AlterNet",
15 | "Amazon",
16 | "Anadolu_Agency_(controversial_topics)",
17 | "Ancestry.com",
18 | "Answers.com",
19 | "Antiwar.com",
20 | "Anti-Defamation_League",
21 | "arXiv",
22 | "Atlas_Obscura_places",
23 | "Bild",
24 | "Blaze_Media",
25 | "Blogger",
26 | "BroadwayWorld",
27 | "California_Globe",
28 | "The_Canary",
29 | "CelebrityNetWorth",
30 | "CESNUR",
31 | "ChatGPT",
32 | "CNET_(November_2022\u2013present)",
33 | "CoinDesk",
34 | "Consortium_News",
35 | "CounterPunch",
36 | "Correo_del_Orinoco",
37 | "Cracked.com",
38 | "Daily_Express",
39 | "Daily_Kos",
40 | "Daily_Sabah",
41 | "The_Daily_Wire",
42 | "Discogs",
43 | "Distractify",
44 | "The_Electronic_Intifada",
45 | "Encyclopaedia_Metallum",
46 | "Ethnicity_of_Celebs",
47 | "Facebook",
48 | "FamilySearch",
49 | "Fandom",
50 | "The_Federalist",
51 | "Find_a_Grave",
52 | "Findmypast",
53 | "Flags_of_the_World",
54 | "Flickr",
55 | "Forbes.com_contributors",
56 | "Fox_News_(politics_and_science)",
57 | "Fox_News_(talk_shows)",
58 | "Gawker",
59 | "GB_News",
60 | "Geni.com",
61 | "gnis-class",
62 | "gns-class",
63 | "GlobalSecurity.org",
64 | "Goodreads",
65 | "Guido_Fawkes",
66 | "Heat_Street",
67 | "History",
68 | "HuffPost_contributors",
69 | "IMDb",
70 | "Independent_Media_Center",
71 | "Inquisitr",
72 | "International_Business_Times",
73 | "Investopedia",
74 | "Jewish_Virtual_Library",
75 | "Joshua_Project",
76 | "Know_Your_Meme",
77 | "Land_Transport_Guru",
78 | "LinkedIn",
79 | "LiveJournal",
80 | "Marquis_Who's_Who",
81 | "Mashable_sponsored_content",
82 | "MEAWW",
83 | "Media_Bias/Fact_Check",
84 | "Media_Research_Center",
85 | "Medium",
86 | "metal-experience",
87 | "Metro",
88 | "The_New_American",
89 | "New_York_Post",
90 | "NGO_Monitor",
91 | "The_Onion",
92 | "Our_Campaigns",
93 | "PanAm_Post",
94 | "Patheos",
95 | "An_Phoblacht",
96 | "The_Post_Millennial",
97 | "arXiv",
98 | "bioRxiv",
99 | "medRxiv",
100 | "PeerJ Preprints",
101 | "Preprints.org",
102 | "SSRN",
103 | "PR_Newswire",
104 | "Quadrant",
105 | "Quillette",
106 | "Quora",
107 | "Raw_Story",
108 | "Reddit",
109 | "RedState",
110 | "ResearchGate",
111 | "Rolling_Stone_(politics_and_society,_2011\u2013present)",
112 | "Rolling_Stone_(Culture_Council)",
113 | "Scribd",
114 | "Scriptural_texts",
115 | "Simple_Flying",
116 | "Sixth_Tone_(politics)",
117 | "The_Skwawkbox",
118 | "SourceWatch",
119 | "Spirit_of_Metal",
120 | "Sportskeeda",
121 | "Stack_Exchange",
122 | "Stack_Overflow",
123 | "MathOverflow",
124 | "Ask_Ubuntu",
125 | "starsunfolded.com",
126 | "Statista",
127 | "TASS",
128 | "The_Truth_About_Guns",
129 | "TV.com",
130 | "TV_Tropes",
131 | "Twitter",
132 | "X.com",
133 | "Urban_Dictionary",
134 | "Venezuelanalysis",
135 | "VGChartz",
136 | "VoC",
137 | "Washington_Free_Beacon",
138 | "Weather2Travel",
139 | "The_Western_Journal",
140 | "We_Got_This_Covered",
141 | "WhatCulture",
142 | "Who's_Who_(UK)",
143 | "WhoSampled",
144 | "Wikidata",
145 | "WikiLeaks",
146 | "Wikinews",
147 | "Wikipedia",
148 | "WordPress.com",
149 | "Worldometer",
150 | "YouTube",
151 | "ZDNet",
152 | }
153 | DEPRECATED = {
154 | "Al_Mayadeen",
155 | "ANNA_News",
156 | "Baidu_Baike",
157 | "China_Global_Television_Network",
158 | "The_Cradle",
159 | "Crunchbase",
160 | "The_Daily_Caller",
161 | "Daily_Mail",
162 | "Daily_Star",
163 | "The_Epoch_Times",
164 | "FrontPage_Magazine",
165 | "The_Gateway_Pundit",
166 | "Global_Times",
167 | "The_Grayzone",
168 | "HispanTV",
169 | "Jihad_Watch",
170 | "Last.fm",
171 | "LifeSiteNews",
172 | "The_Mail_on_Sunday",
173 | "MintPress_News",
174 | "National_Enquirer",
175 | "New_Eastern_Outlook",
176 | "News_Break",
177 | "NewsBlaze",
178 | "News_of_the_World",
179 | "Newsmax",
180 | "NNDB",
181 | "Occupy_Democrats",
182 | "Office_of_Cuba_Broadcasting",
183 | "One_America_News_Network",
184 | "Peerage_websites",
185 | "Press_TV",
186 | "Project_Veritas",
187 | "Rate_Your_Music",
188 | "Republic_TV",
189 | "Royal_Central",
190 | "RT",
191 | "Sputnik",
192 | "The_Sun",
193 | "Taki's_Magazine",
194 | "Tasnim_News_Agency",
195 | "Telesur",
196 | "The_Unz_Review",
197 | "VDARE",
198 | "Voltaire_Network",
199 | "WorldNetDaily",
200 | "Zero_Hedge",
201 | }
202 | BLACKLISTED = {
203 | "Advameg",
204 | "bestgore.com",
205 | "Breitbart_News",
206 | "Centre_for_Research_on_Globalization",
207 | "Examiner.com",
208 | "Famous_Birthdays",
209 | "Healthline",
210 | "InfoWars",
211 | "Lenta.ru",
212 | "LiveLeak",
213 | "Lulu.com",
214 | "MyLife",
215 | "Natural_News",
216 | "OpIndia",
217 | "The_Points_Guy",
218 | "The_Points_Guy_(sponsored_content)",
219 | "Swarajya",
220 | "Veterans_Today",
221 | "ZoomInfo",
222 | }
223 |
224 |
225 | def is_valid_wikipedia_source(url):
226 | parsed_url = urlparse(url)
227 | # Check if the URL is from a reliable domain
228 | combined_set = GENERALLY_UNRELIABLE | DEPRECATED | BLACKLISTED
229 | for domain in combined_set:
230 | if domain in parsed_url.netloc:
231 | return False
232 |
233 | return True
234 |
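# Example sketch: filtering candidate URLs against the perennial-source sets above.
# Note that the check is a case-sensitive substring match on the URL's netloc, so only
# entries that literally appear in the host (e.g. "starsunfolded.com") are filtered.
if __name__ == "__main__":
    candidate_urls = [
        "https://www.starsunfolded.com/some-profile",           # filtered out
        "https://www.nature.com/articles/s41586-021-03819-2",   # kept
    ]
    print([url for url in candidate_urls if is_valid_wikipedia_source(url)])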
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dspy_ai==2.4.9
2 | wikipedia==1.4.0
3 | sentence-transformers
4 | toml
5 | langchain-text-splitters
6 | trafilatura
7 | langchain-huggingface
8 | qdrant-client
9 | langchain-qdrant
10 | numpy==1.26.4
11 | litellm==1.59.3
12 | diskcache
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from setuptools import setup, find_packages
4 |
5 | # Read the content of the README file
6 | with open("README.md", encoding="utf-8") as f:
7 | long_description = f.read()
8 | # Remove p tags.
9 | pattern = re.compile(r"<p.*?>.*?</p>", re.DOTALL)
10 | long_description = re.sub(pattern, "", long_description)
11 |
12 | # Read the content of the requirements.txt file
13 | with open("requirements.txt", encoding="utf-8") as f:
14 | requirements = f.read().splitlines()
15 |
16 |
17 | setup(
18 | name="knowledge-storm",
19 | version="1.1.0",
20 | author="Yijia Shao, Yucheng Jiang",
21 | author_email="shaoyj@stanford.edu, yuchengj@stanford.edu",
22 | description="STORM: A language model-powered knowledge curation engine.",
23 | long_description=long_description,
24 | long_description_content_type="text/markdown",
25 | url="https://github.com/stanford-oval/storm",
26 | license="MIT License",
27 | packages=find_packages(),
28 | classifiers=[
29 | "Development Status :: 3 - Alpha",
30 | "License :: OSI Approved :: MIT License",
31 | "Operating System :: OS Independent",
32 | "Programming Language :: Python :: 3",
33 | "Programming Language :: Python :: 3.10",
34 | "Programming Language :: Python :: 3.11",
35 | ],
36 | python_requires=">=3.10",
37 | install_requires=requirements,
38 | )
39 |
--------------------------------------------------------------------------------