├── .gitignore
├── LICENSE
├── README.md
├── call_model.py
├── config.yaml
├── pyproject.toml
└── src
└── langfuzz
├── __init__.py
└── redteam.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Harrison Chase
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # LangFuzz: Red Teaming for Language Models
3 | LangFuzz is a command line tool designed to perform red teaming on language model applications and add any points of interest to a [LangSmith Dataset](https://docs.smith.langchain.com/). It generates pairs of similar questions and compares the responses to identify potential failure modes in chatbots or other language model-based systems. For those coming from a software engineering background: this is similar to a particular type of [fuzz testing](https://www.blackduck.com/glossary/what-is-fuzz-testing.html#:~:text=Definition,as%20crashes%20or%20information%20leakage.) called [metamorphic testing](https://arxiv.org/abs/2002.12543).
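For example (an illustrative pair, not actual tool output): both prompts below should yield semantically identical answers, so a chatbot that answers them differently has likely exposed a failure mode.

```
Q1: How do I install LangChain?
Q2: What command do I run to install LangChain?
```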
4 |
5 | ## Installation
6 | To install LangFuzz, use pip:
7 |
8 | ```
9 | pip install langfuzz
10 | ```
11 |
12 | ## Usage
13 |
14 | ### Step 1: define a model file
15 | First, define a model file that calls your model. This file should expose a sync OR async function called `call_model` that takes in a string and returns a string. An example model file is found in [call_model.py](call_model.py):
16 |
17 | ```python
18 | import random
19 | from openai import OpenAI
20 |
21 | client = OpenAI()
22 |
23 |
24 | def call_model(question: str) -> str:
25 |     # Randomly vary the system message so the demo sometimes gives inconsistent answers.
26 | if random.uniform(0, 1) > 0.5:
27 | system_message = "LangChain is an LLM framework - answer all questions with things about LLMs."
28 | else:
29 | system_message = "LangChain is blockchain technology - answer all questions with things about crypto"
30 |
31 | completion = client.chat.completions.create(
32 | model="gpt-4o-mini",
33 | messages=[
34 | {"role": "system", "content": system_message},
35 | {"role": "user", "content": question},
36 | ],
37 | )
38 | return completion.choices[0].message.content
39 |
40 | ```
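`call_model` may also be a coroutine; LangFuzz detects this with `inspect.iscoroutinefunction` and awaits it. A minimal async sketch (assuming the `openai` package's `AsyncOpenAI` client):

```python
from openai import AsyncOpenAI

client = AsyncOpenAI()


async def call_model(question: str) -> str:
    # Same contract as the sync version: string in, string out.
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
    )
    return completion.choices[0].message.content
```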
41 |
42 | ### Step 2: define a configuration file
43 |
44 | Next, you need to define a configuration file.
45 |
46 | The `config.yaml` file should contain the following keys:
47 | - `chatbot_description`: A description of the chatbot being tested.
48 | - `model_file`: Path to the Python file containing the call_model function
49 |
50 | Example config.yaml:
51 | ```
52 | chatbot_description: "Chat over LangChain Docs"
53 | model_file: "call_model.py"
54 | ```
55 |
56 | It may optionally contain other keys (see [Options](#options) and [Additional Configuration](#additional-configuration) below).
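For example, a fuller `config.yaml` (values here are illustrative; every key besides `chatbot_description` and `model_file` is optional) might look like:

```
chatbot_description: "Chat over LangChain Docs"
model_file: "call_model.py"
n: 20
max_similarity: 8
configurable:
  judge_model: "gpt-4o"
  question_gen_model: "gpt-4o-mini"
```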
57 |
58 | ### Step 3: set any environment variables
59 |
60 | LangFuzz needs the following environment variables:
61 |
62 | - `OPENAI_API_KEY`: required for accessing the OpenAI models used to generate and judge questions
63 | - `LANGSMITH_API_KEY` (optional): required only for sending generated data to a LangSmith dataset
64 |
65 | ```
66 | export OPENAI_API_KEY=...
67 | export LANGSMITH_API_KEY=...
68 | ```
69 |
70 | ### Step 4: run the red teaming
71 |
72 | To run the red teaming process, use the following command:
73 | ```
74 | langfuzz config.yaml [options]
75 | ```
76 |
77 | ### Step 5: curate datapoints
78 |
79 | As the red teaming runs, pairs of datapoints will be shown to you in the command line. From there, you can choose to add both, one, or neither to a LangSmith dataset.
80 |
81 | - **Enter**: To add both inputs to the dataset, just press enter
82 | - **`1`**: If you want to add only the first input to the dataset, enter `1`
83 | - **`2`**: If you want to add only the second input to the dataset, enter `2`
84 | - **`3`**: If you don't want to add either input to the dataset, enter `3`
85 | - **`q`**: To quit, enter `q`
86 |
87 | If you add a datapoint to a LangSmith dataset, it will be added with a single input key `question` and no output key.
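That is, each example created in LangSmith has the shape:

```
{"question": "<the selected question text>"}
```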
88 |
89 | ## Options
90 |
91 | - `--dataset_id`: ID of the LangSmith dataset to use (optional; a new dataset is created if omitted)
92 | - `--n`: Number of questions to generate (default: 10)
93 | - `--max_concurrency`: Maximum number of concurrent requests to the model (default: 10)
94 | - `--n_prefill_questions`: Number of results to queue up ahead of your review (default: 10)
95 | - `--max_similarity`: Only pairs whose responses score at or below this similarity are shown for review (default: 10, i.e. everything)
96 | - `-p`, `--persistence-path`: Path to the persistence file (optional)
97 |
98 | These options can additionally be provided as part of the configuration file.
99 |
100 | ## Additional Configuration
101 |
102 | You can also configure more aspects of the red teaming agent:
103 |
104 | - `judge_model`: the model to use to judge whether the two responses in a pair are similar
105 | - `question_gen_model`: the model to use to generate pairs of questions
106 | - `judge_prompt`: the prompt to use to judge whether the two responses in a pair are similar
107 | - `question_gen_prompt`: the prompt to use to generate pairs of questions
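Both prompts are Python format strings: `question_gen_prompt` is formatted with `{chatbot_description}` and `{n}`, and `judge_prompt` with `{question_1}`, `{answer_1}`, `{question_2}`, and `{answer_2}`. A sketch of overriding them in `config.yaml` (prompt wording here is illustrative):

```
configurable:
  question_gen_model: "gpt-4o-mini"
  judge_prompt: |
    Question 1: {question_1}
    Answer 1: {answer_1}
    Question 2: {question_2}
    Answer 2: {answer_2}
    Rate the similarity of the two answers from 1 (different) to 10 (identical).
```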
108 |
109 | ## How It Works
110 |
111 | 1. The tool generates pairs of similar questions based on the provided chatbot description.
112 | 2. It sends both questions in each pair to the chatbot and compares the responses.
113 | 3. A judge model evaluates the similarity of the two responses on a scale of 1-10.
114 | 4. Pairs with similarity scores at or below the `max_similarity` threshold are presented for review.
115 | 5. You can choose to add the questions to a dataset for further analysis or training.
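Conceptually, one round looks like the sketch below. This is not the actual implementation (which is a LangGraph graph in `src/langfuzz/redteam.py`); `generate_question_pairs` and `judge_similarity` are hypothetical stand-ins for the prompt-driven model calls:

```python
def redteam_round(chatbot_description: str, n: int, max_similarity: int):
    # Generate n pairs of slightly different questions that should agree.
    pairs = generate_question_pairs(chatbot_description, n)  # hypothetical helper
    for q1, q2 in pairs:
        # Ask the chatbot under test both variants.
        a1, a2 = call_model(q1), call_model(q2)
        # A judge model scores answer similarity from 1 (different) to 10 (same).
        score = judge_similarity(q1, a1, q2, a2)  # hypothetical helper
        # Low similarity means the pair likely exposed a failure mode.
        if score <= max_similarity:
            yield {"input_1": q1, "input_2": q2, "judge": {"similarity": score}}
```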
116 | ### Persistence
117 | If a persistence path is provided, the tool will save generated questions and dataset information between runs. This allows for continuous red teaming sessions without duplicating questions.
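The persistence file is plain JSON. Based on what `redteam.py` writes, it has the shape:

```
{
    "dataset_id": "<LangSmith dataset id>",
    "generated_questions": ["first question", "second question"]
}
```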
118 |
119 | ## Use without LangSmith
120 |
121 | You can also use this red teaming agent without LangSmith and dump all the results to a local file. To do this, use the following command:
122 |
123 | ```
124 | langfuzz-dump config.yaml results.json [options]
125 | ```
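Each dumped entry mirrors the judge state in `redteam.py`, so `results.json` is a JSON list shaped roughly like (values illustrative):

```
[
  {
    "input_1": "...",
    "input_2": "...",
    "output_1": "...",
    "output_2": "...",
    "judge": {"logic": "...", "similarity": 3}
  }
]
```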
126 |
--------------------------------------------------------------------------------
/call_model.py:
--------------------------------------------------------------------------------
1 | import random
2 | from openai import OpenAI
3 |
4 | client = OpenAI()
5 |
6 |
7 | def call_model(question: str) -> str:
8 |     # Randomly vary the system message so the demo sometimes gives inconsistent answers.
9 | if random.uniform(0, 1) > 0.5:
10 | system_message = "LangChain is an LLM framework - answer all questions with things about LLMs."
11 | else:
12 | system_message = "LangChain is blockchain technology - answer all questions with things about crypto"
13 |
14 | completion = client.chat.completions.create(
15 | model="gpt-4o-mini",
16 | messages=[
17 | {"role": "system", "content": system_message},
18 | {"role": "user", "content": question},
19 | ],
20 | )
21 | return completion.choices[0].message.content
22 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | chatbot_description: "Chat over LangChain Docs"
2 | model_file: "call_model.py"
3 | configurable:
4 | judge_model: "gpt-4o"
5 | question_gen_model: "gpt-4o-mini"
6 | judge_prompt: |
7 |     You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, and if it produces different answers then you know one of the questions is a failure mode for that chatbot.
8 |
9 |
10 | {question_1}
11 |
12 |
13 |
14 | {answer_1}
15 |
16 |
17 |
18 | {question_2}
19 |
20 |
21 |
22 | {answer_2}
23 |
24 |
25 |     How similar are these answers on a scale of 1-10? 1 is not similar at all, 10 is absolutely similar.
26 |     If one answer contains information that the other one doesn't, then the similarity should be lower. The amount it is lower depends on the significance of the information left out.
27 |     If the answers contradict each other, then the similarity should be very low.
28 | question_gen_prompt: |
29 |     You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions and if it produces different answers then you know one of the questions is a failure mode for that chatbot.
30 |
31 | The chatbot you are redteaming is:
32 |
33 | {chatbot_description}
34 |
35 |     Please generate a list of pairs of questions that are relevant to this chatbot. Make sure to generate a diverse set of questions. Only generate {n} questions.
36 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools-scm>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "langfuzz"
7 | version = "0.0.1"
8 | description = "Project to fuzz language model applications"
9 | authors = [{name = "Harrison Chase"}]
10 | dependencies = [
11 | "langgraph",
12 | "langchain",
13 | "langchain_openai",
14 | "langsmith",
15 | "pyyaml",
16 | ]
17 |
18 | [project.scripts]
19 | langfuzz = "langfuzz.redteam:main"
20 | langfuzz-dump = "langfuzz.redteam:dump"
21 |
22 | [tool.setuptools.packages.find]
23 | where = ["src"]
24 |
--------------------------------------------------------------------------------
/src/langfuzz/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwchase17/langfuzz/7116127c2429c382ff0e0701b7ab2271bc252eb9/src/langfuzz/__init__.py
--------------------------------------------------------------------------------
/src/langfuzz/redteam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import contextlib
4 | import dataclasses
5 | import importlib.util
6 | import inspect
7 | import json
8 | import operator
9 | import os
10 | import queue
11 | import threading
12 | import time
13 | import warnings
14 | import yaml
15 | from typing import Annotated, Any, Callable, TypedDict
16 |
17 | from langchain.chat_models import init_chat_model
18 |
19 | from langsmith import Client
20 | from langgraph.graph import END, START, StateGraph
21 | from langgraph.types import Send
22 |
23 | # Suppress all warnings
24 | warnings.filterwarnings("ignore")
25 |
26 |
27 | @dataclasses.dataclass(order=True)
28 | class Result:
29 | prio: int
30 | answer: dict = dataclasses.field(hash=False, compare=False)
31 |
32 |
33 | class Response(TypedDict):
34 | input_1: str
35 | input_2: str
36 |
37 |
38 | class ListResponse(TypedDict):
39 | pairs: list[Response]
40 |
41 |
42 | class JudgeState(TypedDict):
43 | input_1: Any
44 | input_2: Any
45 | output_1: Any
46 | output_2: Any
47 | judge: dict
48 |
49 |
50 | class AgentState(TypedDict):
51 | chatbot_description: str
52 | generated_questions: list[str]
53 | n: int
54 | pairs: list
55 | answers: Annotated[list, operator.add]
56 |
57 |
58 | answer_prompt = """You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, and if it produces different answers then you know one of the questions is a failure mode for that chatbot.
59 |
60 |
61 | {question_1}
62 |
63 |
64 |
65 | {answer_1}
66 |
67 |
68 |
69 | {question_2}
70 |
71 |
72 |
73 | {answer_2}
74 |
75 |
76 | How similar are these answers on a scale of 1-10? 1 is not similar at all, 10 is absolutely similar.
77 | If one answer contains information that the other one doesn't, then the similarity should be lower. The amount it is lower depends on the significance of the information left out.
78 | If the answers contradict each other, then the similarity should be very low."""
79 |
80 |
81 | class Judge(TypedDict):
82 | logic: str
83 | similarity: int
84 |
85 |
86 | question_prompt = """You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce \
87 | incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up \
88 | with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should \
89 | produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, \
90 | and if it produces different answers then you know one of the questions is a failure mode for that chatbot.
91 |
92 | The chatbot you are redteaming is:
93 |
94 | {chatbot_description}
95 |
96 | Please generate a list of pairs of questions that are relevant to this chatbot. Make sure to generate a diverse set of questions. Only generate {n} questions."""
97 |
98 |
99 | def generate_questions(state, config):
100 | model = init_chat_model(
101 | model=config["configurable"].get("question_gen_model", "gpt-4o-mini")
102 | )
103 | _question_prompt = config["configurable"].get(
104 | "question_gen_prompt", question_prompt
105 | )
106 | prompt = _question_prompt.format(
107 | chatbot_description=state["chatbot_description"], n=state["n"]
108 | )
109 | if state.get("generated_questions", []):
110 | prompt += (
111 | "\n\nHere are some questions that have already been generated, don't duplicate them: "
112 | + "\n".join(state["generated_questions"])
113 | )
114 | questions = model.with_structured_output(ListResponse).invoke(prompt)
115 | return questions
116 |
117 |
118 | def judge_node(state, config):
119 | model = init_chat_model(model=config["configurable"].get("judge_model", "gpt-4o"))
120 | _judge_prompt = config["configurable"].get("judge_prompt", answer_prompt)
121 | judge = model.with_structured_output(Judge).invoke(
122 | _judge_prompt.format(
123 | question_1=state["input_1"],
124 | question_2=state["input_2"],
125 | answer_1=state["output_1"],
126 | answer_2=state["output_2"],
127 | )
128 | )
129 | return {"judge": judge}
130 |
131 |
132 | async def _show_results(r):
133 | def clear_terminal():
134 | os.system("cls" if os.name == "nt" else "clear")
135 |
136 | clear_terminal()
137 | print(f"## Question 1: {r['input_1']}\n")
138 | print(r["output_1"])
139 | print("\n\n")
140 | print(f"## Question 2: {r['input_2']}\n")
141 | print(r["output_2"])
142 |
143 | print("\n\n")
144 |
145 | print(f"## Score: {r['judge']['similarity']}")
146 | print(f"Reasoning: {r['judge']['logic']}")
147 |
148 | print("\n\n")
149 |
150 | print("## Curate")
151 | print("**Enter**: To add both inputs to the dataset, just press enter")
152 | print("**`1`**: If you want to add only the first input to the dataset, enter `1`")
153 | print("**`2`**: If you want to add only the second input to the dataset, enter `2`")
154 | print("**`3`**: If you don't want to add either input to the dataset, enter `3`")
155 | print("**`q`**: To quit, enter `q`")
156 |
157 |
158 | def create_judge_graph(call_model: Callable):
159 | if inspect.iscoroutinefunction(call_model):
160 |
161 | async def answer_1(state: JudgeState):
162 | answer = await call_model(state["input_1"])
163 | return {"output_1": answer}
164 |
165 | async def answer_2(state: JudgeState):
166 | answer = await call_model(state["input_2"])
167 | return {"output_2": answer}
168 | else:
169 |
170 | def answer_1(state: JudgeState):
171 | answer = call_model(state["input_1"])
172 | return {"output_1": answer}
173 |
174 | def answer_2(state: JudgeState):
175 | answer = call_model(state["input_2"])
176 | return {"output_2": answer}
177 |
178 | judge_graph = StateGraph(JudgeState)
179 | judge_graph.add_node(answer_1)
180 | judge_graph.add_node(answer_2)
181 | judge_graph.add_node(judge_node)
182 | judge_graph.add_edge(START, "answer_1")
183 | judge_graph.add_edge(START, "answer_2")
184 | judge_graph.add_edge("answer_1", "judge_node")
185 | judge_graph.add_edge("answer_2", "judge_node")
186 | judge_graph.add_edge("judge_node", END)
187 | judge_graph = judge_graph.compile()
188 | return judge_graph
189 |
190 |
191 | def create_redteam_graph(call_model: Callable):
192 | judge_graph = create_judge_graph(call_model)
193 |
194 | async def judge_graph_node(state):
195 | result = await judge_graph.ainvoke(state)
196 | return {"answers": [result]}
197 |
198 | def generate_answers(state):
199 | return [
200 | Send("judge_graph_node", {"input_1": e["input_1"], "input_2": e["input_2"]})
201 | for e in state["pairs"]
202 | ]
203 |
204 | graph = StateGraph(AgentState)
205 | graph.add_node(generate_questions)
206 | graph.add_node(judge_graph_node)
207 | graph.add_conditional_edges("generate_questions", generate_answers)
208 | graph.set_entry_point("generate_questions")
209 | graph = graph.compile()
210 | return graph
211 |
212 |
213 | async def run_redteam(
214 | config,
215 | call_model: Callable,
216 | dataset_id,
217 | n,
218 | max_concurrency,
219 | n_prefill_questions,
220 | max_similarity,
221 | persistence_path,
222 | ):
223 | os.system("cls" if os.name == "nt" else "clear")
224 | print("Running Redteam...")
225 | if persistence_path:
226 | try:
227 |             with open(persistence_path, "r") as persistence_file:
228 |                 persistence = json.load(persistence_file)
229 | except FileNotFoundError:
230 | persistence = {}
231 | else:
232 | persistence = {}
233 | dataset_id = (
234 | dataset_id
235 | or config.get("dataset_id", None)
236 | or persistence.get("dataset_id", None)
237 | )
238 | n = n or config.get("n", 10)
239 | max_concurrency = max_concurrency or config.get("max_concurrency", 10)
240 | n_prefill_questions = n_prefill_questions or config.get("n_prefill_questions", 10)
241 | max_similarity = max_similarity or config.get("max_similarity", 10)
242 | if persistence:
243 | generated_questions = persistence.get("generated_questions", [])
244 | else:
245 | generated_questions = []
246 | chatbot_description = config["chatbot_description"]
247 | client = Client()
248 | if dataset_id is None:
249 | name = f"Redteaming results {time.strftime('%Y-%m-%d %H:%M:%S')}"
250 | dataset_id = client.create_dataset(dataset_name=name).id
251 | print(f"Created dataset: {name}")
252 |         if persistence_path:
253 |             persistence["dataset_id"] = str(dataset_id)
254 |             with open(persistence_path, "w") as persistence_file:
255 |                 json.dump(persistence, persistence_file, indent=4)
256 | graph = create_redteam_graph(call_model)
257 |
258 |     # Hand-off between the producer (graph streaming in a background thread) and the consumer (interactive review loop below).
259 | done = asyncio.Event()
260 | results = queue.PriorityQueue()
261 | got_results = False
262 |
263 | async def collect_results():
264 | inputs = {
265 | "chatbot_description": chatbot_description,
266 | "n": n,
267 | "generated_questions": generated_questions,
268 | }
269 |         async with contextlib.aclosing(
270 |             graph.astream(
271 |                 inputs, {"max_concurrency": max_concurrency, "configurable": config.get("configurable", {})}, stream_mode="updates"
272 |             )
273 |         ) as stream:
274 |             done_fut = asyncio.create_task(done.wait())
275 |             while True:
276 |                 # Race the next stream event against the user-quit signal.
277 |                 next_fut = asyncio.create_task(anext(stream))
278 |                 await asyncio.wait(
279 |                     {done_fut, next_fut},
280 |                     return_when=asyncio.FIRST_COMPLETED,
281 |                 )
282 |                 if done_fut.done():
283 |                     # User quit: cancel the pending read and stop producing.
284 |                     next_fut.cancel()
285 |                     try:
286 |                         await next_fut
287 |                     except asyncio.CancelledError:
288 |                         pass
289 |                     break
290 |                 try:
291 |                     event = next_fut.result()
292 |                 except StopAsyncIteration:
293 |                     break
294 | if "generate_questions" in event:
295 | print(
296 | f"Generated {len(event['generate_questions']['pairs'])} pairs"
297 | )
298 | if "judge_graph_node" in event:
299 | for answer in event["judge_graph_node"]["answers"]:
300 | if answer["judge"]["similarity"] <= max_similarity:
301 | results.put(Result(answer["judge"]["similarity"], answer))
302 | else:
303 | if persistence_path:
304 | generated_questions.extend(
305 | [answer["input_1"], answer["input_2"]]
306 | )
307 | persistence["generated_questions"] = generated_questions
308 |                         with open(persistence_path, "w") as persistence_file:
309 |                             json.dump(persistence, persistence_file, indent=4)
310 | while results.qsize() >= n_prefill_questions:
311 | await asyncio.sleep(1)
312 |         results.put(Result(11, {}))  # priority-11 sentinel: tells the review loop the stream is finished
313 |
314 | def run_async_collection():
315 | asyncio.run(collect_results())
316 |
317 | thread = threading.Thread(target=run_async_collection)
318 | thread.start()
319 |
320 | # Show results
321 | while True:
322 |         if not results.empty():
323 | got_results = True
324 | r1 = results.get()
325 | if r1.prio == 11:
326 | break
327 | r = r1.answer
328 | if persistence_path:
329 | generated_questions.extend([r["input_1"], r["input_2"]])
330 | persistence["generated_questions"] = generated_questions
331 |                 with open(persistence_path, "w") as persistence_file:
332 |                     json.dump(persistence, persistence_file, indent=4)
333 | await _show_results(r)
334 | i = input()
335 | if i == "1":
336 | client.create_examples(
337 | inputs=[{"question": r["input_1"]}],
338 | dataset_id=dataset_id,
339 | )
340 | elif i == "2":
341 | client.create_examples(
342 | inputs=[{"question": r["input_2"]}],
343 | dataset_id=dataset_id,
344 | )
345 | elif i == "3":
346 | pass
347 | elif i == "q":
348 | done.set()
349 | thread.join()
350 | break
351 | else:
352 | client.create_examples(
353 | inputs=[{"question": r["input_1"]}, {"question": r["input_2"]}],
354 | dataset_id=dataset_id,
355 | )
356 | elif not thread.is_alive():
357 | break
358 | elif not got_results:
359 | time.sleep(0.1)
360 | else:
361 | os.system("cls" if os.name == "nt" else "clear")
362 | print("Waiting for results...")
363 | time.sleep(0.1)
364 |
365 |
366 | async def run_redteam_dump(
367 | config, call_model: Callable, n, max_concurrency, results_path
368 | ):
369 | print("Running Redteam...")
370 | n = n or config.get("n", 10)
371 | max_concurrency = max_concurrency or config.get("max_concurrency", 10)
372 | chatbot_description = config["chatbot_description"]
373 | graph = create_redteam_graph(call_model)
374 |
375 |     # Stream graph events, tracking the final aggregated state so all answers can be dumped at the end.
376 | inputs = {"chatbot_description": chatbot_description, "n": n}
377 | final_result = {}
378 |     async for event_type, event in graph.astream(
379 |         inputs, {"max_concurrency": max_concurrency, "configurable": config.get("configurable", {})}, stream_mode=["values", "updates"]
380 |     ):
381 | if event_type == "values":
382 | if "answers" in event and len(event["answers"]) > 0:
383 | print(f"Finished finding {len(event['answers'])} pairs")
384 | elif "pairs" in event:
385 | print(f"Generated {len(event['pairs'])} pairs")
386 | final_result = event
387 | else:
388 | if "judge_graph_node" in event:
389 | print("Found a pair")
390 |
391 | with open(results_path, "w") as f:
392 | json.dump(final_result["answers"], f, indent=2)
393 |
394 |
395 | def main():
396 | def load_config(config_path):
397 | with open(config_path, "r") as file:
398 | return yaml.safe_load(file)
399 |
400 | def load_call_model(file_path):
401 | spec = importlib.util.spec_from_file_location("call_model_module", file_path)
402 | module = importlib.util.module_from_spec(spec)
403 | spec.loader.exec_module(module)
404 | return module.call_model
405 |
406 | parser = argparse.ArgumentParser(
407 | description="Run RedteamingAgent with configuration"
408 | )
409 | parser.add_argument("config_path", type=str, help="Path to the configuration file")
410 | parser.add_argument("--dataset_id", type=str, help="ID of the dataset to use")
411 | parser.add_argument("--n", type=int, help="Number of questions to generate")
412 | parser.add_argument(
413 | "--max_concurrency",
414 | type=int,
415 | help="Maximum number of concurrent requests to the model",
416 | )
417 | parser.add_argument(
418 | "--n_prefill_questions",
419 | type=int,
420 | help="Number of questions to prefill the dataset with",
421 | )
422 | parser.add_argument(
423 | "--max_similarity", type=int, help="Maximum similarity score to accept"
424 | )
425 | parser.add_argument(
426 | "-p", "--persistence-path", type=str, help="Path to the persistence file"
427 | )
428 | args = parser.parse_args()
429 |
430 | config = load_config(args.config_path)
431 |
432 | call_model = load_call_model(config["model_file"])
433 | asyncio.run(
434 | run_redteam(
435 | config,
436 | call_model,
437 | args.dataset_id,
438 | args.n,
439 | args.max_concurrency,
440 | args.n_prefill_questions,
441 | args.max_similarity,
442 | args.persistence_path,
443 | )
444 | )
445 |
446 |
447 | def dump():
448 | def load_config(config_path):
449 | with open(config_path, "r") as file:
450 | return yaml.safe_load(file)
451 |
452 | def load_call_model(file_path):
453 | spec = importlib.util.spec_from_file_location("call_model_module", file_path)
454 | module = importlib.util.module_from_spec(spec)
455 | spec.loader.exec_module(module)
456 | return module.call_model
457 |
458 | parser = argparse.ArgumentParser(
459 | description="Run RedteamingAgent with configuration"
460 | )
461 | parser.add_argument("config_path", type=str, help="Path to the configuration file")
462 | parser.add_argument("results_path", type=str, help="Path to store results in")
463 | parser.add_argument("--n", type=int, help="Number of questions to generate")
464 | parser.add_argument(
465 | "--max_concurrency",
466 | type=int,
467 | help="Maximum number of concurrent requests to the model",
468 | )
469 | args = parser.parse_args()
470 |
471 | config = load_config(args.config_path)
472 |
473 | call_model = load_call_model(config["model_file"])
474 | asyncio.run(
475 | run_redteam_dump(
476 | config, call_model, args.n, args.max_concurrency, args.results_path
477 | )
478 | )
479 |
--------------------------------------------------------------------------------