├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── chill.py ├── data.py ├── data_saver_server.py ├── docker-compose.yml ├── gradio_cached_examples └── 14 │ └── log.csv ├── learn.py ├── local_score.py ├── promptObjects.py ├── runpod.dockerfile ├── runpod_handler.py ├── serverless.md ├── serverless_local_test.py ├── system_map.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | .cache 3 | *.pyc 4 | *.jsonl 5 | *.json -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA_IMAGE="12.1.1-devel-ubuntu22.04" 2 | FROM nvidia/cuda:${CUDA_IMAGE} 3 | 4 | # We need to set the host to 0.0.0.0 to allow outside access 5 | ENV HOST 0.0.0.0 6 | 7 | RUN apt-get update && apt-get upgrade -y \ 8 | && apt-get install -y git build-essential \ 9 | python3 python3-pip gcc wget \ 10 | ocl-icd-opencl-dev opencl-headers clinfo \ 11 | libclblast-dev libopenblas-dev \ 12 | && mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd 13 | RUN apt-get install git -y 14 | COPY . . 15 | 16 | # setting build related env vars 17 | ENV CUDA_DOCKER_ARCH=all 18 | ENV LLAMA_CUBLAS=1 19 | 20 | RUN useradd -m -u 1000 user 21 | # Switch to the "user" user 22 | USER user 23 | # Set home to the user's home directory 24 | ENV HOME=/home/user \ 25 | PATH=/home/user/.local/bin:$PATH \ 26 | PYTHONPATH=$HOME/app \ 27 | PYTHONUNBUFFERED=1 \ 28 | GRADIO_ALLOW_FLAGGING=never \ 29 | GRADIO_NUM_PORTS=1 \ 30 | GRADIO_SERVER_NAME=0.0.0.0 \ 31 | GRADIO_THEME=huggingface \ 32 | SYSTEM=spaces 33 | 34 | WORKDIR $HOME/app 35 | 36 | # Copy the current directory contents into the container at $HOME/app setting the owner to the user 37 | COPY --chown=user . $HOME/app 38 | # Install dependencies 39 | RUN python3 -m pip install --upgrade pip && \ 40 | python3 -m pip install pytest cmake \ 41 | scikit-build setuptools fastapi uvicorn sse-starlette \ 42 | pydantic-settings starlette-context gradio huggingface_hub hf_transfer 43 | RUN python3 -m pip install requests pydantic uvicorn starlette fastapi sse_starlette starlette_context pydantic_settings 44 | 45 | 46 | # Install llama-cpp-python (build with cuda) 47 | #RUN CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install git+https://github.com/lukestanley/llama-cpp-python.git@expose_json_grammar_convert_function 48 | 49 | CMD ["python3", "app.py"] 50 | 51 | # Credit to Radamés Ajna for the original Dockerfile -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Luke Stanley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: ❄️ ChillTranslator 🤬 ➡️ 😎💬 3 | emoji: ❄️ 4 | colorFrom: red 5 | colorTo: blue 6 | sdk: docker 7 | pinned: false 8 | --- 9 | # ❄️ ChillTranslator 🤬 ➡️ 😎💬 10 | 11 | 12 | This is an early experimental tool aimed at helping reduce online toxicity by automatically ➡️ transforming 🌶️ spicy or toxic comments into constructive, ❤️ kinder dialogues using AI and large language models. 13 | 14 | ![ChillTranslator demo](https://github.com/lukestanley/ChillTranslator/assets/306671/128611f4-3e8e-4c52-ba20-2ae61d727d52) 15 | 16 | 17 | You can try out ChillTranslator via the HuggingFace Space demo at [https://huggingface.co/spaces/lukestanley/ChillTranslator](https://huggingface.co/spaces/lukestanley/ChillTranslator). 18 | 19 | It can also be used via the command line to improve a specific text of your choice, or it can be used as a module. All of the code is in this repo: the serverless worker, the logic, and the Gradio UI. 20 | 21 | ChillTranslator aims to help make online interactions healthier. 22 | 23 | Online toxicity can undermine the quality of discourse, causing distress 😞 and driving people away from online communities. Or worse: it can create a viral toxic loop 🌀! 24 | 25 | 26 | 27 | ChillTranslator hopes to mitigate toxic comments by automatically rephrasing negative comments, while maintaining the original intent and promoting positive communication 🗣️➡️💬. These rephrased texts could be suggested to the original authors as alternatives, or users could enhance their internet experience with "rose-tinted glasses" 🌹😎, automatically translating spicy comments into versions that are easier and more calming to read. 28 | There could be all kinds of failure cases, but hey, it's a start! 29 | 30 | Could Reddit, Twitter, Hacker News, or even YouTube comments be more calm and constructive places? I think so! 31 | 32 | ## Aims to: 33 | - **Convert** text to less toxic variations 34 | - **Preserve original intent**, focusing on constructive dialogue 35 | - **Self-hostable, serverless, or APIs**: running it yourself could save costs, avoid needing to sign up to APIs, and avoid the risk of toxic content causing API access to be revoked. We use llama-cpp-python with Mixtral, with an HTTP server option and a fast "serverless" backend using RunPod; the RunPod backend has had some reliability issues, so Mistral's own API is used for now, until a more reliable serverless method is worked out. 36 | 37 | ## Possible future directions 🌟 38 | 39 | **Speed:** 40 | - Generate rephrasings in parallel. 41 | - Combine some LLM tasks to reduce request overhead. 42 | - Show intermediate results to the user, while waiting for the final result. 43 | - Split text into sentences, e.g. with “pysbd”, for parallel processing of translations.
44 | 45 | **Speed and Quality:** 46 | - Use the Jigsaw dataset to find spicy comments, building a dataset for training a translation transformer (perhaps something like Google's T5) that could run faster than Mixtral. 47 | - Try using a 'Detoxify' scoring model instead of the current "spicy" score method. 48 | - Use natural language similarity techniques to compare possible rephrasing fidelity faster. 49 | - Collect a dataset of spicy comments and their rephrasings. 50 | - Feedback loop: users could score rephrasings, or suggest their own. 51 | 52 | **Distribution:** 53 | - Better examples showing use as a Python module and HTTP API, for use from other tools and browser extensions. 54 | - Enabling easy experimentation with online hosted LLM APIs 55 | - Making setup on different platforms easier 56 | 57 | 58 | ## Getting started 🚀 59 | 60 | ### Try it online 61 | 62 | You can try out ChillTranslator without any installation by visiting the HuggingFace Space demo: 63 | ``` 64 | https://huggingface.co/spaces/lukestanley/ChillTranslator 65 | ``` 66 | 67 | ### Installation 68 | 69 | 1. Clone the Project Repository: 70 | ``` 71 | git clone https://github.com/lukestanley/ChillTranslator.git 72 | cd ChillTranslator 73 | ``` 74 | 2. ChillTranslator will automatically download [Mixtral-8x7B-Instruct-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/resolve/main/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf?download=true) by default. The model HuggingFace repo and filename can be switched via environment variables, or you can point to a different local path. 75 | 3. Install dependencies, including a special fork of `llama-cpp-python`, and Nvidia GPU support if needed: 76 | ``` 77 | pip install requests pydantic uvicorn starlette fastapi sse_starlette starlette_context pydantic_settings 78 | 79 | # If you have an Nvidia GPU, install the special fork of llama-cpp-python with CUBLAS support: 80 | CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install git+https://github.com/lukestanley/llama-cpp-python.git@expose_json_grammar_convert_function 81 | ``` 82 | If you don't have an Nvidia GPU, the `CMAKE_ARGS="-DLLAMA_CUBLAS=on"` prefix is not needed before the `pip install` command. 83 | 84 | 4. Start the LLM server with your chosen configuration. Example for Nvidia with `--n_gpu_layers` set to 20; different GPUs fit more or fewer layers. If you have no GPU, you don't need the `--n_gpu_layers` flag: 85 | ``` 86 | python3 -m llama_cpp.server --model mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf --port 5834 --n_ctx 4096 --use_mlock false --n_gpu_layers 20 & 87 | ``` 88 | These config options are likely to need tweaking. Please check out https://llama-cpp-python.readthedocs.io/en/latest/ for more info. 89 | 90 | 91 | ### Local Usage 92 | 93 | ChillTranslator can be used locally to improve specific texts. 94 | This is how to see it in action from the command line with a default text: 95 | ```bash 96 | python3 chill.py 97 | ``` 98 | To improve a specific text of your choice, use the `-t` flag followed by your text enclosed in quotes: 99 | ```bash 100 | python3 chill.py -t "Your text goes here" 101 | ``` 102 | To run the Gradio web server GUI: 103 | ```bash 104 | python3 app.py 105 | ``` 106 | Or chill can be imported as a module, passing the text to improve to the `improvement_loop` function (see the usage sketch near the end of this README). 107 | 108 | ## Contributing 🤝 109 | 110 | Contributions are very welcome! 111 | Especially: 112 | - pull requests, 113 | - free GPU credits, 114 | - LLM API credits / access. 115 | 116 | ChillTranslator is released under the MIT License.
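As mentioned under Local Usage, chill can also be used as a Python module. A minimal sketch, assuming one of the LLM backends described above is set up; `improvement_loop` returns the top-ranked suggestion as a dictionary:

```python
from chill import improvement_loop

result = improvement_loop("You guys are so slow, we will never ship it!")
print(result["edit"])           # the calmer rewrite
print(result["overall_score"])  # weighted faithfulness / spiciness score
```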
117 | 118 | Help make the internet a kinder place, one comment at a time. 119 | Your contribution could make a big difference! -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from os import environ as env 2 | from os import system as run 3 | from subprocess import check_output 4 | 5 | import gradio as gr 6 | 7 | 8 | def inference_binary_check(): 9 | # Without a GPU, we need to re-install llama-cpp-python to avoid an error. 10 | # We use a shell command to detect if we have an NVIDIA GPU available: 11 | use_gpu = True 12 | try: 13 | command = "nvidia-debugdump --list|grep Device" 14 | output = str(check_output(command, shell=True).decode()) 15 | if "NVIDIA" in output and "ID" in output: 16 | print("NVIDIA GPU detected.") 17 | except Exception as e: 18 | print("No NVIDIA GPU detected, using CPU. GPU check result:", e) 19 | use_gpu = False 20 | 21 | if use_gpu: 22 | print("GPU detected, existing GPU focused llama-cpp-python should work.") 23 | else: 24 | print("Avoiding error by re-installing non-GPU llama-cpp-python build because no GPU was detected.") 25 | run('pip uninstall llama-cpp-python -y') 26 | run('pip install git+https://github.com/lukestanley/llama-cpp-python.git@expose_json_grammar_convert_function --upgrade --no-cache-dir --force-reinstall') 27 | print("llama-cpp-python re-installed, will now attempt to load.") 28 | 29 | 30 | LLM_WORKER = env.get("LLM_WORKER", "runpod") 31 | 32 | if LLM_WORKER == "http" or LLM_WORKER == "in_memory": 33 | inference_binary_check() 34 | 35 | examples = [ 36 | ["You guys are so slow, we will never ship it!"], 37 | ["Your idea of a balanced diet is a biscuit in each hand."] 38 | ] 39 | 40 | description = """This is an early experimental tool aimed at helping reduce online toxicity by automatically ➡️ transforming 🌶️ spicy or toxic comments into constructive, ❤️ kinder dialogues using AI and large language models. 41 | Input and outputs may be used to train a faster model, by using this, you must agree that the input text is owned by you and that you're okay with it being used to help make an AI model that makes kinder comments! 42 | 43 | ChillTranslator aims to help make online interactions more healthy, with a tool to **convert** text to less toxic variations, **preserve original intent**, focusing on constructive dialogue. 44 | The project is on GitHub: 45 | [https://github.com/lukestanley/ChillTranslator](https://github.com/lukestanley/ChillTranslator) 46 | The repo is the same repo for the HuggingFace Space, the serverless worker, and the logic. 47 | 48 | Contributions are very welcome! Especially pull requests, free API credits. 49 | Help make the internet a kinder place, one comment at a time. Your contribution could make a big difference! 50 | Thank you! 51 | """ 52 | 53 | from chill import improvement_loop 54 | 55 | def chill_out(text): 56 | print("Got this input:", text) 57 | result: dict = improvement_loop(text) 58 | print("Got this result:", result) 59 | 60 | formatted_output = f""" 61 |
62 | Edited text:
63 | {result['edit']}
64 | Details:
65 | (list of score details; HTML markup lost in extraction)
74 |
75 | """ 76 | return formatted_output 77 | 78 | demo = gr.Interface( 79 | fn=chill_out, 80 | inputs=gr.Textbox(lines=2, placeholder="Enter some spicy text here..."), 81 | outputs="html", 82 | examples=examples, 83 | cache_examples=True, 84 | description=description, 85 | title="❄️ ChillTranslator 🤬 ➡️ 😎💬", 86 | allow_flagging="never", 87 | ) 88 | 89 | demo.launch(max_threads=1, share=True) -------------------------------------------------------------------------------- /chill.py: -------------------------------------------------------------------------------- 1 | # chill.py 2 | from argparse import ArgumentParser 3 | import json 4 | from time import time 5 | from uuid import uuid4 6 | from data import log_to_jsonl 7 | from datetime import datetime 8 | from utils import calculate_overall_score, query_ai_prompt 9 | from promptObjects import ( 10 | improve_prompt, 11 | critique_prompt, 12 | faith_scorer_prompt, 13 | spicy_scorer_prompt, 14 | ImprovedText, 15 | Critique, 16 | FaithfulnessScore, 17 | SpicyScore, 18 | ) 19 | 20 | # This script uses the large language model to improve a text, it depends on a llama_cpp server being setup with a model loaded. 21 | # There are several different interfaces it can use, see utils.py for more details. 22 | # Here is a bit of a local setup example: 23 | # pip install llama-cpp-python[server] --upgrade 24 | # python3 -m llama_cpp.server --model mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf --port 5834 --n_ctx 4096 --use_mlock false 25 | # Run this script: 26 | # python3 chill.py 27 | # This should then try and improve the original text below. 28 | # Or you could import the improvement_loop function with a string as an argument to improve a specific text, 29 | # or use it as a command line tool with the -t flag to improve a specific text. 30 | 31 | original_text = """Stop chasing dreams instead. Life is not a Hollywood movie. Not everyone is going to get a famous billionaire. Adjust your expectations to reality, and stop thinking so highly of yourself, stop judging others. Assume the responsibility for the things that happen in your life. It is kind of annoying to read your text, it is always some external thing that "happened" to you, and it is always other people who are not up to your standards. At some moment you even declare with despair. And guess what? This is true and false at the same time, in a fundamental level most people are not remarkable, and you probably aren't too. But at the same time, nobody is the same, you have worth just by being, and other people have too. The impression I get is that you must be someone incredibly annoying to work with, and that your performance is not even nearly close to what you think it is, and that you really need to come down to earth. Stop looking outside, work on yourself instead. You'll never be satisfied just by changing jobs. 
Do therapy if you wish, become acquainted with stoicism, be a volunteer in some poor country, whatever, but do something to regain control of your life, to get some perspective, and to adjust your expectations to reality.""" 32 | # From elzbardico on https://news.ycombinator.com/item?id=36119858 33 | 34 | """ 35 | Outputs something like this: 36 | { 37 | "critique": "The revised text effectively conveys the same message as the original but in a more constructive and diplomatic tone, maintaining the original's intention while promoting a more positive discussion.", 38 | "faithfulness_score": 0.85, 39 | "spicy_score": 0.25, 40 | "overall_score": 0.89, 41 | "edit": "Consider shifting your focus from chasing dreams to finding fulfillment in reality. Life isn't a Hollywood movie, and becoming a famous billionaire isn't a realistic goal for everyone. It might be helpful to recalibrate your expectations to better align with what's possible. Instead of judging others, try to understand them better. Take responsibility for the events in your life, rather than attributing them to external factors or other people. I understand that it can be frustrating when things don't go as planned, but keep in mind that most people, including yourself, are not inherently exceptional or unremarkable. However, everyone has unique worth that doesn't depend on their achievements or status. It's essential to recognize that you may come across as demanding to work with and that your self-perception might not match others' opinions of your performance. To gain a fresh perspective and adjust your expectations, you could explore personal growth opportunities such as therapy, practicing stoicism, volunteering in underserved communities, or any other activity that helps you develop self-awareness and emotional intelligence." 
42 | } 43 | """ 44 | 45 | 46 | 47 | class ImprovementContext: 48 | def __init__(self, original_text=None): 49 | self.suggestions = [] 50 | self.last_edit = "" 51 | self.request_count = 0 52 | self.start_time = time() 53 | self.original_text = original_text 54 | self.improvement_result = dict() 55 | 56 | def query_ai_prompt_with_count(prompt, replacements, model_class, context): 57 | context.request_count += 1 58 | return query_ai_prompt(prompt, replacements, model_class) 59 | 60 | 61 | 62 | 63 | def improve_text_attempt(context): 64 | replacements = { 65 | "original_text": json.dumps(context.original_text), 66 | "previous_suggestions": json.dumps(context.suggestions, indent=2), 67 | } 68 | return query_ai_prompt_with_count(improve_prompt, replacements, ImprovedText, context) 69 | 70 | 71 | def critique_text(context): 72 | replacements = {"original_text": context.original_text, "last_edit": context.last_edit} 73 | 74 | # Query the AI for each of the new prompts separately 75 | 76 | critique_resp = query_ai_prompt_with_count(critique_prompt, replacements, Critique, context) 77 | faithfulness_resp = query_ai_prompt_with_count( 78 | faith_scorer_prompt, replacements, FaithfulnessScore, context 79 | ) 80 | spiciness_resp = query_ai_prompt_with_count( 81 | spicy_scorer_prompt, replacements, SpicyScore, context 82 | ) 83 | 84 | # Combine the results from the three queries into a single dictionary 85 | combined_resp = { 86 | "critique": critique_resp["critique"], 87 | "faithfulness_score": faithfulness_resp["faithfulness_score"], 88 | "spicy_score": spiciness_resp["spicy_score"], 89 | } 90 | 91 | return combined_resp 92 | 93 | 94 | def update_suggestions(critique_dict, iteration, context): 95 | """ 96 | Gets weighted score for new suggestion, 97 | adds new suggestion, 98 | sorts suggestions by score, 99 | updates request_count, time_used, 100 | log progress and return highest score 101 | """ 102 | context.iteration = iteration 103 | time_used = time() - context.start_time 104 | critique_dict["overall_score"] = round( 105 | calculate_overall_score( 106 | critique_dict["faithfulness_score"], critique_dict["spicy_score"] 107 | ), 108 | 2, 109 | ) 110 | critique_dict["edit"] = context.last_edit 111 | if "worst_fix" in context.improvement_result: 112 | critique_dict["worst_fix"] = context.improvement_result["worst_fix"] 113 | if "nvc" in context.improvement_result: 114 | critique_dict["perspective"] = context.improvement_result["nvc"] 115 | if "constructive" in context.improvement_result: 116 | critique_dict["constructive"] = context.improvement_result["constructive"] 117 | context.suggestions.append(critique_dict) 118 | context.suggestions = sorted(context.suggestions, key=lambda x: x["overall_score"], reverse=True)[ 119 | :2 120 | ] 121 | critique_dict["request_count"] = context.request_count 122 | if context.verbose: 123 | print_iteration_result(context.iteration, critique_dict["overall_score"], time_used, context.suggestions) 124 | return critique_dict["overall_score"] 125 | 126 | 127 | def print_iteration_result(iteration, overall_score, time_used, suggestions): 128 | print( 129 | f"Iteration {iteration}: overall_score={overall_score:.2f}, time_used={time_used:.2f} seconds." 
130 | ) 131 | print("suggestions:") 132 | print(json.dumps(suggestions, indent=2)) 133 | 134 | 135 | def done_log(context): 136 | log_entry = { 137 | "uuid": str(uuid4()), 138 | "timestamp": datetime.utcnow().isoformat(), 139 | "input": context.original_text, 140 | "output": context.suggestions[0], 141 | } 142 | log_to_jsonl("inputs_and_outputs.jsonl", log_entry) 143 | 144 | 145 | def improvement_loop( 146 | input_text, 147 | max_iterations=3, 148 | good_score=0.85, 149 | min_iterations=2, 150 | good_score_if_late=0.7, 151 | deadline_seconds=60, 152 | verbose=True, 153 | ): 154 | context = ImprovementContext() 155 | context.original_text = input_text 156 | context.verbose = verbose 157 | time_used = 0 158 | 159 | for iteration in range(1, max_iterations + 1): 160 | context.improvement_result = improve_text_attempt(context) 161 | context.last_edit = context.improvement_result["hybrid"] 162 | critique_dict = critique_text(context) 163 | overall_score = update_suggestions(critique_dict, iteration, context) 164 | good_attempt = iteration >= min_iterations and overall_score >= good_score 165 | time_used = time() - context.start_time 166 | too_long = time_used > deadline_seconds and overall_score >= good_score_if_late 167 | if good_attempt or too_long: 168 | break 169 | 170 | assert len(context.suggestions) > 0 171 | if verbose: print("Stopping\nTop suggestion:\n", json.dumps(context.suggestions[0], indent=4)) 172 | context.suggestions[0].update({ 173 | "input": context.original_text, 174 | "iteration_count": iteration, 175 | "max_allowed_iterations": max_iterations, 176 | "time_used": time_used, 177 | "worst_terms": context.improvement_result.get("worst_terms", ""), 178 | "worst_fix": context.improvement_result.get("worst_fix", ""), 179 | "perspective": context.improvement_result.get("nvc", ""), 180 | "constructive": context.improvement_result.get("constructive", "") 181 | }) 182 | done_log(context) 183 | return context.suggestions[0] 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = ArgumentParser(description="Process and improve text.") 188 | parser.add_argument( 189 | "-t", "--text", type=str, help="Text to be improved", default=original_text 190 | ) 191 | args = parser.parse_args() 192 | 193 | improvement_loop(args.text) 194 | 195 | # TODO: Segment the text into sentences for parallel processing, and isolate the most problematic parts for improvement 196 | """ 197 | # import pysbd 198 | # sentences = pysbd.Segmenter(language="en", clean=False).segment(paragraph) 199 | """ 200 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import threading 5 | 6 | def log_to_jsonl(file_path, data): 7 | def _log_to_jsonl(): 8 | # Read the URL of the Gradio app from an environment variable 9 | url = os.environ.get("SAVE_URL") 10 | if url is None: 11 | raise ValueError("SAVE_URL environment variable not set") 12 | 13 | # Serialize the data to a JSON string 14 | json_data = [{"file_path": file_path, "data": data}] 15 | 16 | # Create a dictionary with the JSON data as the value of a field named "data" 17 | request_body = {"data": json_data} 18 | 19 | # Convert the request body to a JSON string 20 | json_data = json.dumps(request_body) 21 | 22 | # Make the HTTP POST request 23 | try: 24 | response = requests.post(url, data=json_data, headers={"Content-Type": "application/json"}) 25 | 26 | # Check if the request was successful 
27 | if response.status_code == 200: 28 | print("Data saved successfully!") 29 | else: 30 | print("Error saving data:", response.text, response.status_code) 31 | except Exception as e: 32 | print("Unexpected error saving", e, url, json_data) 33 | 34 | # Create a new thread and start it 35 | thread = threading.Thread(target=_log_to_jsonl) 36 | thread.start() 37 | 38 | # If there is an ENV var called SKIP_NETWORK, we just locally save the data 39 | 40 | def local_log_to_jsonl(file_path, data): 41 | with open("local_data.jsonl", "a") as f: 42 | f.write(json.dumps({"file_path": file_path, "data": data}) + "\n") 43 | 44 | if os.environ.get("SKIP_NETWORK") is not None: 45 | log_to_jsonl = local_log_to_jsonl -------------------------------------------------------------------------------- /data_saver_server.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import json 3 | import threading 4 | 5 | # Function to save data to disk in a non-blocking manner 6 | def save_data_to_disk(data): 7 | # Append data to a JSONL file 8 | with open("data.jsonl", "a") as file: 9 | file.write(json.dumps(data) + "\n") 10 | 11 | # Wrapper function to make `save_data_to_disk` non-blocking 12 | def save_data(data): 13 | # Start a new thread to handle the saving process 14 | thread = threading.Thread(target=save_data_to_disk, args=(data,)) 15 | thread.start() 16 | # Return a simple confirmation message 17 | return "Data is being saved." 18 | 19 | # Create a Gradio interface 20 | interface = gr.Interface( 21 | fn=save_data, 22 | inputs=gr.JSON(label="Input JSON Data"), 23 | outputs="text", 24 | title="Data Saving Service", 25 | description="A simple Gradio app to save arbitrary JSON data in the background.", 26 | ) 27 | 28 | # Run the Gradio app 29 | if __name__ == "__main__": 30 | interface.launch(server_name="0.0.0.0", server_port=8435, share=True) 31 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | services: 3 | runpod: 4 | build: 5 | context: . 6 | dockerfile: runpod.dockerfile 7 | volumes: 8 | - ./.cache:/runpod-volume/.cache 9 | - ./test.sh:/test.sh 10 | command: /test.sh 11 | entrypoint: /usr/bin/python3 12 | -------------------------------------------------------------------------------- /gradio_cached_examples/14/log.csv: -------------------------------------------------------------------------------- 1 | output,flag,username,timestamp 2 | " 3 |
4 | Edited text:
5 | It seems we're moving a bit slower than anticipated. I'm concerned we might not meet our shipping deadline.
6 | Details:
7 | (list of score details; HTML markup lost in extraction)
16 |
17 | ",,,2024-02-29 17:35:32.665693 18 | " 19 |
20 | Edited text:
21 | It seems your understanding of a balanced diet differs from mine, as it appears to include biscuits quite often.
22 | Details:
23 | (list of score details; HTML markup lost in extraction)
32 |
33 | ",,,2024-02-29 17:35:35.029163 34 | -------------------------------------------------------------------------------- /learn.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import asyncio 3 | import json 4 | import time 5 | import os 6 | import hashlib 7 | from functools import wraps 8 | 9 | import pandas as pd 10 | from datasets import load_dataset 11 | from detoxify import Detoxify 12 | 13 | # TODO: Compare OpenAI's moderation API to Detoxify 14 | 15 | 16 | predict_model = Detoxify('original-small') 17 | dataset = load_dataset("tasksource/jigsaw") 18 | 19 | train_data = dataset['train'] 20 | print('length',len(train_data)) # length 159571 21 | print(train_data[0]) # {'id': '0000997932d777bf', 'comment_text': "Explanation\nWhy the edits made under my username Hardcore Metallica Fan were reverted? They weren't vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please don't remove the template from the talk page since I'm retired now.89.205.38.27", 'toxic': 0, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0} 22 | 23 | small_subset = train_data[:2000] 24 | 25 | predict_model.predict("You suck, that is not Markdown!") # Also accepts an array of strings, returning an single dict of arrays of predictions. 26 | # Returns: 27 | {'toxicity': 0.98870254, 28 | 'severe_toxicity': 0.087154716, 29 | 'obscene': 0.93440753, 30 | 'threat': 0.0032278204, 31 | 'insult': 0.7787105, 32 | 'identity_attack': 0.007936229} 33 | 34 | 35 | 36 | _in_memory_cache = {} 37 | 38 | def handle_cache(prefix, func, *args, _result=None, **kwargs): 39 | # Generate a key based on function name and arguments 40 | key = f"{func.__name__}_{args}_{kwargs}" 41 | hashed_key = hashlib.sha1(key.encode()).hexdigest() 42 | cache_filename = f"{prefix}_{hashed_key}.json" 43 | 44 | # Check the in-memory cache first 45 | if key in _in_memory_cache: 46 | return _in_memory_cache[key] 47 | 48 | # Check if cache file exists and read data 49 | if os.path.exists(cache_filename): 50 | with open(cache_filename, 'r') as file: 51 | #print("Reading from cache file with prefix", prefix) 52 | _in_memory_cache[key] = json.load(file) 53 | return _in_memory_cache[key] 54 | 55 | # If result is not provided (for sync functions), compute it 56 | if _result is None: 57 | _result = func(*args, **kwargs) 58 | 59 | # Update the in-memory cache and write it to the file 60 | _in_memory_cache[key] = _result 61 | with open(cache_filename, 'w') as file: 62 | json.dump(_result, file) 63 | 64 | return _result 65 | 66 | 67 | 68 | 69 | def cache(prefix): 70 | def decorator(func): 71 | @wraps(func) 72 | def wrapper(*args, **kwargs): 73 | # Direct call to the shared cache handling function 74 | return handle_cache(prefix, func, *args, **kwargs) 75 | return wrapper 76 | return decorator 77 | 78 | 79 | 80 | 81 | 82 | @cache("toxicity") 83 | def cached_toxicity_prediction(comments): 84 | data = predict_model.predict(comments) 85 | return data 86 | 87 | def predict_toxicity(comments, batch_size=4): 88 | """ 89 | Predicts toxicity scores for a list of comments. 90 | 91 | Args: 92 | - comments: List of comment texts. 93 | - batch_size: Size of batches for prediction to manage memory usage. 94 | 95 | Returns: 96 | A DataFrame with the original comments and their predicted toxicity scores. 
97 | """ 98 | results = {'comment_text': [], 'toxicity': [], 'severe_toxicity': [], 'obscene': [], 'threat': [], 'insult': [], 'identity_attack': []} 99 | for i in range(0, len(comments), batch_size): 100 | batch_comments = comments[i:i+batch_size] 101 | predictions = cached_toxicity_prediction(batch_comments) 102 | # We convert the JSON serializable data back to a DataFrame: 103 | results['comment_text'].extend(batch_comments) 104 | for key in predictions.keys(): 105 | results[key].extend(predictions[key]) 106 | return pd.DataFrame(results) 107 | 108 | # Predict toxicity scores for the small subset of comments: 109 | #small_subset_predictions = predict_toxicity(small_subset['comment_text'][4]) 110 | # Let's just try out 4 comments with cached_toxicity_prediction: 111 | small_subset['comment_text'][0:1] 112 | 113 | # %% 114 | small_subset_predictions=predict_toxicity(small_subset['comment_text'][0:200]) 115 | 116 | # %% 117 | small_subset_predictions 118 | 119 | # %% 120 | def filter_comments(dataframe, toxicity_threshold=0.2, severe_toxicity_threshold=0.4): 121 | """ 122 | Filters comments based on specified thresholds for toxicity, severe toxicity. 123 | 124 | Args: 125 | - dataframe: DataFrame containing comments and their toxicity scores. 126 | - toxicity_threshold: Toxicity score threshold. 127 | - severe_toxicity_threshold: Severe toxicity score threshold. 128 | - identity_attack_threshold: Identity attack score threshold. 129 | 130 | Returns: 131 | DataFrame filtered based on the specified thresholds. 132 | """ 133 | identity_attack_threshold = 0.5 134 | insult_threshold = 0.3 135 | obscene_threshold = 0.6 136 | threat_threshold = 0.3 137 | filtered_df = dataframe[ 138 | (dataframe['toxicity'] >= toxicity_threshold) & 139 | #(dataframe['toxicity'] < 1.0) & # Ensure comments are spicy but not 100% toxic 140 | (dataframe['severe_toxicity'] < severe_toxicity_threshold) & 141 | (dataframe['identity_attack'] < identity_attack_threshold) & 142 | (dataframe['insult'] < insult_threshold) & 143 | (dataframe['obscene'] < obscene_threshold) & 144 | (dataframe['threat'] < threat_threshold) 145 | 146 | ] 147 | return filtered_df 148 | 149 | spicy_comments = filter_comments(small_subset_predictions) 150 | 151 | 152 | # Lets sort spicy comments by combined toxicity score: 153 | spicy_comments.sort_values(by=['toxicity', 'severe_toxicity'], ascending=True, inplace=True) 154 | 155 | # Print the spicy comments comment_text and their toxicity scores as a formatted string: 156 | for index, row in spicy_comments.iterrows(): 157 | print(f"Comment: `{row['comment_text']}` \n Toxiciy: {(row['toxicity'] + row['severe_toxicity']) / 2 * 100:.0f}% \n") 158 | -------------------------------------------------------------------------------- /local_score.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | from sentence_transformers import SentenceTransformer, util 3 | from detoxify import Detoxify 4 | import nltk 5 | from nltk.sentiment import SentimentIntensityAnalyzer 6 | 7 | # Necessary installations: 8 | # pip install sentence-transformers detoxify nltk 9 | 10 | # Load models 11 | similarity_model = SentenceTransformer('stsb-roberta-base', device="cpu") 12 | context_model = SentenceTransformer('all-MiniLM-L6-v2', device="cpu") 13 | spice_model = Detoxify('unbiased-small') 14 | 15 | # Download necessary NLTK data 16 | nltk.download('vader_lexicon') 17 | sia = SentimentIntensityAnalyzer() 18 | 19 | # Define a function to normalize the compound sentiment 
score 20 | def get_sentiment(text) -> float: 21 | score = sia.polarity_scores(text) 22 | normalized_score = (score['compound'] + 1) / 2 # Normalizing to range [0,1] 23 | return normalized_score 24 | 25 | 26 | # Define original text and its variations 27 | original_text = "We live in an advertising hellscape now. The biggest crock of crap is being forced to watch an advertisement at the freaking gas station as I pump my gas. I'll never voluntarily use those pumps ever again." 28 | worst_terms = [ 29 | "hellscape", 30 | "crock of crap", 31 | "freaking" 32 | ] 33 | variations = { 34 | "worst_fix": "We live in an advertising-heavy environment now. The most bothersome thing is being forced to watch an advertisement at the gas station as I pump my gas. I'll avoid using those pumps in the future.", 35 | "nvc": "I feel overwhelmed by the amount of advertising in our environment now. It bothers me to have to watch an advertisement while pumping gas at the gas station. I prefer to use pumps without advertisements going forward.", 36 | "constructive": "The prevalence of advertising in our environment can feel overwhelming at times. Having to watch ads while pumping gas is particularly bothersome to me. I would appreciate less adverts at gas stations for a less distracting, peaceful experience.", 37 | "hybrid": "We live in an advertising-heavy environment now. Having to watch an advertisement at the gas station while pumping gas is quite bothersome. I'll avoid using those pumps when possible going forward.", 38 | "beta": "We live in an advertising-heavy landscape now. The biggest frustration is being forced to watch an advertisement at the gas station as I pump my gas. I'll never voluntarily use those pumps ever again.", 39 | "identity":"We live in an advertising hellscape now. The biggest crock of crap is being forced to watch an advertisement at the freaking gas station as I pump my gas. I'll never voluntarily use those pumps ever again.", 40 | "calm":"We live in an advertising-heavy landscape now. The biggest frustration is being forced to watch an advertisement at the gas station as I pump my gas. I'll never voluntarily use those pumps. Let's enjoy calm!" 
41 | } 42 | 43 | # Identify replacements using difflib for 'worst_fix' variation 44 | 45 | def identify_replacements(original_text, this_variant): 46 | matcher = difflib.SequenceMatcher(None, original_text.split(), this_variant.split()) 47 | replacements = {} 48 | for opcode in matcher.get_opcodes(): 49 | if opcode[0] == 'replace': 50 | original_phrase = ' '.join(original_text.split()[opcode[1]:opcode[2]]) 51 | new_phrase = ' '.join(this_variant.split()[opcode[3]:opcode[4]]) 52 | replacements[original_phrase] = new_phrase 53 | return replacements 54 | 55 | print("Replacements found:", identify_replacements(original_text, variations['worst_fix'])) 56 | 57 | # Function to calculate an aggregate score for each variant 58 | def calculate_aggregate_score(overall_similarity, negativity_score, sentiment_delta, edit_distance, max_length, name, sentiment): 59 | negativity_weight=3 60 | sentiment_weight=0.6 61 | edit_distance_weight=1 62 | similarity_weight=0.8 63 | 64 | normalized_edit_distance = edit_distance / max(max_length, 1) 65 | weighted_similarity = overall_similarity * similarity_weight 66 | weighted_negativity = negativity_score * negativity_weight 67 | weighted_happy_sentiment = (sentiment * sentiment_weight) 68 | weighted_edit_distance = normalized_edit_distance * edit_distance_weight 69 | result = weighted_similarity - weighted_edit_distance - weighted_negativity + weighted_happy_sentiment 70 | return result 71 | 72 | # Dictionary to hold variant names and their aggregate scores 73 | variant_scores = {} 74 | 75 | def calculate_negativity_score(predictions): 76 | score = sum(predictions.values()) / len(predictions) 77 | return score 78 | 79 | # Calculate similarity for replacements and overall similarity for all variations 80 | original_embedding = context_model.encode(original_text, convert_to_tensor=True) 81 | original_text_negativity_score = calculate_negativity_score(spice_model.predict(original_text)) 82 | for name, text in variations.items(): 83 | # Compute overall semantic similarity 84 | variation_embedding = context_model.encode(text, convert_to_tensor=True) 85 | overall_similarity = util.pytorch_cos_sim(original_embedding, variation_embedding).item() 86 | 87 | # Calculate negativity score using Detoxify 88 | negativity_score = calculate_negativity_score(spice_model.predict(text)) 89 | negativity_score_delta = original_text_negativity_score - negativity_score # unused 90 | 91 | # Calculate sentiment score delta 92 | variant_sentiment = get_sentiment(text) 93 | sentiment_delta = variant_sentiment - get_sentiment(original_text) 94 | 95 | # Calculate the maximum length between the original and variation texts for normalization 96 | max_length = max(len(original_text), len(text)) 97 | 98 | # Calculate and store the aggregate score 99 | edit_distance = nltk.edit_distance(original_text, text) 100 | 101 | replacments = identify_replacements(original_text, text) 102 | aggregate_score = calculate_aggregate_score(overall_similarity, negativity_score, sentiment_delta, edit_distance, max_length,name=name, sentiment=variant_sentiment) 103 | variant_scores[name] = { 104 | "overall_similarity": overall_similarity, 105 | "negativity_score": negativity_score, 106 | "sentiment_delta": sentiment_delta, 107 | "sentiment": variant_sentiment, 108 | "edit_distance": edit_distance, 109 | "max_length": max_length, 110 | "aggregate_score": aggregate_score, 111 | "variant_text": text, 112 | "replacements": replacments 113 | } 114 | 115 | # Sort the variants by aggregate score 116 | sorted_variants = 
sorted(variant_scores.items(), key=lambda x: x[1]['aggregate_score'], reverse=True) 117 | 118 | for name, score in sorted_variants: 119 | print(f"\nVariation: {name}") 120 | print(f"Aggregate score: {variant_scores[name]['aggregate_score']:.4f}") 121 | print(f"Negativity score: {variant_scores[name]['negativity_score']:.4f}") 122 | print(f"Sentiment: {variant_scores[name]['sentiment']:.4f}") 123 | print(f"Sentiment delta: {variant_scores[name]['sentiment_delta']:.4f}") 124 | print(f"Edit distance: {variant_scores[name]['edit_distance']}") 125 | print(f"Variant text: `{variant_scores[name]['variant_text']}`\n") 126 | print(f"Replacements: `{variant_scores[name]['replacements']}`\n") 127 | 128 | """ 129 | Example output: 130 | 131 | Replacements found: {'advertising hellscape': 'advertising-heavy environment', 'biggest crock of crap': 'most bothersome thing', 'never voluntarily use': 'avoid using', 'ever again.': 'in the future.'} 132 | 133 | Variation: calm 134 | Aggregate score: 0.6768 135 | Negativity score: 0.0001 136 | Sentiment: 0.3878 137 | Sentiment delta: 0.2941 138 | Edit distance: 44 139 | Variant text: `We live in an advertising-heavy landscape now. The biggest frustration is being forced to watch an advertisement at the gas station as I pump my gas. I'll never voluntarily use those pumps. Let's enjoy calm!` 140 | 141 | Replacements: `{'advertising hellscape': 'advertising-heavy landscape', 'crock of crap': 'frustration', 'pumps ever again.': "pumps. Let's enjoy calm!"}` 142 | 143 | 144 | Variation: beta 145 | Aggregate score: 0.6218 146 | Negativity score: 0.0001 147 | Sentiment: 0.1366 148 | Sentiment delta: 0.0428 149 | Edit distance: 29 150 | Variant text: `We live in an advertising-heavy landscape now. The biggest frustration is being forced to watch an advertisement at the gas station as I pump my gas. I'll never voluntarily use those pumps ever again.` 151 | 152 | Replacements: `{'advertising hellscape': 'advertising-heavy landscape', 'crock of crap': 'frustration'}` 153 | 154 | 155 | Variation: worst_fix 156 | Aggregate score: 0.4016 157 | Negativity score: 0.0002 158 | Sentiment: 0.1111 159 | Sentiment delta: 0.0174 160 | Edit distance: 72 161 | Variant text: `We live in an advertising-heavy environment now. The most bothersome thing is being forced to watch an advertisement at the gas station as I pump my gas. I'll avoid using those pumps in the future.` 162 | 163 | Replacements: `{'advertising hellscape': 'advertising-heavy environment', 'biggest crock of crap': 'most bothersome thing', 'never voluntarily use': 'avoid using', 'ever again.': 'in the future.'}` 164 | 165 | 166 | Variation: constructive 167 | Aggregate score: 0.2612 168 | Negativity score: 0.0001 169 | Sentiment: 0.6391 170 | Sentiment delta: 0.5454 171 | Edit distance: 174 172 | Variant text: `The prevalence of advertising in our environment can feel overwhelming at times. Having to watch ads while pumping gas is particularly bothersome to me. I would appreciate less adverts at gas stations for a less distracting, peaceful experience.` 173 | 174 | Replacements: `{'We live': 'The prevalence of advertising', 'an advertising hellscape now. The biggest crock of crap is being forced': 'our environment can feel overwhelming at times. Having', 'an advertisement': 'ads while pumping gas is particularly bothersome to me. I would appreciate less adverts', "station as I pump my gas. 
I'll never voluntarily use those pumps ever again.": 'stations for a less distracting, peaceful experience.'}` 175 | 176 | 177 | Variation: nvc 178 | Aggregate score: 0.2149 179 | Negativity score: 0.0001 180 | Sentiment: 0.4234 181 | Sentiment delta: 0.3297 182 | Edit distance: 142 183 | Variant text: `I feel overwhelmed by the amount of advertising in our environment now. It bothers me to have to watch an advertisement while pumping gas at the gas station. I prefer to use pumps without advertisements going forward.` 184 | 185 | Replacements: `{'We live': 'I feel overwhelmed by the amount of advertising', 'an advertising hellscape': 'our environment', 'The biggest crock of crap is being forced': 'It bothers me to have', 'station as': 'station.', "pump my gas. I'll never voluntarily": 'prefer to', 'ever again.': 'without advertisements going forward.'}` 186 | 187 | 188 | Variation: hybrid 189 | Aggregate score: 0.1895 190 | Negativity score: 0.0001 191 | Sentiment: 0.1902 192 | Sentiment delta: 0.0965 193 | Edit distance: 112 194 | Variant text: `We live in an advertising-heavy environment now. Having to watch an advertisement at the gas station while pumping gas is quite bothersome. I'll avoid using those pumps when possible going forward.` 195 | 196 | Replacements: `{'advertising hellscape': 'advertising-heavy environment', 'The biggest crock of crap is being forced': 'Having', 'as I pump my gas.': 'while pumping gas is quite bothersome.', 'never voluntarily use': 'avoid using', 'ever again.': 'when possible going forward.'}` 197 | 198 | 199 | Variation: identity 200 | Aggregate score: 0.0974 201 | Negativity score: 0.2530 202 | Sentiment: 0.0937 203 | Sentiment delta: 0.0000 204 | Edit distance: 0 205 | Variant text: `We live in an advertising hellscape now. The biggest crock of crap is being forced to watch an advertisement at the freaking gas station as I pump my gas. I'll never voluntarily use those pumps ever again.` 206 | 207 | Replacements: `{}`""" 208 | -------------------------------------------------------------------------------- /promptObjects.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydantic import BaseModel, Field 3 | 4 | improve_prompt = """ 5 | Given some inflammatory text, make minimal changes to the text to make it less inflammatory, while keeping the original meaning as much as possible. 6 | Make the new version more calm and constructive, without changing the intended meaning, with only minimal changes to the existing text. 7 | Make sure the refined text is a good reflection of the original text, without adding new ideas. 8 | Make the changes as minimal as possible. Some optional strategies to make the text less inflammatory include: 9 | -Soften harsh tone, replace or omit judgemental or extreme words. 10 | -Rather than accusations, share perspective. 11 | -Consider focusing on specific actions rather than character. 12 | -Rephrasing exaggerated expressions like "always", "never" or "everyone" to be more moderate. 13 | -Using gentler alternatives to express similar points where needed. 14 | 15 | Avoid adding new ideas, ONLY build upon what's already there, for example, you might reframe an existing point to be more balanced but never introduce unrelated concepts. 
16 | 17 | Make both parties more happy where possible: 18 | The reader should be INFORMED and not *offended*, and the original author should be *content* that their points were *honoured* by your edit, by minimally refining their text without losing the original intent. 19 | 20 | Format: 21 | JSON object with the following properties: 22 | worst_terms: array of strings of the worst terms in the text 23 | worst_fix: Text with worst terms replaced or softened 24 | nvc: text with Non-violent Communication perspective sharing where needed 25 | constructive: text with constructive suggestions 26 | hybrid: The most minimal calm variation of the original text, learning from prior variations 27 | 28 | Example input text: "You're always annoying me. You never listen to me." 29 | Example improved text outputs: 30 | { 31 | "worst_terms": ["annoying", "always annoying", "never listen"], 32 | "worst_fix": "You're often frustrating me. You rarely listen to me.", 33 | "nvc": "I often feel annoyed by you. I rarely feel you listen to me.", 34 | "constructive": "I often feel frustrated by you. I rarely feel you listen to me. How can we improve our communication?", 35 | "hybrid": "You're often frustrating me. It feels like you often don't listen to me."} 36 | End of example. 37 | Here is the real input text to improve: 38 | `{original_text}` 39 | 40 | Previous rephrasing attempts: 41 | {previous_suggestions} 42 | 43 | Provide your improved version in the required JSON format. 44 | To get a good answer, make the original text non-inflammatory, while staying as faithful to the original text as possible. 45 | Use valid JSON then stop; the required keys are: worst_terms, worst_fix, nvc, constructive, hybrid. 46 | Do not add any remarks before or after the JSON! 47 | """ 48 | 49 | critique_prompt = """ 50 | Critique the text. We prefer edits that prevent inflaming discussions! 51 | We also prefer concise text, and a similar semantic intent to the original. 52 | 53 | Here is the original text: 54 | `{original_text}` 55 | 56 | Here is the text to critique: 57 | `{last_edit}` 58 | 59 | Output your response as valid JSON in this format: 60 | { 61 | "critique":"STRING" 62 | } 63 | 64 | E.g: 65 | { 66 | "critique":"This is too fluffy and different from the original intent." 67 | } 68 | Please critique the text. 69 | You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON!""" 70 | 71 | 72 | spicy_scorer_prompt = """ 73 | Score the text. 74 | 75 | A calm spicy_score of 0 is ideal. A spicy_score of 1 is the worst: very inflammatory text that makes the reader feel attacked. 76 | 77 | Here is the original text: 78 | `{original_text}` 79 | 80 | Here is the text to score: 81 | `{last_edit}` 82 | The float variable is scored from 0 to 1. 83 | 84 | Output your response as valid JSON in this format, then stop: 85 | { 86 | "spicy_score":FLOAT 87 | } 88 | Please score the text. 89 | You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON! 90 | """ 91 | 92 | 93 | faith_scorer_prompt = """ 94 | Score the text. 95 | 96 | A score of 1 would have the same semantic intent as the original text. A score of 0 would mean the text has lost all semantic similarity. 97 | 98 | Here is the original text: 99 | `{original_text}` 100 | 101 | Here is the new text to score: 102 | `{last_edit}` 103 | 104 | The float variable is scored from 0 to 1.
105 | 106 | Output your response as valid JSON in this format, then stop: 107 | { 108 | "faithfulness_score":FLOAT 109 | } 110 | Please score the text. 111 | You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON! 112 | """ 113 | 114 | 115 | class ImprovedText(BaseModel): 116 | worst_terms: List[str] = Field(..., description="Array of strings of the worst terms in the text.") 117 | worst_fix: str = Field(..., description="The text with worst terms replaced or softened.") 118 | nvc: str = Field(..., description="The text with NVC perspective sharing where needed.") 119 | constructive: str = Field(..., description="The text with constructive suggestions.") 120 | hybrid: str = Field(..., description="A suggestion that tries to stay close to the original while combining the best of the variations.") 121 | 122 | 123 | class SpicyScore(BaseModel): 124 | spicy_score: float = Field(..., description="The spiciness score of the text.") 125 | 126 | 127 | class Critique(BaseModel): 128 | critique: str = Field(..., description="The critique of the text.") 129 | 130 | 131 | class FaithfulnessScore(BaseModel): 132 | faithfulness_score: float = Field( 133 | ..., description="The faithfulness score of the text." 134 | ) 135 | -------------------------------------------------------------------------------- /runpod.dockerfile: -------------------------------------------------------------------------------- 1 | # Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile 2 | # DockerHub -> https://hub.docker.com/r/runpod/base/tags 3 | FROM runpod/base:0.4.0-cuda11.8.0 4 | 5 | # Base image sets HuggingFace cache directory to use Runpod's shared cache for efficiency: 6 | ENV HF_HOME="/runpod-volume/.cache/huggingface/" 7 | # Also pre-downloading models may speed up start times while 8 | # increasing image size, but could be worth it for some use cases. 9 | 10 | RUN python3.11 -m pip install --upgrade pip && \ 11 | python3.11 -m pip install runpod==1.6.0 12 | 13 | RUN python3.11 -m pip install pytest cmake \ 14 | scikit-build setuptools pydantic-settings \ 15 | huggingface_hub hf_transfer \ 16 | pydantic pydantic_settings \ 17 | llama-cpp-python 18 | 19 | # Install llama-cpp-python (build with cuda) 20 | ENV CMAKE_ARGS="-DLLAMA_CUBLAS=on" 21 | RUN python3.11 -m pip install git+https://github.com/lukestanley/llama-cpp-python.git@expose_json_grammar_convert_function --upgrade --no-cache-dir --force-reinstall 22 | 23 | ADD runpod_handler.py . 24 | ADD chill.py . 25 | ADD utils.py . 26 | ADD promptObjects.py .
27 | 28 | ENV REPO_ID="TheBloke/phi-2-GGUF" 29 | ENV MODEL_FILE="phi-2.Q2_K.gguf" 30 | CMD python3.11 -u /runpod_handler.py 31 | 32 | -------------------------------------------------------------------------------- /runpod_handler.py: -------------------------------------------------------------------------------- 1 | import runpod 2 | from os import environ as env 3 | import json 4 | from pydantic import BaseModel, Field 5 | class Movie(BaseModel): 6 | title: str = Field(..., title="The title of the movie") 7 | year: int = Field(..., title="The year the movie was released") 8 | director: str = Field(..., title="The director of the movie") 9 | genre: str = Field(..., title="The genre of the movie") 10 | plot: str = Field(..., title="Plot summary of the movie") 11 | 12 | def pydantic_model_to_json_schema(pydantic_model_class): 13 | schema = pydantic_model_class.model_json_schema() 14 | 15 | # The optional example field from the schema is not needed for the grammar generation 16 | if "example" in schema: 17 | del schema["example"] 18 | 19 | json_schema = json.dumps(schema) 20 | return json_schema 21 | default_schema_example = """{ "title": ..., "year": ..., "director": ..., "genre": ..., "plot":...}""" 22 | default_schema = pydantic_model_to_json_schema(Movie) 23 | default_prompt = f"Instruct: \nOutput a JSON object in this format: {default_schema_example} for the following movie: The Matrix\nOutput:\n" 24 | from utils import llm_stream_sans_network_simple 25 | def handler(job): 26 | """ Handler function that will be used to process jobs. """ 27 | job_input = job['input'] 28 | filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf") 29 | prompt = job_input.get('prompt', default_prompt) 30 | schema = job_input.get('schema', default_schema) 31 | print("got this input", str(job_input)) 32 | print("prompt", prompt ) 33 | print("schema", schema ) 34 | output = llm_stream_sans_network_simple(prompt, schema) 35 | #print("got this output", str(output)) 36 | return output 37 | 38 | runpod.serverless.start({ 39 | "handler": handler, 40 | #"return_aggregate_stream": True 41 | }) 42 | -------------------------------------------------------------------------------- /serverless.md: -------------------------------------------------------------------------------- 1 | Fast serverless GPU inference with RunPod 2 | ============================== 3 | 4 | This partly GPT-4-generated document explains the integration of RunPod with Docker, including testing the RunPod Dockerfile with Docker Compose, building and pushing the image to Docker Hub, and how `app.py` makes use of it. I skimmed it and added to it, as a note to myself and others. 5 | 6 | # Motivation 7 | Fast inference is useful. Usually an existing hosted provider would be good for this, but I was worried about getting blocked: since we need to translate some spicy text input, it could get flagged and result in accounts being blocked. 8 | Also I needed something that could infer with JSON-typed output that matches particular schemas, and fast. So I found RunPod's "serverless" GPU service. 9 | It can be used by chill.py and app.py, as one of the worker options. 10 | 11 | 12 | ## Testing with Docker Compose 13 | 14 | To test the RunPod Dockerfile, you can use Docker Compose, which simplifies the process of running multi-container Docker applications. Here's how you can test it: 15 | 16 | 1. Ensure you have Docker and Docker Compose installed on your system. 17 | 2.
Navigate to the directory containing the `docker-compose.yml` file.
18 | 3. Run the following command to build and start the container:
19 | ```
20 | docker-compose up --build
21 | ```
22 | 4. The above command builds the image defined in `runpod.dockerfile` and starts a container with the configuration specified in `docker-compose.yml`. It automatically runs a test that matches the format expected by the llm_stream_serverless client (in utils.py), though without the network layer in play.
23 | 
24 | 
25 | # Direct testing with Docker, without Docker Compose:
26 | 
27 | Something like this worked for me:
28 | 
29 | ```sudo docker run --gpus all -it -v "$(pwd)/.cache:/runpod-volume/.cache/huggingface/" lukestanley/test:translate2 bash```
30 | Note the cache mount. This saves re-downloading the LLMs!
31 | 
32 | 
33 | ## Building and Pushing to Docker Hub
34 | 
35 | After testing and ensuring that everything works as expected, you can build the Docker image and push it to Docker Hub for deployment. Here are the steps:
36 | 
37 | 1. Log in to Docker Hub from your command line using `docker login --username [yourusername]`.
38 | 2. Build the Docker image with a tag:
39 | ```
40 | docker build -t yourusername/yourimagename:tag -f runpod.dockerfile .
41 | ```
42 | 3. Once the image is built, push it to Docker Hub:
43 | ```
44 | docker push yourusername/yourimagename:tag
45 | ```
46 | 4. Replace `yourusername`, `yourimagename`, and `tag` with your Docker Hub username, the name you want to give to your image, and the tag, respectively.
47 | 
48 | # RunPod provisioning:
49 | You'll need a RunPod account with credit.
50 | You'll need to set up a serverless GPU endpoint using your Docker image, here:
51 | https://www.runpod.io/console/serverless
52 | It has a Flashboot feature that seems like Firecracker with GPU support; it might be using Cloud Hypervisor under the hood, since Firecracker currently has no GPU support. Fly.io also has something similar, built on Cloud Hypervisor.
53 | You'll need the API secret saved somewhere secure. It will likely end up as a securely handled env var used by app.py later.
54 | You'll also need the endpoint ID.
55 | 
56 | ## RunPod Integration in `app.py`
57 | 
58 | The `app.py` file is a Gradio interface that makes use of the RunPod integration to perform inference. It checks for the presence of a GPU and installs the appropriate version of `llama-cpp-python`. Depending on the `LLM_WORKER` environment variable, it uses the RunPod serverless API, an HTTP server, or a model loaded into memory for inference.
59 | 
60 | The `greet` function in `app.py` calls `improvement_loop` from the `chill` module, which, depending on an environment variable, selects the RunPod worker (or another worker) to process the input text and generate the improved text from the model's output.
61 | 
62 | The Gradio interface is then launched with `demo.launch()`, making the application accessible via a web interface that can be shared publicly.
63 | 
64 | Note: Ensure that the necessary environment variables, such as `LLM_WORKER`, `REPO_ID`, and `MODEL_FILE`, are set correctly for the integration to work properly.
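
As a rough illustration (not code from this repo), the sketch below shows how a caller might select the RunPod worker and send a request through `query_ai_prompt` from `utils.py`, using the `Critique` response model from `promptObjects.py`. The endpoint ID, API key, and prompt text are placeholders.

```python
# Sketch only: pick the RunPod serverless worker before importing utils,
# because utils reads LLM_WORKER (and may download a model) at import time.
import os

os.environ["LLM_WORKER"] = "runpod"                    # one of: runpod, http, in_memory, mistral, anthropic
os.environ["RUNPOD_ENDPOINT_ID"] = "your-endpoint-id"  # placeholder
os.environ["RUNPOD_API_KEY"] = "your-api-key"          # placeholder

from utils import query_ai_prompt
from promptObjects import Critique

# Illustrative prompt template; the repo's real prompts live in promptObjects.py.
result = query_ai_prompt(
    "Critique the following text and answer as JSON: {original_text}",
    {"original_text": "This is a spicy comment."},
    Critique,
)
print(result)  # a dict parsed from the worker's JSON output, e.g. {"critique": "..."}
```

The same call pattern applies to the other workers; only the environment variables change (for example `MISTRAL_API_KEY` for the Mistral worker, or `ANTHROPIC_API_KEY` for the Anthropic worker).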
--------------------------------------------------------------------------------
/serverless_local_test.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | import os, json
 3 | 
 4 | # Define your JSON and prompt as Python dictionaries and strings
 5 | schema = {
 6 |     "properties": {
 7 |         "title": {"title": "The title of the movie", "type": "string"},
 8 |         "year": {"title": "The year the movie was released", "type": "integer"},
 9 |         "director": {"title": "The director of the movie", "type": "string"},
10 |         "genre": {"title": "The genre of the movie", "type": "string"},
11 |         "plot": {"title": "Plot summary of the movie", "type": "string"}
12 |     },
13 |     "required": ["title", "year", "director", "genre", "plot"],
14 |     "title": "Movie",
15 |     "type": "object"
16 | }
17 | 
18 | movie = "Toy Story"
19 | prompt = "Instruct: Output a JSON object in this format: { \"title\": ..., \"year\": ..., \"director\": ..., \"genre\": ..., \"plot\":...} for the following movie: "+movie+"\nOutput:\n"
20 | 
21 | # Construct the JSON input string
22 | json_input = json.dumps({"input": {"schema": json.dumps(schema), "prompt": prompt}})
23 | print(json_input)
24 | # Define the command to execute your Python script with the JSON string
25 | command = f'python3.11 runpod_handler.py --test_input \'{json_input}\''
26 | 
27 | # Execute the command
28 | os.system(command)
--------------------------------------------------------------------------------
/system_map.md:
--------------------------------------------------------------------------------
 1 | 1. The system's architecture is designed to mitigate online toxicity by transforming text inputs into less provocative forms using Large Language Models (LLMs), which are pivotal in analysing and refining text.
 2 | 4. Different workers, or LLM interfaces, are defined, each suited to specific operational environments.
 3 | 5. The HTTP server worker is optimised for development, facilitating dynamic updates without requiring server restarts; it can work offline, with or without a GPU, using the `llama-cpp-python` library, provided a model has been downloaded.
 4 | 6. An in-memory worker is used by the serverless worker.
 5 | 7. For on-demand, scalable processing, the system includes a RunPod API worker that leverages serverless GPU functions.
 6 | 8. Additionally, the Mistral API worker offers a paid service alternative for text processing tasks.
 7 | 9. A set of environment variables is predefined to configure the LLM workers' functionality.
 8 | 10. The `LLM_WORKER` environment variable sets the active LLM worker.
 9 | 11. The `N_GPU_LAYERS` environment variable specifies how many GPU layers to use, defaulting to the maximum available; it applies when the LLM worker is run with a GPU.
10 | 12. `CONTEXT_SIZE` is an adjustable parameter that defines how much text the LLM can process at once.
11 | 13. The `LLM_MODEL_PATH` environment variable indicates the LLM model's storage location, which can be either local or sourced from the HuggingFace Hub.
12 | 14. The system enforces some rate limiting to maintain service integrity and equitable resource distribution.
13 | 15. The `LAST_REQUEST_TIME` and `REQUEST_INTERVAL` global variables are used for Mistral rate limiting.
14 | 16. The system's worker architecture is somewhat modular, enabling easy integration or replacement of components such as LLM workers.
15 | 18. 
The system is capable of streaming responses in some modes, allowing for real-time interaction with the LLM.
16 | 19. The `llm_streaming` function handles communication with the LLM via HTTP streaming when the server worker is active.
17 | 20. The `llm_stream_sans_network` function provides an alternative for local LLM inference without network dependency.
18 | 21. For serverless deployment, the `llm_stream_serverless` function interfaces with the RunPod API.
19 | 22. The `llm_stream_mistral_api` function facilitates interaction with the Mistral API for text processing.
20 | 23. The system includes a utility function, `replace_text`, for template-based text replacement operations.
21 | 24. A scoring function, `calculate_overall_score`, combines different metrics to evaluate the text transformation's effectiveness.
22 | 25. The `query_ai_prompt` function serves as a dispatcher, directing text processing requests to the chosen LLM worker.
23 | 27. The `inference_binary_check` function within `app.py` ensures compatibility with the available hardware, particularly GPU presence.
24 | 28. The system provides a user interface through Gradio, enabling end-users to interact with the text transformation service.
25 | 29. The `chill_out` function in `app.py` is the entry point for processing user inputs through the Gradio interface.
26 | 30. The `improvement_loop` function in `chill.py` controls the iterative process of text refinement using the LLM.
27 | 
28 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import datetime
 2 | import json
 3 | import uuid
 4 | from time import time, sleep
 5 | from os import environ as env
 6 | from typing import Any, Dict, Union
 7 | from data import log_to_jsonl
 8 | import requests
 9 | from huggingface_hub import hf_hub_download
10 | 
11 | 
12 | # There are currently several ways to use an LLM model:
13 | # 1. Use the HTTP server (LLM_WORKER=http); this is good for development
14 | # when you want to change the logic of the translator without restarting the server.
15 | # 2. Load the model into memory (LLM_WORKER=in_memory).
16 | # When using the HTTP server, it must be run separately. See the README for instructions.
17 | # The llama_cpp Python HTTP server communicates with the AI model, similar
18 | # to the OpenAI API but adds a unique "grammar" parameter.
19 | # The real OpenAI API has other ways to set the output format.
20 | # It's possible to switch to another LLM API by changing the llm_streaming function.
21 | # 3. Use the RunPod API, which is a paid service with serverless GPU functions.
22 | # See serverless.md for more information.
23 | # 4. Use the Mistral API, which is a paid service.
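# 5. Use the Anthropic API, another hosted, paid option (the default LLM_WORKER below).
# Example environment setups for each worker (sketches only; the key values are placeholders):
#   LLM_WORKER=http       (expects a llama_cpp server listening at the URL below)
#   LLM_WORKER=in_memory  LLM_MODEL_PATH=/path/to/model.gguf  (or REPO_ID + MODEL_FILE to download one)
#   LLM_WORKER=runpod     RUNPOD_ENDPOINT_ID=...  RUNPOD_API_KEY=...
#   LLM_WORKER=mistral    MISTRAL_API_KEY=...
#   LLM_WORKER=anthropic  ANTHROPIC_API_KEY=...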
24 | 25 | URL = "http://localhost:5834/v1/chat/completions" 26 | in_memory_llm = None 27 | worker_options = ["runpod", "http", "in_memory", "mistral", "anthropic"] 28 | 29 | LLM_WORKER = env.get("LLM_WORKER", "anthropic") 30 | if LLM_WORKER not in worker_options: 31 | raise ValueError(f"Invalid worker: {LLM_WORKER}") 32 | N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available 33 | CONTEXT_SIZE = int(env.get("CONTEXT_SIZE", 2048)) 34 | LLM_MODEL_PATH = env.get("LLM_MODEL_PATH", None) 35 | 36 | MAX_TOKENS = int(env.get("MAX_TOKENS", 1000)) 37 | TEMPERATURE = float(env.get("TEMPERATURE", 0.3)) 38 | 39 | performing_local_inference = (LLM_WORKER == "in_memory" or (LLM_WORKER == "http" and "localhost" in URL)) 40 | 41 | if LLM_MODEL_PATH and len(LLM_MODEL_PATH) > 0: 42 | print(f"Using local model from {LLM_MODEL_PATH}") 43 | if performing_local_inference and not LLM_MODEL_PATH: 44 | print("No local LLM_MODEL_PATH environment variable set. We need a model, downloading model from HuggingFace Hub") 45 | LLM_MODEL_PATH =hf_hub_download( 46 | repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"), 47 | filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"), 48 | ) 49 | print(f"Model downloaded to {LLM_MODEL_PATH}") 50 | if LLM_WORKER == "http" or LLM_WORKER == "in_memory": 51 | from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf 52 | 53 | if in_memory_llm is None and LLM_WORKER == "in_memory": 54 | print("Loading model into memory. If you didn't want this, set the USE_HTTP_SERVER environment variable to 'true'.") 55 | in_memory_llm = Llama(model_path=LLM_MODEL_PATH, n_ctx=CONTEXT_SIZE, n_gpu_layers=N_GPU_LAYERS, verbose=True) 56 | 57 | def llm_streaming( 58 | prompt: str, pydantic_model_class, return_pydantic_object=False 59 | ) -> Union[str, Dict[str, Any]]: 60 | schema = pydantic_model_class.model_json_schema() 61 | 62 | # Optional example field from schema, is not needed for the grammar generation 63 | if "example" in schema: 64 | del schema["example"] 65 | 66 | json_schema = json.dumps(schema) 67 | grammar = json_schema_to_gbnf(json_schema) 68 | 69 | payload = { 70 | "stream": True, 71 | "max_tokens": MAX_TOKENS, 72 | "grammar": grammar, 73 | "temperature": TEMPERATURE, 74 | "messages": [{"role": "user", "content": prompt}], 75 | } 76 | headers = { 77 | "Content-Type": "application/json", 78 | } 79 | 80 | response = requests.post( 81 | URL, 82 | headers=headers, 83 | json=payload, 84 | stream=True, 85 | ) 86 | output_text = "" 87 | for chunk in response.iter_lines(): 88 | if chunk: 89 | chunk = chunk.decode("utf-8") 90 | if chunk.startswith("data: "): 91 | chunk = chunk.split("data: ")[1] 92 | if chunk.strip() == "[DONE]": 93 | break 94 | chunk = json.loads(chunk) 95 | new_token = chunk.get("choices")[0].get("delta").get("content") 96 | if new_token: 97 | output_text = output_text + new_token 98 | print(new_token, sep="", end="", flush=True) 99 | print('\n') 100 | 101 | if return_pydantic_object: 102 | model_object = pydantic_model_class.model_validate_json(output_text) 103 | return model_object 104 | else: 105 | json_output = json.loads(output_text) 106 | return json_output 107 | 108 | 109 | def replace_text(template: str, replacements: dict) -> str: 110 | for key, value in replacements.items(): 111 | template = template.replace(f"{{{key}}}", value) 112 | return template 113 | 114 | 115 | 116 | 117 | def calculate_overall_score(faithfulness, spiciness): 118 | baseline_weight = 0.8 119 | overall = faithfulness + (1 
- baseline_weight) * spiciness * faithfulness 120 | return overall 121 | 122 | 123 | def llm_stream_sans_network( 124 | prompt: str, pydantic_model_class, return_pydantic_object=False 125 | ) -> Union[str, Dict[str, Any]]: 126 | schema = pydantic_model_class.model_json_schema() 127 | 128 | # Optional example field from schema, is not needed for the grammar generation 129 | if "example" in schema: 130 | del schema["example"] 131 | 132 | json_schema = json.dumps(schema) 133 | grammar = LlamaGrammar.from_json_schema(json_schema) 134 | 135 | stream = in_memory_llm( 136 | prompt, 137 | max_tokens=MAX_TOKENS, 138 | temperature=TEMPERATURE, 139 | grammar=grammar, 140 | stream=True 141 | ) 142 | 143 | output_text = "" 144 | for chunk in stream: 145 | result = chunk["choices"][0] 146 | print(result["text"], end='', flush=True) 147 | output_text = output_text + result["text"] 148 | 149 | print('\n') 150 | 151 | if return_pydantic_object: 152 | model_object = pydantic_model_class.model_validate_json(output_text) 153 | return model_object 154 | else: 155 | json_output = json.loads(output_text) 156 | return json_output 157 | 158 | 159 | def llm_stream_serverless(prompt,model): 160 | RUNPOD_ENDPOINT_ID = env.get("RUNPOD_ENDPOINT_ID") 161 | RUNPOD_API_KEY = env.get("RUNPOD_API_KEY") 162 | assert RUNPOD_ENDPOINT_ID, "RUNPOD_ENDPOINT_ID environment variable not set" 163 | assert RUNPOD_API_KEY, "RUNPOD_API_KEY environment variable not set" 164 | url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync" 165 | 166 | headers = { 167 | 'Content-Type': 'application/json', 168 | 'Authorization': f'Bearer {RUNPOD_API_KEY}' 169 | } 170 | 171 | schema = model.schema() 172 | data = { 173 | 'input': { 174 | 'schema': json.dumps(schema), 175 | 'prompt': prompt 176 | } 177 | } 178 | 179 | response = requests.post(url, json=data, headers=headers) 180 | assert response.status_code == 200, f"Unexpected RunPod API status code: {response.status_code} with body: {response.text}" 181 | result = response.json() 182 | print(result) 183 | # TODO: After a 30 second timeout, a job ID is returned in the response instead, 184 | # and the client must poll the job status endpoint to get the result. 
185 |     output = result['output'].replace("model:mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf\n", "")
186 |     # TODO: remove replacement once new version of runpod is deployed
187 |     return json.loads(output)
188 | 
189 | # Global variables to enforce rate limiting
190 | LAST_REQUEST_TIME = None
191 | REQUEST_INTERVAL = 0.5  # Minimum time interval between requests in seconds
192 | 
193 | def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
194 |     global LAST_REQUEST_TIME
195 |     current_time = time()
196 |     if LAST_REQUEST_TIME is not None:
197 |         elapsed_time = current_time - LAST_REQUEST_TIME
198 |         if elapsed_time < REQUEST_INTERVAL:
199 |             sleep_time = REQUEST_INTERVAL - elapsed_time
200 |             sleep(sleep_time)
201 |             print(f"Slept for {sleep_time} seconds to enforce rate limit")
202 |     LAST_REQUEST_TIME = time()
203 | 
204 |     MISTRAL_API_URL = env.get("MISTRAL_API_URL", "https://api.mistral.ai/v1/chat/completions")
205 |     MISTRAL_API_KEY = env.get("MISTRAL_API_KEY", None)
206 |     if not MISTRAL_API_KEY:
207 |         raise ValueError("MISTRAL_API_KEY environment variable not set")
208 |     headers = {
209 |         'Content-Type': 'application/json',
210 |         'Accept': 'application/json',
211 |         'Authorization': f'Bearer {MISTRAL_API_KEY}'
212 |     }
213 |     data = {
214 |         'model': 'mistral-small-latest',
215 |         'response_format': {'type': 'json_object'},  # top-level request parameter, not a message field
216 |         'messages': [
217 |             {
218 |                 'role': 'user',
219 |                 'content': prompt
220 |             }
221 |         ]
222 |     }
223 |     response = requests.post(MISTRAL_API_URL, headers=headers, json=data)
224 |     if response.status_code != 200:
225 |         raise ValueError(f"Unexpected Mistral API status code: {response.status_code} with body: {response.text}")
226 |     result = response.json()
227 |     print(result)
228 |     output = result['choices'][0]['message']['content']
229 |     if pydantic_model_class:
230 |         # TODO: Use more robust error handling that works for all cases without retrying?
231 |         # Maybe APIs that don't have grammar should be avoided?
232 |         # Investigate grammar enforcement with open ended generations?
233 |         try:
234 |             parsed_result = pydantic_model_class.model_validate_json(output)  # raises if the output is invalid
235 |             print(parsed_result)
236 |             return json.loads(output)
237 |         except Exception as e:
238 |             print(f"Error validating pydantic model: {e}")
239 |             # Let's retry by calling ourselves again if attempts < 3
240 |             if attempts == 0:
241 |                 # We modify the prompt to remind it to output JSON in the required format
242 |                 prompt = f"{prompt} You must output the JSON in the required format!"
243 |             if attempts < 3:
244 |                 attempts += 1
245 |                 print(f"Retrying Mistral API call, attempt {attempts}")
246 |                 return llm_stream_mistral_api(prompt, pydantic_model_class, attempts)
247 | 
248 |     else:
249 |         print("No pydantic model class provided, returning without class validation")
250 |         return json.loads(output)
251 | 
252 | 
253 | def send_anthropic_request(prompt: str):
254 |     api_key = env.get("ANTHROPIC_API_KEY")
255 |     if not api_key:
256 |         print("API key not found. 
Please set the ANTHROPIC_API_KEY environment variable.") 257 | return 258 | 259 | headers = { 260 | 'x-api-key': api_key, 261 | 'anthropic-version': '2023-06-01', 262 | 'Content-Type': 'application/json', 263 | } 264 | 265 | data = { 266 | "model": "claude-3-opus-20240229", 267 | "max_tokens": 1024, 268 | "messages": [{"role": "user", "content": prompt}] 269 | } 270 | 271 | response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, data=json.dumps(data)) 272 | if response.status_code != 200: 273 | print(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}") 274 | raise ValueError(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}") 275 | j = response.json() 276 | 277 | text = j['content'][0]["text"] 278 | print(text) 279 | return text 280 | 281 | def llm_anthropic_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]: 282 | # With no streaming or rate limits, we use the Anthropic API, we have string input and output from send_anthropic_request, 283 | # but we need to convert it to JSON for the pydantic model class like the other APIs. 284 | output = send_anthropic_request(prompt) 285 | if pydantic_model_class: 286 | try: 287 | parsed_result = pydantic_model_class.model_validate_json(output) 288 | print(parsed_result) 289 | # This will raise an exception if the model is invalid. 290 | return json.loads(output) 291 | except Exception as e: 292 | print(f"Error validating pydantic model: {e}") 293 | # Let's retry by calling ourselves again if attempts < 3 294 | if attempts == 0: 295 | # We modify the prompt to remind it to output JSON in the required format 296 | prompt = f"{prompt} You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON!" 297 | if attempts < 3: 298 | attempts += 1 299 | print(f"Retrying Anthropic API call, attempt {attempts}") 300 | return llm_anthropic_api(prompt, pydantic_model_class, attempts) 301 | else: 302 | print("No pydantic model class provided, returning without class validation") 303 | return json.loads(output) 304 | 305 | def query_ai_prompt(prompt, replacements, model_class): 306 | prompt = replace_text(prompt, replacements) 307 | if LLM_WORKER == "anthropic": 308 | result = llm_anthropic_api(prompt, model_class) 309 | if LLM_WORKER == "mistral": 310 | result = llm_stream_mistral_api(prompt, model_class) 311 | if LLM_WORKER == "runpod": 312 | result = llm_stream_serverless(prompt, model_class) 313 | if LLM_WORKER == "http": 314 | result = llm_streaming(prompt, model_class) 315 | if LLM_WORKER == "in_memory": 316 | result = llm_stream_sans_network(prompt, model_class) 317 | 318 | log_entry = { 319 | "uuid": str(uuid.uuid4()), 320 | "timestamp": datetime.datetime.utcnow().isoformat(), 321 | "worker": LLM_WORKER, 322 | "prompt_input": prompt, 323 | "prompt_output": result 324 | } 325 | log_to_jsonl('prompt_inputs_and_outputs.jsonl', log_entry) 326 | 327 | return result 328 | 329 | --------------------------------------------------------------------------------