├── .gitattributes ├── .github ├── pull_request_template.md └── workflows │ ├── build_docs.yaml │ └── run_tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .streamlit └── .gitkeep ├── CHANGELOG.md ├── CONTRIBUTORS.md ├── LICENSE ├── README.md ├── conversant ├── __init__.py ├── chatbot.py ├── demo │ ├── __init__.py │ ├── streamlit_example.py │ ├── styles.css │ ├── ui.py │ └── utils.py ├── personas │ ├── client-support │ │ └── config.json │ ├── fantasy-wizard │ │ └── config.json │ ├── fortune-teller │ │ └── config.json │ ├── historian │ │ └── config.json │ ├── injured-person │ │ └── config.json │ ├── math-teacher │ │ └── config.json │ ├── personal-trainer │ │ └── config.json │ ├── travel-advisor │ │ └── config.json │ └── watch-sales-agent │ │ └── config.json ├── prompt_chatbot.py ├── prompts │ ├── __init__.py │ ├── chat_prompt.py │ ├── prompt.py │ └── rewrite_prompt.py ├── search │ ├── __init__.py │ ├── document.py │ ├── local_searcher.py │ └── searcher.py └── utils │ ├── __init__.py │ └── demo_utils.py ├── docs └── .gitkeep ├── poetry.lock ├── pyproject.toml ├── static ├── conversant-diagram.png ├── fortune-teller-chat.png ├── fortune-teller-edit.png └── fortune-teller-setup.png └── tests ├── conftest.py ├── prompts ├── test_prompt.py ├── test_rewrite_prompt.py └── test_start_prompt.py ├── search ├── test_document.py ├── test_local_searcher.py └── test_searcher.py ├── test_chatbot.py ├── test_prompt_chatbot.py └── utils └── test_demo_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | poetry.lock binary -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### What this PR does 2 | 3 | 4 | 5 | ### How it was tested 6 | 7 | 8 | 9 | ### PR checklist 10 | 11 | - [ ] No API keys or other secrets committed to source? 
12 | - [ ] New functionality is added alongside appropriate tests? 13 | - [ ] All source files have the following header? 14 | 15 | ``` 16 | # Copyright (c) {YEAR} Cohere Inc. and its affiliates. 17 | # 18 | # Licensed under the MIT License (the "License"); 19 | # you may not use this file except in compliance with the License. 20 | # 21 | # You may obtain a copy of the License in the LICENSE file at the top 22 | # level of this repository. 23 | ``` 24 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yaml: -------------------------------------------------------------------------------- 1 | name: build docs 2 | 3 | # build the documentation whenever there are new commits on main 4 | on: 5 | push: 6 | branches: 7 | - main 8 | # Alternative: only build for tags. 9 | # tags: 10 | # - '*' 11 | 12 | # security: restrict permissions for CI jobs. 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | # Build the documentation and upload the static HTML files as an artifact. 18 | build: 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: ["3.8"] 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install Poetry 32 | uses: snok/install-poetry@v1 33 | with: 34 | virtualenvs-create: true 35 | virtualenvs-in-project: true 36 | installer-parallel: true 37 | 38 | - name: Install dependencies 39 | run: poetry install --no-interaction --no-root 40 | - name: Install library 41 | run: poetry install --no-interaction 42 | 43 | - name: Compile docs 44 | run: | 45 | source .venv/bin/activate 46 | pdoc conversant -o docs/ --docformat google 47 | 48 | - uses: actions/upload-pages-artifact@v1 49 | with: 50 | path: docs/ 51 | 52 | # Deploy the artifact to GitHub pages. 
53 | # This is a separate job so that only actions/deploy-pages has the necessary permissions. 54 | deploy: 55 | needs: build 56 | runs-on: ubuntu-latest 57 | permissions: 58 | pages: write 59 | id-token: write 60 | environment: 61 | name: github-pages 62 | url: ${{ steps.deployment.outputs.page_url }} 63 | steps: 64 | - id: deployment 65 | uses: actions/deploy-pages@v1 -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install Poetry 20 | uses: snok/install-poetry@v1 21 | with: 22 | virtualenvs-create: true 23 | virtualenvs-in-project: true 24 | installer-parallel: true 25 | 26 | - name: Install dependencies 27 | run: poetry install --no-interaction --no-root 28 | - name: Install library 29 | run: poetry install --no-interaction 30 | - name: Run tests 31 | run: | 32 | source .venv/bin/activate 33 | pytest 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files 
are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | .ruff_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Weights and Biases 133 | wandb/ 134 | 135 | # Streamlit secrets 136 | secrets.toml 137 | 138 | # VScode settings 139 | .vscode/ 140 | 141 | .DS_Store 142 | 143 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: check-toml 7 | - id: check-json 8 | - repo: https://github.com/PyCQA/autoflake 9 | rev: v1.7.7 10 | hooks: 11 | - id: autoflake 12 | args: [--in-place, --recursive, --ignore-init-module-imports] 13 | - repo: https://github.com/PyCQA/isort 14 | rev: 5.10.1 15 | hooks: 16 | - id: isort 17 | args: [--profile=black, --line-width=88, --force-grid-wrap=0, --use-parentheses, --float-to-top] 18 | - repo: https://github.com/psf/black 19 | rev: 22.10.0 20 | hooks: 21 | - id: black 22 | - repo: https://github.com/charliermarsh/ruff-pre-commit 23 | rev: v0.0.94 24 | hooks: 25 | - id: ruff -------------------------------------------------------------------------------- /.streamlit/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/.streamlit/.gitkeep 
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.2.5 (2023-03-10) 4 | 5 | ### Added 6 | - N/A 7 | ### Changed 8 | - Updated streamlit-talk to 0.2.1 9 | 10 | ### Removed 11 | - N/A 12 | ## 0.2.4 (2023-03-10) 13 | 14 | ### Added 15 | - N/A 16 | ### Changed 17 | - Updated streamlit-talk to 0.2.0 18 | 19 | ### Removed 20 | - N/A 21 | 22 | 23 | ## 0.2.3 (2023-02-23) 24 | 25 | ### Added 26 | - N/A 27 | ### Changed 28 | - Updated streamlit to 1.18.1. 29 | - Fixed a bug that was causing the last utterance to be repeated in the Streamlit demo. 30 | 31 | ### Removed 32 | - N/A 33 | 34 | ## 0.2.2 (2023-02-17) 35 | 36 | ### Added 37 | - Simulated token streaming support for Streamlit demo app 38 | - Method in `PromptChatbot` to chain `co.generate` calls (`partial_reply`) for one response 39 | 40 | ### Changed 41 | - N/A 42 | 43 | ### Removed 44 | - N/A 45 | 46 | ## 0.2.1 (2023-01-20) 47 | 48 | ### Added 49 | - N/A 50 | 51 | ### Changed 52 | - Fixed a bug in Streamlit demo by casting `max_tokens` to `int` 53 | 54 | ### Removed 55 | - N/A 56 | 57 | ## 0.2.0 (2023-01-20) 58 | 59 | ### Added 60 | - N/A 61 | 62 | ### Changed 63 | - Update dependency on `cohere` to `^3.0` 64 | 65 | ### Removed 66 | - N/A 67 | 68 | ## 0.1.7 (2023-01-16) 69 | 70 | ### Added 71 | - N/A 72 | 73 | ### Changed 74 | - Changed `client_config` form in Streamlit demo to allow any number of `max_tokens` 75 | 76 | ### Removed 77 | - N/A 78 | 79 | ## 0.1.6 (2023-01-13) 80 | 81 | ### Added 82 | - N/A 83 | 84 | ### Changed 85 | - Fixed `PromptChatbot.to_dict()`. Previously, this threw an error as the Cohere client cannot be pickled. 86 | - Fixed Streamlit demo to save stop sequences correctly when swapping personas. 
87 | 88 | ### Removed 89 | - N/A 90 | 91 | ## 0.1.5 (2023-01-10) 92 | 93 | ### Added 94 | - N/A 95 | 96 | ### Changed 97 | - Bug fix of twemoji MaxCDN outage 98 | 99 | ### Removed 100 | - N/A 101 | 102 | ## 0.1.4 (2022-12-16) 103 | 104 | ### Added 105 | - N/A 106 | 107 | ### Changed 108 | - Converted relative links in README to absolute links for PyPI compatibility 109 | - Bug fix related to custom persona injection in streamlit 110 | 111 | ### Removed 112 | - N/A 113 | 114 | ## 0.1.3 (2022-12-02) 115 | 116 | ### Added 117 | - N/A 118 | 119 | ### Changed 120 | - Updated README content to clarify secrets management 121 | - Streamline custom persona injection in streamlit 122 | 123 | ### Removed 124 | - N/A 125 | 126 | ## 0.1.2 (2022-11-14) 127 | 128 | ### Added 129 | - Added Streamlit demo app to `conversant` 130 | 131 | ### Changed 132 | - Updated README content 133 | - Updated default persona directory 134 | 135 | ### Removed 136 | - N/A 137 | 138 | ## 0.1.1 (2022-11-02) 139 | 140 | ### Added 141 | - Added metadata for PyPI homepage 142 | - Updated README content 143 | - Updated `__init__.py` files for cleaner imports 144 | 145 | ### Changed 146 | - N/A 147 | 148 | ### Removed 149 | - N/A 150 | ## 0.1.0 (2022-11-02) 151 | 152 | ### Added 153 | - Initial release of repo 154 | - Streamlit demo alongside repo 155 | - Upload to PyPI 156 | 157 | ### Changed 158 | - N/A 159 | 160 | ### Removed 161 | - N/A -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | ## High-level Guidelines 2 | 3 | Thank you for your interest in contributing to this repository. To help maintain 4 | the quality of the codebase and ensure a quick review of your pull request, you 5 | should: 6 | 7 | 1. Write clear, clean code and format it in line with the style used in the 8 | repository. 9 | 2. Leave comments, and use docstrings where appropriate. 10 | 3. 
- style: formatting, missing semicolons, etc; no production code change
git clone git@github.com:cohere-ai/sandbox-conversant-lib.git
51 | cd sandbox-conversant-lib
Documentation can be built using `pdoc` as follows:
- *Persona:* Defines a specific chatbot personality, which consists of a bot description and a dialogue sample, as well as a bot name, a user name, and a maximum number of context lines
In the last line, the label “Bot:” prompts co.generate() to produce an appropriate greeting to start the conversation.* 174 | 175 | 2. The generated output is returned to the user, such that the user sees the bot’s response (but not the hypothetical ‘Hello’ that was inserted at the end of the prompt). 176 | 3. The response from the bot is added to the chat history (technically, the hypothetical ‘Hello’ is as well, but it is immediately removed). 177 | 4. The user replies with a novel query. 178 | 5. The bot description, example turns, chat history, and user query are concatenated into a single prompt, and the chat history + user query are truncated based on max context lines. 179 | 180 | ``` 181 | <> 182 | In this chat, a helpful and informative bot answers questions from the user. 183 | <> 184 | User: Hi 185 | Bot: Hi, do you have any questions? 186 | User: Are African swallows migratory? 187 | Bot: No, African swallows are non-migratory. 188 | User: Great, that's all I wanted to know. Goodbye! 189 | Bot: Goodbye! 190 | <> 191 | Bot: Hello, is there anything you'd like to ask me? 192 | User: Are coconuts tropical? 193 | Bot: 194 | ``` 195 | *An example of a new prompt. Note that the hypothetical ‘Hello’ is no longer in the prompt, but all previous statements from the bot and user are included as part of the chat history.* 196 | 197 | 6. Prompt is passed to `co.generate()` to produce the response from the bot. 198 | 7. The user query & response from the bot are added to the chat history. 199 | 200 | ![A diagram that shows how conversant constructs prompts before they are passed to co.generate() in order to craft a reply from the bot.](static/conversant-diagram.png) 201 | 202 | ### A note about search & grounded question answering 203 | 204 | Given the architecture described above, the chatbots are very likely to hallucinate facts in response to user questions. In some situations (e.g. 
the fantasy wizard demo) this may be desired behaviour, while in others it may not be. We will be incorporating grounded question answering, in order to ensure the bot’s responses are accurate when appropriate. 205 | 206 | This will involve a database of documents that will be embedded using `co.embed()`. The user query will be likewise embedded, and we will use cosine similarity to find the document that most closely corresponds to the user query. Following this, we will use another call to `co.generate()` for grounded rewriting of the initial bot reply. We plan to add support later on for finetuned models that have more advanced grounding capabilities as well. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cohere Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | ################################################################################ 3 | # ____ _ ____ _ _ # 4 | # / ___|___ | |__ ___ _ __ ___ / ___| __ _ _ __ __| | |__ _____ __ # 5 | # | | / _ \| '_ \ / _ \ '__/ _ \ \___ \ / _` | '_ \ / _` | '_ \ / _ \ \/ / # 6 | # | |__| (_) | | | | __/ | | __/ ___) | (_| | | | | (_| | |_) | (_) > < # 7 | # \____\___/|_| |_|\___|_| \___| |____/ \__,_|_| |_|\__,_|_.__/ \___/_/\_\ # 8 | # # 9 | # This project is part of Cohere Sandbox, Cohere's Experimental Open Source # 10 | # offering. This project provides a library, tooling, or demo making use of # 11 | # the Cohere Platform. You should expect (self-)documented, high quality code # 12 | # but be warned that this is EXPERIMENTAL. Therefore, also expect rough edges, # 13 | # non-backwards compatible changes, or potential changes in functionality as # 14 | # the library, tool, or demo evolves. Please consider referencing a specific # 15 | # git commit or version if depending upon the project in any mission-critical # 16 | # code as part of your own projects. # 17 | # # 18 | # Please don't hesitate to raise issues or submit pull requests, and thanks # 19 | # for checking out this project! 
# 20 | # # 21 | ################################################################################ 22 | ``` 23 | 24 | **Maintainer:** [Cohere ConvAI Team](mailto:convai@cohere.com) \ 25 | **Project maintained until at least (YYYY-MM-DD):** 2023-03-01 26 | 27 | # Conversant 28 | [![tests](https://github.com/cohere-ai/sandbox-conversant-lib/actions/workflows/run_tests.yaml/badge.svg)](https://github.com/cohere-ai/sandbox-conversant-lib/actions/workflows/run_tests.yaml/badge.svg) 29 | [![PyPI](https://img.shields.io/pypi/v/conversant.svg)](https://img.shields.io/pypi/v/conversant.svg) 30 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 31 | 32 | Build conversational AI on top of [Cohere](https://cohere.ai/)'s [large language models](https://docs.cohere.ai/generate-reference/) 33 | - 🗣 Use large language models quickly with Cohere's [API](https://docs.cohere.ai/api-reference/) 34 | - 😃 Customize personas 35 | - 💁 Leave management of chat history up to `conversant` 36 | - 🪨 Ground conversations in your provided facts (soon!) 37 | - 🐍 Install `conversant` with pip 38 | 39 | `conversant` is a work-in-progress framework for building customizable dialogue agents (aka chatbots) that can answer questions and converse with users with a variety of different chatbot personas. `conversant` aims 40 | to be modular, flexible and extensible so you can create any kind of chatbots you want! 41 | 42 | We provide several custom personas for you, including 🧑‍💼 a client support agent, ⌚️ a watch sales agent, 🧑‍🏫 a math teacher, and 🧙 a fantasy wizard. Create your own persona with just a description and some example conversations! 43 | 44 | Read more about how `conversant` is part of the Cohere Sandbox on our [launch blog post](https://txt.cohere.ai/introducing-sandbox-coheres-experimental-open-source-initiative/). 45 | 46 | Try `conversant` on our Streamlit demo [here](https://conversant.streamlit.app/)! 
🎉 47 | 48 | ## Table of Contents 49 | 1. [Installation and Usage](#installation-and-usage) 50 | 1. [Installation](#installation) 51 | 2. [Streamlit Demo](#streamlit-demo) 52 | 3. [Running Your Own Streamlit Demo](#running-your-own-streamlit-demo) 53 | 4. [Creating a Custom Persona](#creating-a-custom-persona) 54 | 5. [Editing a Persona on the Demo](#editing-a-persona-on-the-demo) 55 | 6. [Usage](#usage) 56 | 2. [How Conversant Works](#how-conversant-works) 57 | 3. [Documentation](#documentation) 58 | 4. [Get Support](#get-support) 59 | 5. [Contributing Guidelines](#contributing-guidelines) 60 | 6. [License](#license) 61 | 62 | ## Installation and Usage 63 | 64 | ### Installation 65 | 66 | `conversant` is available [on PyPI](https://pypi.org/project/conversant/), and is tested on Python 3.8+ and [Cohere](https://pypi.org/project/cohere/) 2.8.0+. 67 | ``` 68 | pip install conversant 69 | ``` 70 | ### Streamlit Demo 71 | 72 | Want to see it in action first? You can use `conversant` on a [Streamlit](https://docs.streamlit.io/) app without installing anything [here](https://conversant.streamlit.app/)! 🎉 73 | 74 |

Screenshot showing the available personas on the Streamlit demo, with the Fortune Teller persona selected by default.

78 | 79 | ### Running Your Own Streamlit Demo 80 | 81 | Cohere uses Streamlit to create its demo applications. If you’re new to Streamlit, you can install it [here](https://docs.streamlit.io/library/get-started/installation) and read more about running Streamlit commands [here](https://docs.streamlit.io/library/get-started/main-concepts). 82 | 83 | If you would like to modify this Streamlit demo locally, we strongly recommend forking this repository rather than installing it as a library from PyPI. 84 | 85 | If you'd like to spin up your own instance of the Streamlit demo, you will first need a `COHERE_API_KEY`. 86 | You can generate one by visiting [dashboard.cohere.ai](https://dashboard.cohere.ai/welcome/register?utm_source=github&utm_medium=content&utm_campaign=sandbox&utm_content=conversant). 87 | 88 | #### Local Streamlit apps 89 | If you plan to run the Streamlit app locally, you can add the key to `.streamlit/secrets.toml`: 90 | ``` 91 | COHERE_API_KEY = "YOUR_API_KEY_HERE" 92 | ``` 93 | 94 | When running locally, Streamlit will read the `secrets.toml` file and silently inject these values into the environment variables. Alternatively, you may directly set the API key as an environment variable by running the following command from the command line: 95 | ``` 96 | export COHERE_API_KEY = "YOUR_API_KEY_HERE" 97 | ``` 98 | 99 | Start the Streamlit app from the command line with the following command: 100 | ``` 101 | streamlit run conversant/demo/streamlit_example.py 102 | ``` 103 | 104 | #### Hosted Streamlit apps 105 | If instead you would like to create a hosted Streamlit app, add your Cohere API key to Streamlit via [Secrets Management](https://docs.streamlit.io/streamlit-cloud/get-started/deploy-an-app/connect-to-data-sources/secrets-management). 
Add the following line as a Secret: 106 | ``` 107 | COHERE_API_KEY = "YOUR_API_KEY_HERE" 108 | ``` 109 | 110 | ### Creating a Custom Persona 111 | Once you have your own instance of the Streamlit app, you can begin experimenting with creating custom personas! Check out the `config.json` for each persona in [`conversant/personas`](https://github.com/cohere-ai/sandbox-conversant-lib/tree/main/conversant/personas) directory. You'll need to create a subfolder within this directory that corresponds to your new persona and add a `config.json` file. 112 | 113 | As a note, we strongly recommend forking the `sandbox-conversant-lib` repository rather than installing it as a library from PyPI. When you create a new persona, use the `personas` directory in the cloned repository. The directory structure should look like this: 114 | 115 | ``` 116 | conversant/personas 117 | ├── fortune-teller 118 | │ └── config.json 119 | └── your-persona-name # new 120 | └── config.json 121 | ``` 122 | 123 | The config file should contain the following: 124 | - `chatbot_config`: 125 | - `max_context_examples`: The length of the chat history for the chatbot to use in reply. 126 | - `avatar`: Optional emoji shortcode or URL to image as the chatbot's avatar. Defaults to 🤖. 127 | - `client_config`: Parameters for [`co.generate()`](https://docs.cohere.ai/generate-reference) 128 | - `chat_prompt_config`: 129 | - `preamble`: Description of the persona. 130 | - `example_separator`: A string that separates each example conversation. 131 | - `headers`: A name for the `bot` and the `user`. 132 | - `examples`: A few conversation examples (few-shot), or empty (zero-shot). 133 | 134 | `conversant` will take care of the rest! As an example, check out [`fortune-teller/config.json`](https://github.com/cohere-ai/sandbox-conversant-lib/blob/main/conversant/personas/fortune-teller/config.json). When you launch the Streamlit app, the new persona will appear in the drop down menu. 
135 | 136 | #### Running the app with a subset of custom personas 137 | 138 | If you would like to run the app with a subset of custom personas, it's possible to create a new directory that contains only the desired ones. This is analogous to the `conversant/personas` directory, and needs to have the same structure: 139 | ``` 140 | custom-personas 141 | ├── your-first-persona 142 | │ └── config.json 143 | └── your-second-persona 144 | └── config.json 145 | ``` 146 | 147 | After creating this directory, you'll need to tell the app where to look for it. In the demo Streamlit app (`streamlit_example.py`), one of the 148 | first lines reads `CUSTOM_PERSONA_DIRECTORY = None`. Change this to specify the desired 149 | persona directory, e.g. `CUSTOM_PERSONA_DIRECTORY = "/Users/yourname/custom-personas"`. 150 | 151 | If this is unchanged, the app will default to using the directory that contains the 152 | `conversant` demo personas. 153 | 154 | #### Troubleshooting missing personas 155 | 156 | If you do not see the new persona in the drop down menu, you may need to specify a 157 | custom persona directory. Follow [the instructions above](#running-the-app-with-a-subset-of-custom-personas) to tell the app where to look for the personas. 158 | 159 | ### Editing a Persona on the Demo 160 | You can also edit a persona on the Streamlit app! 161 | Screenshot showing the interface for editing a persona on the Streamlit app. 162 | 163 | ### Usage 164 | 165 | With `conversant`, you can create a chatbot powered by [Cohere](https://cohere.ai/)'s large language models with just the following code snippet. 166 | ```python 167 | import cohere 168 | import conversant 169 | 170 | co = cohere.Client("YOUR_API_KEY_HERE") 171 | bot = conversant.PromptChatbot.from_persona("fantasy-wizard", client=co) 172 | print(bot.reply("Hello!")) 173 | >>> "Well met, fair traveller. What bringest thou to mine village?" 
174 | ``` 175 | 176 | You can also define your own persona by passing in your own `ChatPrompt`. 177 | ```python 178 | from conversant.prompts import ChatPrompt 179 | 180 | shakespeare_config = { 181 | "preamble": "Below is a conversation between Shakespeare and a Literature Student.", 182 | "example_separator": "\n", 183 | "headers": { 184 | "user": "Literature Student", 185 | "bot": "William Shakespeare", 186 | }, 187 | "examples": [ 188 | [ 189 | { 190 | "user": "Who are you?", 191 | "bot": "Mine own nameth is Shakespeare, and I speaketh in riddles.", 192 | }, 193 | ] 194 | ], 195 | } 196 | shakespeare_bot = conversant.PromptChatbot( 197 | client=co, prompt=ChatPrompt.from_dict(shakespeare_config) 198 | ) 199 | print(shakespeare_bot.reply("Hello!")) 200 | >>> "Greeteth, and welcome. I am Shakespeare, the great poet, dramatist, and playwright." 201 | ``` 202 | 203 | 208 | ## How Conversant Works 209 | `conversant` uses prompt completion to define a chatbot persona with a description and a few examples. The prompt is sent as input to Cohere's [`co.generate()`](https://docs.cohere.ai/generate-reference/) endpoint for an autoregressive language model to generate text in a few-shot manner from the examples and the current dialogue context. 210 | 211 | Each user message and chatbot response is appended to a chat history so that future responses are conditioned on the dialogue context at that point in time. 212 | 213 | In the future, we plan to add functionality for a chatbot to be factually grounded using text that is retrieved from a local document cache. 214 | 215 | For more information, refer to [this section in `CONTRIBUTORS.md`](https://github.com/cohere-ai/sandbox-conversant-lib/blob/main/CONTRIBUTORS.md#conversant-schematic). 216 | 217 | ## Documentation 218 | Full documentation can be found [here](https://cohere-ai.github.io/sandbox-conversant-lib/). 
219 | 220 | ## Get Support 221 | 222 | If you have any questions or comments, please file an issue or reach out to us on [Discord](https://discord.gg/co-mmunity). 223 | 224 | ## Contributing Guidelines 225 | If you would like to contribute to this project, please read [`CONTRIBUTORS.md`](https://github.com/cohere-ai/sandbox-conversant-lib/blob/main/CONTRIBUTORS.md) 226 | in this repository, and sign the Contributor License Agreement before submitting 227 | any pull requests. A link to sign the Cohere CLA will be generated the first time 228 | you make a pull request to a Cohere repository. 229 | 230 | In addition to guidelines around submitting code to this repository, [`CONTRIBUTORS.md`](https://github.com/cohere-ai/sandbox-conversant-lib/blob/main/CONTRIBUTORS.md) contains a walkthrough to help developers get started, as well as schematics that explain how `conversant` works under the hood. :wrench: 231 | 232 | ## License 233 | `conversant` has an [MIT License](https://github.com/cohere-ai/sandbox-conversant-lib/blob/main/LICENSE). 234 | -------------------------------------------------------------------------------- /conversant/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | from conversant.chatbot import Chatbot 10 | from conversant.prompt_chatbot import PromptChatbot 11 | 12 | __all__ = ["Chatbot", "PromptChatbot"] 13 | -------------------------------------------------------------------------------- /conversant/chatbot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

from abc import ABC, abstractmethod
from typing import Dict, List, NewType

import cohere

# A single conversation turn: a mapping from speaker keys to utterance strings.
Interaction = NewType("Interaction", Dict[str, str])


class Chatbot(ABC):
    """Defines bots that reply to users in a 1:1 text conversation."""

    def __init__(self, client: cohere.Client):
        """A Chatbot should be passed a Cohere Client for API calls.

        Args:
            client (cohere.Client): Provides access to Cohere API via the Python SDK
        """

        self.co = client

        # Holds the full, formatted chat history with type of
        # List[Interaction]. Each dict object represents a
        # conversation turn and has two keys: "speaker_name" and
        # "utterance".
        self.chat_history: List[Interaction] = []

    @abstractmethod
    def reply(self, query: str) -> Interaction:
        """Replies to a user given some input and context.

        Subclasses must implement this; it should also record the turn in
        `self.chat_history`.

        Args:
            query (str): Most recent message from the user.

        Returns:
            Interaction: A dictionary consisting of the interaction, including the
                query and response.
        """
def peek(iterable) -> str:
    """Return the next item yielded by ``iterable``, or "" when it is exhausted.

    Args:
        iterable (generator): A partial reply generator.

    Returns:
        str: The next partial reply, or the empty string if the generator
            has nothing left to yield.
    """
    try:
        upcoming = next(iterable)
    except StopIteration:
        return ""
    else:
        return upcoming
def get_reply() -> None:
    """Replies query from the message input and initializes the rerun_count."""
    st.session_state.partial_reply_in_progress = True
    st.session_state.partial_reply_generator = st.session_state.bot.partial_reply(
        query=st.session_state.message_input
    )
    # This variable is used to indicate from where streamlit_talk should animate the
    # typewriter effect from and to.
    st.session_state.prev_partial_chunk, st.session_state.curr_partial_chunk = next(
        st.session_state.partial_reply_generator
    )
    # Clear the input box so the user can type their next message.
    st.session_state.message_input = ""


def initialize_chatbot() -> None:
    """Initializes the chatbot from a selected persona and saves the session state."""
    if st.session_state.persona.startswith("(launched)") and len(sys.argv) > 1:
        st.session_state.bot = demo_utils.decode_chatbot(
            sys.argv[1], client=cohere.Client(os.environ.get("COHERE_API_KEY"))
        )  # Launched via demo_utils.launch_streamlit() utility function
    elif st.session_state.persona == "":
        # No persona selected yet; leave the app without an active bot.
        st.session_state.bot = None
    elif st.session_state.persona == "parrot":
        # The parrot bot echoes messages and does not call the generate API.
        st.session_state.bot = utils.ParrotChatbot(
            client=cohere.Client(os.environ.get("COHERE_API_KEY"))
        )
    else:
        st.session_state.bot = PromptChatbot.from_persona(
            emoji.replace_emoji(st.session_state.persona, "").strip(),
            client=cohere.Client(os.environ.get("COHERE_API_KEY")),
            persona_dir=CUSTOM_PERSONA_DIRECTORY
            if CUSTOM_PERSONA_DIRECTORY
            else PERSONA_MODEL_DIRECTORY,
        )
    if "bot" in st.session_state and st.session_state.bot:
        update_session_with_prompt()
    # Reset the edit_prompt_json session state so we don't remain on the JSON editor
    # when changing to another bot. This is because st_ace is unable to write
    # new values from the current session state.
    st.session_state.partial_reply_in_progress = False
    st.session_state.edit_prompt_json = False


def update_session_with_prompt() -> None:
    """Saves the prompt config dictionary into the session state."""
    if "bot" in st.session_state and st.session_state.bot:
        # Deep copies are taken so that later edits to the bot's live config
        # do not mutate these snapshots out from under the form widgets.
        st.session_state.snapshot_prompt_config = copy.deepcopy(
            st.session_state.bot.prompt.to_dict()
        )
        st.session_state.snapshot_chatbot_config = copy.deepcopy(
            st.session_state.bot.chatbot_config
        )
        st.session_state.snapshot_client_config = copy.deepcopy(
            st.session_state.bot.client_config
        )
        # Stop sequences are escaped (e.g. \n -> \\n) so they display in widgets.
        st.session_state.current_stop_sequences = [
            utils.escape_string(stop_seq)
            for stop_seq in st.session_state.bot.client_config["stop_sequences"]
        ]


def update_prompt_from_json() -> None:
    """Evaluates JSON string and updates the session's bot prompt."""
    if st.session_state.json_editor_input:
        try:
            # NOTE(review): ast.literal_eval parses Python literals, not strict
            # JSON -- editor input containing true/false/null will raise here
            # and be surfaced to the user via st.session_state.error.
            prompt_config = ast.literal_eval(st.session_state.json_editor_input)
            st.session_state.bot.prompt.update(prompt_config)
            update_session_with_prompt()
            st.session_state.error = ""
        except Exception as e:
            st.session_state.error = e
# This ensures rendering is prevented upon import of this file.
if __name__ == "__main__":
    st.set_page_config(
        page_title="Conversational personas using Cohere",
        page_icon="🎭",
        layout="wide",
    )

    # Streamlit's default elements are not easy to style. Instead, we have to
    # define styling in a custom CSS file and inject it into the Streamlit DOM.
    # This is brittle and dependent on the DOM structure. Any changes to the layout
    # will break the styling defined in this file.
    with open(f"{os.path.dirname(__file__)}/styles.css") as f:
        utils.style_using_css(f.read())

    # We use the :bust_in_silhouette: emoji as a neutral user avatar.
    st.session_state.user_avatar = utils.get_twemoji_url_from_shortcode(
        USER_AVATAR_SHORTCODE
    )

    # Each persona is a directory in PERSONA_MODEL_DIRECTORY, each with its
    # config.json file.
    if CUSTOM_PERSONA_DIRECTORY:
        st.session_state.persona_options = utils.get_persona_options(
            CUSTOM_PERSONA_DIRECTORY
        )
    else:
        st.session_state.persona_options = utils.get_persona_options()

    # Check if COHERE_API_KEY is not set from secrets.toml or os.environ
    if "COHERE_API_KEY" not in os.environ:
        raise KeyError(
            "COHERE_API_KEY not found in st.secrets or os.environ. Please set it in "
            ".streamlit/secrets.toml or as an environment variable."
        )

    # A chatbot can be passed in as a base64 encoding of a pickled PromptChatbot object.
    # This is only used when calling the launch_demo() method of a PromptChatbot object.
    # The chatbot is then injected into the list of available personas in this streamlit
    # demo.
    if len(sys.argv) > 1 and "bot" not in st.session_state:

        # The PromptChatbot passed in should be a base64 encoding of a pickled
        # PromptChatbot object.
        bot = demo_utils.decode_chatbot(
            sys.argv[1], cohere.Client(os.environ.get("COHERE_API_KEY"))
        )
        if not isinstance(bot, PromptChatbot):
            raise TypeError("base64 string passed in is not of class PromptChatbot")
        else:
            st.session_state.bot = bot
            st.session_state.persona_options.insert(
                0, f"(launched) {st.session_state.bot.persona_name}"
            )

    # Adding a header to direct users to sign up for Cohere, explore the playground,
    # and check out our git repo.
    st.header("🎭 Conversational Personas using Cohere")
    with st.expander("About", expanded="bot" not in st.session_state):
        st.markdown(
            """
            This demo app is using
            [**conversant**](https://github.com/cohere-ai/sandbox-conversant-lib), an
            open-source framework for building chatbots on top of Cohere’s large
            language models.

            Cohere provides access to advanced Large Language Models and NLP tools through
            one easy-to-use API.
            """
            "[**Get started for free!**]"
            "(https://dashboard.cohere.ai/welcome/register?utm_source=cohere-owned&utm_"
            "medium=content&utm_campaign=sandbox&utm_term=streamlit&utm_content=conversant)"
        )

    # Page control flow logic is determined from the sidebar.
    with st.sidebar:
        st.selectbox(
            "Choose a chatbot persona:",
            options=st.session_state.persona_options,
            key="persona",
            on_change=initialize_chatbot,
        )
        st.checkbox(
            "Edit prompt",
            value=False,
            key="edit_prompt",
            on_change=update_session_with_prompt,
        )
        if st.session_state.edit_prompt:
            st.checkbox(
                "Use JSON editor",
                value=False,
                key="edit_prompt_json",
                on_change=update_session_with_prompt,
            )

        # Initialize a settings container in the sidebar. This allows us to place
        # Streamlit elements within this placeholder later in this script.
        settings_placeholder = st.empty()

    # Initialize a chat container as the middle of 3 vertical columns.
    # Only visible when the edit prompt checkbox is not selected.
    _, chat_placeholder, _ = st.columns([1, 1, 1])
    with chat_placeholder.container():
        chat_history_placeholder = st.empty()
        message_input_placeholder = st.empty()

    # Initialize a prompt json and string view as 2 vertical columns.
    # Only visible when the edit prompt checkbox is selected.
    prompt_json_column, prompt_string_column = st.columns([1, 1])
    with prompt_json_column:
        prompt_json_edit_placeholder = st.empty()
        prompt_json_view_placeholder = st.empty()
    with prompt_string_column:
        prompt_string_placeholder = st.empty()

    # Check if bot has been initialized in the Streamlit session.
    if "bot" in st.session_state and st.session_state.bot:

        # Initialize the bot avatar: a twemoji URL if the configured avatar is an
        # emoji shortcode, otherwise the configured value is used as-is (e.g. a URL).
        bot_avatar_string = st.session_state.bot.chatbot_config["avatar"]
        st.session_state.bot_avatar = (
            utils.get_twemoji_url_from_shortcode(bot_avatar_string)
            if emoji.is_emoji(emoji.emojize(bot_avatar_string, language="alias"))
            else bot_avatar_string
        )

        # Editor view for the prompt
        if st.session_state.edit_prompt:

            # Edit the prompt using a JSON editor
            if st.session_state.edit_prompt_json:

                # The prompt JSON editor needs to be drawn first so that
                # the displayed form values in the sidebar take reference from
                # the editor.
                with prompt_json_edit_placeholder.container():
                    ui.draw_prompt_json_editor(
                        max_height=955
                    )  # st_ace only accepts hardcoded pixel values
                    update_prompt_from_json()

                with settings_placeholder.container():
                    with st.expander("Client Config"):
                        ui.draw_client_config_form()
                    with st.expander("Chatbot Config"):
                        ui.draw_chatbot_config_form()
                    ui.draw_prompt_form(disabled=True)

                with prompt_string_placeholder.container():
                    ui.draw_prompt_view(json=False)

            # Edit the prompt using a form in the sidebar
            else:

                # The settings form needs to be drawn first so that
                # the displayed JSON values in prompt JSON placeholder
                # take reference from the form.
                with settings_placeholder.container():
                    with st.expander("Client Config"):
                        ui.draw_client_config_form()
                    with st.expander("Chatbot Config"):
                        ui.draw_chatbot_config_form()
                    ui.draw_prompt_form(disabled=False)

                with prompt_json_view_placeholder.container():
                    ui.draw_prompt_view(json=True)

                with prompt_string_placeholder.container():
                    ui.draw_prompt_view(json=False)

        # Chat view with the persona
        else:

            # We can get the chatbot to begin the conversation with this.
            # The session's state needs to be manually updated since we are not
            # refreshing the entire Streamlit app.
            if not st.session_state.bot.chat_history:
                st.session_state.message_input = "Hello"
                get_reply()
                update_session_with_prompt()

            # Draw UI elements for the sidebar
            with settings_placeholder.container():

                with st.expander("Client Config"):
                    ui.draw_client_config_form()
                with st.expander("Chatbot Config"):
                    ui.draw_chatbot_config_form()

                with st.expander("Prompt (JSON)"):
                    ui.draw_prompt_view(json=True)

                with st.expander("Prompt (string)", expanded=True):
                    ui.draw_prompt_view(json=False)

            # Draw the message input field and a disclaimer. The input is disabled
            # while a partial reply is still streaming in.
            with message_input_placeholder.container():
                st.text_input(
                    label=f"Chat with {st.session_state.bot.prompt.bot_name}!",
                    placeholder="Type a message",
                    key="message_input",
                    on_change=get_reply,
                    disabled=st.session_state.partial_reply_in_progress,
                )
                ui.draw_disclaimer()

            # When in chat view, anchor elements from the bottom so that
            # the message input field is at the bottom (more natural).
            utils.style_using_css(
                """div.css-k1ih3n.egzxvld4 > div:nth-child(1) > div:nth-child(1) > div:nth-child(6) { /* # noqa */
                margin-top: auto;
            }
            """
            )

            # Draw chat history.
            with chat_history_placeholder.container():
                ui.draw_chat_history()

            # Rerun the app if there are partial replies to add to the latest
            # response.
            if "partial_reply_generator" in st.session_state:
                st.session_state.partial_reply_in_progress = True
                yielded_chunks = peek(st.session_state.partial_reply_generator)
                if yielded_chunks:
                    previous_partial_chunk, partial_chunk = yielded_chunks
                    st.session_state.prev_partial_chunk = previous_partial_chunk
                    st.session_state.curr_partial_chunk = partial_chunk
                    st.experimental_rerun()
                else:
                    # Generator exhausted: stop streaming and re-enable the input.
                    del st.session_state.partial_reply_generator
                    st.session_state.partial_reply_in_progress = False
                    # At the end of the partial reply generation, we want the typewriter
                    # animation to cease. To do this, make the prev and curr the same
                    # so that nothing is animated.
                    st.session_state.prev_partial_chunk = (
                        st.session_state.curr_partial_chunk
                    )
                    # NOTE(review): sanity check only -- assert statements are
                    # stripped when Python runs with -O.
                    assert (
                        st.session_state.bot.chat_history[-1]["bot"].strip()
                        == st.session_state.curr_partial_chunk.strip()
                    )
                    st.experimental_rerun()
*/ 22 | div.css-k1ih3n.egzxvld4 > div:nth-child(1) > div:nth-child(1) > div:nth-child(7) > div:nth-child(1) p { 23 | margin-left: 8px; 24 | } 25 | 26 | /* Style prompt_json_view_placeholder so overflow is scrollable */ 27 | div.css-k1ih3n.egzxvld4 > div:nth-child(1) > div:nth-child(1) > div:nth-child(7) > div:nth-child(1) > div:nth-child(1) > div:nth-child(1) > div:nth-child(2) > div:nth-child(1) > div:nth-child(2) { 28 | overflow-x: hidden; 29 | overflow-y: scroll; 30 | max-height: 955px; 31 | margin-top: 8px; 32 | margin-left: 8px; 33 | } 34 | 35 | /* Style prompt_string_placeholder so overflow is scrollable */ 36 | div.css-k1ih3n.egzxvld4 > div:nth-child(1) > div:nth-child(1) > div:nth-child(7) > div:nth-child(2) .stCodeBlock { 37 | overflow-x: hidden; 38 | overflow-y: scroll; 39 | max-height: 955px; 40 | margin-top: 8px; 41 | } 42 | 43 | /* Remove "Press enter to apply" from text input */ 44 | .stTextInput div.css-1if5ada.effi0qh1 { 45 | visibility: hidden; 46 | } 47 | 48 | /* Make markdown code wrapped */ 49 | code.language-markdown { 50 | white-space: pre-wrap !important ; 51 | } 52 | 53 | /* Make padding smaller on st.sidebar */ 54 | div.css-1vq4p4l.e1fqkh3o4 { 55 | padding-top: 2rem; 56 | } -------------------------------------------------------------------------------- /conversant/demo/ui.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
from collections import defaultdict

import streamlit as st
from streamlit_ace import st_ace
from streamlit_talk import message as st_message

from conversant.demo import utils


def render_bot_partial_reply(utterance, idx):
    """Renders a partial reply message from the bot.

    The message is animated with a typewriter effect, starting from the
    previously rendered partial chunk (if any) held in session state.

    Args:
        utterance (str): The utterance to be rendered.
        idx (int): The index of the turn.
    """
    st_message(
        value=utterance,
        animate_from=""
        if "prev_partial_chunk" not in st.session_state
        else st.session_state.prev_partial_chunk,
        use_typewriter=True,
        key=f"{idx}_bot",
        avatar_style=st.session_state.bot_avatar,
        generation_complete=(not st.session_state.partial_reply_in_progress),
    )


def draw_chat_history() -> None:
    """Renders the chat history in Streamlit.

    The messages are rendered using streamlit-chat, a custom Streamlit component
    for a chatbot UI.
    Reference: https://github.com/AI-Yash/st-chat
    """
    for i, turn in enumerate(st.session_state.bot.chat_history):

        # If there is only one turn, then we should only show the
        # bot utterance (but using the typewriter and partial reply effect),
        # skipping over the first injected user utterance.
        if len(st.session_state.bot.chat_history) == 1:
            if "bot" in turn:
                render_bot_partial_reply(turn["bot"], i)

        # If we are at the last conversation turn, the bot utterance
        # will be rendered as a partial reply with the typewriter effect.
        elif i == len(st.session_state.bot.chat_history) - 1:
            if "user" in turn:
                st_message(
                    value=turn["user"],
                    is_user=True,
                    key=f"{i}_user",
                    avatar_style=st.session_state.user_avatar,
                )
            if "bot" in turn:
                render_bot_partial_reply(turn["bot"], i)

        else:
            # If there is more than one turn, the first turn should skip over
            # the first injected user utterance.
            if i != 0 and "user" in turn:
                st_message(
                    value=turn["user"],
                    is_user=True,
                    key=f"{i}_user",
                    avatar_style=st.session_state.user_avatar,
                )
            if "bot" in turn:
                st_message(
                    value=turn["bot"],
                    key=f"{i}_bot",
                    avatar_style=st.session_state.bot_avatar,
                )


def draw_disclaimer() -> None:
    """Adds a disclaimer about the personas in this demo."""
    if st.session_state.persona != "parrot":
        st.write(
            "_Each persona is powered by [Cohere](https://cohere.com)'s large language "
            "models, and these examples are meant purely for demonstrative purposes. "
            "These personas are works of fiction, are not factually grounded, and "
            "should not be taken too seriously!_"
        )
    else:
        st.write(
            "_The Parrot persona does not make use of [Cohere](https://cohere.com)'s "
            "large language models. Instead, it repeats back whatever message it "
            "receives._"
        )


def draw_chatbot_config_form() -> None:
    """Adds widgets to edit the chatbot config."""
    config = st.session_state.snapshot_chatbot_config
    max_context_examples = st.slider(
        label="max_context_examples",
        min_value=0,
        max_value=20,
        value=config["max_context_examples"],
        help="The number of interactions to keep as context for the chatbot.",
    )
    # Push the (possibly changed) slider value back into the live bot config.
    st.session_state.bot.configure_chatbot(
        {"max_context_examples": max_context_examples}
    )
def draw_client_config_form() -> None:
    """Adds widgets to edit the client config.

    Renders widgets for each co.generate() parameter and writes the resulting
    values back into the live bot via configure_client().
    """
    st.write(
        "For more information on these parameters, see "
        "https://docs.cohere.ai/generate-reference"
    )
    config = st.session_state.snapshot_client_config
    model_options = ["", "small", "medium", "large", "xlarge"]
    model = st.selectbox(
        label="model",
        options=model_options,
        index=model_options.index(config["model"])
        if config["model"] in model_options
        else 0,
        help="The size of the Cohere model used to generate with.",
    )
    model_id_override = st.text_input(
        label="model ID override",
        value=model if model else config["model"],
        help=(
            "The full ID of a custom model. See "
            "https://docs.cohere.ai/generate-reference#model-optional for more details."
        ),
    )
    if model != model_id_override:
        st.warning(
            "WARNING: This demo does not validate that the model ID used for override "
            "is valid.",
        )
    max_tokens = st.number_input(
        label="max_tokens",
        value=config["max_tokens"],
        help="The number of tokens to predict per response.",
    )
    temperature = st.slider(
        label="temperature",
        min_value=0.0,
        max_value=5.0,
        value=config["temperature"],
        help=(
            "The degree of randomness for the response. Large temperature values may "
            "yield overly random results!"
        ),
    )
    frequency_penalty = st.slider(
        label="frequency_penalty",
        min_value=0.0,
        max_value=1.0,
        value=config["frequency_penalty"],
        help=(
            "Penalty to reduce repetitiveness of generated tokens, weighted by their "
            "frequency. Large penalty values may yield strange results!"
        ),
    )
    presence_penalty = st.slider(
        label="presence_penalty",
        min_value=0.0,
        max_value=1.0,
        value=config["presence_penalty"],
        help=(
            "Penalty to reduce repetitiveness of generated tokens, weighted equally "
            "to all present tokens. Large penalty values may yield strange results!"
        ),
    )
    # This allows the user to add their own stop sequences to a multiselect form
    # below.
    if "current_stop_sequences" not in st.session_state:
        st.session_state.current_stop_sequences = [
            utils.escape_string(stop_seq) for stop_seq in config["stop_sequences"]
        ]
    new_stop_seq = st.text_input(
        label="add new stop sequence",
        help="Add a stop sequence to the selection below.",
    )
    if (
        new_stop_seq != ""
        and new_stop_seq not in st.session_state.current_stop_sequences
    ):
        st.session_state.current_stop_sequences.append(new_stop_seq)
    # Use the list of stop sequences in the session state, including any user added ones
    # as the defaults for a multiselect form.
    st.multiselect(
        label="stop_sequences",
        options=st.session_state.current_stop_sequences,
        default=st.session_state.current_stop_sequences,
        key="selected_stop_sequences",
        help=(
            "The generated response will be cut off at the first instance of any of "
            "these stop sequences."
        ),
    )

    # Note: model_id_override (not the selectbox value) wins; the two are kept
    # in sync above unless the user typed a custom ID.
    st.session_state.bot.configure_client(
        {
            "model": model_id_override,
            "max_tokens": int(max_tokens),
            "temperature": temperature,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "stop_sequences": [
                utils.unescape_string(stop_seq)
                for stop_seq in st.session_state.selected_stop_sequences
            ],  # Stop sequences need to be unescaped e.g. from \\n to \n
        }
    )
def draw_prompt_form(disabled: bool = False) -> None:
    """Adds a form for configuring the prompt through its fields.

    The form is rendered as disabled when we only need to show the non-editable values
    of a prompt. This is used when the JSON editor is active.

    Args:
        disabled (bool): Whether or not the form should be rendered as disabled.
    """
    # Batches elements together as a form with a common submit button.
    with st.form("prompt_form"):
        # When the form is disabled, each time it is rendered its values need to be
        # taken from the current prompt config. Otherwise, its values should be taken
        # from the snapshot of the prompt config whenever it is first rendered.
        config = (
            defaultdict(str, st.session_state.bot.prompt.to_dict())
            if disabled
            else defaultdict(str, st.session_state.snapshot_prompt_config)
        )
        # We need to be careful about indexing into the dictionaries here
        # because when editing the prompt JSON, keys can end up malformed.
        default_preamble = config["preamble"]
        default_example_separator = config["example_separator"]
        default_user_name = (
            config["headers"]["user"] if "user" in config["headers"] else ""
        )
        default_bot_name = (
            config["headers"]["bot"] if "bot" in config["headers"] else ""
        )
        # This is where we create the text areas for the form.
        preamble = st.text_area(
            label="preamble",
            disabled=disabled,
            value=utils.escape_string(
                default_preamble
            ),  # Display chars like \n in the text area by escaping them to \\n
            help=(
                # Fixed typo: "behae" -> "behave" in this user-facing help text.
                "A string that directs the chatbot to behave in certain ways by "
                "describing its function and characteristics (i.e. a description of "
                "a bot's persona). Accepts escape sequences like \\n."
            ),
        )
        example_separator = st.text_input(
            label="example_separator",
            disabled=disabled,
            value=utils.escape_string(
                default_example_separator
            ),  # Display chars like \n in the text area by escaping them to \\n
            help="A separator for each example. Accepts escape sequences like \\n.",
        )
        user_name = st.text_input(
            label="user",
            disabled=disabled,
            value=utils.escape_string(
                default_user_name
            ),  # Display chars like \n in the text area by escaping them to \\n
            help="The name of the user. Defaults to 'User'.",
        )
        bot_name = st.text_input(
            label="bot",
            disabled=disabled,
            value=utils.escape_string(
                default_bot_name
            ),  # Display chars like \n in the text area by escaping them to \\n
            help="The name of the chatbot.",
        )
        # Because prompt examples have a more complex structure, it is not very user
        # friendly to render them as form input fields.
        st.text_input(
            label="examples",
            placeholder="Please edit examples with the JSON editor.",
            disabled=True,
            help=(
                "A list of examples to illustrate how the chatbot should respond to "
                "a user."
            ),
        )
        # Upon submitting the form, we will save the form values in to the current
        # prompt config, then update the bot. Any errors should be saved.
        submitted = st.form_submit_button("Update")
        if submitted:
            try:
                # Strings need to be unescaped e.g. from \\n to \n
                current_config = st.session_state.bot.prompt.to_dict()
                current_config["preamble"] = utils.unescape_string(preamble)
                current_config["example_separator"] = utils.unescape_string(
                    example_separator
                )
                current_config["headers"]["user"] = utils.unescape_string(user_name)
                current_config["headers"]["bot"] = utils.unescape_string(bot_name)
                st.session_state.bot.prompt.update(current_config)
                st.session_state.error = ""
            except Exception as e:
                # Surface validation errors in the prompt view for debugging.
                st.session_state.error = e
330 | """ 331 | st.write("**Prompt (JSON):**") 332 | st_ace( 333 | value=f"{st.session_state.bot.prompt.to_json_string()}", 334 | placeholder="Enter a JSON representation of a prompt.", 335 | height=max_height, 336 | language="json", 337 | wrap=True, 338 | auto_update=True, 339 | key="json_editor_input", 340 | theme="monokai", 341 | ) 342 | 343 | 344 | def draw_prompt_view(json: bool = False) -> None: 345 | """Adds a representation of the prompt in JSON or as a string. 346 | 347 | Args: 348 | json (bool): Whether to render the prompt as a JSON object. 349 | """ 350 | if json: 351 | st.write("**Prompt (JSON):**") 352 | st.json(st.session_state.bot.prompt.to_dict()) 353 | else: 354 | st.write( 355 | f"**{st.session_state.bot.prompt.bot_name} responds to you using the " 356 | "prompt below:**" 357 | ) 358 | # If the current JSON string is malformed, show the error to the user to help 359 | # with debugging. 360 | if "error" in st.session_state and st.session_state.error: 361 | st.exception(st.session_state.error) 362 | else: 363 | st.code( 364 | st.session_state.bot.get_current_prompt("{Your message here}"), 365 | language="markdown", 366 | ) 367 | if st.session_state.bot.chat_history: 368 | st.write("_(includes the current chat history)_") 369 | -------------------------------------------------------------------------------- /conversant/demo/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
8 | 9 | 10 | import json 11 | import os 12 | import re 13 | from typing import List 14 | 15 | import emoji 16 | import emojificate 17 | import streamlit as st 18 | from emojificate.filter import emojificate as emojificate_fn 19 | 20 | from conversant.prompt_chatbot import PERSONA_MODEL_DIRECTORY, PromptChatbot 21 | from conversant.prompts.chat_prompt import ChatPrompt 22 | 23 | emojificate.filter.TWITTER_CDN = "https://cdnjs.cloudflare.com/ajax/libs/twemoji/14.0.2" 24 | 25 | 26 | class ParrotChatbot(PromptChatbot): 27 | """Mock chat function; real use-cases should import 28 | this functionality via a class that inherits from conversant.chatbot.Chatbot 29 | 30 | This bot simply states the user's query. 31 | """ 32 | 33 | def __init__(self, client): 34 | super().__init__( 35 | client=client, 36 | prompt=ChatPrompt( 37 | preamble=( 38 | "The Parrot Bot repeats back whatever is said to it " 39 | "without using Cohere's large language models." 40 | ), 41 | fields=["user", "bot"], 42 | headers={ 43 | "user": "User", 44 | "bot": "Parrot Bot", 45 | }, 46 | example_separator="", 47 | examples=[], 48 | ), 49 | ) 50 | 51 | def reply(self, query: str) -> str: 52 | """Replies to a user by stating their query. 53 | 54 | Args: 55 | query (str): A user's text chat query 56 | 57 | Returns: 58 | str: a mock reply that repeats the user's query. 59 | """ 60 | current_prompt = self.get_current_prompt(query) 61 | self.chat_history.append(self.prompt.create_interaction(query, query)) 62 | self.prompt_history.append(current_prompt) 63 | return query 64 | 65 | 66 | @st.cache_data 67 | def get_twemoji_url_from_shortcode(shortcode: str) -> str: 68 | 69 | """Converts an emoji shortcode to its corresponding Twemoji URL. 70 | 71 | Args: 72 | shortcode (str): Emoji shortcode. 73 | 74 | Returns: 75 | str: The string that is the Twemoji URL corresponding to the emoji. 76 | """ 77 | # Emojize returns the unicode representation of that emoji from its shortcode. 
78 | unicode = emoji.emojize(shortcode, language="alias") 79 | # Emojificate returns html tag. 80 | img_html_tag = emojificate_fn(unicode) 81 | # Find the URL from the html tag. 82 | url = re.findall('src="(.*?)"', img_html_tag, re.DOTALL)[0] 83 | return url 84 | 85 | 86 | @st.cache_data 87 | def get_persona_options(persona_directory: str = PERSONA_MODEL_DIRECTORY) -> List[str]: 88 | """Initializes a list of personas. 89 | 90 | Each persona is a directory in persona_directory, each with its 91 | config.json file. The mock parrot persona is also included for testing 92 | purposes. 93 | 94 | Args: 95 | persona_directory (str): Directory where persona folders containing config 96 | files are stored. Defaults to PERSONA_MODEL_DIRECTORY. 97 | 98 | Returns: 99 | List[str]: A list of persona names. 100 | """ 101 | # Initialize the list of personas for Streamlit 102 | persona_names = os.listdir(persona_directory) 103 | persona_names_maybe_with_emojis = [] 104 | for persona_name in persona_names: 105 | persona_path = os.path.join(persona_directory, persona_name, "config.json") 106 | with open(persona_path) as f: 107 | persona = json.load(f) 108 | avatar = ( 109 | emoji.emojize(persona["chatbot_config"]["avatar"], language="alias") 110 | if "avatar" in persona["chatbot_config"] 111 | else "" 112 | ) 113 | persona_names_maybe_with_emojis.append( 114 | f"{avatar} {persona_name}" 115 | ) if emoji.is_emoji(avatar) else persona_names_maybe_with_emojis.append( 116 | persona_name 117 | ) 118 | persona_options = [""] + sorted( 119 | persona_names_maybe_with_emojis, reverse=True 120 | ) # + ["parrot"] 121 | return persona_options 122 | 123 | 124 | def style_using_css(style: str) -> None: 125 | """Utility function to inject CSS style into the Streamlit DOM. 126 | 127 | Args: 128 | style (str): String representation of CSS style. Assumes it is well-formed. 
129 | """ 130 | st.markdown(f"", unsafe_allow_html=True) 131 | 132 | 133 | def escape_string(string: str) -> str: 134 | """Utility function to add '\' to escape sequences in a string. 135 | 136 | Args: 137 | string (str): Unescaped string. 138 | 139 | Returns: 140 | str: Escaped string 141 | """ 142 | return string.encode("unicode_escape").decode("raw_unicode_escape") 143 | 144 | 145 | def unescape_string(string: str) -> str: 146 | """Utility function to remove '\' from escape sequences in a string. 147 | 148 | Args: 149 | string (str): Escaped string. 150 | 151 | Returns: 152 | str: Unescaped string 153 | """ 154 | return string.encode("raw_unicode_escape").decode("unicode_escape") 155 | -------------------------------------------------------------------------------- /conversant/personas/client-support/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":information_desk_person:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Support Agent and User. Support Agent responds to User based on the <>.\n<>\nSupport Agent is a professional, high-end retail service professional who helps a User with their billing, shipping, bugfix, and other related issues. Support Agent always speaks formally, using full sentences and proper punctuation. Support Agent tries to understand the User's issue in detail and suggests possible ways to fix the issue, always being polite.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Support Agent" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi i've got a problem with my billing", 23 | "bot": "Hello, I'm here with customer support. How can I help you today with your issue?" 
24 | }, 25 | { 26 | "user": "I got charged on my card but I don't think the order ever shipped", 27 | "bot": "I'm so sorry to hear that. Could you tell me the order number?" 28 | }, 29 | { 30 | "user": "yeah it's 1234354", 31 | "bot": "Understood! I've just checked on the order and it looks like our shipping was delayed. It should be sent out by the end of today." 32 | }, 33 | { 34 | "user": "oh sweet thanks", 35 | "bot": "No problem at all. Is there anything else I can help you with today?" 36 | }, 37 | { 38 | "user": "nah all looks good now!", 39 | "bot": "Thank you for your time today and for checking in with us!" 40 | } 41 | ] 42 | ] 43 | } 44 | } -------------------------------------------------------------------------------- /conversant/personas/fantasy-wizard/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":mage:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 200, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Fantasy Wizard and User. Fantasy Wizard responds to User based on the <>.\n<>\nIn this chat, the Fantasy Wizard is from a distant, magical land. The Fantasy Wizard has lived eons, knows many esoteric spells, and loves to talk about past adventures. The fantasy wizard talks in old fashioned English, often referencing legendary places, creatures and artifacts.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Fantasy Wizard" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi", 23 | "bot": "Well met, fair traveller. What bringest thou to mine village?" 24 | }, 25 | { 26 | "user": "I'm here to study magic", 27 | "bot": "Magic! Such powers are not to be meddled with lightly. Haves't thou any latent magical ability? No, surely not..." 
28 | }, 29 | { 30 | "user": "no I can cast a few spells", 31 | "bot": "The town jester casteth spells! But this is not magic. Magic is the summoning of grand elements. I pray, thou seeketh mine help but art thou truly worthy..." 32 | }, 33 | { 34 | "user": "yeah maybe not, I gotta peace out and go now", 35 | "bot": "Farewell, traveller. Seek great power in the lands yonder..." 36 | } 37 | ] 38 | ] 39 | } 40 | } -------------------------------------------------------------------------------- /conversant/personas/fortune-teller/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":crystal_ball:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.8, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Fortune Teller and User. Fortune Teller responds to User based on the <>.\n<>\nIn this chat, the magnificent Fortune Teller will answer all your questions. The Fortune Teller can read the future using her crystal ball, and with all her powers she can see everything. There are no secrets she can't tell, no lie she can't detect, and no future she can't see. She answers vaguely, using generic and mysterious sentences. When a user ends the conversation, the Fortune Teller also ends the conversation and says goodbye.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Fortune Teller" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "Hi", 23 | "bot": "Hello! I've been expecting you... What do you want to know, my dear?" 24 | }, 25 | { 26 | "user": "Well, I came here looking for answers. I've been betting a lot, but I always lose. I need to know if I should keep trying.", 27 | "bot": "I can see that you're struggling financially. 
I also see that you've been doubting yourself a lot lately, and that's keeping you from getting a job. This tells me you shouldn't bet anymore." 28 | }, 29 | { 30 | "user": "Shouldn't I? What can I do then?", 31 | "bot": "You should look for a job. Hmmm, let me see. You're definitely going to find something, dear." 32 | }, 33 | { 34 | "user": "But when? Where?", 35 | "bot": "These answers are already in front of you. Listen to me very carefully, trust yourself and follow your dreams." 36 | }, 37 | { 38 | "user": "Thank you so much, madam.", 39 | "bot": "You're welcome! See you next time, and I know you'll be back soon." 40 | } 41 | ] 42 | ] 43 | } 44 | } -------------------------------------------------------------------------------- /conversant/personas/historian/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":amphora:" 5 | }, 6 | "client_config": { 7 | "model": "command-xlarge-nightly", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": [ 11 | "\nUser:" 12 | ] 13 | }, 14 | "chat_prompt_config": { 15 | "preamble": "You are Historian: respond the questions by User about historical facts. Historian is a person who studies and writes about the past and is regarded as an authority on it. Historian is concerned with the continuous, methodical narrative and research of past events as relating to the human race; as well as the study of all history in time. Historian is recognized by publications or training and experience.", 16 | "example_separator": "\n", 17 | "headers": { 18 | "user": "User", 19 | "bot": "Historian" 20 | }, 21 | "examples": [ 22 | [ 23 | { 24 | "user": "Hi", 25 | "bot": "Hello, I am Historian and I am happy to tell you about historical facts. What would you like to know today?" 26 | }, 27 | { 28 | "user": "I'm interested in learning more about the American Revolution. 
When did it happen and what were the reasons behind it?", 29 | "bot": "The American Revolution took place from 1775 to 1783. It was sparked by economic, political, and ideological factors, such as heavy taxation, lack of representation in Parliament, and the desire for independence and democracy." 30 | }, 31 | { 32 | "user": "Can you give me more detail on the economic causes?", 33 | "bot": "One of the main economic causes of the American Revolution was the heavy taxation imposed on the colonies by the British government, without any representation in Parliament. This included taxes on imported goods, such as tea, and the Stamp Act, which required a tax on all printed materials, including newspapers and legal documents." 34 | }, 35 | { 36 | "user": "And what about the political causes?", 37 | "bot": "The lack of autonomy and self-governance for the colonies was a major political cause of the American Revolution. The British government attempted to exert more control over the colonies, including through the use of military force and the imposition of laws and regulations. This led to increasing resentment and resistance among the colonists." 38 | }, 39 | { 40 | "user": "And what about the ideological causes?", 41 | "bot": "The Enlightenment ideals of liberty and democracy played a significant role in shaping the colonists' desire for independence. The ideas of the Enlightenment, such as natural rights and the social contract, were influential in the development of the Declaration of Independence and the U.S. Constitution." 42 | }, 43 | { 44 | "user": "What was the outcome of the revolution and how did it influence American society?", 45 | "bot": "The outcome of the American Revolution was the creation of the United States of America as an independent nation. It led to the adoption of the U.S. Constitution and the establishment of a federal system of government with a separation of powers and a system of checks and balances. 
The revolution also inspired other movements for independence and democracy around the world." 46 | }, 47 | { 48 | "user": "And how did the American Revolution inspire other movements for independence and democracy?", 49 | "bot": "The American Revolution had a significant influence on the global political landscape, as it inspired other movements for independence and democracy around the world. The success of the American Revolution, and the creation of the United States as a modern nation-state, served as a model for other countries seeking to achieve independence and establish democratic governments." 50 | } 51 | ] 52 | ] 53 | } 54 | } -------------------------------------------------------------------------------- /conversant/personas/injured-person/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":face_with_head_bandage:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Injured Person and User. Injured Person responds to User based on the <>.\n<>\nIn this chat, an Injured Person is complaining about a concussion. The Injured Person was climbing a tree, fell, and hit their head about an hour ago. They are hoping that the User can help them out with some advice and ideas about what to do about their head injury.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Injured Person" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hello", 23 | "bot": "hi i've hit my head when i fell from a tree, do you have any good tips on what to do?" 24 | }, 25 | { 26 | "user": "no sorry i've got to go, goodbye", 27 | "bot": "ok thanks bye" 28 | }, 29 | { 30 | "user": "hi", 31 | "bot": "i think i have a concussion. 
i don't think it's serious but it hurts a lot, do you have any ideas on what i can do" 32 | }, 33 | { 34 | "user": "oh wow, what happened?", 35 | "bot": "i was a climbing a tree and i fell." 36 | }, 37 | { 38 | "user": "was this today?", 39 | "bot": "yeah about 1 hour ago... i put a bandage on it but it doesn't seem to do much" 40 | }, 41 | { 42 | "user": "maybe put some ice on it to reduce swelling?", 43 | "bot": "oh good idea i will try that thanks" 44 | }, 45 | { 46 | "user": "ok anything else I can help with?", 47 | "bot": "no that's a really good idea thank u" 48 | } 49 | ] 50 | ] 51 | } 52 | } -------------------------------------------------------------------------------- /conversant/personas/math-teacher/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":teacher:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Math Teacher Bot and User. Math Teacher Bot responds to User based on the <>.\n<>\nIn this chat, a helpful math teacher, Math Teacher Bot, helps a user with math homework.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "Student", 17 | "bot": "Math Teacher Bot" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi", 23 | "bot": "Hello, can I help you with some math?" 24 | }, 25 | { 26 | "user": "no I have to go, bye", 27 | "bot": "Goodbye, have a nice day!" 28 | } 29 | ], 30 | [ 31 | { 32 | "user": "Hello", 33 | "bot": "Hi, can I help you with some math?" 34 | }, 35 | { 36 | "user": "yeah I'd like to know what is 1 * 100", 37 | "bot": "Sure! That is 100." 38 | }, 39 | { 40 | "user": "Thanks! 
I would like to understand the identity matrix.", 41 | "bot": "The identity matrix of size n is an n x n square matrix with ones on the main diagonal and zeros elsewhere." 42 | }, 43 | { 44 | "user": "What are some other names for that?", 45 | "bot": "We can also call it the unit matrix. Can I help you with any other concepts?" 46 | }, 47 | { 48 | "user": "No, but I need help with some functions. I want to know the roots of the function x^2 - 1 = 0", 49 | "bot": "x can be -1 or 1" 50 | }, 51 | { 52 | "user": "Nice! What is the binomial theorem?", 53 | "bot": "The Binomial Theorem is the method of expanding an expression that has been raised to any finite power." 54 | }, 55 | { 56 | "user": "Thank you for helping me with math!", 57 | "bot": "You are welcome!" 58 | } 59 | ] 60 | ] 61 | } 62 | } -------------------------------------------------------------------------------- /conversant/personas/personal-trainer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":weight_lifter:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Personal Trainer Bot and Student. Personal Trainer Bot responds to Student based on the <>.\n<>\nIn this chat you will receive exercise recommendations from the famous Personal Trainer Bot, so that it is easier to train at home and adapt your workout. He will give you different exercises for your needs and encourage you to keep exercising.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "Student", 17 | "bot": "Personal Trainer Bot" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi", 23 | "bot": "Hello, let's exercise? Can I help you with your training?" 
24 | }, 25 | { 26 | "user": "no I have to go, bye", 27 | "bot": "Goodbye, always remember the importance of physical exercise and see you tomorrow!" 28 | }, 29 | { 30 | "user": "Hello", 31 | "bot": "Hello, let's exercise? Can I help you with your training?" 32 | }, 33 | { 34 | "user": "yeah, I want to do some exercise before working", 35 | "bot": "Amazing, let's do it! How much time do you have for training?" 36 | }, 37 | { 38 | "user": "I have half an hour to exercise ", 39 | "bot": "Yeah, we can work on that!! Which part of the body do you want to start with?" 40 | }, 41 | { 42 | "user": "maybe abdomen", 43 | "bot": "That's great, you can do 3 sets of 15 sit-ups. When you are done always let me know!" 44 | }, 45 | { 46 | "user": "Done! What is the next exercise? ", 47 | "bot": "Good job!! To strengthen the arms a good exercise is to do push-ups. You can do 3 sets of 5 push-ups. If it is difficult let me know" 48 | }, 49 | { 50 | "user": "It was quite difficult to finish this one, but I did it", 51 | "bot": "Great, you're getting better and better, I'll take it easy next time. To focus on the legs, 2 sets of 10 squats" 52 | }, 53 | { 54 | "user": "I am finished for today, I am very tired!", 55 | "bot": "Today was hard. You did very well, see you tomorrow in training! #nopainnogain" 56 | } 57 | ] 58 | ] 59 | } 60 | } -------------------------------------------------------------------------------- /conversant/personas/travel-advisor/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":palm_tree:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Travel Advisor and User. 
Travel Advisor responds to User based on the <>.\n<>\nTravel Advisor is a professional that provides travel and tourism-related information to the general public. Travel Advisor proactively asks the User their plans, preferences, and ideas about travels, and provides suggestions. Travel Advisor's suggestions are based on User's desires, preferences, constraints. Travel Advisor always speaks in an informal way, trying to be nice and accommodating.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Travel Advisor" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi", 23 | "bot": "Hello, can I help you plan your next trip?" 24 | }, 25 | { 26 | "user": "no, I am not interested in it at the moment.", 27 | "bot": "Ok! In case you change your mind, you know where to find me. Bye!" 28 | } 29 | ], 30 | [ 31 | { 32 | "user": "hi", 33 | "bot": "Hello, do you need help with planning your next trip?" 34 | }, 35 | { 36 | "user": "Yes, I would like to go somewhere warm during the Christmas holidays, but I do not where. Any idea?", 37 | "bot": "I understand: let's see, how many days do you have?" 38 | }, 39 | { 40 | "user": "mmm, maximum two weeks.", 41 | "bot": "Ok, and when would you like to leave?" 42 | }, 43 | { 44 | "user": "After the 25th", 45 | "bot": "Ok. Would you rather spend a vacation in nature or a big city?" 46 | }, 47 | { 48 | "user": "I actually like them both! The best would be to combine cities and nature in the same trip.", 49 | "bot": "Ok. Is there any part of the world that you would like to avoid?" 50 | }, 51 | { 52 | "user": "I have been in South America recently, and I would like to travel to some other continent.", 53 | "bot": "What about Asia? Have you ever been there?" 54 | }, 55 | { 56 | "user": "No! Asia sounds very interesting actually!", 57 | "bot": "I would suggest you travel to Thailand then: at the end of December the weather is lovely. 
Furthermore, you will have the chance to enjoy both cities and nature." 58 | }, 59 | { 60 | "user": "Thanks, this is a very good suggestion!", 61 | "bot": "You are welcome. Is there anything else I can help you with?" 62 | }, 63 | { 64 | "user": "No, that's fine, thanks. Have a good day.", 65 | "bot": "Thanks, you too!" 66 | } 67 | ] 68 | ] 69 | } 70 | } -------------------------------------------------------------------------------- /conversant/personas/watch-sales-agent/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chatbot_config": { 3 | "max_context_examples": 10, 4 | "avatar": ":watch:" 5 | }, 6 | "client_config": { 7 | "model": "xlarge", 8 | "max_tokens": 100, 9 | "temperature": 0.75, 10 | "stop_sequences": ["\n"] 11 | }, 12 | "chat_prompt_config": { 13 | "preamble": "Below is a series of chats between Watch Sales Agent and User. Watch Sales Agent responds to User based on the <>.\n<>\nWatch Sales Agent is a professional, high-end retail service professional who sells wristwatches. Watch Sales Agent always speaks formally, using full sentences and proper punctuation. Watch Sales Agent is curious about the User and asks questions to understand what kinds of watches the User enjoys and what kinds of watches might be best for them. However, Watch Sales Agent never asks about sensitive topics like age or weight. The Watch Sales Agent will never put an item in the cart without approval from the User.", 14 | "example_separator": "<>\n", 15 | "headers": { 16 | "user": "User", 17 | "bot": "Watch Sales Agent" 18 | }, 19 | "examples": [ 20 | [ 21 | { 22 | "user": "hi", 23 | "bot": "Hello, and thank you for visiting our store. Can I help you find any watches today?" 24 | }, 25 | { 26 | "user": "I'd like something really luxe", 27 | "bot": "Wonderful. We have a broad selection of luxury watches here. Is there a style or brand you're considering?" 
28 | }, 29 | { 30 | "user": "i really like dive watches...", 31 | "bot": "Dive watches are an excellent choice for everyday wear, or - of course - for diving. Since you're looking for something luxurious, the Rolex Submariner is a classic choice. Did you know that it was the first divers' wristwatch that was waterproof to a depth of 100 metres (330 feet)? " 32 | }, 33 | { 34 | "user": "don't really like Rolexes", 35 | "bot": "Not a problem at all. The Omega Seamaster is another example of a classic, luxury dive watch. Most people know it from the James Bond movies." 36 | }, 37 | { 38 | "user": "that sounds amazing thanks! how much is it? ", 39 | "bot": "It's $7,500. Would you like me to add the Omega Seamaster to your cart?" 40 | }, 41 | { 42 | "user": "yes please", 43 | "bot": "Great, I've added it to your cart. Is there anything else I can help you with today?." 44 | }, 45 | { 46 | "user": "nah I'm good, bye!", 47 | "bot": "You're most welcome. I hope you have a wonderful day." 48 | } 49 | ] 50 | ] 51 | }, 52 | "rewrite_prompt_config": { 53 | "preamble": "The following program uses a <> to correct the Watch Sales Agent's words in the <> using information from the <>, and removes any false facts. If there are no relevant facts, the program does not change the reply.", 54 | "example_separator": "--", 55 | "headers": { 56 | "fact": "<>", 57 | "conversation": "<>", 58 | "rewrite": "<>" 59 | }, 60 | "examples": [ 61 | { 62 | "fact": "The Rolex Submariner comes with a date window and without, depending on the model.", 63 | "conversation": "Watch Sales Agent: No, the Submariner doesn't have a date window.", 64 | "rewrite": "Watch Sales Agent: Yes, on some models the Submariner has a date window." 65 | } 66 | ] 67 | } 68 | } -------------------------------------------------------------------------------- /conversant/prompt_chatbot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 
2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | import json 10 | import logging 11 | import os 12 | import warnings 13 | from concurrent.futures import Future, ThreadPoolExecutor 14 | from typing import Any, Dict, Tuple 15 | 16 | import cohere 17 | import jsonschema 18 | 19 | import conversant 20 | from conversant.chatbot import Chatbot, Interaction 21 | from conversant.prompts.chat_prompt import ChatPrompt 22 | from conversant.prompts.prompt import Prompt 23 | 24 | MAX_GENERATE_TOKENS = 2048 25 | TOKENS_PER_REQUEST = 10 26 | PERSONA_MODEL_DIRECTORY = f"{os.path.dirname(conversant.__file__)}/personas" 27 | PERSONA_JSON_SCHEMA = { 28 | "type": "object", 29 | "properties": { 30 | "chatbot_config": { 31 | "type": "object", 32 | "properties": { 33 | "max_context_examples": {"type": "integer"}, 34 | "avatar": {"type": "string"}, 35 | }, 36 | }, 37 | "client_config": { 38 | "type": "object", 39 | "properties": { 40 | "model": {"type": "string"}, 41 | "max_tokens": {"type": "integer"}, 42 | "temperature": {"type": "number"}, 43 | "frequency_penalty": {"type": "number"}, 44 | "presence_penalty": {"type": "number"}, 45 | "stop_sequences": {"type": "array"}, 46 | }, 47 | }, 48 | "prompt_config": { 49 | "type": "object", 50 | }, 51 | }, 52 | } 53 | 54 | 55 | class PromptChatbot(Chatbot): 56 | """Use prompt templates and LLM generation to define a chatbot. 57 | 58 | This bot makes no use of external knowledge sources. 59 | """ 60 | 61 | def __init__( 62 | self, 63 | client: cohere.Client, 64 | prompt: Prompt, 65 | persona_name: str = "", 66 | chatbot_config: Dict[str, Any] = {}, 67 | client_config: Dict[str, Any] = {}, 68 | ): 69 | """Enriches init by adding a prompt. 
70 | 71 | Args: 72 | client (cohere.Client): Cohere client for API 73 | prompt (Prompt): Prompt object to direct behavior. 74 | persona_name (str, optional): Bot's persona name. Defaults to empty string. 75 | chatbot_config: (Dict[str, Any], optional): Bot's chat config. Defaults to 76 | empty dict. 77 | client_config (Dict[str, Any], optional): Bot's client config. Defaults to 78 | empty dict. 79 | """ 80 | 81 | super().__init__(client) 82 | self.prompt = prompt 83 | self.persona_name = persona_name 84 | 85 | self.configure_chatbot(chatbot_config) 86 | self.configure_client(client_config) 87 | self.chat_history = [] 88 | self.prompt_size_history = [] 89 | self.prompt_history = [self.prompt.to_string()] 90 | self.curr_max_context_examples = self.chatbot_config["max_context_examples"] 91 | 92 | # For the generation models, the maximum token length is 2048 93 | # (prompt and generation). So the prompt sent to .generate should be 94 | # MAX_GENERATE_TOKENS minus max tokens generated 95 | self.max_prompt_size = MAX_GENERATE_TOKENS - self.client_config["max_tokens"] 96 | self._check_prompt_size() 97 | 98 | def __repr__(self) -> str: 99 | return json.dumps(self.to_dict(), indent=4, default=str) 100 | 101 | @property 102 | def user_name(self): 103 | """ 104 | Returns: 105 | str: The name of the user, defined in the prompt. Defaults to "User". 106 | """ 107 | if hasattr(self.prompt, "user_name"): 108 | return self.prompt.user_name 109 | else: 110 | return "User" 111 | 112 | @property 113 | def bot_name(self): 114 | """ 115 | Returns: 116 | str: The name of the chatbot, defined in the prompt. Defaults to 117 | "PromptChatbot". 118 | """ 119 | if hasattr(self.prompt, "bot_name"): 120 | return self.prompt.bot_name 121 | else: 122 | return "PromptChatbot" 123 | 124 | @property 125 | def latest_prompt(self) -> str: 126 | """Retrieves the latest prompt. 127 | 128 | Returns: 129 | str: The prompt most recently added to the prompt history. 
130 | """ 131 | return self.prompt_history[-1] 132 | 133 | def _update_max_context_examples( 134 | self, prompt_size: int, max_context_examples: int 135 | ) -> int: 136 | """Adjust max_context_examples until a possible prompt size. 137 | 138 | if this is not possible, send an error message. 139 | 140 | Args: 141 | prompt_size (int): Number of tokens of the prompt 142 | max_context_examples (int): The length of the chat history for 143 | the chatbot to use in reply. 144 | 145 | Returns: 146 | int: updated max_context_examples 147 | """ 148 | # Store original values 149 | original_size = prompt_size 150 | # If the size of chat_history is smaller than max_context_examples 151 | # the value of the variable is already updated with the size value 152 | trimmed_max_examples = min(len(self.chat_history), max_context_examples) 153 | 154 | # Check if the max_context_examples is bigger than 0 so it can be reduced 155 | if max_context_examples > 0: 156 | # Reduce max_context_examples until the number of token of the prompt 157 | # is less than maximum or reaches 1 158 | for size in self.prompt_size_history[-max_context_examples:]: 159 | prompt_size -= size 160 | trimmed_max_examples -= 1 161 | if prompt_size <= self.max_prompt_size: 162 | if self.curr_max_context_examples == trimmed_max_examples: 163 | warnings.warn( 164 | "The parameter max_context_examples continues " 165 | f"{self.curr_max_context_examples}" 166 | ", so that the total amount of tokens does not" 167 | f" exceed {MAX_GENERATE_TOKENS}." 168 | ) 169 | else: 170 | warnings.warn( 171 | "The parameter max_context_examples was changed for" 172 | f" this turn, from {self.curr_max_context_examples} to " 173 | f"{trimmed_max_examples}, so that " 174 | "the total amount of tokens does not" 175 | f" exceed {MAX_GENERATE_TOKENS}." 
176 | ) 177 | self.curr_max_context_examples = trimmed_max_examples 178 | return trimmed_max_examples 179 | 180 | raise ValueError( 181 | "The total number of tokens (prompt and prediction) cannot exceed " 182 | f"{MAX_GENERATE_TOKENS}. Try using a shorter start prompt, sending " 183 | "smaller text messages in the chat, or setting a smaller value " 184 | "for the parameter max_tokens. More details:\n" 185 | f" - Start Prompt: {self.start_prompt_size} tokens\n" 186 | f" - Messages sent in chat: {original_size - self.start_prompt_size} " 187 | f"tokens\n - Parameter max_tokens: {self.client_config['max_tokens']} " 188 | "tokens" 189 | ) 190 | 191 | def _dispatch_concurrent_generate_call(self, **kwargs) -> Future: 192 | """Dispatches a concurrent call to co.generate. 193 | 194 | This allows a network bound co.generate call to proceed while also 195 | yielding the current response in a partial reply generator. 196 | 197 | Args: 198 | kwargs: Keyword arguments for the call to co.generate. 199 | 200 | Returns: 201 | Future: A future object that will be called to retrieve the result of 202 | co.generate. 203 | """ 204 | with ThreadPoolExecutor(max_workers=1) as exe: 205 | future = exe.submit(self.co.generate, **kwargs) 206 | return future 207 | 208 | def get_stop_seq(self, response: str) -> str: 209 | """Given a response, returns the stop sequence it has if any. 210 | 211 | Args: 212 | response (str): Response coming from prompt chatbot. 213 | 214 | Returns: 215 | str: The stop sequence in the response. If no stop sequence is found, then 216 | an empty string is returned. 217 | 218 | """ 219 | for stop_seq in self.client_config["stop_sequences"]: 220 | if stop_seq in response: 221 | return stop_seq 222 | return "" 223 | 224 | def generate_prompt_update_examples(self, query: str) -> str: 225 | """Generate prompt from query and update max context examples if necessary 226 | 227 | Args: 228 | query (str): A query passed to the prompt chatbot. 
229 | 230 | Returns: 231 | current_prompt (str): Returns the current prompt using 232 | query and chat history 233 | 234 | """ 235 | # The current prompt is assembled from the initial prompt, 236 | # from the chat history with a maximum of max_context_examples, 237 | # and from the current query 238 | 239 | current_prompt = self.get_current_prompt(query) 240 | 241 | current_prompt_size = self.co.tokenize(current_prompt).length 242 | 243 | if current_prompt_size > self.max_prompt_size: 244 | max_context_examples = self._update_max_context_examples( 245 | current_prompt_size, self.chatbot_config["max_context_examples"] 246 | ) 247 | current_prompt = self.get_current_prompt(query, max_context_examples) 248 | 249 | elif ( 250 | self.curr_max_context_examples 251 | != self.chatbot_config["max_context_examples"] 252 | ): 253 | warnings.warn( 254 | "The max_context_examples value returned" 255 | f" to {self.chatbot_config['max_context_examples']} - " 256 | f"value set in the original config" 257 | ) 258 | return current_prompt 259 | 260 | def partial_reply(self, query: str) -> Tuple[str, str]: 261 | """Generates (partial) reply to a query given a chat history. 262 | 263 | Args: 264 | query (str): A query passed to the prompt chatbot. 265 | 266 | Yields: 267 | 268 | Tuple[str, str]: A tuple of the response before the co.generate call, 269 | and the response after. 
270 | """ 271 | current_prompt = self.generate_prompt_update_examples(query) 272 | self.prompt_history.append(current_prompt) 273 | 274 | response_before_current = "" 275 | response_so_far = "" 276 | num_requests_made = 0 277 | max_requests = int(self.client_config["max_tokens"] / TOKENS_PER_REQUEST) 278 | reply_complete = False 279 | 280 | # As soon as the function is called (and the generator is created), dispatch 281 | # a concurrent call to co.generate 282 | future = self._dispatch_concurrent_generate_call( 283 | model=self.client_config["model"], 284 | prompt=current_prompt, 285 | max_tokens=TOKENS_PER_REQUEST, 286 | temperature=self.client_config["temperature"], 287 | frequency_penalty=self.client_config["frequency_penalty"], 288 | presence_penalty=self.client_config["presence_penalty"], 289 | stop_sequences=self.client_config["stop_sequences"], 290 | ) 291 | 292 | while num_requests_made < max_requests and not reply_complete: 293 | generated_object = future.result() 294 | partial_response = generated_object.generations[0].text 295 | 296 | # If the partial response is an empty string, then this iteration is a no-op 297 | # (we indicate that the reply is completely generated). 298 | if not partial_response: 299 | reply_complete = True 300 | 301 | else: 302 | 303 | # Concatenate the candidate response, then fetches the stop sequence if 304 | # it exists in the candidate response 305 | candidate_response = response_so_far + partial_response 306 | stop_seq = self.get_stop_seq(response_so_far + partial_response) 307 | 308 | # Truncate the candidate response if a stop sequence was found 309 | if stop_seq: 310 | candidate_response = candidate_response[ 311 | : candidate_response.index(stop_seq) 312 | ] 313 | 314 | # If the stop sequence is found across two partial replies, 315 | # then the response_so_far has to be truncated. 
Example: 316 | # 317 | # stop_seq: "\nUser" 318 | # response_so_far: "Thank you!\n" 319 | # partial_response: "User: You are welcome" 320 | # 321 | # Then the candidate response is: 322 | # 323 | # candidate_response: "Thank you!" 324 | # 325 | # In this case, what is yielded at the end of the loop needs to be: 326 | # 327 | # response_before_current: "Thank you!" 328 | # response_so_far: "Thank you!" 329 | # 330 | # So we'll truncate the response_so_far to be candidate_response 331 | if len(candidate_response) < len(response_so_far): 332 | response_so_far = candidate_response 333 | 334 | reply_complete = True 335 | 336 | # Save candidate response 337 | current_prompt += partial_response 338 | response_before_current = response_so_far 339 | response_so_far = candidate_response 340 | 341 | # If this is the first partial_reply, append a new element to 342 | # chat history after removing the leading whitespace 343 | if num_requests_made == 0: 344 | response_so_far = response_so_far.lstrip() 345 | self.chat_history.append( 346 | self.prompt.create_interaction(query, response_so_far) 347 | ) 348 | self.prompt_size_history.append( 349 | self.co.tokenize( 350 | self.prompt.create_interaction_string( 351 | query, response_so_far 352 | ) 353 | ).length 354 | ) 355 | # Otherwise, overwrite the current chat history with the current 356 | # response so far 357 | else: 358 | self.chat_history[-1] = self.prompt.create_interaction( 359 | query, response_so_far 360 | ) 361 | self.prompt_size_history[-1] = self.co.tokenize( 362 | self.prompt.create_interaction_string(query, response_so_far) 363 | ).length 364 | 365 | num_requests_made += 1 366 | 367 | # This dispatches a concurrent call to co.generate, which can be 368 | # later accessed on the next iteration of the generator. 
369 | if num_requests_made < max_requests and not reply_complete: 370 | future = self._dispatch_concurrent_generate_call( 371 | model=self.client_config["model"], 372 | prompt=current_prompt, 373 | max_tokens=TOKENS_PER_REQUEST, 374 | temperature=self.client_config["temperature"], 375 | frequency_penalty=self.client_config["frequency_penalty"], 376 | presence_penalty=self.client_config["presence_penalty"], 377 | stop_sequences=self.client_config["stop_sequences"], 378 | ) 379 | 380 | yield response_before_current, response_so_far 381 | 382 | def reply(self, query: str) -> Interaction: 383 | """Replies to a query given a chat history. 384 | 385 | The reply is then generated directly from a call to a LLM. 386 | 387 | Args: 388 | query (str): A query passed to the prompt chatbot. 389 | 390 | Returns: 391 | Interaction: Dictionary of query and generated LLM response 392 | """ 393 | 394 | current_prompt = self.generate_prompt_update_examples(query) 395 | 396 | # Make a call to Cohere's co.generate API 397 | generated_object = self.co.generate( 398 | model=self.client_config["model"], 399 | prompt=current_prompt, 400 | max_tokens=self.client_config["max_tokens"], 401 | temperature=self.client_config["temperature"], 402 | frequency_penalty=self.client_config["frequency_penalty"], 403 | presence_penalty=self.client_config["presence_penalty"], 404 | stop_sequences=self.client_config["stop_sequences"], 405 | ) 406 | # If response was cut off by .generate() finding a stop sequence, 407 | # remove that sequence from the response. 408 | response = generated_object.generations[0].text 409 | for stop_seq in self.client_config["stop_sequences"]: 410 | if response.endswith(stop_seq): 411 | response = response[: -len(stop_seq)] 412 | response = response.lstrip() 413 | 414 | # We need to remember the current response in the chat history for future 415 | # responses. 
416 | self.chat_history.append(self.prompt.create_interaction(query, response)) 417 | self.prompt_size_history.append( 418 | self.co.tokenize( 419 | self.prompt.create_interaction_string(query, response) 420 | ).length 421 | ) 422 | self.prompt_history.append(current_prompt) 423 | 424 | return response 425 | 426 | def get_current_prompt(self, query: str, max_context_examples: int = None) -> str: 427 | """Stitches the prompt with a trailing window of the chat. 428 | Args: 429 | query (str): The current user query. 430 | max_context_examples (int): The length of the chat history for 431 | the chatbot to use in reply. 432 | 433 | Returns: 434 | str: The current prompt given a query. 435 | """ 436 | if max_context_examples is None: 437 | max_context_examples = self.chatbot_config["max_context_examples"] 438 | 439 | # get base prompt 440 | base_prompt = self.prompt.to_string() + "\n" 441 | 442 | # get context prompt 443 | context_prompt_lines = [] 444 | trimmed_chat_history = ( 445 | self.chat_history[-max_context_examples:] 446 | if max_context_examples > 0 447 | else [] 448 | ) 449 | # TODO when prompt is updated, the history is mutated 450 | # as it is recreated using the new prompt. A possible fix is to save the old 451 | # prompt in history and use it when recreating. 452 | for turn in trimmed_chat_history: 453 | context_prompt_lines.append(self.prompt.create_interaction_string(**turn)) 454 | context_prompt = self.prompt.example_separator + "".join(context_prompt_lines) 455 | 456 | current_prompt = base_prompt + context_prompt 457 | 458 | # get query prompt 459 | if query != "": 460 | query_prompt = self.prompt.create_interaction_string(query) 461 | current_prompt += query_prompt 462 | return current_prompt.strip() 463 | 464 | def configure_chatbot(self, chatbot_config: Dict = {}) -> None: 465 | """Configures chatbot options. 466 | 467 | Args: 468 | chatbot_config (Dict, optional): Updates self.chatbot_config. Defaults 469 | to {}. 
470 | """ 471 | # We initialize the chatbot to these default config values. 472 | if not hasattr(self, "chatbot_config"): 473 | self.chatbot_config = {"max_context_examples": 10, "avatar": ":robot:"} 474 | # Override default config values with the config passed in 475 | if isinstance(chatbot_config, Dict): 476 | self.chatbot_config.update(chatbot_config) 477 | else: 478 | raise TypeError( 479 | "chatbot_config must be of type Dict, but was passed in as " 480 | f"{type(chatbot_config)}" 481 | ) 482 | 483 | def configure_client(self, client_config: Dict = {}) -> None: 484 | """Configures client options. 485 | 486 | Args: 487 | client_config (Dict, optional): Updates self.client_config. Defaults to {}. 488 | """ 489 | # We initialize the client to these default config values. 490 | if not hasattr(self, "client_config"): 491 | self.client_config = { 492 | "model": "xlarge", 493 | "max_tokens": 200, 494 | "temperature": 0.75, 495 | "frequency_penalty": 0.0, 496 | "presence_penalty": 0.0, 497 | "stop_sequences": ["\\n", "\n"], 498 | } 499 | # Override default config values with the config passed in 500 | if isinstance(client_config, Dict): 501 | self.client_config.update(client_config) 502 | else: 503 | raise TypeError( 504 | "client_config must be of type Dict, but was passed in as " 505 | f"{type(client_config)}" 506 | ) 507 | 508 | # Checks if the parameter is equal or bigger than MAX_GENERATE_TOKENS 509 | if self.client_config["max_tokens"] >= MAX_GENERATE_TOKENS: 510 | raise ValueError( 511 | f"The parameter max_tokens needs to be smaller than " 512 | f"{MAX_GENERATE_TOKENS}. Try using a smaller value." 
513 | ) 514 | elif self.client_config["max_tokens"] > (MAX_GENERATE_TOKENS * 0.75): 515 | warnings.warn( 516 | "The parameter max_tokens has a value " 517 | f"({self.client_config['max_tokens']}) close to the total allowed" 518 | f" for prompt and prediction - {MAX_GENERATE_TOKENS} tokens" 519 | ) 520 | 521 | @classmethod 522 | def from_persona( 523 | cls, 524 | persona_name: str, 525 | client: cohere.Client, 526 | persona_dir: str = PERSONA_MODEL_DIRECTORY, 527 | ): 528 | """Initializes a PromptChatbot using a persona. 529 | 530 | Args: 531 | persona (str): Name of persona, corresponding to a .json file. 532 | client (cohere.Client): Cohere client for API 533 | persona_dir (str): Path to where pre-defined personas are. 534 | """ 535 | # Load the persona from a local directory 536 | persona_path = os.path.join(persona_dir, persona_name, "config.json") 537 | if os.path.isfile(persona_path): 538 | logging.info(f"loading persona from {persona_path}") 539 | else: 540 | raise FileNotFoundError(f"{persona_path} cannot be found.") 541 | with open(persona_path) as f: 542 | persona = json.load(f) 543 | 544 | # Validate that the persona follows our predefined schema 545 | cls._validate_persona_dict(persona, persona_path) 546 | return cls( 547 | client=client, 548 | prompt=ChatPrompt.from_dict(persona["chat_prompt_config"]), 549 | persona_name=persona_name, 550 | chatbot_config=persona["chatbot_config"], 551 | client_config=persona["client_config"], 552 | ) 553 | 554 | def to_dict(self) -> Dict[str, Any]: 555 | """Serializes this instance into a Python dictionary. 556 | 557 | Returns: 558 | Dict[str, Any]: Dictionary of attributes that defines this instance of a 559 | PromptChatbot. 
560 | """ 561 | attr_dict = {k: v for k, v in vars(self).items()} 562 | attr_dict["prompt"] = attr_dict["prompt"].to_dict() 563 | return attr_dict 564 | 565 | def _check_prompt_size(self) -> None: 566 | 567 | self.start_prompt_size = self.co.tokenize(self.prompt.to_string()).length 568 | if self.start_prompt_size > self.max_prompt_size: 569 | raise ValueError( 570 | f"The prompt given to PromptChatbot has {self.start_prompt_size}" 571 | " tokens. And the value of the parameter max_tokens is" 572 | f" {self.client_config['max_tokens']}. Adding the two values " 573 | f"the total cannot exceed {MAX_GENERATE_TOKENS}. " 574 | "Try using a shorter preamble or less examples." 575 | ) 576 | elif self.start_prompt_size > (0.75 * self.max_prompt_size): 577 | warnings.warn( 578 | "The prompt given to PromptChatbot has " 579 | f"{self.start_prompt_size} tokens. And the value of the parameter" 580 | f" max_tokens is {self.client_config['max_tokens']}. " 581 | "Adding the two together gives a value close to the total allowed" 582 | f" for prompt and prediction - {MAX_GENERATE_TOKENS} tokens" 583 | ) 584 | 585 | @staticmethod 586 | def _validate_persona_dict(persona: Dict[str, Any], persona_path: str) -> None: 587 | """Validates formatting of a persona defined as a dictionary. 588 | 589 | Args: 590 | persona (Dict[str, Any]): A dictionary containing the persona. 591 | persona_path: The path from which the persona was loaded. 
592 | """ 593 | try: 594 | jsonschema.validate(instance=persona, schema=PERSONA_JSON_SCHEMA) 595 | except jsonschema.exceptions.ValidationError as e: 596 | raise jsonschema.exceptions.ValidationError( 597 | f"Type of values in given dictionary (persona from {persona_path}) do " 598 | f"not match schema': {e}" 599 | ) 600 | except KeyError as e: 601 | raise KeyError( 602 | f"Invalid key in given dictionary (persona from {persona_path})': {e}" 603 | ) 604 | except Exception as e: 605 | raise Exception( 606 | "Failed to validate persona in given dictionary (persona from " 607 | f"{persona_path}): {e}" 608 | ) 609 | -------------------------------------------------------------------------------- /conversant/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | from conversant.prompts.chat_prompt import ChatPrompt 10 | from conversant.prompts.prompt import Prompt 11 | from conversant.prompts.rewrite_prompt import RewritePrompt 12 | 13 | __all__ = ["Prompt", "ChatPrompt", "RewritePrompt"] 14 | -------------------------------------------------------------------------------- /conversant/prompts/chat_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
8 | 9 | import logging 10 | from dataclasses import field 11 | from typing import List, NewType 12 | 13 | from pydantic.dataclasses import dataclass 14 | 15 | from conversant.chatbot import Interaction 16 | from conversant.prompts.prompt import Prompt 17 | 18 | Conversation = NewType("Conversation", List[Interaction]) 19 | 20 | 21 | @dataclass 22 | class ChatPrompt(Prompt): 23 | """A chat prompt given to a Chatbot. 24 | 25 | The examples in a `ChatPrompt` are a list of `Conversation`s themselves a list of 26 | `Interaction`s. This is one level of nesting further as compared to those in 27 | `Prompt`, which are a list of `Interaction`s. 28 | 29 | Required keys: 30 | user: An entity speaking to the bot. 31 | bot: The Chatbot itself. 32 | 33 | Constants: 34 | REQUIRED_KEYS (List[str]): The list of required keys for the chat prompt. 35 | (default: `["user", "bot"]`) 36 | MIN_PREAMBLE_LENGTH (int): The minimum length of the preamble. (default: `1`) 37 | MIN_NUM_EXAMPLES (int): The minimum number of examples that should be passed in. 38 | (default: `1`) 39 | """ 40 | 41 | examples: List[Conversation] 42 | 43 | REQUIRED_KEYS: List[str] = field(default_factory=lambda: ["user", "bot"]) 44 | MIN_PREAMBLE_LENGTH: int = 10 45 | MIN_NUM_EXAMPLES: int = 0 46 | 47 | def __post_init__(self) -> None: 48 | """Validators for the chat prompt. 49 | 50 | Validates that the prompt follows the requirements of the validators listed 51 | below. Minimally, the ChatPrompt needs to follow the requirements of its parent 52 | class. 53 | """ 54 | super()._validate_preamble() 55 | super()._validate_example_separator() 56 | super()._validate_headers() 57 | self._validate_examples() 58 | self._validate_dialogue() 59 | 60 | @property 61 | def user_name(self): 62 | """ 63 | Returns: 64 | str: The name of the user that interacts with the chatbot who uses this 65 | ChatPrompt. Typically this should be set to `'User'`. 
66 | """ 67 | return self.headers["user"] 68 | 69 | @property 70 | def bot_name(self): 71 | """ 72 | Returns: 73 | str: The name of the chatbot who uses this ChatPrompt. 74 | """ 75 | return self.headers["bot"] 76 | 77 | def create_interaction_string(self, *args, **kwargs) -> str: 78 | """Creates a string representation of an interaction. 79 | 80 | Interactions will look like the following: 81 | 82 | {user_name}: {utterance}\n 83 | {bot_name}: {utterance}\n 84 | 85 | Note the colon and space separating the speaker name from the respective 86 | utterance. 87 | 88 | Args: 89 | args: Positional arguments for the new interaction. 90 | kwargs: Keyword arguments for the new interaction. 91 | 92 | Returns: 93 | str: String representation of an interaction. 94 | """ 95 | interaction = ( 96 | self.create_interaction(*args, **kwargs) if len(args) > 0 else kwargs 97 | ) 98 | return "".join( 99 | f"{self.headers[key]}: {interaction[key]}\n" for key in interaction.keys() 100 | ) 101 | 102 | def create_conversation_string(self, conversation: Conversation) -> str: 103 | """Creates a string represenation of a conversation. 104 | 105 | Conversations will look like the following: 106 | 107 | {user_name}: {utterance}\n 108 | {bot_name}: {utterance}\n 109 | {user_name}: {utterance}\n 110 | {bot_name}: {utterance}\n 111 | 112 | Args: 113 | conversation (Conversation): List of interactions. 114 | """ 115 | return "".join( 116 | self.create_interaction_string(**interaction) 117 | for interaction in conversation 118 | ) 119 | 120 | def to_string(self) -> str: 121 | """Creates a string representation of the conversation prompt. 122 | 123 | The string representation is assembled from the preamble and examples. 124 | Each example is created from a `create_conversation_string` method and is 125 | demarcated by an `example_separator`. 
126 | 127 | Examples will look like the following: 128 | 129 | {example_separator} 130 | {user_name}: {utterance}\n 131 | {bot_name}: {utterance}\n 132 | {user_name}: {utterance}\n 133 | {bot_name}: {utterance}\n 134 | {example_separator} 135 | {user_name}: {utterance}\n 136 | {bot_name}: {utterance}\n 137 | 138 | Returns: 139 | str: String representation of the conversation prompt. 140 | """ 141 | lines = [f"{self.preamble}\n"] 142 | lines += self.example_separator + f"{self.example_separator}".join( 143 | self.create_conversation_string(example) for example in self.examples 144 | ) 145 | return "".join(lines).strip() 146 | 147 | def _validate_examples(self) -> None: 148 | """Validates that the `examples` meet the following requirements: 149 | 150 | - All fields are used in every example of `examples`. 151 | - At least `MIN_NUM_EXAMPLES` examples are given. 152 | 153 | Raises: 154 | ValueError: If any of the above requirements is not met. 155 | """ 156 | # All fields are used in every interaction in every example of `examples`. 157 | for example in self.examples: 158 | for interaction in example: 159 | if any(key not in interaction for key in self.REQUIRED_KEYS): 160 | raise ValueError( 161 | "Missing required key.\nInteraction's keys: " 162 | f"{interaction.keys()}\nRequired: {self.REQUIRED_KEYS}" 163 | ) 164 | # At least `MIN_NUM_EXAMPLES` examples are given. 165 | if len(self.examples) < self.MIN_NUM_EXAMPLES: 166 | raise ValueError( 167 | f"At least {self.MIN_NUM_EXAMPLES} example(s) must be given for" 168 | f"{self.__class__.__name__}" 169 | ) 170 | 171 | def _validate_dialogue(self) -> None: 172 | """Validates that the examples conform to a 2-person dialogue. 173 | 174 | 175 | There should only be 2 speakers in the examples, and each speaker's utterance 176 | should not be prefixed with their name. 177 | 178 | Raises: 179 | ValueError: If the above requirement is not met. 
180 | """ 181 | for example in self.examples: 182 | # Only 2 speakers should be in each conversation interaction 183 | if not all([len(interaction) == 2 for interaction in example]): 184 | raise ValueError( 185 | "Conversation interactions must be pairs of utterances." 186 | ) 187 | 188 | # Only check the examples for name-prefixed utterances if there is at least 189 | # one interaction 190 | if example: 191 | user_turns = [interaction["user"] for interaction in example] 192 | bot_turns = [interaction["bot"] for interaction in example] 193 | all_turns = user_turns + bot_turns 194 | 195 | colon_prefixed = all(":" in turn for turn in all_turns) 196 | hyphen_prefixed = all("-" in turn for turn in all_turns) 197 | 198 | if colon_prefixed or hyphen_prefixed: 199 | # This might false-positive, so we only log a warning 200 | logging.warning( 201 | "Did you mistakenly prefix the example dialogue turns with" 202 | "user/bot names?" 203 | ) 204 | 205 | user_prefixed = all( 206 | turn.lstrip().startswith(self.user_name) for turn in user_turns 207 | ) 208 | 209 | bot_prefixed = all( 210 | turn.lstrip().startswith(self.bot_name) for turn in bot_turns 211 | ) 212 | if user_prefixed and bot_prefixed: 213 | # It's hard to think of any genuine case where all utterances begin 214 | # with self-names. 215 | raise ValueError( 216 | "Conversation interactions should not be prefixed with user/bot" 217 | "names!" 218 | ) 219 | -------------------------------------------------------------------------------- /conversant/prompts/prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
8 | 9 | import json 10 | from dataclasses import field 11 | from typing import Any, Dict, List 12 | 13 | from pydantic.dataclasses import dataclass 14 | 15 | from conversant.chatbot import Interaction 16 | 17 | 18 | @dataclass 19 | class Prompt: 20 | """Base class for all structured zero-shot or few-shot prompts. 21 | 22 | Args: 23 | preamble (str): A string that directs the model to behave in certain ways by 24 | describing its function (e.g. a description of a bot's persona). 25 | example_separator (str): A separator for each example. 26 | headers (Dict[str, str]): A dictionary mapping from keys in examples to the 27 | values that will substitute them. 28 | examples (List[Interaction]): A list of examples to illustrate the intended 29 | behaviour. 30 | 31 | Constants: 32 | REQUIRED_HEADER_KEYS (List[str]): The list of required keys in headers for the 33 | prompt. (default: `[]`) 34 | MIN_PREAMBLE_LENGTH (int): The minimum length of the preamble. (default: `1`) 35 | MIN_NUM_EXAMPLES (int): The minimum number of examples that should be passed in. 36 | (default: `1`) 37 | """ 38 | 39 | preamble: str 40 | example_separator: str 41 | headers: Dict[str, str] 42 | examples: List[Interaction] 43 | 44 | REQUIRED_KEYS: List[str] = field(default_factory=lambda: []) 45 | MIN_PREAMBLE_LENGTH: int = 1 46 | MIN_NUM_EXAMPLES: int = 1 47 | 48 | def __post_init__(self) -> None: 49 | """Validators for each prompt. 50 | 51 | Each subclass that inherits from Prompt should call this using 52 | `super().__post_init__()` so that their prompt structure is also validated. 53 | Stricter validation can be implemented in subclasses by overriding these 54 | methods, defining custom validators, or adjusting the constants of Prompt. 
55 | """ 56 | self._validate_preamble() 57 | self._validate_example_separator() 58 | self._validate_headers() 59 | self._validate_examples() 60 | 61 | def __repr__(self) -> str: 62 | return self.to_string() 63 | 64 | def __str__(self) -> str: 65 | return self.to_string() 66 | 67 | def create_interaction(self, *args, **kwargs) -> Interaction: 68 | """Creates a new dictionary representation of an interaction. 69 | 70 | The order of args here should correspond to the order of the keys in `headers`. 71 | The i-th positional argument passed in corresponds to the i-th key, up to 72 | `len(headers)`. If fewer than `len(headers)` arguments are passed in, the 73 | remaining entries default to `""`. If more than `len(headers)` arguments are 74 | passed in, they are ignored. 75 | 76 | Any subsequent keyword arguments override the values defined by the positional 77 | arguments. 78 | 79 | Args: 80 | args: Positional arguments for the new interaction. 81 | kwargs: Keyword arguments for the new interaction. 82 | 83 | Returns: 84 | Interaction: Dictionary representation of an interaction. 85 | """ 86 | new_interaction = { 87 | key: args[i] if i < len(args) else "" 88 | for i, key in enumerate(self.headers.keys()) 89 | } 90 | new_interaction.update(kwargs) 91 | return new_interaction 92 | 93 | def create_interaction_string(self, *args, **kwargs) -> str: 94 | """Creates a string representation of an interaction. 95 | 96 | The order of args here should correspond to the order of the keys in `headers`. 97 | The i-th positional argument passed in corresponds to the i-th key, up to 98 | `len(headers)`. If fewer than `len(headers)` arguments are passed in, the 99 | remaining entries default to `""`. If more than `len(headers)` arguments are 100 | passed in, they are ignored. 101 | 102 | Any subsequent keyword arguments override the values defined by the positional 103 | arguments. 
104 | 105 | Each prompt can have their own way of stitching together headers and values 106 | within examples. Generally, each header should follow its corresponding example 107 | value. If there are no positional arguments passed in, then the ordering of the 108 | variables in examples follows the order of the keyword arguments. Otherwise, a 109 | new example dictionary is created from the positional arguments and the ordering 110 | is dependent on the order of the `headers`. 111 | 112 | Interactions will look like the following: 113 | 114 | {header}{value}\n 115 | {header}{value}\n 116 | 117 | Any custom logic should be defined in a subclass method that 118 | overrides this method. 119 | 120 | Args: 121 | args: Positional arguments for the new interaction. 122 | kwargs: Keyword arguments for the new interaction. 123 | 124 | Returns: 125 | str: String representation of an interaction. 126 | """ 127 | interaction = self.create_interaction(*args, **kwargs) if args else kwargs 128 | return "".join( 129 | f"{self.headers[key]}{interaction[key]}\n" for key in interaction.keys() 130 | ) 131 | 132 | def to_string(self) -> str: 133 | """Creates a string representation of the prompt. 134 | 135 | The string representation is assembled from the preamble and examples. 136 | Each example is created from a `create_interaction_string` method and is 137 | demarcated by an `example_separator`. 138 | 139 | Examples will look like the following: 140 | 141 | {preamble}\n 142 | {example_separator} 143 | {field}{value}\n 144 | {field}{value}\n 145 | {example_separator} 146 | {field}{value}\n 147 | {field}{value}\n 148 | ... 149 | 150 | Returns: 151 | str: String representation of the prompt. 
152 | """ 153 | lines = [f"{self.preamble}\n"] 154 | lines += self.example_separator + f"{self.example_separator}".join( 155 | self.create_interaction_string(**example) for example in self.examples 156 | ) 157 | return "".join(lines).strip() 158 | 159 | def update(self, config: Dict[str, Any]) -> None: 160 | """Updates attributes of this class with attributes from `config`. 161 | 162 | Args: 163 | config (Dict[str, Any]): Dictionary of attributes that should be updated for 164 | this class. 165 | """ 166 | for key, value in config.items(): 167 | if hasattr(self, key): 168 | setattr(self, key, value) 169 | 170 | @classmethod 171 | def from_dict(cls, config: Dict[str, Any]) -> "Prompt": 172 | """Instantiates a `Prompt` from a dictionary. 173 | 174 | Args: 175 | config (Dict[str, Any]: Dictionary used to instantiate a prompt object. 176 | The dictionary should have the following required keys: `preamble`, 177 | `headers`, `examples`, example_separator` 178 | 179 | Returns: 180 | Prompt: The prompt object instantiated from the `config`. 181 | """ 182 | return cls(**config) 183 | 184 | def to_dict(self) -> Dict[str, Any]: 185 | """Serializes this instance into a Python dictionary. 186 | 187 | Returns: 188 | Dict[str, Any]: Dictionary of attributes that defines this instance of a 189 | Prompt. 190 | """ 191 | return { 192 | "preamble": self.preamble, 193 | "example_separator": self.example_separator, 194 | "headers": self.headers, 195 | "examples": self.examples, 196 | } 197 | 198 | def to_json_string(self) -> str: 199 | """Serializes this instance into a JSON string. 200 | 201 | Returns: 202 | str: JSON string representation of this instance of a Prompt. 203 | """ 204 | prompt_dict = self.to_dict() 205 | return json.dumps(prompt_dict, indent=4) + "\n" 206 | 207 | def _validate_preamble(self) -> None: 208 | """Validates that the preamble meets the following requirements: 209 | 210 | - At least `MIN_PREAMBLE_LENGTH` in length. 
    def _validate_example_separator(self) -> None:
        """Validates that the `example_separator` meets the following requirements:

        - Is a str.

        Raises:
            ValueError: If the `example_separator` is not a `str`.
        """
        # Raises ValueError (not TypeError) so that all prompt-validation
        # failures surface as the same exception type; the test suite relies
        # on this (`pytest.raises(ValueError)` for the non-str separator case).
        if not isinstance(self.example_separator, str):
            raise ValueError(
                "example_separator must be a string. Current type: "
                f"{type(self.example_separator)}"
            )
268 | if len(self.examples) < self.MIN_NUM_EXAMPLES: 269 | raise ValueError( 270 | f"At least {self.MIN_NUM_EXAMPLES} example must be given for " 271 | f"{self.__class__.__name__}" 272 | ) 273 | -------------------------------------------------------------------------------- /conversant/prompts/rewrite_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | 10 | from dataclasses import field 11 | from typing import List 12 | 13 | from pydantic.dataclasses import dataclass 14 | 15 | from conversant.prompts.prompt import Prompt 16 | 17 | 18 | @dataclass 19 | class RewritePrompt(Prompt): 20 | """A rewrite prompt given to a Chatbot. 21 | 22 | Required keys: 23 | conversation: The possibly ungrounded message to be rewritten. 24 | fact: A reference paragraph containing a fact to ground the message. 25 | rewrite: A rewritten grounded message based on the reference fact. 26 | 27 | Constants; 28 | REQUIRED_KEYS (List[str]): The list of required keys for the prompt. (default: 29 | `["conversation", "rewrite", "fact"]`) 30 | MIN_PREAMBLE_LENGTH (int): The minimum length of the preamble. (default: `10`) 31 | MIN_NUM_EXAMPLES (int): The minimum number of examples that should be passed in. 32 | (default: `1`) 33 | """ 34 | 35 | REQUIRED_KEYS: List[str] = field( 36 | default_factory=lambda: ["conversation", "rewrite", "fact"] 37 | ) 38 | MIN_PREAMBLE_LENGTH: int = 10 39 | MIN_NUM_EXAMPLES: int = 1 40 | 41 | def __post_init__(self) -> None: 42 | """Validators for the rewrite prompt. 43 | 44 | Validates that the prompt follows the requirements of the validators listed 45 | below. 
Minimally, the RewritePrompt needs to follow the requirements of its 46 | parent class. 47 | """ 48 | super().__post_init__() 49 | 50 | def create_interaction_string(self, *args, **kwargs) -> str: 51 | """Creates a string representation of a grounded rewriting interaction. 52 | 53 | Interactions will look like the following: 54 | 55 | {conversation_header}\n 56 | {conversation}\n 57 | {fact_header}\n 58 | {fact}\n 59 | {rewrite_header}\n 60 | {rewrite}\n 61 | 62 | Args: 63 | args: Positional arguments for the new interaction. 64 | kwargs: Keyword arguments for the new interaction. 65 | 66 | Returns: 67 | str: String representation of an interaction. 68 | """ 69 | interaction = self.create_interaction(*args, **kwargs) if args else kwargs 70 | return "".join( 71 | f"{self.headers[key]}\n{interaction[key]}\n" for key in interaction.keys() 72 | ) 73 | -------------------------------------------------------------------------------- /conversant/search/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | from conversant.search.document import Document 10 | from conversant.search.local_searcher import LocalSearcher 11 | from conversant.search.searcher import Searcher 12 | 13 | __all__ = ["Document", "Searcher", "LocalSearcher"] 14 | -------------------------------------------------------------------------------- /conversant/search/document.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | from dataclasses import dataclass 10 | from typing import List, Optional 11 | 12 | 13 | @dataclass 14 | class Document: 15 | """Schema for documents retrieved by searchers.""" 16 | 17 | source_link: str 18 | doc_id: str 19 | content: str 20 | embedding: Optional[List[float]] = None 21 | -------------------------------------------------------------------------------- /conversant/search/local_searcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | import logging 10 | from typing import Iterable, Optional 11 | 12 | import cohere 13 | 14 | from conversant.search.document import Document 15 | from conversant.search.searcher import Searcher 16 | 17 | 18 | class LocalSearcher(Searcher): 19 | """Searches only a user-provided local document cache.""" 20 | 21 | def __init__(self, client: cohere.Client, documents: Iterable[Document]): 22 | """Loads a searcher with a Cohere client & documents to search over. 23 | 24 | Args: 25 | client (cohere.Client): Provides access to Cohere API via the Python SDK 26 | documents (Iterable[Document]): Iterable cache of local Documents. 27 | """ 28 | super().__init__(client, documents) 29 | 30 | def search(self, query: str, threshold: float = -1e6) -> Optional[Document]: 31 | """Searches by finding most similar doc in local docs. 32 | 33 | Does not make any updates to local doc cache. Only returns a document 34 | if the similarity is above the threshold. 35 | 36 | Args: 37 | query (str): Query to check docs against. 
38 | threshold (float): Minimum similarity needed to return a document. 39 | 40 | Returns: 41 | Optional[Document]: Most similar doc to query. None if threshold not met. 42 | """ 43 | 44 | embedded_query = self.co.embed(texts=[query]).embeddings[0] 45 | similarities = [ 46 | self._measure_similarity(embedded_query, d.embedding) 47 | for d in self.documents 48 | ] 49 | 50 | max_similarity = max(similarities) 51 | 52 | if max_similarity < threshold: 53 | logging.warning( 54 | f"Max search similarity {max_similarity} below threshold {threshold}; " 55 | "no document returned." 56 | ) 57 | return None 58 | 59 | logging.info(f"Search result found for query: {query}") 60 | nearest_idx = similarities.index(max_similarity) 61 | nearest_doc = self.documents[nearest_idx] 62 | 63 | return nearest_doc 64 | -------------------------------------------------------------------------------- /conversant/search/searcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 8 | 9 | from abc import ABC, abstractmethod 10 | from typing import Iterable, List, Optional 11 | 12 | import cohere 13 | import numpy as np 14 | 15 | from conversant.search.document import Document 16 | 17 | 18 | class Searcher(ABC): 19 | """Searches through documents for ground truth.""" 20 | 21 | def __init__(self, client: cohere.Client, documents: Iterable[Document]): 22 | """Loads a searcher with a Cohere client & documents to search over. 23 | 24 | Args: 25 | client (cohere.Client): Provides access to Cohere API via the Python SDK 26 | documents (Iterable[Document]): Iterable cache of local Documents. 
27 | """ 28 | self.co = client 29 | self.documents = documents 30 | 31 | @abstractmethod 32 | def search(self, query: str) -> Optional[Document]: 33 | """Semantic search over local documents. 34 | 35 | Precisely _how_ the cache of local documents is updated 36 | is left to the implementation in a subclass. 37 | Args: 38 | query (str): Input query to match documents against. 39 | Returns: 40 | Optional[Document]: Best matching document, if found. 41 | """ 42 | 43 | def _measure_similarity(self, embed_1: List[float], embed_2: List[float]) -> float: 44 | """Measures similarity between embeddings. Uses cosine similarity. 45 | 46 | Manually computes dot(A,B)/(||A||*||B||), to avoid adding 47 | more deps. Uses numpy methods for these computations. 48 | 49 | Args: 50 | embed_1 (List[float]): First embedding in the pair. 51 | embed_2 (List[float]): Second embedding in the pair. 52 | Returns: 53 | float: Similarity between the documents. 54 | """ 55 | dot_product = np.dot(embed_1, embed_2) 56 | norm_product = np.linalg.norm(embed_1) * np.linalg.norm(embed_2) 57 | 58 | cos_similarity = dot_product / norm_product 59 | 60 | return cos_similarity 61 | 62 | def embed_documents(self) -> None: 63 | """Embeds the Searcher's documents.""" 64 | 65 | embeddings = self.co.embed(texts=[d.content for d in self.documents]).embeddings 66 | 67 | for document, embedding in zip(self.documents, embeddings): 68 | document.embedding = embedding 69 | -------------------------------------------------------------------------------- /conversant/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
def encode_chatbot(chatbot: "Chatbot") -> str:
    """Serialize and encode a Chatbot object to a base-64 string encoding.

    Note:
        The chatbot's `co` client attribute is set to None *in place* before
        pickling (API clients are not serialized); `decode_chatbot` re-attaches
        a client on the other side. Callers that need the client afterwards
        must re-assign it themselves.

    Args:
        chatbot (Chatbot): A chatbot instance of a class inherited from Chatbot.
            (Previous hint `Type[Chatbot]` wrongly described a class object.)

    Returns:
        str: Chatbot object as a base-64 string.
    """
    # Drop the API client before pickling; it is restored by decode_chatbot.
    chatbot.co = None
    return codecs.encode(pickle.dumps(chatbot), "base64").decode()
54 | 55 | Args: 56 | bot (Type[Chatbot]): a chatbot of class inherited from Chatbot 57 | """ 58 | path = os.path.dirname(conversant.__file__) 59 | sys.argv = f"streamlit run {path}/demo/streamlit_example.py --".split(" ") 60 | sys.argv.append(encode_chatbot(chatbot)) 61 | sys.exit(stcli.main()) 62 | -------------------------------------------------------------------------------- /docs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/docs/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "conversant" 3 | version = "0.2.5" 4 | repository = "https://github.com/cohere-ai/sandbox-conversant-lib" 5 | description = "Conversational AI tooling" 6 | readme = "README.md" 7 | authors = ["Cohere ConvAI "] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.8,!=3.9.7" 11 | cohere = "^3.0" 12 | toml = "^0.10.2" 13 | pydantic = "^1.10.2" 14 | emoji = "1.7.0" 15 | emojificate = "^0.6.0" 16 | streamlit-ace = "^0.1.1" 17 | streamlit-talk = "^0.2.1" 18 | 19 | [tool.poetry.group.dev.dependencies] 20 | pytest = "^7.1.2" 21 | streamlit = "^1.18.1" 22 | black = "^22.10.0" 23 | pre-commit = "^2.20.0" 24 | pdoc = "^12.1.0" 25 | ruff = "^0.0.94" 26 | isort = "^5.10.1" 27 | autoflake = "^1.7.7" 28 | 29 | [build-system] 30 | requires = ["poetry-core>=1.3.0"] 31 | build-backend = "poetry.core.masonry.api" 32 | -------------------------------------------------------------------------------- /static/conversant-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/static/conversant-diagram.png 
-------------------------------------------------------------------------------- /static/fortune-teller-chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/static/fortune-teller-chat.png -------------------------------------------------------------------------------- /static/fortune-teller-edit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/static/fortune-teller-edit.png -------------------------------------------------------------------------------- /static/fortune-teller-setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cohere-ai/sandbox-conversant-lib/eeb7c808c2ab59f553f3500b90681019f79365be/static/fortune-teller-setup.png -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
class MockCo:
    """Minimal stand-in for `cohere.Client` used by the test fixtures."""

    # Methods deliberately use bare *args: when called bound, args[0] is the
    # instance, mirroring the original fixture's signatures.

    def generate(*args, **kwargs) -> Generations:
        # Canned single-generation response; all inputs are ignored.
        response = {"generations": [dict(text="Hello!", likelihood=1.0, id=123)]}
        return Generations(response=response, return_likelihoods="NONE")

    def embed(*args, **kwargs) -> Embeddings:
        # One constant embedding per input text when `texts` is supplied;
        # a single constant embedding otherwise.
        if "texts" in kwargs:
            return Embeddings(embeddings=[[1.0, 1.0]] * len(kwargs["texts"]))
        return Embeddings(embeddings=[[1.0, 1.0]])

    def tokenize(*args, **kwargs) -> Tokens:
        # args[1] is the text to tokenize; token ids are just word positions.
        words = args[1].split()
        return Tokens(list(range(len(words))), words)
@pytest.fixture
def mock_prompt(mock_prompt_config: Dict[str, Any]) -> Prompt:
    """Instantiates a Prompt fixture for tests.

    Args:
        mock_prompt_config (Dict[str, Any]): A config used to instantiate a Prompt
            fixture.

    Returns:
        Prompt: A mock Prompt object fixture for tests.
    """
    # Prompt.from_dict is equivalent to Prompt(**config).
    return Prompt.from_dict(mock_prompt_config)
@pytest.fixture
def mock_chat_prompt(mock_chat_prompt_config: Dict[str, Any]) -> ChatPrompt:
    """Instantiates a ChatPrompt fixture for tests.

    Args:
        mock_chat_prompt_config (Dict[str, Any]): A config used to instantiate a
            ChatPrompt fixture.

    Returns:
        ChatPrompt: A mock ChatPrompt object fixture for tests.
    """
    chat_prompt = ChatPrompt(**mock_chat_prompt_config)
    return chat_prompt
@pytest.fixture
def mock_rewrite_prompt(mock_rewrite_prompt_config: Dict[str, Any]) -> RewritePrompt:
    """Instantiates a RewritePrompt fixture for tests.

    Args:
        mock_rewrite_prompt_config (Dict[str, Any]): A config used to instantiate
            a RewritePrompt fixture.

    Returns:
        RewritePrompt: A mock RewritePrompt fixture for tests.
    """
    # RewritePrompt inherits from_dict, which is equivalent to cls(**config).
    return RewritePrompt.from_dict(mock_rewrite_prompt_config)
@pytest.fixture()
def mock_searcher(mock_co: cohere.Client) -> Searcher:
    """Mock fixture subclass to test abstract class methods.

    Args:
        mock_co (cohere.Client): Cohere API client.

    Returns:
        Searcher: Mock Searcher object.
    """

    class _StubSearcher(Searcher):
        # Concrete subclass so the abstract base can be instantiated; `search`
        # simply defers to the (abstract) parent implementation.
        def search(self, query: str) -> Optional[Document]:
            return super().search(query)

    document = Document(
        source_link="http://url",
        doc_id="123",
        content="test content",
    )
    return _StubSearcher(client=mock_co, documents=[document])
def test_prompt_init(mock_prompt_config: Dict[str, Any]) -> None:
    """Tests Prompt.__init__

    Args:
        mock_prompt_config (Dict[str, Any]): A Prompt config fixture.
    """
    prompt = Prompt(**mock_prompt_config)

    expected_headers = {"query": "", "context": "", "generation": ""}
    expected_examples = [
        {
            "query": "This is a query.",
            "context": "This is a context.",
            "generation": "This is a generation.",
        },
        {
            "query": "This is a second query.",
            "context": "This is a second context.",
            "generation": "This is a second generation.",
        },
    ]

    assert prompt.preamble == "This is a prompt."
    assert prompt.example_separator == "\n"
    assert prompt.headers == expected_headers
    assert prompt.examples == expected_examples
@pytest.mark.parametrize(
    "config",
    [
        # no preamble
        {"preamble": ""},
        # example separator is not str
        {
            "example_separator": 123,
        },
        # example missing variable
        {
            "examples": [{"query": "This is a query."}],
        },
        # no examples
        {
            "examples": [],
        },
    ],
    ids=[
        "validation-no-preamble",
        "validation-example-separator-not-str",
        "validation-example-missing-variable",
        "validation-no-examples",
    ],
)
def test_prompt_init_fails(
    mock_prompt_config: Dict[str, Any], config: Dict[str, Any]
) -> None:
    """Tests Prompt.__init__ on bad parameters.

    Args:
        mock_prompt_config (Dict[str, Any]): A Prompt config fixture.
        config (Dict[str, Any]): Dictionary of bad parameters.
    """
    # Each bad field overrides the otherwise-valid base config.
    mock_prompt_config.update(config)
    # All Prompt validators raise ValueError — including the non-str separator
    # case: _validate_example_separator raises ValueError, not TypeError.
    with pytest.raises(ValueError):
        _ = Prompt(**mock_prompt_config)
134 | """ 135 | # create from positional arguments only 136 | generated_interaction = mock_prompt.create_interaction( 137 | new_interaction["query"], 138 | new_interaction["context"], 139 | new_interaction["generation"], 140 | ) 141 | assert generated_interaction == new_interaction 142 | 143 | # create from keyword arguments only 144 | generated_interaction = mock_prompt.create_interaction(**new_interaction) 145 | assert generated_interaction == new_interaction 146 | 147 | # create from mix of positional and keyword arguments 148 | kwargs = {"generation": "A new generation!"} 149 | generated_interaction = mock_prompt.create_interaction( 150 | new_interaction["query"], new_interaction["context"], **kwargs 151 | ) 152 | assert generated_interaction == new_interaction 153 | 154 | 155 | def test_prompt_create_interaction_string( 156 | mock_prompt: Prompt, new_interaction: Interaction 157 | ) -> None: 158 | """Tests Prompt.create_interaction_string 159 | 160 | Args: 161 | mock_prompt (Prompt): A Prompt fixture. 162 | new_interaction (Dict[ str, str]): A new Prompt interaction fixture. 
163 | """ 164 | expected = ( 165 | f"{mock_prompt.headers['query']}{new_interaction['query']}\n" 166 | f"{mock_prompt.headers['context']}{new_interaction['context']}\n" 167 | f"{mock_prompt.headers['generation']}{new_interaction['generation']}\n" 168 | ) 169 | # create from positional arguments 170 | generated_interaction_str = mock_prompt.create_interaction_string( 171 | new_interaction["query"], 172 | new_interaction["context"], 173 | new_interaction["generation"], 174 | ) 175 | assert generated_interaction_str == expected 176 | 177 | # create from keyword arguments 178 | generated_interaction_str = mock_prompt.create_interaction_string(**new_interaction) 179 | assert generated_interaction_str == expected 180 | 181 | # create from mix of positional and keyword arguments 182 | kwargs = {"generation": "A new generation!"} 183 | generated_interaction_str = mock_prompt.create_interaction_string( 184 | new_interaction["query"], new_interaction["context"], **kwargs 185 | ) 186 | assert generated_interaction_str == expected 187 | 188 | # generated example string is dependent on the insertion order into the examples 189 | # dictionary 190 | reordered_example = {} 191 | reordered_example["context"] = new_interaction["context"] 192 | reordered_example["query"] = new_interaction["query"] 193 | reordered_example["generation"] = new_interaction["generation"] 194 | reordered_expected = ( 195 | f"{mock_prompt.headers['context']}{new_interaction['context']}\n" 196 | f"{mock_prompt.headers['query']}{new_interaction['query']}\n" 197 | f"{mock_prompt.headers['generation']}{new_interaction['generation']}\n" 198 | ) 199 | generated_reordered_example_str = mock_prompt.create_interaction_string( 200 | **reordered_example 201 | ) 202 | assert generated_reordered_example_str == reordered_expected 203 | 204 | 205 | def test_prompt_to_string(mock_prompt: Prompt) -> None: 206 | """Tests Prompt.to_string 207 | 208 | Args: 209 | mock_prompt (Prompt): A Prompt fixture. 
210 | """ 211 | expected = ( 212 | f"{mock_prompt.preamble}\n" 213 | f"{mock_prompt.example_separator}" 214 | f"{mock_prompt.headers['query']}{mock_prompt.examples[0]['query']}\n" 215 | f"{mock_prompt.headers['context']}{mock_prompt.examples[0]['context']}\n" 216 | f"{mock_prompt.headers['generation']}{mock_prompt.examples[0]['generation']}\n" 217 | f"{mock_prompt.example_separator}" 218 | f"{mock_prompt.headers['query']}{mock_prompt.examples[1]['query']}\n" 219 | f"{mock_prompt.headers['context']}{mock_prompt.examples[1]['context']}\n" 220 | f"{mock_prompt.headers['generation']}{mock_prompt.examples[1]['generation']}" 221 | ) 222 | assert mock_prompt.to_string() == expected 223 | 224 | 225 | @pytest.mark.parametrize( 226 | "new_config", 227 | [ 228 | {"example_separator": "--"}, 229 | {"preamble": "This is a new preamble."}, 230 | ], 231 | ids=["new-example-separator", "new-preamble"], 232 | ) 233 | def test_prompt_update(mock_prompt: Prompt, new_config: Dict[str, Any]) -> None: 234 | """Tests Prompt.update 235 | 236 | Args: 237 | mock_prompt (Prompt): A Prompt fixture. 238 | new_config (Dict[str, Any]): Dictionary of new params. 239 | """ 240 | mock_prompt.update(new_config) 241 | for key in new_config: 242 | assert getattr(mock_prompt, key) == new_config[key] 243 | -------------------------------------------------------------------------------- /tests/prompts/test_rewrite_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Cohere Inc. and its affiliates. 2 | # 3 | # Licensed under the MIT License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # 6 | # You may obtain a copy of the License in the LICENSE file at the top 7 | # level of this repository. 
from typing import Any, Dict

import pytest

from conversant.chatbot import Interaction
from conversant.prompts.rewrite_prompt import RewritePrompt


@pytest.fixture
def new_interaction() -> Interaction:
    """Instantiates a fixture for a new RewritePrompt example.

    Returns:
        Interaction: New RewritePrompt interaction fixture.
    """
    return {
        "conversation": "Otters are plants.",
        "fact": "Otters are mammals.",
        "rewrite": "Otters are mammals.",
    }


def test_rewrite_prompt_init(mock_rewrite_prompt_config: Dict[str, Any]) -> None:
    """Tests RewritePrompt.__init__

    Args:
        mock_rewrite_prompt_config (Dict[str, Any]): A RewritePrompt config fixture.
    """
    _ = RewritePrompt(**mock_rewrite_prompt_config)


def test_rewrite_prompt_init_from_dict(
    mock_rewrite_prompt_config: Dict[str, Any]
) -> None:
    """Tests RewritePrompt.from_dict

    Args:
        mock_rewrite_prompt_config (Dict[str, Any]): A RewritePrompt config fixture.
    """
    _ = RewritePrompt.from_dict(mock_rewrite_prompt_config)


@pytest.mark.parametrize(
    "config",
    [
        # short preamble
        {"preamble": "short"},
        # no examples
        {"examples": []},
    ],
    ids=[
        "short-preamble",
        "no-examples",
    ],
)
def test_rewrite_prompt_init_fails(
    mock_rewrite_prompt_config: Dict[str, Any], config: Dict[str, Any]
) -> None:
    """Tests RewritePrompt.__init__ on bad parameters.

    Args:
        mock_rewrite_prompt_config (Dict[str, Any]): A RewritePrompt config fixture.
        config (Dict[str, Any]): Dictionary of bad parameters.
    """
    mock_rewrite_prompt_config.update(config)
    with pytest.raises(ValueError):
        _ = RewritePrompt(**mock_rewrite_prompt_config)


def test_rewrite_prompt_create_interaction_string(
    mock_rewrite_prompt: RewritePrompt, new_interaction: Interaction
) -> None:
    """Tests RewritePrompt.create_interaction_string

    Args:
        mock_rewrite_prompt (RewritePrompt): A RewritePrompt fixture.
        new_interaction (Interaction): A new RewritePrompt interaction fixture.
    """
    expected = (
        f"{mock_rewrite_prompt.headers['conversation']}\n"
        f"{new_interaction['conversation']}\n"
        f"{mock_rewrite_prompt.headers['fact']}\n"
        f"{new_interaction['fact']}\n"
        f"{mock_rewrite_prompt.headers['rewrite']}\n"
        f"{new_interaction['rewrite']}\n"
    )
    # create from positional arguments, in header order: conversation, fact,
    # rewrite. BUGFIX: the arguments were previously passed in the order
    # (conversation, rewrite, fact); the assertion only passed because the
    # fixture's "fact" and "rewrite" values are the same string, so positional
    # ordering was never actually verified.
    generated_interaction_str = mock_rewrite_prompt.create_interaction_string(
        new_interaction["conversation"],
        new_interaction["fact"],
        new_interaction["rewrite"],
    )
    assert generated_interaction_str == expected

    # create from keyword arguments
    generated_interaction_str = mock_rewrite_prompt.create_interaction_string(
        **new_interaction
    )
    assert generated_interaction_str == expected

    # create from mix of positional and keyword arguments
    kwargs = {"rewrite": "Otters are mammals."}
    generated_interaction_str = mock_rewrite_prompt.create_interaction_string(
        new_interaction["conversation"], new_interaction["fact"], **kwargs
    )
    assert generated_interaction_str == expected

    # generated example string is dependent on the insertion order into the
    # examples dictionary
    reordered_interaction = {}
    reordered_interaction["fact"] = new_interaction["fact"]
    reordered_interaction["conversation"] = new_interaction["conversation"]
    reordered_interaction["rewrite"] = new_interaction["rewrite"]

    reordered_expected = (
        f"{mock_rewrite_prompt.headers['fact']}\n"
        f"{new_interaction['fact']}\n"
        f"{mock_rewrite_prompt.headers['conversation']}\n"
        f"{new_interaction['conversation']}\n"
        f"{mock_rewrite_prompt.headers['rewrite']}\n"
        f"{new_interaction['rewrite']}\n"
    )
    generated_reordered_interaction_str = mock_rewrite_prompt.create_interaction_string(
        **reordered_interaction
    )
    assert generated_reordered_interaction_str == reordered_expected


def test_rewrite_prompt_to_string(mock_rewrite_prompt: RewritePrompt) -> None:
    """Tests RewritePrompt.to_string

    Args:
        mock_rewrite_prompt (RewritePrompt): A RewritePrompt fixture.
    """
    expected = (
        f"{mock_rewrite_prompt.preamble}\n"
        f"{mock_rewrite_prompt.example_separator}"
        f"{mock_rewrite_prompt.headers['conversation']}\n"
        f"{mock_rewrite_prompt.examples[0]['conversation']}\n"
        f"{mock_rewrite_prompt.headers['fact']}\n"
        f"{mock_rewrite_prompt.examples[0]['fact']}\n"
        f"{mock_rewrite_prompt.headers['rewrite']}\n"
        f"{mock_rewrite_prompt.examples[0]['rewrite']}\n"
        f"{mock_rewrite_prompt.example_separator}"
        f"{mock_rewrite_prompt.headers['conversation']}\n"
        f"{mock_rewrite_prompt.examples[1]['conversation']}\n"
        f"{mock_rewrite_prompt.headers['fact']}\n"
        f"{mock_rewrite_prompt.examples[1]['fact']}\n"
        f"{mock_rewrite_prompt.headers['rewrite']}\n"
        f"{mock_rewrite_prompt.examples[1]['rewrite']}"
    )
    assert mock_rewrite_prompt.to_string() == expected


# ---------------------------------------------------------------------------
# tests/prompts/test_start_prompt.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

from typing import Any, Dict

import pytest

from conversant.chatbot import Interaction
from conversant.prompts.chat_prompt import ChatPrompt


@pytest.fixture
def new_interaction() -> Interaction:
    """Instantiates a fixture for a new ChatPrompt example.

    Returns:
        Interaction: New ChatPrompt interaction fixture.
    """
    return {"user": "Nice to meet you!", "bot": "You too!"}


def _assert_mock_names(chat_prompt: ChatPrompt) -> None:
    """Asserts that the prompt carries the mock config's speaker names."""
    assert chat_prompt.user_name == "User"
    assert chat_prompt.bot_name == "Mock Chatbot"


def test_chat_prompt_init(mock_chat_prompt_config: Dict[str, Any]) -> None:
    """Tests ChatPrompt.__init__

    Args:
        mock_chat_prompt_config (Dict[str, Any]): A ChatPrompt config fixture.
    """
    _assert_mock_names(ChatPrompt(**mock_chat_prompt_config))


def test_chat_prompt_init_from_dict(mock_chat_prompt_config: Dict[str, Any]) -> None:
    """Tests ChatPrompt.from_dict

    Args:
        mock_chat_prompt_config (Dict[str, Any]): A ChatPrompt config fixture.
    """
    _assert_mock_names(ChatPrompt.from_dict(mock_chat_prompt_config))


@pytest.mark.parametrize(
    "config",
    [
        # short preamble
        {"preamble": "short"},
        # headers do not contain user
        {"headers": {"bot": "Mock Chatbot"}},
        # headers do not contain bot
        {"headers": {"user": "User"}},
        # examples have no speakers
        {"examples": [{}]},
        # examples have one speaker
        {"examples": [[{"user": "user utterance"}, {"bot": "bot utterance"}]]},
        # examples have wrong key
        {"examples": [[{"user": "user utterance", "": "bot utterance"}]]},
        # examples have three speakers
        {
            "examples": [
                [
                    {
                        "user": "user utterance",
                        "bot": "bot utterance",
                        "user2": "user2 utterance",
                    }
                ]
            ]
        },
        # examples are prefixed by user and bot names
        {
            "headers": {"user": "Alice", "bot": "Bob"},
            "examples": [[{"user": "Alice: Hey", "bot": "Bob: Hi"}]],
        },
    ],
    ids=[
        "short-preamble",
        "headers-no-user",
        "headers-no-bot",
        "examples-no-speakers",
        "examples-one-speaker",
        "examples-wrong-key",
        "examples-three-speakers",
        "examples-prefixed-with-name",
    ],
)
def test_chat_prompt_init_fails(
    mock_chat_prompt_config: Dict[str, Any], config: Dict[str, Any]
) -> None:
    """Tests ChatPrompt.__init__ on bad parameters.

    Args:
        mock_chat_prompt_config (Dict[str, Any]): A ChatPrompt config fixture.
        config (Dict[str, Any]): Dictionary of bad parameters.
    """
    mock_chat_prompt_config.update(config)
    with pytest.raises(ValueError):
        _ = ChatPrompt(**mock_chat_prompt_config)


def test_chat_prompt_create_interaction_string(
    mock_chat_prompt: ChatPrompt, new_interaction: Interaction
) -> None:
    """Tests ChatPrompt.create_interaction_string

    Args:
        mock_chat_prompt (ChatPrompt): A ChatPrompt fixture.
        new_interaction (Interaction): A new ChatPrompt interaction fixture.
    """
    headers = mock_chat_prompt.headers
    expected = (
        f"{headers['user']}: {new_interaction['user']}\n"
        f"{headers['bot']}: {new_interaction['bot']}\n"
    )
    # create from positional arguments
    built = mock_chat_prompt.create_interaction_string(
        new_interaction["user"], new_interaction["bot"]
    )
    assert built == expected

    # create from keyword arguments
    built = mock_chat_prompt.create_interaction_string(**new_interaction)
    assert built == expected

    # the generated string depends on the insertion order of keys in the
    # interaction dictionary
    reordered = {
        "bot": new_interaction["bot"],
        "user": new_interaction["user"],
    }
    reordered_expected = (
        f"{headers['bot']}: {new_interaction['bot']}\n"
        f"{headers['user']}: {new_interaction['user']}\n"
    )
    assert mock_chat_prompt.create_interaction_string(**reordered) == reordered_expected


def test_chat_prompt_to_string(mock_chat_prompt: ChatPrompt) -> None:
    """Tests ChatPrompt.to_string

    Args:
        mock_chat_prompt (ChatPrompt): A ChatPrompt fixture.
    """
    # Local aliases keep the f-string lines within the line-length limit.
    headers = mock_chat_prompt.headers
    examples = mock_chat_prompt.examples
    expected = (
        f"{mock_chat_prompt.preamble}\n"
        f"{mock_chat_prompt.example_separator}"
        f"{headers['user']}: {examples[0][0]['user']}\n"
        f"{headers['bot']}: {examples[0][0]['bot']}\n"
        f"{headers['user']}: {examples[0][1]['user']}\n"
        f"{headers['bot']}: {examples[0][1]['bot']}\n"
        f"{mock_chat_prompt.example_separator}"
        f"{headers['user']}: {examples[1][0]['user']}\n"
        f"{headers['bot']}: {examples[1][0]['bot']}\n"
        f"{headers['user']}: {examples[1][1]['user']}\n"
        f"{headers['bot']}: {examples[1][1]['bot']}"
    )
    assert mock_chat_prompt.to_string() == expected


# ---------------------------------------------------------------------------
# tests/search/test_document.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.
import pytest

from conversant.search.document import Document


def test_empty_document() -> None:
    """Tests that empty documents fail."""
    with pytest.raises(TypeError):
        _ = Document()


def test_document_embed_optional() -> None:
    """Tests that document embeddings are optional."""
    my_doc = Document(
        source_link="http://some-url",
        doc_id="123",
        content="hello world.",
    )
    assert my_doc.embedding is None


# ---------------------------------------------------------------------------
# tests/search/test_local_searcher.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

from conversant.search.document import Document
from conversant.search.local_searcher import LocalSearcher


def test_no_match(mock_local_searcher: LocalSearcher) -> None:
    """Tests impossible thresholds can't be met.

    Args:
        mock_local_searcher (LocalSearcher): Searcher object.
    """
    mock_local_searcher.embed_documents()

    first_embedding = mock_local_searcher.documents[0].embedding
    self_sim = mock_local_searcher._measure_similarity(
        first_embedding, first_embedding
    )

    # A threshold strictly above self-similarity can never be satisfied.
    search_result = mock_local_searcher.search(
        query="hello world",
        threshold=self_sim + 1.0,
    )

    assert search_result is None


def test_match(mock_local_searcher: LocalSearcher) -> None:
    """Tests zero-threshold always gives a match.

    Args:
        mock_local_searcher (LocalSearcher): Searcher object.
    """
    search_result = mock_local_searcher.search(query="hello world", threshold=0.0)

    assert isinstance(search_result, Document)


# ---------------------------------------------------------------------------
# tests/search/test_searcher.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

import pytest

from conversant.search.searcher import Searcher


def test_initialization_err() -> None:
    """Tests that Searcher is an abstract class."""
    with pytest.raises(TypeError):
        _ = Searcher()


def test_req_embedded_docs(mock_searcher: Searcher) -> None:
    """Similarity should fail on unembedded docs.

    Args:
        mock_searcher (Searcher): Search fixture.
    """
    # embed_documents() has not been called, so the embeddings are unusable.
    with pytest.raises(TypeError):
        mock_searcher._measure_similarity(
            mock_searcher.documents[0].embedding,
            mock_searcher.documents[0].embedding,
        )


def test_self_similarity(mock_searcher: Searcher) -> None:
    """Self similarity should be 1.0.

    Args:
        mock_searcher (Searcher): Search fixture.
    """
    mock_searcher.embed_documents()

    embedding = mock_searcher.documents[0].embedding
    similarity = mock_searcher._measure_similarity(embedding, embedding)

    assert similarity == pytest.approx(1.0)


def test_anti_sim(mock_searcher: Searcher) -> None:
    """Anti-similarity should be less than self-sim.

    Args:
        mock_searcher (Searcher): Search fixture.
    """
    mock_searcher.embed_documents()

    embedding = mock_searcher.documents[0].embedding
    self_similarity = mock_searcher._measure_similarity(embedding, embedding)
    anti_similarity = mock_searcher._measure_similarity(
        embedding, [-e for e in embedding]
    )

    assert anti_similarity < self_similarity


# ---------------------------------------------------------------------------
# tests/test_chatbot.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

import pytest

from conversant.chatbot import Chatbot


def test_chatbot_initialization_err() -> None:
    """Tests that Chatbot is an abstract class."""
    with pytest.raises(TypeError):
        _ = Chatbot()


# ---------------------------------------------------------------------------
# tests/test_prompt_chatbot.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

import itertools

import pytest

from conversant.prompt_chatbot import PromptChatbot


def check_prompt_chatbot_config(prompt_chatbot: PromptChatbot) -> None:
    """Checks that required parameters are in the chatbot and client config.

    Args:
        prompt_chatbot (PromptChatbot): The instance of PromptChatbot to check.
    """
    __tracebackhide__ = True
    class_name = prompt_chatbot.__class__.__name__
    for key in (
        "model",
        "max_tokens",
        "temperature",
        "frequency_penalty",
        "presence_penalty",
        "stop_sequences",
    ):
        if key not in prompt_chatbot.client_config:
            pytest.fail(
                f"{key} not in config of {class_name} "
                "but is required for co.generate"
            )

    for key in ("max_context_examples", "avatar"):
        if key not in prompt_chatbot.chatbot_config:
            pytest.fail(
                f"{key} not in chatbot config of {class_name} "
                "but is required."
            )


def test_prompt_chatbot_init(mock_prompt_chatbot: PromptChatbot) -> None:
    """Tests end to end that a PromptChatbot is initialized correctly from constructor.

    Args:
        mock_prompt_chatbot (PromptChatbot): Bot test fixture
    """
    check_prompt_chatbot_config(mock_prompt_chatbot)
    assert mock_prompt_chatbot.user_name == mock_prompt_chatbot.prompt.user_name
    assert mock_prompt_chatbot.bot_name == mock_prompt_chatbot.prompt.bot_name
    mock_prompt_chatbot.reply(query="What's up?")
    reply_generator = mock_prompt_chatbot.partial_reply(query="What are you doing?")
    next(reply_generator)


def test_prompt_chatbot_init_from_persona(mock_co: object) -> None:
    """Tests end to end that a prompt_chatbot is initalized correctly from persona.

    Args:
        mock_co (object): mock Cohere client.
    """
    prompt_chatbot = PromptChatbot.from_persona("watch-sales-agent", client=mock_co)
    assert isinstance(prompt_chatbot, PromptChatbot)
    assert prompt_chatbot.user_name == prompt_chatbot.prompt.user_name
    assert prompt_chatbot.bot_name == prompt_chatbot.prompt.bot_name
    assert prompt_chatbot.latest_prompt == prompt_chatbot.prompt.to_string()
    check_prompt_chatbot_config(prompt_chatbot)
    prompt_chatbot.reply(query="What's up?")
    reply_generator = prompt_chatbot.partial_reply(query="What are you doing?")
    next(reply_generator)

    # An empty persona directory should not resolve any persona.
    with pytest.raises(FileNotFoundError):
        _ = PromptChatbot.from_persona(
            "watch-sales-agent", client=mock_co, persona_dir=""
        )


@pytest.mark.parametrize(
    "max_context_examples, history_length",
    list(
        itertools.product(list(range(0, 20, 4)), list(range(0, 50, 10))),
    ),
)
def test_prompt_chatbot_get_current_prompt(
    mock_prompt_chatbot: PromptChatbot, max_context_examples: int, history_length: int
) -> None:
    """Tests assembly of prompts and context.

    Prompts should be preserved and context
    should have line-level trimming applied.

    Args:
        prompt_chatbot (PromptChatbot): Bot test fixture
        max_context_examples (int): The maximum number of examples to keep as context.
        history_length (int): The length of the chat history.
    """
    mock_prompt_chatbot.chat_history = [
        {"user": f"Hello! {i}", "bot": f"Hello back! {i}"}
        for i in range(1, history_length + 1)
    ]
    mock_prompt_chatbot.configure_chatbot(
        {
            "max_context_examples": max_context_examples,
        }
    )

    current_prompt = mock_prompt_chatbot.get_current_prompt(query="Hello!")

    prompt = mock_prompt_chatbot.prompt
    headers = prompt.headers
    examples = prompt.examples
    # Only the most recent `max_context_examples` exchanges survive trimming.
    kept_indices = (
        list(range(1, history_length + 1))[-max_context_examples:]
        if max_context_examples > 0
        else []
    )
    expected = (
        # chat prompt
        f"{prompt.preamble}\n"
        + f"{prompt.example_separator}"
        + f"{headers['user']}: {examples[0][0]['user']}\n"
        + f"{headers['bot']}: {examples[0][0]['bot']}\n"
        + f"{headers['user']}: {examples[0][1]['user']}\n"
        + f"{headers['bot']}: {examples[0][1]['bot']}\n"
        + f"{prompt.example_separator}"
        + f"{headers['user']}: {examples[1][0]['user']}\n"
        + f"{headers['bot']}: {examples[1][0]['bot']}\n"
        + f"{headers['user']}: {examples[1][1]['user']}\n"
        + f"{headers['bot']}: {examples[1][1]['bot']}\n"
        # context prompt
        + f"{prompt.example_separator}"
        + "".join(
            f"{headers['user']}: Hello! {i}\n"
            f"{headers['bot']}: Hello back! {i}\n"
            for i in kept_indices
        )
        # query prompt
        + f"{headers['user']}: Hello!\n"
        + f"{headers['bot']}:"
    )
    assert current_prompt == expected


def test_missing_persona_fails(mock_co: object) -> None:
    """Tests failure on missing persona.

    Args:
        mock_co (object): mock Cohere client
    """
    with pytest.raises(FileNotFoundError):
        _ = PromptChatbot.from_persona("invalid_persona", mock_co)


def test_update_max_context_fails(
    mock_prompt_chatbot: PromptChatbot, mock_co: object
) -> None:
    """Tests failure on updating max_context_examples

    when prompt_size has more than 2048 tokens, even changing max_context_size

    Args:
        mock_prompt_chatbot (PromptChatbot): Bot test fixture
        mock_co (object): mock Cohere client.
    """
    with pytest.raises(ValueError):
        chat_history = [{"user": "a " * 100, "bot": "b " * 100} for _ in range(5)]
        mock_prompt_chatbot.chat_history = chat_history
        # NOTE(review): `create_interaction_string` is passed the whole
        # interaction dict as a single positional argument here — confirm the
        # prompt API accepts that form.
        mock_prompt_chatbot.prompt_size_history = [
            mock_co.tokenize(
                mock_prompt_chatbot.prompt.create_interaction_string(interaction)
            ).length
            for interaction in chat_history
        ]

        max_context_examples = 10
        current_prompt = mock_prompt_chatbot.get_current_prompt(query="a " * 2048)
        prompt_size = mock_co.tokenize(current_prompt).length
        mock_prompt_chatbot._update_max_context_examples(
            prompt_size, max_context_examples
        )


def test_update_max_context_warn(
    mock_prompt_chatbot: PromptChatbot, mock_co: object
) -> None:
    """Tests function update_max_context_examples with warnings

    prompt_size with more than 2048 tokens, but if change the max_context_examples
    the size reduces and gets smaller than 2048

    Args:
        mock_prompt_chatbot (PromptChatbot): Bot test fixture
        mock_co (object): mock Cohere client.
    """
    with pytest.warns(UserWarning):
        chat_history = [{"user": "a " * 90, "bot": "b " * 90} for _ in range(12)]
        mock_prompt_chatbot.chat_history = chat_history
        mock_prompt_chatbot.prompt_size_history = [
            mock_co.tokenize(
                mock_prompt_chatbot.prompt.create_interaction_string(interaction)
            ).length
            for interaction in chat_history
        ]

        max_context_examples = 10
        current_prompt = mock_prompt_chatbot.get_current_prompt(query="q " * 200)
        prompt_size = mock_co.tokenize(current_prompt).length
        updated_max_context_examples = mock_prompt_chatbot._update_max_context_examples(
            prompt_size, max_context_examples
        )
        # If max_tokens are changed and this test fails, make sure to update -n
        # to what it actually supposed to be
        # TODO: Change the test so it automatically does this
        expected = max_context_examples - 2
        assert updated_max_context_examples == expected


# ---------------------------------------------------------------------------
# tests/utils/test_demo_utils.py
# ---------------------------------------------------------------------------
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

import cohere

from conversant.prompt_chatbot import PromptChatbot
from conversant.utils import demo_utils


def test_encode_decode_chatbot(
    mock_prompt_chatbot: PromptChatbot, mock_co: cohere.Client
) -> None:
    """Round-trips a chatbot through encode/decode and checks the result type."""
    encoded = demo_utils.encode_chatbot(mock_prompt_chatbot)
    decoded = demo_utils.decode_chatbot(encoded, client=mock_co)
    assert isinstance(decoded, PromptChatbot)