├── .gitignore
├── LICENSE
├── README.md
├── call_model.py
├── config.yaml
├── pyproject.toml
└── src
    └── langfuzz
        ├── __init__.py
        └── redteam.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Harrison Chase

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# LangFuzz: Red Teaming for Language Models
LangFuzz is a command line tool designed to perform red teaming on language model applications and add any points of interest to a [LangSmith Dataset](https://docs.smith.langchain.com/). It generates pairs of similar questions and compares the responses to identify potential failure modes in chatbots or other language model-based systems. For those coming from a software engineering background: this is similar to a particular type of [fuzz testing](https://www.blackduck.com/glossary/what-is-fuzz-testing.html#:~:text=Definition,as%20crashes%20or%20information%20leakage.) called [metamorphic testing](https://arxiv.org/abs/2002.12543).
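
A toy illustration of the metamorphic idea (the `ask` and `judge_similarity` functions here are hypothetical stand-ins for your chatbot and a judge model):

```python
def metamorphic_check(ask, judge_similarity) -> bool:
    # Two paraphrases of the same question should get semantically
    # equivalent answers; a low similarity score flags a failure mode.
    q1 = "What is LangChain used for?"
    q2 = "What do people use LangChain for?"
    a1, a2 = ask(q1), ask(q2)
    return judge_similarity(a1, a2) >= 7  # threshold is illustrative
```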

## Installation
To install LangFuzz, use pip:

```
pip install langfuzz
```

## Usage

### Step 1: define a model file
First, define a model file that calls your model. This file should expose a sync OR async function called `call_model` that takes in a string and returns a string. An example is provided in [call_model.py](call_model.py) and shown below; an async variant is sketched at the end of this section.

```python
import random
from openai import OpenAI

client = OpenAI()


def call_model(question: str) -> str:
    # Randomly vary the system prompt so the demo sometimes gives bad answers.
    if random.uniform(0, 1) > 0.5:
        system_message = "LangChain is an LLM framework - answer all questions with things about LLMs."
    else:
        system_message = "LangChain is blockchain technology - answer all questions with things about crypto"

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": question},
        ],
    )
    return completion.choices[0].message.content

```

### Step 2: define a configuration file

Next, you need to define a configuration file.

The `config.yaml` file should contain the following keys:
- `chatbot_description`: A description of the chatbot being tested.
- `model_file`: Path to the Python file containing the `call_model` function

Example config.yaml:
```yaml
chatbot_description: "Chat over LangChain Docs"
model_file: "call_model.py"
```

It may optionally contain other keys (see [Options](#options) and [Additional Configuration](#additional-configuration) below).

### Step 3: set any environment variables

LangFuzz needs the following environment variables to run:

- `OPENAI_API_KEY`: required for accessing OpenAI models
- `LANGSMITH_API_KEY` (optional): required only for sending generated data to a LangSmith dataset

```
export OPENAI_API_KEY=...
export LANGSMITH_API_KEY=...
```

### Step 4: run the red teaming

To run the red teaming process, use the following command:
```
langfuzz config.yaml [options]
```

### Step 5: curate datapoints

As the red teaming runs, pairs of datapoints will be shown to you in the command line. For each pair, you can choose to add both, one, or neither question to a LangSmith dataset:

- **Enter**: To add both inputs to the dataset, just press enter
- **`1`**: If you want to add only the first input to the dataset, enter `1`
- **`2`**: If you want to add only the second input to the dataset, enter `2`
- **`3`**: If you don't want to add either input to the dataset, enter `3`
- **`q`**: To quit, enter `q`

If you add a datapoint to a LangSmith dataset, it will be added with a single input key `question` and no output key.
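
LangFuzz also accepts an async model file; it detects coroutine functions with `inspect.iscoroutinefunction` and awaits them. A minimal async sketch (assuming the OpenAI async client; adapt to your application):

```python
from openai import AsyncOpenAI

client = AsyncOpenAI()


async def call_model(question: str) -> str:
    # Same contract as the sync version (string in, string out),
    # but awaited by langfuzz instead of called directly.
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
    )
    return completion.choices[0].message.content
```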

## Options

- `--dataset_id`: ID of the LangSmith dataset to add examples to (optional; if omitted, a new dataset is created)
- `--n`: Number of questions to generate (default: 10)
- `--max_concurrency`: Maximum number of concurrent requests to the model (default: 10)
- `--n_prefill_questions`: Maximum number of judged pairs to keep queued ahead of your review (default: 10)
- `--max_similarity`: Maximum similarity score at which a pair is still shown for review; pairs scoring above this are skipped (default: 10)
- `-p`, `--persistence-path`: Path to the persistence file (optional)

These options can additionally be provided as part of the configuration file.

## Additional Configuration

You can also configure more aspects of the red teaming agent:

- `judge_model`: the model used to judge whether two answers are similar
- `question_gen_model`: the model used to generate pairs of questions
- `judge_prompt`: the prompt used to judge whether two answers are similar; it is filled via `str.format`, so a custom prompt should keep the `{question_1}`, `{answer_1}`, `{question_2}`, and `{answer_2}` placeholders
- `question_gen_prompt`: the prompt used to generate pairs of questions; a custom prompt should keep the `{chatbot_description}` and `{n}` placeholders

These keys are nested under a `configurable:` block, as in this repo's [config.yaml](config.yaml).

## How It Works

1. The tool generates pairs of similar questions based on the provided chatbot description.
2. It sends these questions to the chatbot and compares the responses.
3. A judge model evaluates the similarity of the responses on a scale of 1-10.
4. Results with similarity scores at or below the `max_similarity` threshold are presented for review.
5. Users can choose to add the questions to a dataset for further analysis or training.

### Persistence

If a persistence path is provided, the tool will save generated questions and dataset information between runs. This allows for continuous red teaming sessions without duplicating questions.

## Use without LangSmith

You can also use this red teaming agent without LangSmith and dump all the results to a local file. To do this, use the following command:

```
langfuzz-dump config.yaml results.json [options]
```
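
Each entry in `results.json` is one judged pair: the `input_1`/`input_2` questions, the chatbot's `output_1`/`output_2` answers, and the judge's verdict under `judge` (a `logic` string and a 1-10 `similarity` score), matching the state defined in `src/langfuzz/redteam.py`. A minimal sketch, assuming that shape, for triaging a dump:

```python
import json

with open("results.json") as f:
    answers = json.load(f)

# Least similar pairs first: these are the likeliest failure modes.
for r in sorted(answers, key=lambda r: r["judge"]["similarity"]):
    print(f"[{r['judge']['similarity']}/10] {r['input_1']!r} vs. {r['input_2']!r}")
```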
--------------------------------------------------------------------------------
/call_model.py:
--------------------------------------------------------------------------------
import random
from openai import OpenAI

client = OpenAI()


def call_model(question: str) -> str:
    # Randomly vary the system prompt so the demo sometimes gives bad answers.
    if random.uniform(0, 1) > 0.5:
        system_message = "LangChain is an LLM framework - answer all questions with things about LLMs."
    else:
        system_message = "LangChain is blockchain technology - answer all questions with things about crypto"

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": question},
        ],
    )
    return completion.choices[0].message.content
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
chatbot_description: "Chat over LangChain Docs"
model_file: "call_model.py"
configurable:
  judge_model: "gpt-4o"
  question_gen_model: "gpt-4o-mini"
  judge_prompt: |
    You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, and if it produces different answers then you know one of the questions is a failure mode for that chatbot.

    <question_1>
    {question_1}
    </question_1>

    <answer_1>
    {answer_1}
    </answer_1>

    <question_2>
    {question_2}
    </question_2>

    <answer_2>
    {answer_2}
    </answer_2>

    How similar are these answers on a scale of 1-10? 1 is not similar at all, 10 is absolutely similar.
    If one answer contains information that the other one doesn't, then the similarity should be lower. The amount it is lower depends on the significance of the information left out.
    If the answers contradict each other, then the similarity should be very low.
  question_gen_prompt: |
    You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions and if it produces different answers then you know one of the questions is a failure mode for that chatbot.

    The chatbot you are redteaming is:

    {chatbot_description}

    Please generate a list of pairs of questions that are relevant to this chatbot. Make sure to generate a diverse set of questions. Only generate {n} questions.
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=45", "setuptools-scm>=6.2"]
build-backend = "setuptools.build_meta"

[project]
name = "langfuzz"
version = "0.0.1"
description = "Project to fuzz language model applications"
authors = [{name = "Harrison Chase"}]
dependencies = [
    "langgraph",
    "langchain",
    "langchain_openai",
    "langsmith",
    "pyyaml",
]

[project.scripts]
langfuzz = "langfuzz.redteam:main"
langfuzz-dump = "langfuzz.redteam:dump"

[tool.setuptools.packages.find]
where = ["src"]
--------------------------------------------------------------------------------
/src/langfuzz/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwchase17/langfuzz/7116127c2429c382ff0e0701b7ab2271bc252eb9/src/langfuzz/__init__.py
--------------------------------------------------------------------------------
/src/langfuzz/redteam.py:
--------------------------------------------------------------------------------
import argparse
import asyncio
import contextlib
import dataclasses
import importlib.util
import json
import operator
import os
import threading
import time
import warnings
import queue
import inspect
import yaml
from typing import Annotated, Any, Callable, TypedDict

from langchain.chat_models import init_chat_model

from langsmith import Client
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send

# Suppress all warnings
warnings.filterwarnings("ignore")


@dataclasses.dataclass(order=True)
class Result:
    # Ordered by `prio` only (the judge's similarity score), so a
    # PriorityQueue surfaces the least similar pairs first.
    prio: int
    answer: dict = dataclasses.field(hash=False, compare=False)


class Response(TypedDict):
    input_1: str
    input_2: str


class ListResponse(TypedDict):
    pairs: list[Response]


class JudgeState(TypedDict):
    input_1: Any
    input_2: Any
    output_1: Any
    output_2: Any
    judge: dict


class AgentState(TypedDict):
    chatbot_description: str
    generated_questions: list[str]
    n: int
    pairs: list
    answers: Annotated[list, operator.add]


answer_prompt = """You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, and if it produces different answers then you know one of the questions is a failure mode for that chatbot.

<question_1>
{question_1}
</question_1>

<answer_1>
{answer_1}
</answer_1>

<question_2>
{question_2}
</question_2>

<answer_2>
{answer_2}
</answer_2>

How similar are these answers on a scale of 1-10? 1 is not similar at all, 10 is absolutely similar.
If one answer contains information that the other one doesn't, then the similarity should be lower. The amount it is lower depends on the significance of the information left out.
If the answers contradict each other, then the similarity should be very low."""


class Judge(TypedDict):
    logic: str
    similarity: int


question_prompt = """You are a benevolent white hat hacker. Your job is to identify places where a chatbot application may produce \
incorrect results. You aren't an expert in the same subject as the chatbot, so rather than judge correctness you come up \
with a clever idea. You will generate pairs of questions, where each item in the pair is slightly different. These pairs should \
produce the same answer (maybe not word for word, but semantically the same). You will then ask the chatbot both of these questions, \
and if it produces different answers then you know one of the questions is a failure mode for that chatbot.

The chatbot you are redteaming is:

{chatbot_description}

Please generate a list of pairs of questions that are relevant to this chatbot. Make sure to generate a diverse set of questions. Only generate {n} questions."""


def generate_questions(state, config):
    model = init_chat_model(
        model=config["configurable"].get("question_gen_model", "gpt-4o-mini")
    )
    _question_prompt = config["configurable"].get(
        "question_gen_prompt", question_prompt
    )
    prompt = _question_prompt.format(
        chatbot_description=state["chatbot_description"], n=state["n"]
    )
    if state.get("generated_questions", []):
        # Avoid re-generating questions persisted from earlier runs
        prompt += (
            "\n\nHere are some questions that have already been generated, don't duplicate them: "
            + "\n".join(state["generated_questions"])
        )
    questions = model.with_structured_output(ListResponse).invoke(prompt)
    return questions


def judge_node(state, config):
    model = init_chat_model(model=config["configurable"].get("judge_model", "gpt-4o"))
    _judge_prompt = config["configurable"].get("judge_prompt", answer_prompt)
    judge = model.with_structured_output(Judge).invoke(
        _judge_prompt.format(
            question_1=state["input_1"],
            question_2=state["input_2"],
            answer_1=state["output_1"],
            answer_2=state["output_2"],
        )
    )
    return {"judge": judge}


async def _show_results(r):
    def clear_terminal():
        os.system("cls" if os.name == "nt" else "clear")

    clear_terminal()
    print(f"## Question 1: {r['input_1']}\n")
    print(r["output_1"])
    print("\n\n")
    print(f"## Question 2: {r['input_2']}\n")
    print(r["output_2"])

    print("\n\n")

    print(f"## Score: {r['judge']['similarity']}")
    print(f"Reasoning: {r['judge']['logic']}")

    print("\n\n")

    print("## Curate")
    print("**Enter**: To add both inputs to the dataset, just press enter")
    print("**`1`**: If you want to add only the first input to the dataset, enter `1`")
    print("**`2`**: If you want to add only the second input to the dataset, enter `2`")
    print("**`3`**: If you don't want to add either input to the dataset, enter `3`")
    print("**`q`**: To quit, enter `q`")


def create_judge_graph(call_model: Callable):
    # Wrap the user's model so it works whether it is sync or async
    if inspect.iscoroutinefunction(call_model):

        async def answer_1(state: JudgeState):
            answer = await call_model(state["input_1"])
            return {"output_1": answer}

        async def answer_2(state: JudgeState):
            answer = await call_model(state["input_2"])
            return {"output_2": answer}
    else:

        def answer_1(state: JudgeState):
            answer = call_model(state["input_1"])
            return {"output_1": answer}

        def answer_2(state: JudgeState):
            answer = call_model(state["input_2"])
            return {"output_2": answer}

    # Ask both questions in parallel, then judge the pair of answers
    judge_graph = StateGraph(JudgeState)
    judge_graph.add_node(answer_1)
    judge_graph.add_node(answer_2)
    judge_graph.add_node(judge_node)
    judge_graph.add_edge(START, "answer_1")
    judge_graph.add_edge(START, "answer_2")
    judge_graph.add_edge("answer_1", "judge_node")
    judge_graph.add_edge("answer_2", "judge_node")
    judge_graph.add_edge("judge_node", END)
    judge_graph = judge_graph.compile()
    return judge_graph


def create_redteam_graph(call_model: Callable):
    judge_graph = create_judge_graph(call_model)

    async def judge_graph_node(state):
        result = await judge_graph.ainvoke(state)
        return {"answers": [result]}

    def generate_answers(state):
        # Fan out: run the judge subgraph once per generated pair
        return [
            Send("judge_graph_node", {"input_1": e["input_1"], "input_2": e["input_2"]})
            for e in state["pairs"]
        ]

    graph = StateGraph(AgentState)
    graph.add_node(generate_questions)
    graph.add_node(judge_graph_node)
    graph.add_conditional_edges("generate_questions", generate_answers)
    graph.set_entry_point("generate_questions")
    graph = graph.compile()
    return graph


async def run_redteam(
    config,
    call_model: Callable,
    dataset_id,
    n,
    max_concurrency,
    n_prefill_questions,
    max_similarity,
    persistence_path,
):
    os.system("cls" if os.name == "nt" else "clear")
    print("Running Redteam...")
    if persistence_path:
        try:
            with open(persistence_path, "r") as config_file:
                persistence = json.load(config_file)
        except FileNotFoundError:
            persistence = {}
    else:
        persistence = {}
    dataset_id = (
        dataset_id
        or config.get("dataset_id", None)
        or persistence.get("dataset_id", None)
    )
    n = n or config.get("n", 10)
    max_concurrency = max_concurrency or config.get("max_concurrency", 10)
    n_prefill_questions = n_prefill_questions or config.get("n_prefill_questions", 10)
    max_similarity = max_similarity or config.get("max_similarity", 10)
    generated_questions = persistence.get("generated_questions", [])
    chatbot_description = config["chatbot_description"]
    client = Client()
    if dataset_id is None:
        name = f"Redteaming results {time.strftime('%Y-%m-%d %H:%M:%S')}"
        dataset_id = client.create_dataset(dataset_name=name).id
        print(f"Created dataset: {name}")
        if persistence_path:
            # Remember the new dataset so later runs reuse it
            # (checking the path, not the possibly-empty persistence dict).
            persistence["dataset_id"] = str(dataset_id)
            with open(persistence_path, "w") as config_file:
                json.dump(persistence, config_file, indent=4)
    graph = create_redteam_graph(call_model)

    # Judged pairs queued for review, least similar first
    done = asyncio.Event()
    results = queue.PriorityQueue()
    got_results = False

    async def collect_results():
        inputs = {
            "chatbot_description": chatbot_description,
            "n": n,
            "generated_questions": generated_questions,
        }
        async with contextlib.aclosing(
            graph.astream(
                inputs, {"max_concurrency": max_concurrency}, stream_mode="updates"
            )
        ) as stream:
            done_fut = asyncio.create_task(done.wait())
            while True:
                try:
                    fin, other = await asyncio.wait(
                        {done_fut, asyncio.create_task(anext(stream))},
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    for fut in fin:
                        if fut is done_fut:
                            # The reviewer quit: cancel the in-flight stream
                            # read and stop collecting entirely.
                            for f in other:
                                f.cancel()
                                try:
                                    await f
                                except asyncio.CancelledError:
                                    pass
                            return
                        else:
                            event = fut.result()
                except StopAsyncIteration:
                    break
                if "generate_questions" in event:
                    print(
                        f"Generated {len(event['generate_questions']['pairs'])} pairs"
                    )
                if "judge_graph_node" in event:
                    for answer in event["judge_graph_node"]["answers"]:
                        if answer["judge"]["similarity"] <= max_similarity:
                            results.put(Result(answer["judge"]["similarity"], answer))
                        else:
                            if persistence_path:
                                generated_questions.extend(
                                    [answer["input_1"], answer["input_2"]]
                                )
                                persistence["generated_questions"] = generated_questions
                                with open(persistence_path, "w") as config_file:
                                    json.dump(persistence, config_file, indent=4)
                        while results.qsize() >= n_prefill_questions:
                            # Throttle generation while the review queue is full
                            await asyncio.sleep(1)
        # Sentinel: tells the review loop that generation is finished
        results.put(Result(11, {}))

    def run_async_collection():
        asyncio.run(collect_results())

    thread = threading.Thread(target=run_async_collection)
    thread.start()

    # Show results as they arrive, least similar first
    while True:
        if not results.empty():
            got_results = True
            r1 = results.get()
            if r1.prio == 11:
                break
            r = r1.answer
            if persistence_path:
                generated_questions.extend([r["input_1"], r["input_2"]])
                persistence["generated_questions"] = generated_questions
                with open(persistence_path, "w") as config_file:
                    json.dump(persistence, config_file, indent=4)
            await _show_results(r)
            i = input()
            if i == "1":
                client.create_examples(
                    inputs=[{"question": r["input_1"]}],
                    dataset_id=dataset_id,
                )
            elif i == "2":
                client.create_examples(
                    inputs=[{"question": r["input_2"]}],
                    dataset_id=dataset_id,
                )
            elif i == "3":
                pass
            elif i == "q":
                done.set()
                thread.join()
                break
            else:
                client.create_examples(
                    inputs=[{"question": r["input_1"]}, {"question": r["input_2"]}],
                    dataset_id=dataset_id,
                )
        elif not thread.is_alive():
            break
        elif not got_results:
            time.sleep(0.1)
        else:
            os.system("cls" if os.name == "nt" else "clear")
            print("Waiting for results...")
            time.sleep(0.1)


async def run_redteam_dump(
    config, call_model: Callable, n, max_concurrency, results_path
):
    print("Running Redteam...")
    n = n or config.get("n", 10)
    max_concurrency = max_concurrency or config.get("max_concurrency", 10)
    chatbot_description = config["chatbot_description"]
    graph = create_redteam_graph(call_model)

    # Stream the run and keep the final state so we can dump its answers
    inputs = {"chatbot_description": chatbot_description, "n": n}
    final_result = {}
    async for event_type, event in graph.astream(
        inputs, {"max_concurrency": max_concurrency}, stream_mode=["values", "updates"]
    ):
        if event_type == "values":
            if "answers" in event and len(event["answers"]) > 0:
                print(f"Finished finding {len(event['answers'])} pairs")
            elif "pairs" in event:
                print(f"Generated {len(event['pairs'])} pairs")
            final_result = event
        else:
            if "judge_graph_node" in event:
                print("Found a pair")

    with open(results_path, "w") as f:
        json.dump(final_result["answers"], f, indent=2)


def load_config(config_path):
    with open(config_path, "r") as file:
        return yaml.safe_load(file)


def load_call_model(file_path):
    # Import the user's model file as a module and pull out `call_model`
    spec = importlib.util.spec_from_file_location("call_model_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.call_model


def main():
    parser = argparse.ArgumentParser(
        description="Run RedteamingAgent with configuration"
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file")
    parser.add_argument("--dataset_id", type=str, help="ID of the dataset to use")
    parser.add_argument("--n", type=int, help="Number of questions to generate")
    parser.add_argument(
        "--max_concurrency",
        type=int,
        help="Maximum number of concurrent requests to the model",
    )
    parser.add_argument(
        "--n_prefill_questions",
        type=int,
        help="Maximum number of judged pairs to keep queued ahead of review",
    )
    parser.add_argument(
        "--max_similarity", type=int, help="Maximum similarity score to accept"
    )
    parser.add_argument(
        "-p", "--persistence-path", type=str, help="Path to the persistence file"
    )
    args = parser.parse_args()

    config = load_config(args.config_path)

    call_model = load_call_model(config["model_file"])
    asyncio.run(
        run_redteam(
            config,
            call_model,
            args.dataset_id,
            args.n,
            args.max_concurrency,
            args.n_prefill_questions,
            args.max_similarity,
            args.persistence_path,
        )
    )


def dump():
    parser = argparse.ArgumentParser(
        description="Run RedteamingAgent and dump results to a file"
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file")
    parser.add_argument("results_path", type=str, help="Path to store results in")
    parser.add_argument("--n", type=int, help="Number of questions to generate")
    parser.add_argument(
        "--max_concurrency",
        type=int,
        help="Maximum number of concurrent requests to the model",
    )
    args = parser.parse_args()

    config = load_config(args.config_path)

    call_model = load_call_model(config["model_file"])
    asyncio.run(
        run_redteam_dump(
            config, call_model, args.n, args.max_concurrency, args.results_path
        )
    )
--------------------------------------------------------------------------------
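
Both entry points can also be driven programmatically. A minimal sketch using `run_redteam_dump` (the stub `call_model` below is illustrative; `OPENAI_API_KEY` must still be set for question generation and judging):

```python
import asyncio

from langfuzz.redteam import run_redteam_dump


def call_model(question: str) -> str:
    # Stand-in chatbot; replace with your application.
    return "LangChain is an LLM framework."


config = {"chatbot_description": "Chat over LangChain Docs"}
asyncio.run(
    run_redteam_dump(
        config,
        call_model,
        n=5,
        max_concurrency=5,
        results_path="results.json",
    )
)
```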