├── .github └── workflows │ ├── ci.yaml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── doc ├── class-diagram.dia ├── class-diagram.png └── images │ ├── cmp_candidates-1.png │ ├── cmp_candidates-2.png │ └── concept.png ├── examples ├── antonyms │ ├── antonyms.db │ ├── antonyms.db.cfg │ ├── antonyms.db.cfg.default │ ├── antonyms.db.default │ └── dataset.json └── words.json ├── ppromptor ├── __init__.py ├── agent.py ├── analyzers │ └── __init__.py ├── base │ ├── __init__.py │ ├── command.py │ └── schemas.py ├── config.py ├── db.py ├── evaluators │ └── __init__.py ├── job_queues │ └── __init__.py ├── llms │ ├── __init__.py │ └── wizardlm.py ├── loggers │ └── __init__.py ├── proposers │ └── __init__.py ├── scorefuncs │ └── __init__.py └── utils.py ├── requirements.txt ├── requirements_local_model.txt ├── requirements_test.txt ├── scripts └── ppromptor-cli.py ├── setup.py ├── tests ├── fake_llms.py ├── test_agents.py ├── test_schemas │ ├── test_analysis.py │ ├── test_db.py │ ├── test_eval_result.py │ ├── test_iopair.py │ ├── test_prompt_candidate.py │ └── test_recommendation.py └── test_scorefuncs.py └── ui ├── app.py ├── components.py ├── config.py └── utils.py /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10", "3.11"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | pip install -r requirements_test.txt 24 | 25 | - name: Test with pytest 26 | run: | 27 | echo "PYTHONPATH=/home/runner/work/ppromptor/ppromptor" >> $GITHUB_ENV 28 | export PYTHONPATH=/home/runner/work/ppromptor/ppromptor 29 | pytest -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
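# --- Illustrative note (not part of the upstream workflows) --------------------
# The test job defined in ci.yaml above can be approximated locally from the
# repository root with the same steps the runner executes:
#
#   python -m pip install --upgrade pip
#   pip install -r requirements.txt -r requirements_test.txt
#   PYTHONPATH="$(pwd)" pytest
#
# Setting PYTHONPATH to the checkout directory mirrors the $GITHUB_ENV export
# used in ci.yaml.
# --------------------------------------------------------------------------------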
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Editors 2 | .vscode/ 3 | .idea/ 4 | 5 | # Vagrant 6 | .vagrant/ 7 | 8 | # Mac/OSX 9 | .DS_Store 10 | 11 | # Windows 12 | Thumbs.db 13 | 14 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | 113 | .pytest_cache/ 114 | *_private_llms/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 pikho 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt-Promptor: An Autonomous Agent Framework for Prompt Engineering 2 | 3 | Prompt-Promptor (ppromptor for short) is a Python library designed to automatically generate and improve prompts for LLMs. It draws inspiration from autonomous agents like AutoGPT and consists of three agents: Proposer, Evaluator, and Analyzer. These agents work together with human experts to continuously improve the generated prompts. 4 | 5 | ## 🚀 Features: 6 | 7 | - 🤖 Uses LLMs to prompt themselves, given only a few samples. 8 | 9 | - 💪 Guidance for OSS LLMs (e.g., LLaMA) by more powerful LLMs (e.g., GPT-4). 10 | 11 | - 📈 Continuous improvement. 12 | 13 | - 👨‍👨‍👧‍👦 Collaboration with human experts. 14 | 15 | - 💼 Experiment management for prompt engineering. 16 | 17 | - 🖼 Web GUI. 18 | 19 | - 🏳️‍🌈 Open Source. 20 | 21 | ## Warning 22 | - This project is currently in its early stage, and major design changes are anticipated in the future. 23 | 24 | - The main function runs an infinite loop to keep improving the generated prompts. If you use OpenAI's ChatGPT as the Target/Analysis LLM, make sure you set a usage limit. 25 | 26 | ## Concept 27 | 28 | ![Concept](https://github.com/pikho/ppromptor/blob/main/doc/images/concept.png?raw=true) 29 | 30 | A more detailed class diagram can be found in [doc](https://github.com/pikho/ppromptor/tree/main/doc). 31 | 32 | ## Installation 33 | 34 | ### From GitHub 35 | 36 | 1. Install the package 37 | ``` 38 | pip install ppromptor --upgrade 39 | ``` 40 | 41 | 2. Clone the repository from GitHub 42 | ``` 43 | git clone https://github.com/pikho/ppromptor.git 44 | ``` 45 | 46 | 3. Start the Web UI 47 | ``` 48 | cd ppromptor 49 | streamlit run ui/app.py 50 | ``` 51 | 52 | ### Running a Local Model (WizardLM) 53 | 1. Install the required packages 54 | ``` 55 | pip install -r requirements_local_model.txt 56 | ``` 57 | 58 | 2. Test that WizardLM runs correctly 59 | ``` 60 | cd ppromptor/llms 61 | python wizardlm.py 62 | ``` 63 | 64 | ## Usage 65 | 66 | 1. Start the Web App 67 | ``` 68 | cd ppromptor 69 | streamlit run ui/app.py 70 | ``` 71 | 72 | 2. Load the Demo Project 73 | Load `examples/antonyms/antonyms.db` (the default) for demo purposes. It demonstrates how to use ChatGPT to guide WizardLM to generate antonyms for given inputs. 74 | 75 | 3. Configuration 76 | In the Configuration tab, set `Target LLM` to `wizardlm` if you can run this model locally, or set both `Target LLM` and `Analysis LLM` to `chatgpt`. If `chatgpt` is used, please provide the OpenAI API key. 77 | 78 | 4. Load the Dataset 79 | The demo project already contains 5 records. Optionally, you can add your own dataset. 80 | 81 | 5. Start the Workflow 82 | Press the `Start` button to activate the workflow. 83 | 84 | 6. 
Prompt Candidates 85 | Generated prompts can be found in the `Prompt Candidates` tab. Users can modify generated prompts by selecting only 1 Candidate, then modifying the prompt, then `Create Prompt`. This new prompt will be evaluated by Evaluator agent and then keep improving by Analyzer agent. By selecting 2 prompts, we can compare these prompts side by side. 86 | 87 | ![Compare Prompts](https://github.com/pikho/ppromptor/blob/main/doc/images/cmp_candidates-1.png?raw=true) 88 | 89 | ![Compare Prompts](https://github.com/pikho/ppromptor/blob/main/doc/images/cmp_candidates-2.png?raw=true) 90 | 91 | ## Contribution 92 | We welcome all kinds of contributions, including new feature requests, bug fixes, new feature implementation, examples, and documentation updates. If you have a specific request, please use the "Issues" section. For other contributions, simply create a pull request (PR). Your participation is highly valued in improving our project. Thank you! -------------------------------------------------------------------------------- /doc/class-diagram.dia: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/doc/class-diagram.dia -------------------------------------------------------------------------------- /doc/class-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/doc/class-diagram.png -------------------------------------------------------------------------------- /doc/images/cmp_candidates-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/doc/images/cmp_candidates-1.png -------------------------------------------------------------------------------- /doc/images/cmp_candidates-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/doc/images/cmp_candidates-2.png -------------------------------------------------------------------------------- /doc/images/concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/doc/images/concept.png -------------------------------------------------------------------------------- /examples/antonyms/antonyms.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/examples/antonyms/antonyms.db -------------------------------------------------------------------------------- /examples/antonyms/antonyms.db.cfg: -------------------------------------------------------------------------------- 1 | { 2 | "openai_key": "", 3 | "target_llm": "chatgpt", 4 | "analysis_llm": "chatgpt" 5 | } -------------------------------------------------------------------------------- /examples/antonyms/antonyms.db.cfg.default: -------------------------------------------------------------------------------- 1 | { 2 | "openai_key": "", 3 | "target_llm": "chatgpt", 4 | "analysis_llm": "chatgpt" 5 | } -------------------------------------------------------------------------------- /examples/antonyms/antonyms.db.default: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/examples/antonyms/antonyms.db.default -------------------------------------------------------------------------------- /examples/antonyms/dataset.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input": "small", 4 | "output": "big" 5 | }, 6 | { 7 | "input": "beautiful", 8 | "output": "ugly" 9 | }, 10 | { 11 | "input": "good", 12 | "output": "bad" 13 | }, 14 | { 15 | "input": "long", 16 | "output": "short" 17 | }, 18 | { 19 | "input": "near", 20 | "output": "far" 21 | } 22 | ] 23 | -------------------------------------------------------------------------------- /examples/words.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input": "EdgeDB 3.0 is now available as a free and open source product. With it comes a host of new features and enhancements, including support for Java, Elixir, Clojure, and TypeScript. We'll spend the rest of this chapter describing some of these new features in greater detail, but first, let's take a look at the UI. The UI of EdgeDB 2.0 was great, but we didn't have the tools to make it better; with EdgeDB 3, we've improved it dramatically. Triggers and Mutation Rewrites were two of the most requested features by our customers when EdgeDB first came out, but since they weren't very well-known, no one had much use for them until recently. In order to make their database more approachable, and easier to use, people wanted EdgeDB to have built-in triggers and mutation rewrites. These allow you to quickly and easily change the types of data in the database without ever having to explicitly switch to another database. Now that's just not possible. To solve this problem, they've added support for \"mutation\" flows, which allows you to write queries that periodically update the underlying database. More on this in a minute. While all of this is really useful, we also want to give you an overview of how it stacks up against other databases like PostgreSQL. Here are some other big differences between PostgreSQL and EdgeDB: You can run your queries directly from the command line instead of having to go through a bunch of intermediate-level libraries; you can nest multiple modules together without worrying about backwards compatibility ; you can usefully combine multiple different programming languages into a single query; and finally, you can generate dashboards using various statistical tools such as Metabase, Cluvio, etc. Finally, EdgeDB now supports Java and Elixir. Our goal is to have clients for all major programming languages and runtimes, so we've created a set of plugins for those. They're called \" EdgeDB Cloud\" and they're designed to be extremely easy to use and extensible.", 4 | "output": "['mutation', 'dramatically', 'approachable', 'Enhancements']" 5 | }, 6 | { 7 | "input": "In this chapter, the authors discuss two approaches to fine-tuning an LLM: knowledge graph analysis and large language models. Knowledge graphs can be used to answer complex graph questions because they allow the end-user to ask questions that a trained expert might not be able to answer directly. Large language models like OpenAI's GPT LLM can also be used for this purpose. 
In this case, the author uses an RDF knowledge graph of a process flow sheet to help him understand how Valve-104 interacts with Reflux-401.", 8 | "output": "['Reflux', 'Expert', 'Interacts', 'Concisely']" 9 | }, 10 | { 11 | "input": "GitHub Copilot for Enterprises vs. Codeium for Enterprises: On-Prem GitHub Copilot.TL;DR You shouldn't have to choose between best security practices and improving developer productivity with AI-powered code assistance. Our solution is purpose built to run on-prem or in your VPC - no data or telemetry ever leaves. No matter where you put your Codeium instance, you'll always have control over it. It's completely self-hosted, which means that you don't need to worry about third-party vendors taking advantage of you. This makes it much easier for enterprises to set up and maintain. The only downside to this deployment process is that it can be very time-consuming. However, because it's so well-designed and well-tested, enterprises are willing to fork over large portions of their IT budgets to support it.", 12 | "output": "['telemetry']" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /ppromptor/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.4" 2 | -------------------------------------------------------------------------------- /ppromptor/agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import copy, deepcopy 3 | from typing import Dict, List 4 | 5 | import sqlalchemy 6 | from langchain.chains.llm import LLMChain 7 | from ppromptor.analyzers import Analyzer 8 | from ppromptor.base.command import CommandExecutor 9 | from ppromptor.base.schemas import EvalSet, IOPair, PromptCandidate 10 | from ppromptor.config import DEFAULT_PRIORITY 11 | from ppromptor.db import create_engine, get_session, reset_running_cmds 12 | from ppromptor.evaluators import Evaluator 13 | from ppromptor.job_queues import BaseJobQueue, ORMJobQueue, PriorityJobQueue 14 | from ppromptor.loggers import logger 15 | from ppromptor.proposers import Proposer 16 | from ppromptor.scorefuncs import SequenceMatcherScore 17 | 18 | 19 | class BaseAgent: 20 | def __init__(self, eval_llm, analysis_llm, db=None): 21 | self.eval_llm = eval_llm 22 | self.analysis_llm = analysis_llm 23 | if isinstance(db, str): 24 | engine = create_engine(db) 25 | self.db_sess = get_session(engine) 26 | elif isinstance(db, sqlalchemy.orm.session.Session): 27 | self.db_sess = db 28 | else: 29 | self.db_sess = None 30 | self._agent_state = 0 # 0: stoped, 1: running 2: waiting for stopping 31 | 32 | @property 33 | def state(self): 34 | return self._agent_state 35 | 36 | @state.setter 37 | def state(self, state: int): 38 | assert state in [0, 1, 2] 39 | 40 | self._agent_state = state 41 | 42 | def stop(self): 43 | self.state = 2 44 | logger.info(f" agent call stop, state {self.state}") 45 | 46 | def run(self, dataset) -> None: 47 | pass 48 | 49 | 50 | class SimpleAgent(BaseAgent): 51 | def run(self, dataset) -> None: 52 | candidates: List[PromptCandidate] = [] 53 | 54 | while True: 55 | 56 | # 1. 
Propose Candidates 57 | if len(candidates) <= 0: 58 | proposer = Proposer(self.analysis_llm) 59 | candidate = proposer.propose(dataset) 60 | else: 61 | candidate = candidates.pop() 62 | 63 | if self.db_sess: 64 | self.db_sess.add(candidate) 65 | self.db_sess.commit() 66 | 67 | evaluatee_prompt = candidate.prompt 68 | print("evaluatee_prompt", evaluatee_prompt) 69 | 70 | # 2. Evaluate Candidates and Generate EvalResults 71 | evaluator = Evaluator(self.eval_llm) 72 | evaluator.add_score_func(SequenceMatcherScore(llm=None)) 73 | 74 | eval_set = evaluator.evaluate(dataset, candidate) 75 | 76 | if self.db_sess: 77 | self.db_sess.add(eval_set) 78 | for res in eval_set.results: 79 | self.db_sess.add(res) 80 | self.db_sess.commit() 81 | 82 | logger.info(f"Final score: {eval_set.final_score}") 83 | 84 | # 3. Analyze EvalResults and Generate Analysis and Recommendation 85 | reports = [] 86 | analyzer = Analyzer(llm=self.analysis_llm) 87 | 88 | analysis = analyzer.analyze(candidate, [eval_set]) 89 | reports.append(analysis) 90 | 91 | print("\n*** Role ***") 92 | for rp in reports: 93 | print(rp.recommendation.role) 94 | 95 | print("\n*** Goal ***") 96 | for rp in reports: 97 | print(rp.recommendation.goal) 98 | 99 | for i in range(len(reports)): 100 | idx = (i + 1) * -1 101 | report = reports[idx] 102 | if report.recommendation: 103 | revised_candidate = PromptCandidate( 104 | report.recommendation.role, 105 | report.recommendation.goal, 106 | report.recommendation.guidelines, 107 | report.recommendation.constraints, 108 | examples=[], 109 | output_format="" 110 | ) 111 | candidates.append(revised_candidate) 112 | break 113 | 114 | if self.db_sess: 115 | self.db_sess.add(analysis) 116 | self.db_sess.commit() 117 | 118 | 119 | class JobQueueAgent(BaseAgent): 120 | def __init__(self, eval_llm, analysis_llm, db=None) -> None: 121 | super().__init__(eval_llm, analysis_llm, db) 122 | self._queue: BaseJobQueue = ORMJobQueue(session=self.db_sess) 123 | 124 | self._cmds: Dict[str, CommandExecutor] = { 125 | "Evaluator": Evaluator(self.eval_llm, 126 | [SequenceMatcherScore(llm=None)]), 127 | "Analyzer": Analyzer(self.analysis_llm), 128 | "Proposer": Proposer(self.analysis_llm) 129 | } 130 | 131 | self._cmd_output = { 132 | "Evaluator": "eval_sets", 133 | "Analyzer": "analysis", 134 | "Proposer": "candidate" 135 | } 136 | 137 | self._next_action = { 138 | "Evaluator": "Analyzer", 139 | "Analyzer": "Proposer", 140 | "Proposer": "Evaluator" 141 | } 142 | 143 | def get_runner(self, cmd_s: str): 144 | return self._cmds[cmd_s] 145 | 146 | def add_command(self, cmd_s: str, data: dict, priority: int): 147 | self._queue.put({ 148 | "cmd": cmd_s, 149 | "data": data 150 | }, priority) 151 | 152 | def run(self, dataset, epochs=-1) -> None: 153 | 154 | self.state = 1 155 | 156 | # FIXME: This is a workround for incompleted commands 157 | # Background: When a command is popped from the queue, 158 | # its state is set to 1 (running). However, if the process 159 | # is interupted before the task finishes, it will 160 | # be ignored. This workaround will reset all 161 | # tasks with a stat of 1 to 0 at the begining of run(), which 162 | # assume that only one agent can access the queue. 163 | # Using Context Manager should be a better way to fix this. 
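        # Illustrative sketch (comments only, not part of the original source):
        # given the _next_action table defined in __init__, each task pulled from
        # the queue schedules the next stage, so execution cycles through
        #     Proposer -> Evaluator -> Analyzer -> Proposer -> ...
        # A caller would therefore drive this agent roughly as follows, where
        # "project.db" is just an example database path:
        #
        #     agent = JobQueueAgent(eval_llm, analysis_llm, db="project.db")
        #     agent.run(dataset, epochs=3)   # executes three commands, then stops
        #
        # Note that `epochs` counts individual commands (acc_epochs below), not
        # full Proposer/Evaluator/Analyzer cycles.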
164 | 165 | reset_running_cmds(self.db_sess) 166 | 167 | data = { 168 | "candidate": None, 169 | "dataset": dataset, 170 | "eval_sets": None, 171 | "analysis": None 172 | } 173 | 174 | if self._queue.empty(): 175 | self.add_command("Proposer", data, DEFAULT_PRIORITY) 176 | 177 | acc_epochs = 0 178 | 179 | while self.state == 1 and (not self._queue.empty()): 180 | priority, task = self._queue.get() 181 | 182 | task_id = task["id"] 183 | cmd_s = task["cmd"] 184 | logger.info(f"Execute Command(cmd={cmd_s}, id={task_id})") 185 | 186 | for k, v in task["data"].items(): 187 | if v: 188 | data[k] = v 189 | 190 | runner = self.get_runner(cmd_s) 191 | result = runner.run_cmd(**task["data"]) 192 | logger.info(f"Result: {result}") 193 | 194 | if self.db_sess: 195 | self.db_sess.add(result) 196 | self.db_sess.commit() 197 | 198 | data = copy(data) 199 | # Cannot use deepcopy here since the copied elements are not 200 | # associatiated with any ORM session, which causes error. 201 | # Shallow copy() is suitable in this use case to prevent 202 | # different tasks accessing the same data object 203 | 204 | data[self._cmd_output[cmd_s]] = result 205 | 206 | self.add_command(self._next_action[cmd_s], 207 | data, 208 | DEFAULT_PRIORITY) 209 | self._queue.done(task, 2) 210 | 211 | acc_epochs += 1 212 | 213 | if acc_epochs == epochs: 214 | self.state = 0 215 | return None 216 | 217 | self.state = 0 218 | -------------------------------------------------------------------------------- /ppromptor/analyzers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import textwrap 4 | from abc import abstractmethod 5 | from typing import List, Union 6 | 7 | from langchain.chains.llm import LLMChain 8 | from langchain.llms.base import BaseLLM 9 | from langchain.prompts import PromptTemplate 10 | from ppromptor.base.command import CommandExecutor 11 | from ppromptor.base.schemas import (Analysis, EvalResult, EvalSet, 12 | PromptCandidate, Recommendation) 13 | from ppromptor.config import PP_VERBOSE 14 | from ppromptor.loggers import logger 15 | from ppromptor.scorefuncs import score_func_selector 16 | from ppromptor.utils import bulletpointize, get_llm_params 17 | 18 | 19 | class BaseAnalyzer(CommandExecutor): 20 | def __init__(self, llm: BaseLLM) -> None: 21 | self._prompt: Union[PromptTemplate, None] = None 22 | self.template: PromptCandidate 23 | self.llm = llm 24 | self._prompt_str: str 25 | self._validate_prompt() 26 | 27 | @property 28 | def prompt(self): 29 | if self._prompt is None: 30 | self._prompt = PromptTemplate( 31 | template=textwrap.dedent(self._prompt_str), 32 | input_variables=["role", 33 | "goal", 34 | "guidelines", 35 | "constraints", 36 | "prediction_anwsers", 37 | "evaluation_scores", 38 | "score_funcs"]) 39 | return self._prompt 40 | 41 | def _validate_prompt(self): 42 | assert isinstance(self._prompt_str, str) 43 | assert "{role}" in self._prompt_str 44 | assert "{goal}" in self._prompt_str 45 | assert "{guidelines}" in self._prompt_str 46 | 47 | @abstractmethod 48 | def analyze(self, candidate, eval_sets: List[EvalSet], **kwargs): 49 | pass 50 | 51 | def run_cmd(self, **kwargs): 52 | return self.analyze(**kwargs) 53 | 54 | def _select_results(self, eval_sets: List[EvalSet]): 55 | pass 56 | 57 | 58 | class Analyzer(BaseAnalyzer): 59 | def __init__(self, llm): 60 | self._prompt_str = """ 61 | I create an LLM AI robot that work as a {role} to {goal}. 
This AI robot 62 | is equipped with a LLM to generate output and is expected to follow 63 | below GUIDELINES and CONSTRAINTS. I expect to get get below answers, 64 | however, the AI robot outputs the prediction as provided. 65 | 66 | ROLE: 67 | {role} 68 | 69 | GOAL: 70 | {goal} 71 | 72 | GUIDELINES: 73 | {guidelines} 74 | 75 | CONSTRAINTS: 76 | {constraints} 77 | 78 | Input, Prediction and Expected Answer triplets: 79 | {prediction_anwsers} 80 | 81 | Evaluation Scores: 82 | {evaluation_scores} 83 | 84 | Description of Evaluation Functions: 85 | {score_funcs} 86 | 87 | Please refer above evaluation scores and describe the difference 88 | between preditions and expected answers. And explain why the given 89 | role, goal, guidelines and constraints produce the predictions 90 | instead of expected answers. Write down in "THOUGHTS" Section. 91 | 92 | Then, Base on above thoughts, please provide revised ROLE, GOAL, 93 | GUIDELINES and CONSTRAINTS that maximize evaluation 94 | scores. Write down in "REVISION" section in below format: 95 | 96 | REVISION: 97 | 98 | revised ROLE: ... 99 | revised GOAL: ... 100 | revised GUIDELINES: 101 | 1. ... 102 | 2. ... 103 | ... 104 | revised CONSTRAINTS: 105 | 1. ... 106 | 2. ... 107 | ... 108 | 109 | Ok, now, lets think step by step. 110 | """ 111 | super().__init__(llm) 112 | 113 | def _select_results(self, eval_sets: List[EvalSet]): 114 | return eval_sets 115 | 116 | def analyze(self, candidate, eval_sets: List[EvalSet], **kwargs): 117 | if isinstance(eval_sets, EvalSet): 118 | eval_sets = [eval_sets] 119 | 120 | results = self._select_results(eval_sets) 121 | 122 | chain = LLMChain(llm=self.llm, prompt=self.prompt, verbose=PP_VERBOSE) 123 | 124 | pred = "\n".join(["\n".join([str(x) for x in r_set.results]) for r_set in results]) 125 | eval_scores = "\n".join(["\n".join([str(x.scores) for x in r_set.results]) for r_set in results]) 126 | 127 | used_scorefunc_names = set() 128 | for r_set in results: 129 | for result in r_set.results: 130 | for key in result.scores.keys(): 131 | used_scorefunc_names.add(key) 132 | 133 | score_funcs = score_func_selector(list(used_scorefunc_names)) 134 | score_funcc_desc = [f"{func.__name__}: {func().description}" for func in score_funcs] # type: ignore[attr-defined, operator] 135 | value = { 136 | "role": candidate.role, 137 | "goal": candidate.goal, 138 | "guidelines": bulletpointize(candidate.guidelines), 139 | "constraints": bulletpointize(candidate.constraints), 140 | "prediction_anwsers": pred, 141 | "evaluation_scores": eval_scores, 142 | "score_funcs": bulletpointize(score_funcc_desc) 143 | } 144 | 145 | res = chain(value) 146 | 147 | recommendation = self.parse_output(res["text"]) 148 | 149 | return Analysis(self.__class__.__name__, 150 | results, recommendation) 151 | 152 | def parse_output(self, output): 153 | 154 | logger.info(f"Output: {output}") 155 | 156 | try: 157 | thoughts = re.findall('(.*)REVISION', 158 | output, 159 | re.DOTALL | re.IGNORECASE)[0] 160 | except IndexError: 161 | thoughts = output 162 | 163 | try: 164 | res = re.findall('REVISION(.*)', 165 | output, 166 | re.DOTALL | re.IGNORECASE)[0] 167 | revision = res 168 | 169 | except IndexError: 170 | revision = "" 171 | 172 | try: 173 | role = re.findall('ROLE:(.*?)\n', output, re.IGNORECASE)[0] 174 | except IndexError: 175 | role = "" 176 | 177 | try: 178 | goal = re.findall('GOAL:(.*?)\n', output, re.IGNORECASE)[0] 179 | except IndexError: 180 | goal = "" 181 | 182 | try: 183 | guidelines = re.findall('GUIDELINES:(.*?)CONSTRAINTS', 184 | output, 
185 | re.DOTALL | re.IGNORECASE)[0] 186 | 187 | guidelines = re.findall('\d\.(.*?)\n', 188 | guidelines) 189 | except IndexError: 190 | guidelines = [] 191 | 192 | try: 193 | constraints = re.findall('CONSTRAINTS:(.*)', 194 | output, 195 | re.DOTALL | re.IGNORECASE)[0] 196 | constraints = re.findall('\d\.(.*?)\n', 197 | constraints) 198 | except IndexError: 199 | constraints = [] 200 | 201 | return Recommendation(thoughts, revision, role, goal, 202 | guidelines, constraints, 203 | examples=[], output_format="") 204 | -------------------------------------------------------------------------------- /ppromptor/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/ppromptor/base/__init__.py -------------------------------------------------------------------------------- /ppromptor/base/command.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class CommandExecutor: 5 | @abstractmethod 6 | def run_cmd(self, **kwargs): 7 | pass 8 | -------------------------------------------------------------------------------- /ppromptor/base/schemas.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | from dataclasses_json import dataclass_json 6 | from langchain.prompts import PromptTemplate 7 | from ppromptor.utils import bulletpointize 8 | from sqlalchemy import JSON, Column, ForeignKey, Table 9 | from sqlalchemy.orm import (DeclarativeBase, Mapped, MappedAsDataclass, 10 | mapped_column, relationship) 11 | 12 | 13 | class Base(MappedAsDataclass, DeclarativeBase): 14 | """subclasses will be converted to dataclasses""" 15 | 16 | 17 | @dataclass_json 18 | class PromptCandidate(Base): 19 | __tablename__ = "prompt_candidate" 20 | 21 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 22 | 23 | role: Mapped[str] = mapped_column() 24 | goal: Mapped[str] = mapped_column() 25 | guidelines: Mapped[List[str]] = Column(JSON) 26 | constraints: Mapped[List[str]] = Column(JSON) 27 | examples: Mapped[List[str]] = Column(JSON) 28 | output_format: Mapped[str] = mapped_column(default="") 29 | 30 | @property 31 | def prompt(self): 32 | guidelines = bulletpointize(self.guidelines) 33 | constraints = bulletpointize(self.constraints) 34 | 35 | prompt_str = (f"You are a {self.role}. 
Your job is to {self.goal}.", 36 | "Always follow below guidelines:", 37 | "", 38 | "Guideline:", 39 | f"{guidelines}", 40 | "", 41 | "Strickly follow below constraints:", 42 | "", 43 | "Constraints:", 44 | f"{constraints}", 45 | "", 46 | "Input:", 47 | "{input}", 48 | "", 49 | "Now, generate output accordingly:") 50 | 51 | print(prompt_str) 52 | return PromptTemplate(template=textwrap.dedent("\n".join(prompt_str)), 53 | input_variables=["input"]) 54 | 55 | 56 | @dataclass_json 57 | class IOPair(Base): 58 | __tablename__ = "io_pair" 59 | 60 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 61 | 62 | input: Mapped[str] = mapped_column() 63 | output: Mapped[str] = mapped_column() 64 | 65 | def __str__(self): 66 | return f"Input: {self.input}; Output: {self.output}" 67 | 68 | 69 | association_result_set = Table( 70 | "association_result_set", 71 | Base.metadata, 72 | Column("eval_set_id", ForeignKey("eval_set.id")), 73 | Column("eval_result_id", ForeignKey("eval_result.id")), 74 | ) 75 | 76 | 77 | @dataclass_json 78 | class EvalResult(Base): 79 | __tablename__ = "eval_result" 80 | 81 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 82 | evaluator_name: Mapped[str] = mapped_column() 83 | 84 | candidate: Mapped["PromptCandidate"] = relationship() 85 | data: Mapped["IOPair"] = relationship() 86 | 87 | prediction: Mapped[str] = mapped_column() 88 | 89 | scores: Mapped[Dict[str, float]] = Column(JSON) 90 | llm_params: Mapped[Dict[str, Any]] = Column(JSON) 91 | 92 | candidate_id: Mapped[int] = mapped_column( 93 | ForeignKey("prompt_candidate.id"), default=None) 94 | data_id: Mapped[int] = mapped_column( 95 | ForeignKey("io_pair.id"), default=None) 96 | 97 | def __str__(self): 98 | return (f"Input: [{self.data.input}]," 99 | f" Prediction: [{self.prediction}]," 100 | f" Answer: [{self.data.output}]") 101 | 102 | 103 | @dataclass_json 104 | class EvalSet(Base): 105 | __tablename__ = "eval_set" 106 | 107 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 108 | candidate: Mapped["PromptCandidate"] = relationship() 109 | results: Mapped[List[EvalResult]] = relationship( 110 | secondary=association_result_set, default_factory=list) 111 | scores: Mapped[Dict[str, float]] = Column(JSON, default={}) 112 | final_score: Mapped[float] = mapped_column(default=None) 113 | candidate_id: Mapped[int] = mapped_column( 114 | ForeignKey("prompt_candidate.id"), default=None) 115 | 116 | 117 | @dataclass_json 118 | class Recommendation(Base): 119 | __tablename__ = "recommendation" 120 | 121 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 122 | 123 | thoughts: Mapped[str] = mapped_column() 124 | revision: Mapped[str] = mapped_column() 125 | 126 | role: Mapped[str] = mapped_column() 127 | goal: Mapped[str] = mapped_column() 128 | guidelines: Mapped[List[str]] = Column(JSON) 129 | constraints: Mapped[List[str]] = Column(JSON) 130 | examples: Mapped[List[str]] = Column(JSON) 131 | output_format: Mapped[Optional[str]] = mapped_column(default=None) 132 | 133 | 134 | association_resultset_analysis = Table( 135 | "association_resultset_analysis", 136 | Base.metadata, 137 | Column("analysis_id", ForeignKey("analysis.id")), 138 | Column("eval_set_id", ForeignKey("eval_set.id")), 139 | ) 140 | 141 | 142 | @dataclass_json 143 | class Analysis(Base): 144 | __tablename__ = "analysis" 145 | 146 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 147 | 148 | analyzer_name: Mapped[str] = mapped_column() 149 | 150 | eval_sets: Mapped[List[EvalSet]] = relationship( 151 
| secondary=association_resultset_analysis) 152 | 153 | recommendation: Mapped["Recommendation"] = relationship() 154 | 155 | rcm_id: Mapped[int] = mapped_column(ForeignKey("recommendation.id"), 156 | default=None) 157 | 158 | 159 | @dataclass_json 160 | class Command(Base): 161 | __tablename__ = "command" 162 | 163 | id: Mapped[int] = mapped_column(init=False, primary_key=True) 164 | 165 | cmd: Mapped[str] = Column(JSON) 166 | """ 167 | { 168 | "cls": str, 169 | "params": { 170 | "key": value 171 | } 172 | } 173 | """ 174 | data: Mapped[str] = Column(JSON) 175 | 176 | """ 177 | { 178 | "param_key": { 179 | "data_cls": str, 180 | "data_id": List[int] 181 | } 182 | } 183 | """ 184 | owner: Mapped[str] = mapped_column(default=None) 185 | priority: Mapped[int] = mapped_column(default=0) 186 | state: Mapped[int] = mapped_column(default=0) 187 | # 0: waiting, 1: running, 2: successful 3: failed 188 | 189 | @property 190 | def data_obj(self): 191 | return None 192 | 193 | @property 194 | def cmd_obj(self): 195 | return None 196 | 197 | 198 | TABLE_MAP = { 199 | "PromptCandidate": PromptCandidate, 200 | "IOPair": IOPair, 201 | "EvalResult": EvalResult, 202 | "EvalSet": EvalSet, 203 | "Recommendation": Recommendation, 204 | "Analysis": Analysis 205 | } 206 | -------------------------------------------------------------------------------- /ppromptor/config.py: -------------------------------------------------------------------------------- 1 | PP_VERBOSE = False 2 | DEFAULT_PRIORITY = 0 3 | -------------------------------------------------------------------------------- /ppromptor/db.py: -------------------------------------------------------------------------------- 1 | from ppromptor.base.schemas import (Analysis, Command, EvalResult, EvalSet, 2 | IOPair, PromptCandidate, Recommendation, 3 | association_result_set, 4 | association_resultset_analysis) 5 | from sqlalchemy import create_engine as slc_create_engine 6 | from sqlalchemy.orm import Session 7 | 8 | CMD_STATE_CODE = { 9 | 0: "W", 10 | 1: "R", 11 | 2: "S", 12 | 3: "F" 13 | } 14 | 15 | def create_engine(db_path, echo=False): 16 | engine = slc_create_engine(f"sqlite+pysqlite:///{db_path}", echo=echo) 17 | 18 | Analysis.__table__.create(engine, checkfirst=True) 19 | EvalResult.__table__.create(engine, checkfirst=True) 20 | EvalSet.__table__.create(engine, checkfirst=True) 21 | IOPair.__table__.create(engine, checkfirst=True) 22 | PromptCandidate.__table__.create(engine, checkfirst=True) 23 | Recommendation.__table__.create(engine, checkfirst=True) 24 | Command.__table__.create(engine, checkfirst=True) 25 | association_result_set.create(engine, checkfirst=True) 26 | association_resultset_analysis.create(engine, checkfirst=True) 27 | 28 | return engine 29 | 30 | 31 | def get_session(engine): 32 | session = Session(engine) 33 | return session 34 | 35 | 36 | def get_dataset(sess): 37 | return sess.query(IOPair).all() 38 | 39 | 40 | def get_iopair_by_id(sess, id): 41 | return sess.query(IOPair).filter_by(id=id).one() 42 | 43 | 44 | def add_iopair(sess, input_, output_): 45 | iopair = IOPair(input=input_, output=output_) 46 | sess.add(iopair) 47 | sess.commit() 48 | 49 | 50 | def update_iopair(sess, id, input_, output_): 51 | iopair = get_iopair_by_id(sess, id) 52 | iopair.input = input_ 53 | iopair.output = output_ 54 | sess.add(iopair) 55 | sess.commit() 56 | 57 | def get_candidates(sess): 58 | return sess.query(PromptCandidate).all() 59 | 60 | 61 | def get_candidate_by_id(sess, id): 62 | return sess.query(PromptCandidate).filter_by(id=id).one() 
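# Illustrative usage sketch (not part of the original module); it assumes the
# SQLite file already contains evaluated prompt candidates, as the bundled
# examples/antonyms/antonyms.db does:
#
#     engine = create_engine("examples/antonyms/antonyms.db")
#     sess = get_session(engine)
#     best = get_candidates_with_score(sess)[0]   # rows are ordered by final_score desc
#     print(best.id, best.final_score, best.role)
#
# Because create_engine() creates any missing tables (checkfirst=True), pointing
# it at a new, empty file is also safe and simply yields an empty database.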
63 | 64 | 65 | def get_results(sess): 66 | return sess.query(EvalResult).all() 67 | 68 | 69 | def get_result_by_id(sess, id): 70 | return sess.query(EvalResult).filter_by(id=id).one() 71 | 72 | 73 | def get_eval_sets(sess): 74 | return sess.query(EvalSet).all() 75 | 76 | 77 | def get_analysis(sess): 78 | return sess.query(Analysis).all() 79 | 80 | 81 | def get_analysis_by_id(sess, id): 82 | return sess.query(Analysis).filter_by(id=id).one() 83 | 84 | 85 | def get_analysis_by_candidate_id(sess, candidate_id): 86 | res = sess.query(Analysis).join(EvalSet, Analysis.eval_sets).filter( 87 | EvalSet.candidate_id == candidate_id).one() 88 | return res 89 | 90 | 91 | def get_candidates_with_score(sess): 92 | return sess.query(PromptCandidate.id, 93 | EvalSet.final_score, 94 | PromptCandidate.role, 95 | PromptCandidate.goal) \ 96 | .join(EvalSet) \ 97 | .order_by(EvalSet.final_score.desc())\ 98 | .all() 99 | 100 | 101 | def get_commands_as_dict(sess, limit=10): 102 | cmds = (sess.query(Command) 103 | # .order_by(Command.priority.asc()) 104 | .order_by(Command.id.desc()) 105 | .limit(limit) 106 | .all() 107 | ) 108 | cmds = [{"id": x.id, 109 | "cmd": x.cmd["cmd"], 110 | "state": CMD_STATE_CODE[x.state], 111 | "priority": x.priority, 112 | # "owner": x.owner 113 | } for x in cmds] 114 | return cmds 115 | 116 | 117 | def reset_running_cmds(sess): 118 | """ 119 | Reset state of running cmds to 0 (waiting) 120 | """ 121 | cmds = sess.query(Command).filter_by(state=1).all() 122 | 123 | for cmd in cmds: 124 | cmd.state = 0 125 | sess.add(cmd) 126 | logger.debug(f"Command(id={cmd.id}, state={cmd.state}) state reseted") 127 | 128 | sess.commit() 129 | 130 | 131 | if __name__ == '__main__': 132 | engine = create_engine('test3.db') 133 | sess = get_session(engine) 134 | dataset = get_dataset(sess) 135 | breakpoint() 136 | print(dataset[0]) 137 | -------------------------------------------------------------------------------- /ppromptor/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from abc import abstractmethod 3 | from typing import List, Optional 4 | 5 | from langchain.chains.llm import LLMChain 6 | from langchain.chat_models import ChatOpenAI 7 | from langchain.llms.base import BaseLLM 8 | from langchain.prompts import PromptTemplate 9 | from ppromptor.base.command import CommandExecutor 10 | from ppromptor.base.schemas import EvalResult, EvalSet, IOPair, PromptCandidate 11 | from ppromptor.config import PP_VERBOSE 12 | from ppromptor.loggers import logger 13 | from ppromptor.scorefuncs import BaseScoreFunc 14 | from ppromptor.utils import bulletpointize, get_llm_params 15 | 16 | 17 | class BaseEvaluator(CommandExecutor): 18 | def __init__(self, 19 | llm: BaseLLM, 20 | score_funcs: Optional[List[BaseScoreFunc]] = None) -> None: 21 | if score_funcs is None: 22 | self.score_funcs = [] 23 | else: 24 | self.score_funcs = score_funcs 25 | 26 | self._prompt_str: str 27 | self._validate_prompt() 28 | self.llm = llm 29 | self._prompt = None 30 | 31 | @property 32 | def prompt(self): 33 | if self._prompt is None: 34 | self._prompt = PromptTemplate( 35 | template=textwrap.dedent(self._prompt_str), 36 | input_variables=["role", 37 | "goal", 38 | "input", 39 | "guidelines", 40 | "constraints"] 41 | ) 42 | return self._prompt 43 | 44 | def _validate_prompt(self): 45 | assert isinstance(self._prompt_str, str) 46 | assert "{input}" in self._prompt_str 47 | assert "{guidelines}" in self._prompt_str 48 | assert "{constraints}" in 
self._prompt_str 49 | 50 | def add_score_func(self, score_func): 51 | self.score_funcs.append(score_func) 52 | 53 | @abstractmethod 54 | def evaluate(self, dataset: List[IOPair], # type: ignore[empty-body] 55 | candidate: PromptCandidate, 56 | **kwargs) -> EvalSet: 57 | pass 58 | 59 | def run_cmd(self, **kwargs): 60 | return self.evaluate(**kwargs) 61 | 62 | 63 | class Evaluator(BaseEvaluator): 64 | def __init__(self, 65 | llm: BaseLLM, 66 | score_funcs: Optional[List[BaseScoreFunc]] = None) -> None: 67 | 68 | self._prompt_str = ("You are a {role}." 69 | " Base on below INPUT, {goal}\n") + """ 70 | INPUT: 71 | '{input}' 72 | 73 | GUIDELINES: 74 | {guidelines} 75 | 76 | CONSTRAINTS: 77 | {constraints} 78 | 79 | Please strictly follow above guidelines and constraints. 80 | Answer: 81 | """ 82 | super().__init__(llm, score_funcs) 83 | 84 | def _get_scores(self, results) -> dict: 85 | res = {} 86 | 87 | for res in results: 88 | for key, value in res.items(): 89 | if key not in res: 90 | res[key] = value 91 | else: 92 | res[key] += value 93 | return res 94 | 95 | def _get_final_score(self, results) -> float: 96 | score = 0.0 97 | 98 | for res in results: 99 | for key, value in res.items(): 100 | score += value 101 | return score 102 | 103 | def evaluate(self, 104 | dataset: List[IOPair], 105 | candidate: PromptCandidate, 106 | **kwargs) -> EvalSet: 107 | 108 | chain = LLMChain(llm=self.llm, prompt=self.prompt, verbose=PP_VERBOSE) 109 | 110 | results = [] 111 | 112 | for record in dataset: 113 | data = { 114 | "role": candidate.role, 115 | "goal": candidate.goal, 116 | "input": record.input, 117 | "guidelines": bulletpointize(candidate.guidelines), 118 | "constraints": bulletpointize(candidate.constraints) 119 | } 120 | 121 | pred = chain(data)["text"].strip() 122 | 123 | rec_scores = {} 124 | for sf in self.score_funcs: 125 | rec_scores[sf.name] = sf.score(candidate, 126 | record, 127 | pred) 128 | 129 | logger.debug(f"Evaluator Prediction: {pred}") 130 | logger.debug(f"Evaluator Answer: {record.output}") 131 | logger.debug(f"Score: {rec_scores}") 132 | 133 | res = EvalResult(self.__class__.__name__, 134 | candidate, 135 | record, 136 | pred, 137 | rec_scores, 138 | llm_params=get_llm_params(self.llm)) 139 | results.append(res) 140 | 141 | scores = [x.scores for x in results] 142 | res_set = EvalSet(candidate=candidate, 143 | results=results, 144 | scores=self._get_scores(scores), 145 | final_score=self._get_final_score(scores) 146 | ) 147 | return res_set 148 | -------------------------------------------------------------------------------- /ppromptor/job_queues/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from queue import PriorityQueue 3 | from typing import Any, Tuple 4 | 5 | from ppromptor.base.schemas import TABLE_MAP, Command 6 | from sqlalchemy import column 7 | 8 | 9 | class BaseJobQueue: 10 | def __init__(self): 11 | self._queue 12 | 13 | @abstractmethod 14 | def put(self, job, priority): 15 | pass 16 | 17 | @abstractmethod 18 | def get(self) -> Tuple[int, Any]: 19 | pass 20 | 21 | @abstractmethod 22 | def empty(self) -> bool: 23 | pass 24 | 25 | @abstractmethod 26 | def done(self, cmd, state_code: int) -> None: 27 | pass 28 | 29 | 30 | class PriorityJobQueue(BaseJobQueue): 31 | def __init__(self): 32 | self._queue = PriorityQueue() 33 | 34 | def put(self, job, priority): 35 | self._queue.put((priority, job)) 36 | 37 | def get(self) -> Tuple[int, Any]: 38 | return self._queue.get() 39 | 40 | def 
empty(self) -> bool: 41 | return self._queue.empty() 42 | 43 | def done(self, cmd, state_code: int) -> None: 44 | return None 45 | 46 | 47 | class ORMJobQueue(BaseJobQueue): 48 | def __init__(self, session): 49 | self._sess = session 50 | 51 | def _serialize_data(self, data): 52 | res = {} 53 | 54 | for key, value in data.items(): 55 | if value is None: 56 | rec = None 57 | elif isinstance(value, list): 58 | rec = { 59 | "cls": value[0].__class__.__name__, 60 | "value": [x.id for x in value] 61 | } 62 | else: 63 | rec = { 64 | "cls": value.__class__.__name__, 65 | "value": value.id 66 | } 67 | 68 | res[key] = rec 69 | 70 | return res 71 | 72 | def _deserialize_data(self, data): 73 | res = {} 74 | # breakpoint() 75 | for key, value in data.items(): 76 | if value is None: 77 | rec = None 78 | elif isinstance(value["value"], list): 79 | table = TABLE_MAP[value["cls"]] 80 | rec = self._sess.query(table).filter(table.id.in_(value["value"])).all() 81 | else: 82 | table = TABLE_MAP[value["cls"]] 83 | rec = self._sess.query(table).filter(table.id == value["value"]).first() 84 | res[key] = rec 85 | return res 86 | 87 | def put(self, job, priority): 88 | _job = Command(cmd={"cmd": job["cmd"]}, 89 | data=self._serialize_data(job["data"]), 90 | owner="user", 91 | priority=priority, 92 | state=0) 93 | self._sess.add(_job) 94 | self._sess.commit() 95 | 96 | def get(self) -> Tuple[int, Any]: 97 | cmd = (self._sess.query(Command) 98 | .filter_by(state=0) 99 | .order_by(Command.priority.asc()) 100 | .first()) 101 | cmd.state = 1 102 | self._sess.add(cmd) 103 | self._sess.commit() 104 | 105 | job = { 106 | "id": cmd.id, 107 | "cmd": cmd.cmd["cmd"], 108 | "data": self._deserialize_data(cmd.data), 109 | "orig_obj": cmd 110 | } 111 | return (cmd.priority, job) 112 | 113 | def empty(self) -> bool: 114 | count = (self._sess.query(Command) 115 | .filter_by(state=0) 116 | .order_by(Command.priority.asc()) 117 | .count()) 118 | return count == 0 119 | 120 | def done(self, cmd, state_code: int) -> None: 121 | cmd = cmd["orig_obj"] 122 | cmd.state = state_code 123 | self._sess.add(cmd) 124 | self._sess.commit() 125 | -------------------------------------------------------------------------------- /ppromptor/llms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pikho/ppromptor/d80782c3b755e54c6333d85bcb88513a28868833/ppromptor/llms/__init__.py -------------------------------------------------------------------------------- /ppromptor/llms/wizardlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is adapted from 3 | https://huggingface.co/mosaicml/mpt-7b/discussions/16 4 | """ 5 | 6 | from functools import partial 7 | from threading import Thread 8 | from typing import Any, Dict, List, Mapping, Optional, Set 9 | 10 | import torch 11 | from auto_gptq import AutoGPTQForCausalLM 12 | from langchain.callbacks.manager import CallbackManagerForLLMRun 13 | from langchain.llms.base import LLM 14 | from pydantic import Extra, Field, root_validator 15 | from transformers import AutoTokenizer 16 | 17 | 18 | class WizardLM(LLM): 19 | model_name: str = Field("TheBloke/WizardLM-30B-GPTQ", 20 | alias='model_name') 21 | """The name of the model to use.""" 22 | 23 | model_basename: str = Field("wizardlm-30b-GPTQ-4bit.act.order", 24 | alias="model_basename") 25 | 26 | tokenizer_name: str = Field("TheBloke/WizardLM-30B-GPTQ", 27 | alias='tokenizer_name') 28 | """The name of the sentence tokenizer to 
use.""" 29 | 30 | config: Any = None #: :meta private: 31 | """The reference to the loaded configuration.""" 32 | 33 | tokenizer: Any = None #: :meta private: 34 | """The reference to the loaded tokenizer.""" 35 | 36 | model: Any = None #: :meta private: 37 | """The reference to the loaded model.""" 38 | 39 | stop: Optional[List[str]] = [] 40 | """A list of strings to stop generation when encountered.""" 41 | 42 | temperature: Optional[float] = Field(0.8, alias='temperature') 43 | """The temperature to use for sampling.""" 44 | 45 | max_new_tokens: Optional[int] = Field(512, alias='max_new_tokens') 46 | """The maximum number of tokens to generate.""" 47 | 48 | top_p: Optional[float] = Field(0.95, alias='top_p') 49 | repetition_penalty: Optional[float] = Field(1.15, 50 | alias='repetition_penalty') 51 | skip_special_tokens: Optional[bool] = Field(True, 52 | alias='skip_special_tokens') 53 | 54 | class Config: 55 | """Configuration for this pydantic object.""" 56 | 57 | extra = Extra.forbid 58 | 59 | def _wizard_default_params(self) -> Dict[str, Any]: 60 | """Get the default parameters.""" 61 | return { 62 | "max_new_tokens": self.max_new_tokens, 63 | "temperature": self.temperature, 64 | "top_p": self.top_p, 65 | "repetition_penalty": self.repetition_penalty, 66 | "skip_special_tokens": self.skip_special_tokens 67 | } 68 | 69 | @staticmethod 70 | def _wizard_param_names() -> Set[str]: 71 | """Get the identifying parameters.""" 72 | return { 73 | "max_new_tokens", 74 | "temperature", 75 | "top_p", 76 | "repetition_penalty", 77 | "skip_special_tokens" 78 | } 79 | 80 | @staticmethod 81 | def _model_param_names(model_name: str) -> Set[str]: 82 | """Get the identifying parameters.""" 83 | # TODO: fork for different parameters for different model variants. 84 | return WizardLM._wizard_param_names() 85 | 86 | @property 87 | def _default_params(self) -> Dict[str, Any]: 88 | """Get the default parameters.""" 89 | return self._wizard_default_params() 90 | 91 | @root_validator() 92 | def validate_environment(cls, values: Dict) -> Dict: 93 | """Validate the environment.""" 94 | try: 95 | model = AutoGPTQForCausalLM.from_quantized(values["model_name"], 96 | model_basename=values["model_basename"], 97 | use_safetensors=True, 98 | trust_remote_code=True, 99 | device="cuda:0", 100 | use_triton=False, 101 | quantize_config=None) 102 | tokenizer = AutoTokenizer.from_pretrained(values["model_name"], 103 | use_fast=True) 104 | 105 | values["model"] = model 106 | values["tokenizer"] = tokenizer 107 | 108 | except Exception as e: 109 | raise Exception(f"WizardLM failed to load with error: {e}") 110 | return values 111 | 112 | @property 113 | def _identifying_params(self) -> Mapping[str, Any]: 114 | """Get the identifying parameters.""" 115 | return { 116 | "model": self.model_name, 117 | **self._default_params, 118 | **{ 119 | k: v 120 | for k, v in self.__dict__.items() 121 | if k in self._model_param_names(self.model_name) 122 | }, 123 | } 124 | 125 | @property 126 | def _llm_type(self) -> str: 127 | """Return the type of llm.""" 128 | return "wizardlm" 129 | 130 | def _call( 131 | self, 132 | prompt: str, 133 | stop: Optional[List[str]] = None, 134 | run_manager: Optional[CallbackManagerForLLMRun] = None, 135 | ) -> str: 136 | r"""Call out to WizardLM's generate method via transformers. 137 | 138 | Args: 139 | prompt: The prompt to pass into the model. 140 | stop: A list of strings to stop generation when encountered. 141 | 142 | Returns: 143 | The string generated by the model. 144 | 145 | Example: 146 | .. 
code-block:: python 147 | 148 | prompt = "This is a story about a big sabre tooth tiger: " 149 | response = model(prompt) 150 | """ 151 | text_callback = None 152 | if run_manager: 153 | text_callback = partial(run_manager.on_llm_new_token, 154 | verbose=self.verbose) 155 | text = "" 156 | model = self.model 157 | tokenizer = self.tokenizer 158 | 159 | prompt_template = '''A chat between a curious user and an artificial intelligence assistant. 160 | The assistant gives helpful, detailed, and polite answers to the user's questions. 161 | USER: {prompt} 162 | ASSISTANT:''' 163 | 164 | inputs = tokenizer([prompt_template.format(prompt=prompt)], 165 | return_tensors="pt").input_ids.cuda() 166 | 167 | gen_ids = model.generate(inputs=inputs, 168 | temperature=self.temperature, 169 | max_new_tokens=self.max_new_tokens, 170 | top_p=self.top_p, 171 | repetition_penalty=self.repetition_penalty) 172 | 173 | output = tokenizer.batch_decode(gen_ids, 174 | skip_special_tokens=self.skip_special_tokens) 175 | 176 | return [x[x.find('ASSISTANT:')+10:] for x in output][0] 177 | 178 | 179 | if __name__ == '__main__': 180 | # llm = MLegoLLM() 181 | llm = WizardLM() 182 | 183 | text = """ 184 | Summarize below text into 15 words: 185 | 186 | 187 | Truly unprecedented. An AI discovery that, weirdly enough, will make Elon Musk as much excited as you… while leaving Hollywood extremely worried. 188 | 189 | In my opinion, this is potentially more significant than GPT-4, considering the novelty of the field and the potential for disruption. 190 | 191 | Cutting to the chase, NVIDIA and the creators of Stable Diffusion have presented VideoLDM, a new state-of-the-art video synthesis model that proves, once again, that the world will never be the same after AI. 192 | 193 | It can go for minutes long with entirely made-up scenery and interactions. 194 | It can recreate multiple scenarios at wish, with no human interference. 195 | Entirely by itself. 196 | 197 | The first text-to-video high-quality generator. 198 | 199 | Let’s understand how humans managed to create this… “thing”, how does it work by showing examples, and, above all… should we be afraid? 
200 | 201 | """ 202 | print(llm(text)) 203 | -------------------------------------------------------------------------------- /ppromptor/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from loguru import logger as loguru_logger 4 | 5 | 6 | class Logger: 7 | def __init__(self): 8 | self._logger = loguru_logger 9 | try: 10 | self._logger.remove(0) 11 | except ValueError: 12 | self._logger.warning("Unable to remove previous logger") 13 | 14 | self._logger.add(sys.stdout, 15 | colorize=True, 16 | format="{level} {message}") 17 | 18 | def trace(self, message): 19 | self._logger.trace(message) 20 | 21 | def debug(self, message): 22 | self._logger.debug(message) 23 | 24 | def info(self, message): 25 | self._logger.info(message) 26 | 27 | def success(self, message): 28 | self._logger.success(message) 29 | 30 | def warning(self, message): 31 | self._logger.warning(message) 32 | 33 | def error(self, message): 34 | self._logger.error(message) 35 | 36 | def critical(self, message): 37 | self._logger.critical(message) 38 | 39 | 40 | logger = Logger() 41 | -------------------------------------------------------------------------------- /ppromptor/proposers/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import textwrap 3 | from abc import abstractmethod 4 | 5 | from langchain.chains.llm import LLMChain 6 | from langchain.prompts import PromptTemplate 7 | from ppromptor.base.command import CommandExecutor 8 | from ppromptor.base.schemas import PromptCandidate 9 | from ppromptor.config import PP_VERBOSE 10 | from ppromptor.utils import gen_prompt 11 | 12 | 13 | class BaseProposer(CommandExecutor): 14 | def __init__(self, llm) -> None: 15 | self.llm = llm 16 | self.prompt: PromptTemplate 17 | 18 | @abstractmethod 19 | def propose(self, dataset, analysis=None, **kwargs) -> PromptCandidate: # type: ignore[empty-body] 20 | pass 21 | 22 | def run_cmd(self, **kwargs): 23 | return self.propose(**kwargs) 24 | 25 | 26 | class Proposer(BaseProposer): 27 | def __init__(self, llm) -> None: 28 | goal = """Your job is to design an LLM AI robot to generate 29 | the below ouputs, according to the given inputs. The LLM AI robot 30 | is equipped with a pretrained LLM to generate outputs. You need to 31 | speicify the ROLE, GOAL, GUIDELINES, CONSTRAINTS to guide the AI 32 | robot to generate correct answers(outputs). Here is the 33 | definitation of each propertiy: 34 | 35 | 1. ROLE: A job occupation or profession this robot has/is. For example, 36 | a senior editor in fashion magzine, a senior product manager 37 | in software company. 38 | 2. GOAL: One sentence to describe what should be done by the AI robot 39 | to generate correct answer. For example, to proof read an 40 | article, to create a travel plan or to summarize an article. 41 | 3. GUIDELINES: A list of precise rules that the AI robot should follow 42 | so that it can generate better fit to the answer. Usually 43 | they describe the properties of the answer, the 44 | characteristics of the convertion function or the relation 45 | between inputs and outputs. 46 | 4. CONSTRAINTS: A list of precise limitations that the AI robot should 47 | strictly follow so that it can generate better fit to the answer. 
48 | """ 49 | guidelines = [ 50 | "Do not instruct the AI robot to use external resource", 51 | "Content of examples must NOT appear in ROLE, GOAL, GUIDELINES and CONSTRAINTS you designed" 52 | ] 53 | 54 | examples = """ 55 | Below is a list of example input-output pairs: 56 | 57 | {examples} 58 | """ 59 | examples_prompt = PromptTemplate( 60 | template=examples, input_variables=["examples"]) 61 | 62 | instrutions = """ 63 | Please provide the ROLE, GOAL, GUIDELINES, CONSTRAINTS of the AI robot 64 | that can generate above pairs. 65 | """ 66 | 67 | super().__init__(llm) 68 | 69 | self.prompt = gen_prompt(goal=goal, 70 | instrutions=instrutions, 71 | guidelines=guidelines, 72 | examples=examples_prompt) 73 | 74 | def propose(self, dataset, analysis=None, **kwargs): 75 | if analysis is None: 76 | chain = LLMChain(llm=self.llm, prompt=self.prompt, verbose=PP_VERBOSE) 77 | 78 | res = chain({"examples": "\n".join([f"{v+1}. {str(x)}" for v, x in enumerate(dataset)])}) 79 | 80 | prompt_proposal = res["text"] # .replace("INSTRUCTION:", "").strip() 81 | 82 | return self._parse(prompt_proposal) 83 | else: 84 | return PromptCandidate( 85 | role=analysis.recommendation.role, 86 | goal=analysis.recommendation.goal, 87 | guidelines=analysis.recommendation.guidelines, 88 | constraints=analysis.recommendation.constraints, 89 | examples=analysis.recommendation.examples, 90 | output_format=analysis.recommendation.output_format 91 | ) 92 | 93 | def _parse(self, prompt_proposal): 94 | role = re.findall("role:(.*?)\n", 95 | prompt_proposal, 96 | re.IGNORECASE)[0].strip() 97 | goal = re.findall("goal:(.*?)\n", 98 | prompt_proposal, 99 | re.IGNORECASE)[0].strip() 100 | guidelines = re.findall("guidelines:(.*?)constraints", 101 | prompt_proposal, 102 | re.IGNORECASE | re.DOTALL)[0] 103 | guidelines = [x.strip() for x in guidelines.split("\n")] 104 | guidelines = list(filter(lambda x: x != "", guidelines)) 105 | guidelines = [re.findall("[\d|\-|\*]\.(.*)", 106 | x)[0].strip() for x in guidelines] 107 | 108 | constraints = re.findall("constraints:(.*?)$", 109 | prompt_proposal, 110 | re.IGNORECASE | re.DOTALL)[0] 111 | constraints = [x.strip() for x in constraints.split("\n")] 112 | constraints = list(filter(lambda x: x != "", constraints)) 113 | constraints = [re.findall("[\d|\-|\*]\.(.*)", 114 | x)[0].strip() for x in constraints] 115 | 116 | res = PromptCandidate( 117 | role=role, 118 | goal=goal, 119 | guidelines=guidelines, 120 | constraints=constraints, 121 | examples=[], 122 | output_format="" 123 | ) 124 | return res 125 | -------------------------------------------------------------------------------- /ppromptor/scorefuncs/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List 3 | 4 | 5 | class ScoreFuncRegistry(type): 6 | 7 | REGISTRY: List = [] 8 | 9 | def __new__(cls, name, bases, attrs): 10 | 11 | new_cls = type.__new__(cls, name, bases, attrs) 12 | cls.REGISTRY.append(new_cls) 13 | return new_cls 14 | 15 | 16 | class BaseScoreFunc(metaclass=ScoreFuncRegistry): 17 | def __init__(self, llm=None): 18 | self.llm = llm 19 | 20 | @property 21 | def name(self): 22 | return self.__class__.__name__ 23 | 24 | @property 25 | @abstractmethod 26 | def description(self): 27 | pass 28 | 29 | @classmethod 30 | @abstractmethod 31 | def score(cls, candidate, record, prediction) -> float: # type: ignore[empty-body] 32 | pass 33 | 34 | @classmethod 35 | def is_me(cls, query): 36 | return query == cls.__name__ 37 | 38 | 
39 | def score_func_selector(names: List[str]) -> List[BaseScoreFunc]:
40 |     # Accept either a single score-function name or a list of names
41 |     names = [names] if isinstance(names, str) else names
42 |     score_funcs = []
43 |     for func in BaseScoreFunc.REGISTRY:
44 |         if any(func.is_me(name) for name in names):
45 |             score_funcs.append(func)
46 |     return score_funcs
47 | 
48 | 
49 | class SequenceMatcherScore(BaseScoreFunc):
50 | 
51 |     def __init__(self, llm=None):
52 |         self.llm = llm
53 | 
54 |     @property
55 |     def description(self):
56 |         return """score is calculated by a string similarity algorithm which
57 |         compares the prediction and the expected answer word by word and
58 |         computes a similarity ratio. The highest score is 1.0, which means the
59 |         prediction is exactly the same as the expected answer; the lowest score
60 |         is 0.0, which means the prediction is completely different."""
61 | 
62 |     @classmethod
63 |     def score(cls, candidate, record, prediction) -> float:
64 |         import difflib
65 |         seq = difflib.SequenceMatcher(a=record.output.lower(),
66 |                                       b=prediction.lower())
67 |         return seq.ratio()
68 | 
69 | 
70 | # class SentenceEmbeddingScore(BaseScoreFunc):
71 | #     pass
72 | 
73 | 
74 | # class WordEmbeddingScore(BaseScoreFunc):
75 | #     pass
76 | 
--------------------------------------------------------------------------------
/ppromptor/utils.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from typing import Dict, Optional, Union
3 | 
4 | from langchain.prompts import PromptTemplate
5 | from loguru import logger
6 | 
7 | 
8 | def get_llm_params(llm) -> Dict:
9 |     res = {}
10 |     res["llm_name"] = llm.__class__.__name__
11 |     res.update(llm._default_params)
12 |     return res
13 | 
14 | 
15 | def bulletpointize(lst):
16 |     res = []
17 |     for idx, v in enumerate(lst):
18 |         res.append(f"{str(idx+1)}. {v}")
19 |     return res
20 | 
21 | 
22 | def evaluate_guideline_contribution(evaluator,
23 |                                     dataset,
24 |                                     prompt,
25 |                                     guidelines,
26 |                                     llm):
27 |     for idx in range(len(guidelines)):
28 |         gl = guidelines[:]
29 |         del gl[idx]
30 |         logger.debug(f"Removed guideline index: {idx}")
31 |         evaluator.eval(dataset, prompt, gl, llm)
32 | 
33 | 
34 | def gen_prompt(goal: Union[str, PromptTemplate],
35 |                instrutions: Union[str, PromptTemplate],
36 |                guidelines: Optional[Union[str, PromptTemplate, list]] = None,
37 |                examples: Optional[Union[str, PromptTemplate]] = None,
38 |                input_variables: Optional[list] = None):
39 | 
40 |     if input_variables is None:
41 |         input_variables = []
42 | 
43 |     if isinstance(goal, str):
44 |         goal = PromptTemplate(template=goal, input_variables=[])
45 | 
46 |     if isinstance(instrutions, str):
47 |         instrutions = PromptTemplate(template=instrutions, input_variables=[])
48 | 
49 |     if (guidelines is None) or (isinstance(guidelines, list) and len(guidelines) == 0):
50 |         guidelines = PromptTemplate(template="\n", input_variables=[])
51 |     elif isinstance(guidelines, list):
52 |         _gls = "\nGuidelines:\n" + "\n".join(bulletpointize(guidelines))
53 |         guidelines = PromptTemplate(template=_gls, input_variables=[])
54 |     elif isinstance(guidelines, str):
55 |         guidelines = PromptTemplate(template=guidelines, input_variables=[])
56 | 
57 |     if (examples is None) or (isinstance(examples, list) and len(examples) == 0):
58 |         examples = PromptTemplate(template="\n", input_variables=[])  # keep a PromptTemplate so .template access below works
59 |     elif isinstance(examples, list):
60 |         _exs = "Examples:\n" + "\n".join(bulletpointize(examples))
61 |         examples = PromptTemplate(template=_exs, input_variables=[])
62 |     elif isinstance(examples, str):
63 |         examples = PromptTemplate(template=examples, input_variables=[])
64 | 
65 |     components = [goal, guidelines, examples, instrutions]
66 |     prompt_str = 
"\n".join([textwrap.dedent(x.template) for x in components]) # type: ignore[union-attr] 67 | 68 | for component in components: 69 | for i in component.input_variables: # type: ignore[union-attr] 70 | assert i not in input_variables 71 | input_variables.append(i) 72 | 73 | return PromptTemplate(template=prompt_str, 74 | input_variables=input_variables) 75 | 76 | 77 | def load_lm(name): 78 | if name == "mlego_wizardlm": 79 | from ppromptor._private_llms.mlego_llm import WizardLLM 80 | 81 | return WizardLLM() 82 | elif name == "wizardlm": 83 | from ppromptor.llms.wizardlm import WizardLLM 84 | 85 | return WizardLLM() 86 | elif name == "chatgpt": 87 | from langchain.chat_models import ChatOpenAI 88 | 89 | return ChatOpenAI(model_name='gpt-3.5-turbo', 90 | temperature=0.1) 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | langchain 3 | dataclasses_json 4 | sqlalchemy 5 | 6 | requests 7 | openai 8 | 9 | beautifulsoup4 10 | streamlit 11 | streamlit-diff-viewer 12 | streamlit-autorefresh 13 | streamlit-aggrid -------------------------------------------------------------------------------- /requirements_local_model.txt: -------------------------------------------------------------------------------- 1 | pydantic 2 | transformers 3 | torch 4 | auto_gptq==0.2.2 5 | langchain -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | pytest -------------------------------------------------------------------------------- /scripts/ppromptor-cli.py: -------------------------------------------------------------------------------- 1 | #!python3 2 | import argparse 3 | import os 4 | from typing import List 5 | 6 | from ppromptor.agent import JobQueueAgent, SimpleAgent 7 | from ppromptor.base.schemas import IOPair 8 | from ppromptor.db import create_engine, get_session 9 | from ppromptor.loggers import logger 10 | from ppromptor.utils import load_lm 11 | 12 | if __name__ == '__main__': 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='ArgumentParser') 17 | parser.add_argument( 18 | '--data', 19 | required=True, 20 | help='Path to dataset.') 21 | 22 | parser.add_argument( 23 | '--eval_llm', 24 | required=True, 25 | choices=('wizardlm', 'chatgpt', 'mlego_wizardlm'), 26 | help='Name of LLM to use as evaluator') 27 | 28 | parser.add_argument( 29 | '--analysis_llm', 30 | required=True, 31 | choices=('wizardlm', 'chatgpt', 'mlego_wizardlm'), 32 | help='Name of LLM to use as analyzer') 33 | 34 | parser.add_argument( 35 | '--database_name', 36 | default=None, 37 | help='Path or name of databse') 38 | 39 | return parser.parse_args() 40 | 41 | args = parse_args() 42 | 43 | if args.database_name and os.path.exists(args.database_name): 44 | engine = create_engine(args.database_name) 45 | sess = get_session(engine) 46 | dataset = sess.query(IOPair).all() 47 | 48 | logger.success(f"Data loaded from db: {args.database_name}") 49 | else: 50 | with open(args.data, 'r') as f: 51 | jstr = f.read() 52 | dataset = IOPair.schema().loads(jstr, many=True) # type: ignore[attr-defined] 53 | 54 | logger.success(f"Data loaded from file: {args.data}") 55 | 56 | engine = create_engine(args.database_name) 57 | sess = get_session(engine) 58 | 59 | for d in dataset: 60 | sess.add(d) 61 | 62 | sess.commit() 63 | 
logger.success(f"Data successfully inserted into db") 64 | 65 | agent = JobQueueAgent(load_lm(args.eval_llm), 66 | load_lm(args.analysis_llm), 67 | db=sess) 68 | agent.run(dataset) 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | from pathlib import Path 5 | 6 | from setuptools import find_packages, setup 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | 11 | def read(*parts): 12 | with codecs.open(os.path.join(here, *parts), "r") as fp: 13 | return fp.read() 14 | 15 | 16 | def find_version(*file_paths): 17 | version_file = read(*file_paths) 18 | version_match = re.search( 19 | r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 20 | if version_match: 21 | return version_match.group(1) 22 | raise RuntimeError("Unable to find version string.") 23 | 24 | 25 | this_directory = Path(__file__).parent 26 | long_description = (this_directory / "README.md").read_text() 27 | 28 | setup( 29 | name="ppromptor", 30 | version=find_version("ppromptor", "__init__.py"), 31 | url="https://github.com/pikho/ppromptor", 32 | author="Pikho Tan", 33 | author_email="pikho.tan@gmail.com", 34 | description="An autonomous agent framework for prompt engineering", 35 | long_description=long_description, 36 | long_description_content_type='text/markdown', 37 | packages=find_packages(exclude=["*_private_llms*"]), 38 | package_data={"ppromptor": ["examples/*"]}, 39 | scripts=['scripts/ppromptor-cli.py'], 40 | install_requires=[ 41 | "langchain", "loguru", "dataclasses_json", 42 | "requests", "openai", "sqlalchemy", 43 | "beautifulsoup4", 44 | "streamlit", "streamlit-diff-viewer", "streamlit-autorefresh", 45 | "streamlit-aggrid" 46 | ], 47 | include_package_data=True, 48 | extras_require={ 49 | 'docs': ['sphinx'], 50 | 'test': ['pytest'], 51 | 'local-llms': ["torch", "auto_gptq", "transformers"] 52 | }, 53 | python_requires=">=3.8", 54 | ) 55 | -------------------------------------------------------------------------------- /tests/fake_llms.py: -------------------------------------------------------------------------------- 1 | """Fake LLM wrapper for testing purposes.""" 2 | from typing import Any, List, Mapping, Optional 3 | 4 | from langchain.callbacks.manager import CallbackManagerForLLMRun 5 | from langchain.llms.base import LLM 6 | 7 | 8 | class FakeListLLM(LLM): 9 | """Fake LLM wrapper for testing purposes.""" 10 | 11 | responses: List 12 | i: int = 0 13 | 14 | @property 15 | def _llm_type(self) -> str: 16 | """Return type of llm.""" 17 | return "fake-list" 18 | 19 | def _call( 20 | self, 21 | prompt: str, 22 | stop: Optional[List[str]] = None, 23 | run_manager: Optional[CallbackManagerForLLMRun] = None, 24 | ) -> str: 25 | """First try to lookup in queries, else return 'foo' or 'bar'.""" 26 | response = self.responses[self.i] 27 | self.i += 1 28 | if self.i >= len(self.responses): 29 | self.i = 0 30 | 31 | return response 32 | 33 | @property 34 | def _identifying_params(self) -> Mapping[str, Any]: 35 | return {} 36 | 37 | @property 38 | def _default_params(self): 39 | return {} 40 | -------------------------------------------------------------------------------- /tests/test_agents.py: -------------------------------------------------------------------------------- 1 | from fake_llms import FakeListLLM 2 | from ppromptor.analyzers import Analyzer 3 | from ppromptor.base.schemas import (Analysis, EvalResult, 
EvalSet, IOPair, 4 | PromptCandidate, Recommendation) 5 | from ppromptor.evaluators import Evaluator 6 | from ppromptor.proposers import Proposer 7 | from ppromptor.scorefuncs import SequenceMatcherScore, score_func_selector 8 | 9 | dataset = [IOPair(input="1", output="2"), 10 | IOPair(input="2", output="4")] 11 | 12 | 13 | def test_proposer(): 14 | proposal = """ 15 | ROLE: test role 16 | 17 | GOAL: test goal 18 | 19 | GUIDELINES: 20 | 1. guideline1 21 | 2. guideline2 22 | 23 | CONSTRAINTS: 24 | 1. constraint1 25 | 2. constraint2 26 | """ 27 | llm = FakeListLLM(responses=[proposal]) 28 | proposer = Proposer(llm) 29 | res = proposer.propose(dataset) 30 | assert isinstance(res, PromptCandidate) 31 | assert res.role == 'test role' 32 | assert res.goal == 'test goal' 33 | assert res.guidelines == ['guideline1', 'guideline2'] 34 | assert res.constraints == ['constraint1', 'constraint2'] 35 | 36 | return res 37 | 38 | 39 | def test_evaluator(): 40 | candidate = PromptCandidate(role='test', goal='test', 41 | guidelines=['test'], constraints=['test']) 42 | llm = FakeListLLM(responses=['2', '1']) 43 | evaluator = Evaluator(llm, [SequenceMatcherScore(None)]) 44 | res = evaluator.evaluate(dataset, candidate) 45 | 46 | assert isinstance(res, EvalSet) 47 | assert res.final_score == 2 48 | 49 | return res 50 | 51 | def test_analyzer(): 52 | report = """ 53 | THOUGHTS: 54 | test thoughts 55 | 56 | REVISION: 57 | revised ROLE: test role 58 | 59 | revised GOAL: test goal 60 | 61 | revised GUIDELINES: 62 | 1. guideline1 63 | 2. guideline2 64 | 65 | revised CONSTRAINTS: 66 | 1. constraint1 67 | 2. constraint2 68 | """ 69 | candidate = test_proposer() 70 | eval_set = test_evaluator() 71 | 72 | llm = FakeListLLM(responses=[report]) 73 | test_analyzer = Analyzer(llm) 74 | report = test_analyzer.analyze(candidate=candidate, 75 | eval_sets=[eval_set]) 76 | assert isinstance(report, Analysis) 77 | 78 | 79 | # if __name__ == '__main__': 80 | # res = test_evaluator() 81 | # print(res) 82 | -------------------------------------------------------------------------------- /tests/test_schemas/test_analysis.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ppromptor.base.schemas import (Analysis, EvalResult, EvalSet, IOPair, 3 | PromptCandidate, Recommendation) 4 | 5 | 6 | def test_Analysis_1(): 7 | a = Analysis( 8 | analyzer_name="test1", 9 | eval_sets=[], 10 | recommendation=None 11 | ) 12 | 13 | 14 | def test_Analysis_req_params(): 15 | with pytest.raises(TypeError): 16 | a = Analysis( 17 | analyzer_name="test1", 18 | eval_sets=[] 19 | ) 20 | 21 | with pytest.raises(TypeError): 22 | a = Analysis( 23 | analyzer_name="test1" 24 | ) 25 | 26 | 27 | def test_Analysis_3(): 28 | res = EvalResult( 29 | evaluator_name="evaluator", 30 | candidate=None, 31 | data=[], 32 | prediction="", 33 | scores={}, 34 | llm_params={} 35 | ) 36 | 37 | eset = EvalSet(candidate=None, 38 | results=[res], 39 | scores={}, 40 | final_score=0.1) 41 | 42 | recomm = Recommendation( 43 | thoughts="", 44 | revision="", 45 | role="", 46 | goal="", 47 | guidelines=[], 48 | constraints=[], 49 | examples=[], 50 | output_format="" 51 | ) 52 | 53 | a = Analysis( 54 | analyzer_name="test1", 55 | eval_sets=[res], 56 | recommendation=recomm 57 | ) 58 | -------------------------------------------------------------------------------- /tests/test_schemas/test_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from 
ppromptor.base.schemas import (Analysis, EvalResult, EvalSet, IOPair, 5 | PromptCandidate, Recommendation) 6 | from ppromptor.db import create_engine, get_session 7 | 8 | 9 | def test_create_engine(): 10 | with tempfile.TemporaryDirectory() as tmp: 11 | db_path = os.path.join(tmp, 'test1.db') 12 | engine = create_engine(db_path) 13 | 14 | 15 | def test_sess(): 16 | with tempfile.TemporaryDirectory() as tmp: 17 | db_path = os.path.join(tmp, 'test1.db') 18 | engine = create_engine(db_path) 19 | 20 | sess = get_session(engine) 21 | 22 | 23 | def test_objs(): 24 | with tempfile.TemporaryDirectory() as tmp: 25 | db_path = os.path.join(tmp, 'test1.db') 26 | engine = create_engine(db_path) 27 | 28 | sess = get_session(engine) 29 | 30 | iopair = IOPair( 31 | input="1", 32 | output="2" 33 | ) 34 | 35 | candidate = PromptCandidate( 36 | role="", 37 | goal="", 38 | guidelines=["1"], 39 | constraints=["1"], 40 | examples=["1"], 41 | output_format="" 42 | ) 43 | 44 | eval_result = EvalResult( 45 | evaluator_name="evaluator", 46 | candidate=candidate, 47 | data=iopair, 48 | prediction="", 49 | scores={}, 50 | llm_params={} 51 | ) 52 | 53 | eval_set = EvalSet(candidate, [eval_result], 54 | scores={}, final_score=0.1) 55 | 56 | recomm = Recommendation( 57 | thoughts="", 58 | revision="", 59 | role="", 60 | goal="", 61 | guidelines=[], 62 | constraints=[], 63 | examples=[], 64 | output_format="" 65 | ) 66 | 67 | analysis = Analysis( 68 | analyzer_name="test1", 69 | eval_sets=[eval_set], 70 | recommendation=recomm 71 | ) 72 | 73 | sess.add(iopair) 74 | sess.commit() 75 | 76 | sess.add(candidate) 77 | sess.commit() 78 | 79 | sess.add(eval_result) 80 | 81 | sess.add(recomm) 82 | sess.commit() 83 | 84 | sess.add(analysis) 85 | sess.commit() 86 | -------------------------------------------------------------------------------- /tests/test_schemas/test_eval_result.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ppromptor.base.schemas import (Analysis, EvalResult, EvalSet, IOPair, 3 | PromptCandidate, Recommendation) 4 | 5 | 6 | def test_EvalResult(): 7 | res = EvalResult( 8 | evaluator_name="evaluator", 9 | candidate=None, 10 | data="", 11 | prediction="", 12 | scores={}, 13 | llm_params={} 14 | ) 15 | 16 | 17 | def test_EvalSet(): 18 | res1 = EvalResult( 19 | evaluator_name="evaluator", 20 | candidate=None, 21 | data="", 22 | prediction="", 23 | scores={}, 24 | llm_params={} 25 | ) 26 | 27 | res2 = EvalResult( 28 | evaluator_name="evaluator", 29 | candidate=None, 30 | data="", 31 | prediction="", 32 | scores={}, 33 | llm_params={} 34 | ) 35 | 36 | eval_set = EvalSet(candidate=None, 37 | results=[res1, res2], 38 | scores={}, 39 | final_score=0.0) 40 | -------------------------------------------------------------------------------- /tests/test_schemas/test_iopair.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ppromptor.base.schemas import (Analysis, EvalResult, IOPair, 3 | PromptCandidate, Recommendation) 4 | 5 | 6 | def test_IOPair(): 7 | res = IOPair( 8 | input="1", 9 | output="2" 10 | ) 11 | -------------------------------------------------------------------------------- /tests/test_schemas/test_prompt_candidate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ppromptor.base.schemas import (Analysis, EvalResult, IOPair, 3 | PromptCandidate, Recommendation) 4 | 5 | 6 | def test_PromptCandidate(): 7 | res = PromptCandidate( 
8 | role="", 9 | goal="", 10 | guidelines=[], 11 | constraints=[], 12 | examples=[], 13 | output_format="" 14 | ) 15 | -------------------------------------------------------------------------------- /tests/test_schemas/test_recommendation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ppromptor.base.schemas import (Analysis, EvalResult, IOPair, 3 | PromptCandidate, Recommendation) 4 | 5 | 6 | def test_Recommendation(): 7 | recomm = Recommendation( 8 | thoughts="", 9 | revision="", 10 | role="", 11 | goal="", 12 | guidelines=[], 13 | constraints=[], 14 | examples=[], 15 | output_format="" 16 | ) 17 | -------------------------------------------------------------------------------- /tests/test_scorefuncs.py: -------------------------------------------------------------------------------- 1 | from ppromptor.base.schemas import IOPair 2 | from ppromptor.scorefuncs import SequenceMatcherScore, score_func_selector 3 | 4 | 5 | def test_SequenceMatcherScore(): 6 | 7 | rec = IOPair(None, "aaa bbb ccc") 8 | pred = "aaa bbb ccc" 9 | s = SequenceMatcherScore.score(None, rec, pred) 10 | assert s == 1.0 11 | 12 | 13 | def test_selector(): 14 | res = score_func_selector('SequenceMatcherScore') 15 | assert len(res) == 1 16 | 17 | 18 | if __name__ == '__main__': 19 | test_SequenceMatcherScore() 20 | -------------------------------------------------------------------------------- /ui/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from copy import copy 5 | from datetime import datetime 6 | from io import StringIO 7 | from threading import Thread 8 | 9 | import pandas as pd 10 | import sqlalchemy 11 | import streamlit as st 12 | from components import (render_candidate, render_config, render_data_as_table, 13 | render_iopair, render_report, render_result, 14 | show_candidte_comp) 15 | from config import DEFAULT_CONFIG 16 | from ppromptor.agent import JobQueueAgent 17 | from ppromptor.base.schemas import IOPair, PromptCandidate 18 | from ppromptor.config import DEFAULT_PRIORITY 19 | from ppromptor.db import (add_iopair, create_engine, get_analysis, 20 | get_analysis_by_candidate_id, get_analysis_by_id, 21 | get_candidate_by_id, get_candidates, 22 | get_candidates_with_score, get_commands_as_dict, 23 | get_dataset, get_result_by_id, get_results, 24 | get_session, update_iopair) 25 | from ppromptor.loggers import logger 26 | from ppromptor.utils import load_lm 27 | from streamlit_autorefresh import st_autorefresh 28 | 29 | st.set_page_config(layout="wide") 30 | 31 | global engine 32 | global sess 33 | global db_path 34 | 35 | os.environ["OPENAI_API_KEY"] = "NONE" 36 | 37 | 38 | def save_config(cfg_data, path): 39 | with open(path, 'w') as f: 40 | f.write(json.dumps(cfg_data, indent=4)) 41 | 42 | 43 | def load_config(path): 44 | if os.path.exists(path): 45 | with open(path, 'r') as f: 46 | cfg_data = json.loads(f.read()) 47 | else: 48 | cfg_data = DEFAULT_CONFIG 49 | return cfg_data 50 | 51 | 52 | def enable_config(cfg_data): 53 | if cfg_data["analysis_llm"]: 54 | analysis_llm = load_lm(cfg_data["analysis_llm"]) 55 | st.session_state['analysis_llm'] = analysis_llm 56 | 57 | if cfg_data["target_llm"]: 58 | target_llm = load_lm(cfg_data["target_llm"]) 59 | st.session_state['target_llm'] = target_llm 60 | 61 | if cfg_data["openai_key"]: 62 | os.environ["OPENAI_API_KEY"] = cfg_data["openai_key"] 63 | 64 | 65 | def render_db_cfg(): 66 | db_path = 
st.text_input('Database Name/Path', 'examples/antonyms/antonyms.db') 67 | if os.path.exists(db_path): 68 | btn_db = st.button('Load') 69 | else: 70 | btn_db = st.button('Create') 71 | 72 | if btn_db: 73 | st.session_state['db_path'] = db_path 74 | st.session_state['engine'] = create_engine(db_path) 75 | st.session_state['sess'] = get_session(st.session_state['engine']) 76 | 77 | if 'target_llm' in st.session_state: 78 | st.experimental_rerun() 79 | 80 | if 'sess' in st.session_state and st.session_state['sess']: 81 | st.divider() 82 | 83 | cfg_data = load_config(st.session_state['db_path']+".cfg") 84 | cfg_data = render_config(cfg_data) 85 | enable_config(cfg_data) 86 | 87 | save_cfg = st.button('Save') 88 | 89 | if save_cfg: 90 | save_config(cfg_data, st.session_state['db_path']+'.cfg') 91 | st.success("Config was saved successfully") 92 | 93 | return True 94 | 95 | return False 96 | 97 | 98 | if not ('sess' in st.session_state and st.session_state['sess']): 99 | st.header("Welcome to Prompt-Promptor") 100 | if render_db_cfg(): 101 | st.experimental_rerun() 102 | 103 | else: 104 | with st.sidebar: 105 | 106 | db_path = st.session_state['db_path'] 107 | target_llm = st.session_state['target_llm'] 108 | analysis_llm = st.session_state['analysis_llm'] 109 | 110 | if "agent" not in st.session_state: 111 | agent = JobQueueAgent(target_llm, analysis_llm, db_path) 112 | st.session_state['agent'] = agent 113 | 114 | if 'run_state' not in st.session_state: 115 | st.session_state['run_state'] = False 116 | 117 | refresh_btn = st.sidebar.button('Refresh Page') 118 | 119 | if refresh_btn: 120 | st.experimental_rerun() 121 | 122 | ckb_auto = st.checkbox("Auto-Refesh") 123 | if ckb_auto: 124 | st_autorefresh(interval=2000, key="dataframerefresh") 125 | 126 | st.divider() 127 | 128 | st.write("Command Queue") 129 | 130 | if st.session_state['agent'].state == 1: 131 | run_btn = st.sidebar.button('Stop') 132 | if run_btn: 133 | st.session_state['run_state'] = False 134 | st.session_state['agent'].stop() 135 | with st.spinner(text="Wait for agent to stop..."): 136 | while True: 137 | if st.session_state['agent'].state == 0: 138 | break 139 | else: 140 | time.sleep(0.5) 141 | 142 | st.experimental_rerun() 143 | 144 | elif st.session_state['agent'].state == 0: 145 | run_btn = st.sidebar.button('Start') 146 | if run_btn: 147 | dataset = get_dataset(st.session_state['sess']) 148 | st.session_state['run_state'] = True 149 | 150 | agent = st.session_state['agent'] 151 | agent_thread = Thread(target=agent.run, 152 | kwargs={"dataset": dataset}) 153 | agent_thread.start() 154 | st.session_state['agent_thread'] = agent_thread 155 | st.experimental_rerun() 156 | 157 | cmds = get_commands_as_dict(st.session_state['sess']) 158 | df = pd.DataFrame(cmds) 159 | st.dataframe(df, use_container_width=True, hide_index=True) 160 | # selected_rows = render_data_as_table(cmds) 161 | 162 | if not st.session_state['sess']: 163 | st.text("Please load or create database") 164 | 165 | else: 166 | sess = st.session_state['sess'] 167 | 168 | tab_candidate, tab_analysis, tab_data, tab_config = st.tabs([ 169 | "Prompt Candidates", "Analysis", "Dataset", "Configuration"]) 170 | 171 | with tab_candidate: 172 | st.header("Prompt Candidates") 173 | 174 | # candidates = get_candidates(sess) 175 | candidates = get_candidates_with_score(sess) 176 | 177 | selected_rows = render_data_as_table(candidates, 178 | multiple_selection=True) 179 | selected_candidates = copy(selected_rows) 180 | st.divider() 181 | 182 | if len(selected_rows) == 1: 183 | 
184 |                 id_ = selected_rows[0]["id"]
185 |                 # input_ = selected_rows[0]["input"]
186 |                 # output_ = selected_rows[0]["output"]
187 | 
188 |                 cdd_data = render_candidate(id_, sess)
189 | 
190 |                 cdd_add_btn = st.button('Create Candidate')
191 |                 if cdd_add_btn:
192 |                     candidate = PromptCandidate(**cdd_data)
193 | 
194 |                     sess.add(candidate)
195 |                     sess.commit()
196 | 
197 |                     __data = {
198 |                         "candidate": candidate,
199 |                         "dataset": get_dataset(sess),
200 |                         "eval_sets": None,
201 |                         "analysis": None
202 |                     }
203 | 
204 |                     st.session_state['agent'].add_command("Evaluator",
205 |                                                           __data,
206 |                                                           DEFAULT_PRIORITY-100)
207 | 
208 |                     st.success("PromptCandidate was created successfully")
209 |                     st.experimental_rerun()
210 |             elif len(selected_rows) == 2:
211 |                 id_1 = selected_rows[1]["id"]
212 |                 id_2 = selected_rows[0]["id"]
213 | 
214 |                 score_1 = selected_rows[1]["final_score"]
215 |                 score_2 = selected_rows[0]["final_score"]
216 | 
217 |                 # the score of id_2 is always lower than that of id_1;
218 |                 # for UX reasons, show id_2 on the left and id_1 on the right
219 |                 show_candidte_comp([id_1, score_1], [id_2, score_2], sess)
220 | 
221 |             elif len(selected_rows) > 2:
222 |                 st.write("Comparison is only available for two candidates")
223 | 
224 |         with tab_analysis:
225 | 
226 |             st.header("Analysis Reports")
227 | 
228 |             analysis_mode = st.radio(
229 |                 "Which analysis reports to show?",
230 |                 ('Selected', 'All'))
231 | 
232 |             if analysis_mode == 'All':
233 |                 reports = get_analysis(sess)
234 | 
235 |                 selected_rows = render_data_as_table(reports)
236 | 
237 |                 if len(selected_rows) > 0:  # single-selection grid returns at most one row
238 |                     id_ = selected_rows[0]["id"]
239 |                     render_report(id_, sess)
240 | 
241 |             elif analysis_mode == 'Selected':
242 |                 cdd_ids = [x["id"] for x in selected_candidates]
243 | 
244 |                 if len(cdd_ids) == 0:
245 |                     st.write("Please select Prompt Candidate(s) first")
246 |                 elif len(cdd_ids) == 1:
247 |                     try:
248 |                         report = get_analysis_by_candidate_id(sess, cdd_ids[0])
249 |                         render_report(report.id, sess)
250 |                     except sqlalchemy.exc.NoResultFound:
251 |                         st.text(("Report not found. (Maybe still "
252 |                                  "waiting for analysis?)"))
253 | 
254 |                 elif len(cdd_ids) > 1:
255 |                     st.text(("Warning: multiple candidates are selected; "
256 |                              f"only candidate id=={cdd_ids[0]} is shown below"))
257 | 
258 |                     # for cdd_id in cdd_ids:
259 |                     try:
260 |                         report = get_analysis_by_candidate_id(sess, cdd_ids[0])
261 |                         render_report(report.id, sess)
262 |                     except sqlalchemy.exc.NoResultFound:
263 |                         st.text("Report not found. 
(Maybe still analyzing?)") 264 | 265 | with tab_data: 266 | st.header("Dataset") 267 | 268 | uploaded_file = st.file_uploader("Import Dataset") 269 | if uploaded_file: 270 | stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) 271 | jstr = stringio.read() 272 | dataset = IOPair.schema().loads(jstr, many=True) # type: ignore[attr-defined] 273 | 274 | st.write(dataset) 275 | 276 | btn_import = st.button("Import") 277 | if btn_import: 278 | for d in dataset: 279 | sess.add(d) 280 | sess.commit() 281 | 282 | uploaded_file = None 283 | st.experimental_rerun() 284 | 285 | dataset = get_dataset(sess) 286 | selected_rows = render_data_as_table(dataset) 287 | 288 | if len(selected_rows) > 0: 289 | id_ = selected_rows[0]["id"] 290 | input_ = selected_rows[0]["input"] 291 | output_ = selected_rows[0]["output"] 292 | 293 | render_iopair([id_, input_, output_], sess) 294 | else: 295 | render_iopair([None, "", ""], sess) 296 | 297 | with tab_config: 298 | render_db_cfg() 299 | -------------------------------------------------------------------------------- /ui/components.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import sqlalchemy 5 | import streamlit as st 6 | from diff_viewer import diff_viewer 7 | from ppromptor.db import (add_iopair, create_engine, get_analysis, 8 | get_analysis_by_id, get_candidate_by_id, 9 | get_result_by_id, update_iopair) 10 | from st_aggrid import (AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder, 11 | GridUpdateMode) 12 | 13 | 14 | def render_data_as_table(records, multiple_selection=False): 15 | if len(records) == 0: 16 | st.text("No data") 17 | return [] 18 | 19 | if isinstance(records[0], sqlalchemy.engine.row.Row): 20 | data = [dict(x._mapping) for x in records] 21 | elif isinstance(records[0], dict): 22 | data = records 23 | else: 24 | data = [x.to_dict() for x in records] 25 | 26 | df = pd.DataFrame(data) 27 | 28 | gb = GridOptionsBuilder.from_dataframe(df) 29 | # configure selection 30 | if multiple_selection: 31 | gb.configure_selection(selection_mode="multiple", 32 | use_checkbox=True) 33 | else: 34 | gb.configure_selection(selection_mode="single") 35 | 36 | gb.configure_pagination(enabled=True, 37 | paginationAutoPageSize=False, 38 | paginationPageSize=10) 39 | gridOptions = gb.build() 40 | 41 | data = AgGrid(df, 42 | edit=False, 43 | gridOptions=gridOptions, 44 | allow_unsafe_jscode=True, 45 | fit_columns_on_grid_load=False, 46 | # update_mode=GridUpdateMode.SELECTION_CHANGED, 47 | columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS) 48 | 49 | selected_rows = data["selected_rows"] 50 | 51 | return selected_rows 52 | 53 | 54 | def render_config(cfg_data): 55 | 56 | target_llm = st.selectbox( 57 | "Target LLM", 58 | (cfg_data["target_llm"], "mlego_wizardlm", "wizardlm", "chatgpt") 59 | ) 60 | 61 | analysis_llm = st.selectbox( 62 | "Analysis LLM", 63 | (cfg_data["analysis_llm"], "mlego_wizardlm", "wizardlm", "chatgpt") 64 | ) 65 | 66 | openai_key = st.text_input("OpenAI Key", value=cfg_data["openai_key"]) 67 | 68 | if "chatgpt" in (analysis_llm, target_llm) and not openai_key: 69 | st.warning("Please set OpenAI Key for using ChatGPT") 70 | 71 | return {"openai_key": openai_key, 72 | "target_llm": target_llm, 73 | "analysis_llm": analysis_llm 74 | } 75 | 76 | 77 | 78 | 79 | def render_candidate(id_, sess): 80 | rec = get_candidate_by_id(sess, id_) 81 | 82 | cdd_role = st.text_input("Role", value=rec.role, key=f"cdd_role{id_}") 83 | cdd_goal = st.text_input('Goal', 
value=rec.goal, key=f"cdd_goal{id_}") 84 | cdd_guidelines = st.text_area('Guidelines', 85 | value="\n\n".join(rec.guidelines), 86 | key=f"cdd_guide{id_}") 87 | cdd_constraints = st.text_area('Constraints', 88 | value="\n\n".join(rec.constraints), 89 | key=f"cdd_cons{id_}") 90 | cdd_examples = st.text_input('Examples', value=rec.examples, 91 | key=f"cdd_examp{id_}") 92 | cdd_output_format = st.text_input('Output Format', 93 | value=rec.output_format, 94 | key=f"cdd_output{id_}") 95 | 96 | return { 97 | "role": cdd_role, 98 | "goal": cdd_goal, 99 | "guidelines": cdd_guidelines.split("\n\n"), 100 | "constraints": cdd_constraints.split("\n\n"), 101 | "examples": cdd_examples, 102 | "output_format": cdd_output_format 103 | } 104 | 105 | 106 | def show_candidte_comp(value1, value2, sess): 107 | id_1, score_1 = value1 108 | id_2, score_2 = value2 109 | 110 | rec_1 = get_candidate_by_id(sess, id_1) 111 | rec_2 = get_candidate_by_id(sess, id_2) 112 | 113 | col_cmp1, col_cmp2 = st.columns(2) 114 | 115 | with col_cmp1: 116 | cdd_id = st.text_input('ID', 117 | value=id_1, 118 | key=f"cdd_id{id_1}") 119 | cdd_score = st.text_input('Score', 120 | value=score_1, 121 | key=f"cdd_score{id_1}") 122 | with col_cmp2: 123 | cdd_id = st.text_input('ID', 124 | value=id_2, 125 | key=f"cdd_id{id_2}") 126 | cdd_score = st.text_input('Score', 127 | value=score_2, 128 | key=f"cdd_score{id_2}") 129 | if rec_2.role != rec_1.role: 130 | st.write("Role") 131 | diff_viewer(old_text=rec_1.role, 132 | new_text=rec_2.role, 133 | lang="none") 134 | else: 135 | cdd_role = st.text_input("Role", 136 | value=rec_2.role, 137 | key=f"cdd_role{id_1}") 138 | if rec_2.goal != rec_1.goal: 139 | st.write("Goal") 140 | diff_viewer(old_text=rec_1.goal, 141 | new_text=rec_2.goal, 142 | lang="none") 143 | else: 144 | cdd_goal = st.text_input('Goal', 145 | value=rec_2.goal, 146 | key=f"cdd_goal{id_1}") 147 | if rec_2.guidelines != rec_1.guidelines: 148 | st.write("Gguidelines") 149 | diff_viewer(old_text="\n".join(rec_1.guidelines), 150 | new_text="\n".join(rec_2.guidelines), 151 | lang="none") 152 | else: 153 | cdd_guidelines = st.text_area('Guidelines', 154 | value="\n\n".join(rec_2.guidelines), 155 | key=f"cdd_guide{id_1}") 156 | if rec_2.constraints != rec_1.constraints: 157 | st.write("Constraints") 158 | diff_viewer(old_text="\n".join(rec_1.constraints), 159 | new_text="\n".join(rec_2.constraints), 160 | lang="none") 161 | else: 162 | cdd_constraints = st.text_area('Constraints', 163 | value="\n\n".join(rec_2.constraints), 164 | key=f"cdd_cons{id_1}") 165 | if rec_2.examples != rec_1.examples: 166 | st.write("Examples") 167 | diff_viewer(old_text=rec_1.examples, 168 | new_text=rec_2.examples, 169 | lang="none") 170 | else: 171 | 172 | cdd_examples = st.text_input('Examples', value=rec_2.examples, 173 | key=f"cdd_examp{id_1}") 174 | 175 | if rec_2.output_format != rec_1.output_format: 176 | st.write("Output Format") 177 | diff_viewer(old_text=rec_1.output_format, 178 | new_text=rec_2.output_format, 179 | lang="none") 180 | else: 181 | cdd_output_format = st.text_input('Output Format', 182 | value=rec_2.output_format, 183 | key=f"cdd_output{id_1}") 184 | 185 | 186 | def render_result(id_, sess): 187 | rec = get_result_by_id(sess, id_) 188 | 189 | rst_input = st.text_area("Input", value=rec.data.input) 190 | 191 | col1, col2 = st.columns(2) 192 | 193 | with col1: 194 | rst_output = st.text_area("Correct Answer", value=rec.data.output) 195 | 196 | with col2: 197 | rst_prediction = st.text_area("LLM Prediction", value=rec.prediction) 198 | 
199 | st.text("Evaluation Scores") 200 | st.json(rec.scores) 201 | 202 | st.text("Evaluator LLM") 203 | st.json(rec.llm_params) 204 | 205 | 206 | def render_report(report_id, sess): 207 | try: 208 | report = get_analysis_by_id(sess, report_id) 209 | __show_report_continue = True 210 | except sqlalchemy.exc.NoResultFound: 211 | __show_report_continue = False 212 | st.write(f"Unable to get report: {report_id}") 213 | 214 | if __show_report_continue: 215 | st.divider() 216 | st.subheader("Recommendation") 217 | 218 | rcmm = report.recommendation 219 | 220 | rst_thoughts = st.text_area("Thoughts", value=rcmm.thoughts) 221 | rst_revision = st.text_area("Revision", value=rcmm.revision) 222 | 223 | st.divider() 224 | st.subheader("Evaluation Results") 225 | selected_rows = render_data_as_table(report.eval_sets[0].results) 226 | 227 | if len(selected_rows) > 0: 228 | id_ = selected_rows[0]["id"] 229 | 230 | render_result(id_, sess) 231 | 232 | 233 | def render_iopair(data, sess): 234 | io_input = st.text_area("Input", 235 | value=data[1], 236 | key="io_input_1") 237 | io_output = st.text_area("Output", 238 | value=data[2], 239 | key="io_output_1") 240 | 241 | col_io_1, col_io_2, col_io_3, col_io_4 = st.columns(4) 242 | with col_io_1: 243 | io_new_btn = st.button("Create") 244 | if io_new_btn: 245 | add_iopair(sess, io_input, io_output) 246 | st.experimental_rerun() 247 | 248 | with col_io_2: 249 | if data[0]: 250 | io_update_btn = st.button("Update") 251 | if io_update_btn: 252 | update_iopair(sess, data[0], io_input, io_output) 253 | st.success("Record was updated successfully") 254 | -------------------------------------------------------------------------------- /ui/config.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CONFIG = { 2 | "openai_key": "", 3 | "target_llm": "chatgpt", 4 | "analysis_llm": "chatgpt" 5 | } 6 | -------------------------------------------------------------------------------- /ui/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import sqlalchemy 3 | from st_aggrid import (AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder, 4 | GridUpdateMode) 5 | 6 | 7 | def _render_data_as_table(records, multiple_selection=False): 8 | if len(records) == 0: 9 | st.text("No data") 10 | return [] 11 | 12 | if isinstance(records[0], sqlalchemy.engine.row.Row): 13 | data = [dict(x._mapping) for x in records] 14 | elif isinstance(records[0], dict): 15 | data = records 16 | else: 17 | data = [x.to_dict() for x in records] 18 | 19 | df = pd.DataFrame(data) 20 | 21 | gb = GridOptionsBuilder.from_dataframe(df) 22 | # configure selection 23 | if multiple_selection: 24 | gb.configure_selection(selection_mode="multiple", 25 | use_checkbox=True) 26 | else: 27 | gb.configure_selection(selection_mode="single") 28 | 29 | gb.configure_pagination(enabled=True, 30 | paginationAutoPageSize=False, 31 | paginationPageSize=10) 32 | gridOptions = gb.build() 33 | 34 | data = AgGrid(df, 35 | edit=False, 36 | gridOptions=gridOptions, 37 | allow_unsafe_jscode=True, 38 | fit_columns_on_grid_load=False, 39 | # update_mode=GridUpdateMode.SELECTION_CHANGED, 40 | columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS) 41 | 42 | selected_rows = data["selected_rows"] 43 | 44 | return selected_rows 45 | --------------------------------------------------------------------------------
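
Taken together, the pieces above compose as follows: scripts/ppromptor-cli.py loads an IOPair dataset, stores it through ppromptor.db, and drives a JobQueueAgent built from an evaluation LLM and an analysis LLM, while ui/app.py exposes the same flow through Streamlit. The sketch below condenses the CLI flow for programmatic use; it is a minimal illustration based only on the code shown above, and the dataset path, database filename, and the choice of "chatgpt" for both LLMs are placeholders rather than required values.

# Minimal sketch of the flow in scripts/ppromptor-cli.py (paths and model names are illustrative).
from ppromptor.agent import JobQueueAgent
from ppromptor.base.schemas import IOPair
from ppromptor.db import create_engine, get_session
from ppromptor.utils import load_lm

# Load input/output pairs from a JSON dataset file.
with open("dataset.json", "r") as f:
    dataset = IOPair.schema().loads(f.read(), many=True)

# Persist the dataset so the agent (and the Streamlit UI) can read it back later.
engine = create_engine("example.db")
sess = get_session(engine)
for pair in dataset:
    sess.add(pair)
sess.commit()

# One LLM evaluates prompt candidates against the dataset; the other analyzes the results.
agent = JobQueueAgent(load_lm("chatgpt"), load_lm("chatgpt"), db=sess)
agent.run(dataset)

When "chatgpt" is chosen, the OPENAI_API_KEY environment variable must be set beforehand; the Streamlit UI does this from its saved configuration.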